[mlpack] 265/324: Hierarchical GMMs store params in GaussianDistributions. Makes code clearer and simplifies Save/Load.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:17 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 2d24ce4539eb6be9351565dfa143d0599a1f3e8e
Author: michaelfox99 <michaelfox99 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Tue Aug 5 12:58:59 2014 +0000
Hierarchical GMMs store params in GaussianDistributions. Makes code clearer and simplifies Save/Load.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16949 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/gmm/gmm.hpp | 110 +++++++++++++++++++++++++++++------------
1 file changed, 79 insertions(+), 31 deletions(-)
diff --git a/src/mlpack/methods/gmm/gmm.hpp b/src/mlpack/methods/gmm/gmm.hpp
index 15341d1..3de507b 100644
--- a/src/mlpack/methods/gmm/gmm.hpp
+++ b/src/mlpack/methods/gmm/gmm.hpp
@@ -1,5 +1,6 @@
/**
* @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Michael Fox
* @file gmm.hpp
*
* Defines a Gaussian Mixture model and
@@ -28,14 +29,12 @@ namespace gmm /** Gaussian Mixture Models. */ {
*
* @code
* void Estimate(const arma::mat& observations,
- * std::vector<arma::vec>& means,
- * std::vector<arma::mat>& covariances,
+ * std::vector<distribution::GaussianDistribution>& dists,
* arma::vec& weights);
*
* void Estimate(const arma::mat& observations,
* const arma::vec& probabilities,
- * std::vector<arma::vec>& means,
- * std::vector<arma::mat>& covariances,
+ * std::vector<distribution::GaussianDistribution>& dists,
* arma::vec& weights);
* @endcode
*
@@ -78,10 +77,14 @@ class GMM
size_t gaussians;
//! The dimensionality of the model.
size_t dimensionality;
- //! Vector of means; one for each Gaussian.
+
+ //! Vector of Gaussians
+ std::vector<distribution::GaussianDistribution> dists;
+
+ //! Legacy member data, not used.
std::vector<arma::vec> means;
- //! Vector of covariances; one for each Gaussian.
std::vector<arma::mat> covariances;
+
//! Vector of a priori weights for each Gaussian.
arma::vec weights;
@@ -126,19 +129,16 @@ class GMM
FittingType& fitter);
/**
- * Create a GMM with the given means, covariances, and weights.
+ * Create a GMM with the given dists and weights.
*
- * @param means Means of the model.
- * @param covariances Covariances of the model.
+ * @param dists Distributions of the model.
* @param weights Weights of the model.
*/
- GMM(const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
+ GMM(const std::vector<distribution::GaussianDistribution> & dists,
const arma::vec& weights) :
- gaussians(means.size()),
- dimensionality((!means.empty()) ? means[0].n_elem : 0),
- means(means),
- covariances(covariances),
+ gaussians(dists.size()),
+ dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+ dists(dists),
weights(weights),
localFitter(FittingType()),
fitter(localFitter) { /* Nothing to do. */ }
@@ -152,14 +152,12 @@ class GMM
* @param covariances Covariances of the model.
* @param weights Weights of the model.
*/
- GMM(const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
+ GMM(const std::vector<distribution::GaussianDistribution> & dists,
const arma::vec& weights,
FittingType& fitter) :
- gaussians(means.size()),
- dimensionality((!means.empty()) ? means[0].n_elem : 0),
- means(means),
- covariances(covariances),
+ gaussians(dists.size()),
+ dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+ dists(dists),
weights(weights),
fitter(fitter) { /* Nothing to do. */ }
@@ -202,6 +200,21 @@ class GMM
*/
void Save(const std::string& filename) const;
+ /**
+ * Load a GMM from a SaveRestoreUtility. The format should be the same
+ * as is generated by the Save() method.
+ *
+ * @param sr SaveRestoreUtility containing the model to be loaded.
+ */
+ void Load(const util::SaveRestoreUtility& sr);
+
+ /**
+ * Save a GMM to a SaveRestoreUtility.
+ *
+ * @param sr SaveRestoreUtility object to save the model to.
+ */
+ void Save(util::SaveRestoreUtility& sr) const;
+
//! Return the number of gaussians in the model.
size_t Gaussians() const { return gaussians; }
//! Modify the number of gaussians in the model. Careful! You will have to
@@ -214,15 +227,46 @@ class GMM
//! each mean and covariance matrix yourself.
size_t& Dimensionality() { return dimensionality; }
- //! Return a const reference to the vector of means (mu).
- const std::vector<arma::vec>& Means() const { return means; }
- //! Return a reference to the vector of means (mu).
- std::vector<arma::vec>& Means() { return means; }
+ /**
+ * Return a const reference to a component distribution.
+ *
+ * @param i index of component.
+ */
+ const distribution::GaussianDistribution& Component(size_t i) const {
+ return dists[i]; }
+ /**
+ * Return a reference to a component distribution.
+ *
+ * @param i index of component.
+ */
+ distribution::GaussianDistribution& Component(size_t i) { return dists[i]; }
+
+ //! Legacy accessors from earlier releases; they now trigger Log::Fatal. Use Component(i) instead.
+ const std::vector<arma::vec>& Means() const
+ {
+ Log::Fatal << "GMM::Means() no longer supported."
+ << "See GMM::Components().";
+ return means;
+ }
+ std::vector<arma::vec>& Means()
+ {
+ Log::Fatal << "GMM::Means() no longer supported."
+ << "See GMM::Components().";
+ return means;
+ }
+ const std::vector<arma::mat>& Covariances() const
+ {
+ Log::Fatal << "GMM::Covariances() no longer supported."
+ << "See GMM::Components().";
+ return covariances;
+ }
+ std::vector<arma::mat>& Covariances()
+ {
+ Log::Fatal << "GMM::Covariances() no longer supported."
+ << "See GMM::Components().";
+ return covariances;
+ }
- //! Return a const reference to the vector of covariance matrices (sigma).
- const std::vector<arma::mat>& Covariances() const { return covariances; }
- //! Return a reference to the vector of covariance matrices (sigma).
- std::vector<arma::mat>& Covariances() { return covariances; }
//! Return a const reference to the a priori weights of each Gaussian.
const arma::vec& Weights() const { return weights; }
@@ -338,6 +382,11 @@ class GMM
* Returns a string representation of this object.
*/
std::string ToString() const;
+
+ /**
+ * Returns a string indicating the type.
+ */
+ static std::string const Type() { return "GMM"; }
private:
/**
@@ -350,8 +399,7 @@ class GMM
* @param weights Weights of the given mixture model.
*/
double LogLikelihood(const arma::mat& dataPoints,
- const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covars,
+ const std::vector<distribution::GaussianDistribution>& distsL,
const arma::vec& weights) const;
//! Locally-stored fitting object; in case the user did not pass one.
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list