[mlpack] 01/207: adds GammaDistribution::Train(observations, probabilities)
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:35 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 3c0a56c2080b6ee00878eebadc141ab47e954fd4
Author: yashu-seth <yashuseth2503 at gmail.com>
Date: Sun Dec 18 10:47:54 2016 -0800
adds GammaDistribution::Train(observations, probabilities)
---
src/mlpack/core/dists/gamma_distribution.cpp | 29 ++++++++++++++++++++++++++++
src/mlpack/core/dists/gamma_distribution.hpp | 3 +--
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/core/dists/gamma_distribution.cpp b/src/mlpack/core/dists/gamma_distribution.cpp
index 1802b88..0a46fed 100644
--- a/src/mlpack/core/dists/gamma_distribution.cpp
+++ b/src/mlpack/core/dists/gamma_distribution.cpp
@@ -68,6 +68,35 @@ void GammaDistribution::Train(const arma::mat& rdata, const double tol)
Train(logMeanxVec, meanLogxVec, meanxVec, tol);
}
+// Fits an alpha and beta parameter to each dimension of the data, weighting
+// each observation (column of rdata) by the matching entry of probabilities.
+void GammaDistribution::Train(const arma::mat& rdata,
+                              const arma::vec& probabilities,
+                              const double tol)
+{
+  // No observations means nothing to fit.  Checking n_elem (not only 0x0) also
+  // catches an Nx0 matrix, which would otherwise divide by a zero weight sum.
+  if (rdata.n_elem == 0)
+    return;
+
+  // Weighted per-dimension sums of x and log(x).  Armadillo's debug checks
+  // reject a probabilities vector whose length differs from rdata.n_cols.
+  arma::vec meanxVec = rdata * probabilities;
+  arma::vec meanLogxVec = arma::log(rdata) * probabilities;
+
+  // Normalize by the total weight to turn the sums into weighted means.
+  const double totProbability = arma::accu(probabilities);
+  meanxVec /= totProbability;
+  meanLogxVec /= totProbability;
+
+  // log(E[x]) is taken after averaging, unlike E[log(x)] computed above.
+  const arma::vec logMeanxVec = arma::log(meanxVec);
+
+  // Call the statistics-only GammaDistribution::Train() function to fit the
+  // parameters. That function does all the work so we're done.
+  Train(logMeanxVec, meanLogxVec, meanxVec, tol);
+}
+
// Fits an alpha and beta parameter to each dimension of the data.
void GammaDistribution::Train(const arma::vec& logMeanxVec,
const arma::vec& meanLogxVec,
diff --git a/src/mlpack/core/dists/gamma_distribution.hpp b/src/mlpack/core/dists/gamma_distribution.hpp
index 718cd7f..766f666 100644
--- a/src/mlpack/core/dists/gamma_distribution.hpp
+++ b/src/mlpack/core/dists/gamma_distribution.hpp
@@ -101,11 +101,10 @@ class GammaDistribution
* @param tol Convergence tolerance. This is *not* an absolute measure:
* It will stop the approximation once the *change* in the value is
* smaller than tol.
- *
+ */
void Train(const arma::mat& observations,
const arma::vec& probabilities,
const double tol = 1e-8);
- */
/**
* This function trains (fits distribution parameters) to a dataset with
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list