[mlpack] 01/207: adds GammaDistribution::Train(observations, probabilities)

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:35 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 3c0a56c2080b6ee00878eebadc141ab47e954fd4
Author: yashu-seth <yashuseth2503 at gmail.com>
Date:   Sun Dec 18 10:47:54 2016 -0800

    adds GammaDistribution::Train(observations, probabilities)
---
 src/mlpack/core/dists/gamma_distribution.cpp | 29 ++++++++++++++++++++++++++++
 src/mlpack/core/dists/gamma_distribution.hpp |  3 +--
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/core/dists/gamma_distribution.cpp b/src/mlpack/core/dists/gamma_distribution.cpp
index 1802b88..0a46fed 100644
--- a/src/mlpack/core/dists/gamma_distribution.cpp
+++ b/src/mlpack/core/dists/gamma_distribution.cpp
@@ -68,6 +68,35 @@ void GammaDistribution::Train(const arma::mat& rdata, const double tol)
   Train(logMeanxVec, meanLogxVec, meanxVec, tol);
 }
 
+// Fits an alpha and beta parameter to each dimension of the data, weighting
+// each observation (column of rdata) by the matching entry of probabilities.
+void GammaDistribution::Train(const arma::mat& rdata,
+                              const arma::vec& probabilities,
+                              const double tol)
+{
+  // If there are no observations (or no dimensions), there is nothing to do.
+  if (rdata.is_empty())
+    return;
+
+  // Normalizing constant for the weighted averages below.  If every weight is
+  // zero the observations carry no information, so leave the fit unchanged.
+  const double totalProbability = arma::accu(probabilities);
+  if (totalProbability == 0.0)
+    return;
+
+  // Weighted means of x and of log(x), one entry per dimension.  Unless
+  // ARMA_NO_DEBUG is defined, Armadillo's size checks on these products reject
+  // a probabilities vector whose length differs from rdata.n_cols.
+  const arma::vec meanxVec = (rdata * probabilities) / totalProbability;
+  const arma::vec meanLogxVec =
+      (arma::log(rdata) * probabilities) / totalProbability;
+  const arma::vec logMeanxVec = arma::log(meanxVec);
+
+  // Call the statistics-only GammaDistribution::Train() function to fit the
+  // parameters. That function does all the work so we're done.
+  Train(logMeanxVec, meanLogxVec, meanxVec, tol);
+}
+
 // Fits an alpha and beta parameter to each dimension of the data.
 void GammaDistribution::Train(const arma::vec& logMeanxVec, 
                               const arma::vec& meanLogxVec,
diff --git a/src/mlpack/core/dists/gamma_distribution.hpp b/src/mlpack/core/dists/gamma_distribution.hpp
index 718cd7f..766f666 100644
--- a/src/mlpack/core/dists/gamma_distribution.hpp
+++ b/src/mlpack/core/dists/gamma_distribution.hpp
@@ -101,11 +101,10 @@ class GammaDistribution
      * @param tol Convergence tolerance. This is *not* an absolute measure:
      *    It will stop the approximation once the *change* in the value is 
      *    smaller than tol.
-     *
+     */
     void Train(const arma::mat& observations,
                const arma::vec& probabilities,
                const double tol = 1e-8);
-     */
 
     /**
      * This function trains (fits distribution parameters) to a dataset with

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list