[mlpack] 283/324: Added Save, Load tests
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:19 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 4d49b10898cf57ccbb68157389ee9a6b14e8e6a0
Author: michaelfox99 <michaelfox99 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Tue Aug 5 13:43:36 2014 +0000
Added Save, Load tests
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16967 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/tests/hmm_test.cpp | 302 ++++++++++++++++++++++++++++++++----------
1 file changed, 229 insertions(+), 73 deletions(-)
diff --git a/src/mlpack/tests/hmm_test.cpp b/src/mlpack/tests/hmm_test.cpp
index c223478..2d5ec99 100644
--- a/src/mlpack/tests/hmm_test.cpp
+++ b/src/mlpack/tests/hmm_test.cpp
@@ -777,24 +777,24 @@ BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
gmms[0].Weights() = arma::vec("0.75 0.25");
// N([2.25 3.10], [1.00 0.20; 0.20 0.89])
- gmms[0].Means()[0] = arma::vec("4.25 3.10");
- gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
+ gmms[0].Component(0) = GaussianDistribution("4.25 3.10",
+ "1.00 0.20; 0.20 0.89");
// N([4.10 1.01], [1.00 0.00; 0.00 1.01])
- gmms[0].Means()[1] = arma::vec("7.10 5.01");
- gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+ gmms[0].Component(1) = GaussianDistribution("7.10 5.01",
+ "1.00 0.00; 0.00 1.01");
gmms[1] = GMM<>(3, 2);
gmms[1].Weights() = arma::vec("0.4 0.2 0.4");
- gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
- gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+ gmms[1].Component(0) = GaussianDistribution("-3.00 -6.12",
+ "1.00 0.00; 0.00 1.00");
- gmms[1].Means()[1] = arma::vec("-4.25 -7.12");
- gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+ gmms[1].Component(1) = GaussianDistribution("-4.25 -7.12",
+ "1.50 0.60; 0.60 1.20");
- gmms[1].Means()[2] = arma::vec("-6.15 -2.00");
- gmms[1].Covariances()[2] = arma::mat("1.00 0.80; 0.80 1.00");
+ gmms[1].Component(2) = GaussianDistribution("-6.15 -2.00",
+ "1.00 0.80; 0.80 1.00");
// Default MATLAB initial probabilities.
arma::vec initial("1 0");
@@ -846,20 +846,20 @@ BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
gmms[0].Weights() = arma::vec("0.3 0.7");
// N([2.25 3.10], [1.00 0.20; 0.20 0.89])
- gmms[0].Means()[0] = arma::vec("4.25 3.10");
- gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
-
+ gmms[0].Component(0) = GaussianDistribution("4.25 3.10",
+ "1.00 0.20; 0.20 0.89");
+
// N([4.10 1.01], [1.00 0.00; 0.00 1.01])
- gmms[0].Means()[1] = arma::vec("7.10 5.01");
- gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+ gmms[0].Component(1) = GaussianDistribution("7.10 5.01",
+ "1.00 0.00; 0.00 1.01");
gmms[1].Weights() = arma::vec("0.20 0.80");
- gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
- gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+ gmms[1].Component(0) = GaussianDistribution("-3.00 -6.12",
+ "1.00 0.00; 0.00 1.00");
- gmms[1].Means()[1] = arma::vec("-4.25 -2.12");
- gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+ gmms[1].Component(1) = GaussianDistribution("-4.25 -2.12",
+ "1.50 0.60; 0.60 1.20");
// Transition matrix.
arma::mat transMat("0.40 0.60;"
@@ -912,33 +912,34 @@ BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[1]] -
gmms[0].Weights()[1], 0.08);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][0] -
- gmms[0].Means()[0][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][1] -
- gmms[0].Means()[0][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][0] -
- gmms[0].Means()[1][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][1] -
- gmms[0].Means()[1][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 0) -
- gmms[0].Covariances()[0](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 1) -
- gmms[0].Covariances()[0](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 0) -
- gmms[0].Covariances()[0](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 1) -
- gmms[0].Covariances()[0](1, 1), 0.3);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 0) -
- gmms[0].Covariances()[1](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 1) -
- gmms[0].Covariances()[1](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 0) -
- gmms[0].Covariances()[1](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 1) -
- gmms[0].Covariances()[1](1, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[0]).Mean()[0] -
+ gmms[0].Component(0).Mean()[0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[0]).Mean()[1] -
+ gmms[0].Component(0).Mean()[1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[1]).Mean()[0] -
+ gmms[0].Component(1).Mean()[0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[1]).Mean()[1] -
+ gmms[0].Component(1).Mean()[1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[0]).
+ Covariance()(0, 0) - gmms[0].Component(0).Covariance()(0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[0]).
+ Covariance()(0, 1) - gmms[0].Component(0).Covariance()(0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[0]).
+ Covariance()(1, 0) - gmms[0].Component(0).Covariance()(1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[0]).
+ Covariance()(1, 1) - gmms[0].Component(0).Covariance()(1, 1), 0.3);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[1]).
+ Covariance()(0, 0) - gmms[0].Component(1).Covariance()(0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[1]).
+ Covariance()(0, 1) - gmms[0].Component(1).Covariance()(0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[1]).
+ Covariance()(1, 0) - gmms[0].Component(1).Covariance()(1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Component(sortedIndices[1]).
+ Covariance()(1, 1) - gmms[0].Component(1).Covariance()(1, 1), 0.3);
+
// Sort the GMM.
sortedIndices = sort_index(hmm.Emission()[1].Weights());
@@ -948,33 +949,188 @@ BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[1]] -
gmms[1].Weights()[1], 0.08);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][0] -
- gmms[1].Means()[0][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][1] -
- gmms[1].Means()[0][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][0] -
- gmms[1].Means()[1][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][1] -
- gmms[1].Means()[1][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 0) -
- gmms[1].Covariances()[0](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 1) -
- gmms[1].Covariances()[0](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 0) -
- gmms[1].Covariances()[0](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 1) -
- gmms[1].Covariances()[0](1, 1), 0.3);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 0) -
- gmms[1].Covariances()[1](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 1) -
- gmms[1].Covariances()[1](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 0) -
- gmms[1].Covariances()[1](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 1) -
- gmms[1].Covariances()[1](1, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[0]).Mean()[0] -
+ gmms[1].Component(0).Mean()[0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[0]).Mean()[1] -
+ gmms[1].Component(0).Mean()[1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[1]).Mean()[0] -
+ gmms[1].Component(1).Mean()[0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[1]).Mean()[1] -
+ gmms[1].Component(1).Mean()[1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[0]).
+ Covariance()(0, 0) - gmms[1].Component(0).Covariance()(0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[0]).
+ Covariance()(0, 1) - gmms[1].Component(0).Covariance()(0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[0]).
+ Covariance()(1, 0) - gmms[1].Component(0).Covariance()(1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[0]).
+ Covariance()(1, 1) - gmms[1].Component(0).Covariance()(1, 1), 0.3);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[1]).
+ Covariance()(0, 0) - gmms[1].Component(1).Covariance()(0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[1]).
+ Covariance()(0, 1) - gmms[1].Component(1).Covariance()(0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[1]).
+ Covariance()(1, 0) - gmms[1].Component(1).Covariance()(1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Component(sortedIndices[1]).
+ Covariance()(1, 1) - gmms[1].Component(1).Covariance()(1, 1), 0.3);
+}
+
+/**
+ * Test saving and loading of GMM HMMs.  Every emission GMM is filled with
+ * random weights, means and covariances; the model is written to an XML file,
+ * read back into a freshly constructed HMM, and every parameter is compared.
+ */
+BOOST_AUTO_TEST_CASE(GMMHMMLoadSaveTest)
+{
+  // Create a GMM HMM (3 states, each emission a GMM<>(4, 3)), save it, and
+  // load it.
+  HMM<GMM<> > hmm(3, GMM<>(4, 3));
+
+  for (size_t j = 0; j < hmm.Emission().size(); ++j)
+  {
+    hmm.Emission()[j].Weights().randu();
+    for (size_t i = 0; i < hmm.Emission()[j].Gaussians(); ++i)
+    {
+      hmm.Emission()[j].Component(i).Mean().randu();
+      hmm.Emission()[j].Component(i).Covariance().randu();
+    }
+  }
+
+  util::SaveRestoreUtility sr;
+  hmm.Save(sr);
+  sr.WriteFile("test-hmm-save.xml");
+
+  util::SaveRestoreUtility sr2;
+  sr2.ReadFile("test-hmm-save.xml");
+  HMM<GMM<> > hmm2(3, GMM<>(4, 3));
+  hmm2.Load(sr2);
+
+  // Remove clutter.
+  remove("test-hmm-save.xml");
+
+  // Guard the per-state indexing into hmm2 below.
+  BOOST_REQUIRE_EQUAL(hmm.Emission().size(), hmm2.Emission().size());
+
+  for (size_t j = 0; j < hmm.Emission().size(); ++j)
+  {
+    BOOST_REQUIRE_EQUAL(hmm.Emission()[j].Gaussians(),
+                        hmm2.Emission()[j].Gaussians());
+    BOOST_REQUIRE_EQUAL(hmm.Emission()[j].Dimensionality(),
+                        hmm2.Emission()[j].Dimensionality());
+
+    // There is one mixture weight per Gaussian component (not per dimension),
+    // so iterate over Gaussians() to compare every weight.
+    for (size_t i = 0; i < hmm.Emission()[j].Gaussians(); ++i)
+      BOOST_REQUIRE_CLOSE(hmm.Emission()[j].Weights()[i],
+          hmm2.Emission()[j].Weights()[i], 1e-3);
+
+    for (size_t i = 0; i < hmm.Emission()[j].Gaussians(); ++i)
+    {
+      for (size_t l = 0; l < hmm.Emission()[j].Dimensionality(); ++l)
+      {
+        BOOST_REQUIRE_CLOSE(hmm.Emission()[j].Component(i).Mean()[l],
+            hmm2.Emission()[j].Component(i).Mean()[l], 1e-3);
+
+        for (size_t k = 0; k < hmm.Emission()[j].Dimensionality(); ++k)
+        {
+          BOOST_REQUIRE_CLOSE(hmm.Emission()[j].Component(i).Covariance()(l, k),
+              hmm2.Emission()[j].Component(i).Covariance()(l, k), 1e-3);
+        }
+      }
+    }
+  }
+}
+
+/**
+ * Test saving and loading of Gaussian HMMs: each state's Gaussian gets a
+ * random mean and covariance, the model is round-tripped through an XML file,
+ * and every mean and covariance entry of the reloaded model is compared.
+ */
+BOOST_AUTO_TEST_CASE(GaussianHMMLoadSaveTest)
+{
+  // A 3-state HMM whose emissions are GaussianDistribution(2) objects.
+  HMM<GaussianDistribution> hmm(3, GaussianDistribution(2));
+
+  for (size_t state = 0; state < hmm.Emission().size(); ++state)
+  {
+    GaussianDistribution& dist = hmm.Emission()[state];
+    dist.Mean().randu();
+    dist.Covariance().randu();
+  }
+
+  util::SaveRestoreUtility sr;
+  hmm.Save(sr);
+  sr.WriteFile("test-hmm-save.xml");
+
+  util::SaveRestoreUtility sr2;
+  sr2.ReadFile("test-hmm-save.xml");
+  HMM<GaussianDistribution> hmm2(3, GaussianDistribution(2));
+  hmm2.Load(sr2);
+
+  // Remove clutter.
+  remove("test-hmm-save.xml");
+
+  for (size_t state = 0; state < hmm.Emission().size(); ++state)
+  {
+    const GaussianDistribution& saved = hmm.Emission()[state];
+    const GaussianDistribution& loaded = hmm2.Emission()[state];
+
+    BOOST_REQUIRE_EQUAL(saved.Dimensionality(), loaded.Dimensionality());
+
+    // Compare the mean entry-wise, interleaved with the matching
+    // covariance row.
+    for (size_t r = 0; r < saved.Dimensionality(); ++r)
+    {
+      BOOST_REQUIRE_CLOSE(saved.Mean()[r], loaded.Mean()[r], 1e-3);
+      for (size_t c = 0; c < saved.Dimensionality(); ++c)
+      {
+        BOOST_REQUIRE_CLOSE(saved.Covariance()(r, c),
+            loaded.Covariance()(r, c), 1e-3);
+      }
+    }
+  }
+}
+
+/**
+ * Test saving and loading of Discrete HMMs.  Each state's emission
+ * distribution gets random, normalized probabilities; the model is written to
+ * an XML file, read back into a fresh HMM, and the probabilities are compared.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteHMMLoadSaveTest)
+{
+  // Create a Discrete HMM (3 states, each emission a distribution over 3
+  // symbols), save it, and load it.
+  HMM<DiscreteDistribution> hmm(3, DiscreteDistribution(3));
+
+  for (size_t j = 0; j < hmm.Emission().size(); ++j)
+  {
+    // Normalize by the sum of the vector just drawn, so that each emission
+    // distribution sums to one.
+    hmm.Emission()[j].Probabilities() = arma::randu<arma::vec>(3);
+    hmm.Emission()[j].Probabilities() /=
+        accu(hmm.Emission()[j].Probabilities());
+  }
+
+  util::SaveRestoreUtility sr;
+  hmm.Save(sr);
+  sr.WriteFile("test-hmm-save.xml");
+
+  util::SaveRestoreUtility sr2;
+  sr2.ReadFile("test-hmm-save.xml");
+  HMM<DiscreteDistribution> hmm2(3, DiscreteDistribution(3));
+  hmm2.Load(sr2);
+
+  // Remove clutter.
+  remove("test-hmm-save.xml");
+
+  // Guard the per-state indexing into hmm2 below.
+  BOOST_REQUIRE_EQUAL(hmm.Emission().size(), hmm2.Emission().size());
+
+  for (size_t j = 0; j < hmm.Emission().size(); ++j)
+  {
+    // Sizes must agree before indexing into the loaded probabilities.
+    BOOST_REQUIRE_EQUAL(hmm.Emission()[j].Probabilities().n_elem,
+                        hmm2.Emission()[j].Probabilities().n_elem);
+
+    for (size_t i = 0; i < hmm.Emission()[j].Probabilities().n_elem; ++i)
+    {
+      BOOST_REQUIRE_CLOSE(hmm.Emission()[j].Probabilities()[i],
+          hmm2.Emission()[j].Probabilities()[i], 1e-3);
+    }
+  }
 }
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list