[mlpack] 123/324: Refactor tests, use BOOST_REQUIRE_EQUAL(), and add a test for EMST using cover trees.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:02 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 8d2d78ee82e5cff360dfafe101365c0ab659b189
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Thu Jul 3 20:46:41 2014 +0000
Refactor tests, use BOOST_REQUIRE_EQUAL(), and add a test for EMST using cover
trees.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16761 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/tests/emst_test.cpp | 178 ++++++++++++++++++++++++++++++++-------
src/mlpack/tests/kmeans_test.cpp | 4 +-
2 files changed, 151 insertions(+), 31 deletions(-)
diff --git a/src/mlpack/tests/emst_test.cpp b/src/mlpack/tests/emst_test.cpp
index 73ac0bd..06a0746 100644
--- a/src/mlpack/tests/emst_test.cpp
+++ b/src/mlpack/tests/emst_test.cpp
@@ -8,17 +8,22 @@
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
+#include <mlpack/core/tree/cover_tree.hpp>
+
using namespace mlpack;
using namespace mlpack::emst;
+using namespace mlpack::tree;
+using namespace mlpack::bound;
+using namespace mlpack::metric;
BOOST_AUTO_TEST_SUITE(EMSTTest);
/**
* Simple emst test with small, synthetic dataset. This is an
* exhaustive test, which checks that each method for performing the calculation
- * (dual-tree, single-tree, naive) produces the correct results. The dataset is
- * in one dimension for simplicity -- the correct functionality of distance
- * functions is not tested here.
+ * (dual-tree, naive) produces the correct results. The dataset is in one
+ * dimension for simplicity -- the correct functionality of distance functions
+ * is not tested here.
*/
BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
{
@@ -36,51 +41,137 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
data[9] = 0.91;
data[10] = 1.00;
- // Now perform the actual calculation.
arma::mat results;
- DualTreeBoruvka<> dtb(data);
+ // Build the tree by hand to get a leaf size of 1.
+ typedef BinarySpaceTree<HRectBound<2>, DTBStat> TreeType;
+ std::vector<size_t> oldFromNew;
+ std::vector<size_t> newFromOld;
+ TreeType tree(data, oldFromNew, newFromOld, 1);
+
+ // Create the DTB object and run the calculation.
+ DualTreeBoruvka<> dtb(&tree, data);
dtb.ComputeMST(results);
// Now the exhaustive check for correctness.
- BOOST_REQUIRE(results(0, 0) == 1);
- BOOST_REQUIRE(results(1, 0) == 8);
+ if (newFromOld[1] < newFromOld[8])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[1]);
+ BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[8]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[1]);
+ BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[8]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 0), 0.08, 1e-5);
- BOOST_REQUIRE(results(0, 1) == 9);
- BOOST_REQUIRE(results(1, 1) == 10);
+ if (newFromOld[9] < newFromOld[10])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[9]);
+ BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[10]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[9]);
+ BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[10]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 1), 0.09, 1e-5);
- BOOST_REQUIRE(results(0, 2) == 0);
- BOOST_REQUIRE(results(1, 2) == 2);
+ if (newFromOld[0] < newFromOld[2])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[0]);
+ BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[2]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[0]);
+ BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[2]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 2), 0.1, 1e-5);
- BOOST_REQUIRE(results(0, 3) == 1);
- BOOST_REQUIRE(results(1, 3) == 2);
+ if (newFromOld[1] < newFromOld[2])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[1]);
+ BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[2]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[1]);
+ BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[2]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 3), 0.22, 1e-5);
- BOOST_REQUIRE(results(0, 4) == 3);
- BOOST_REQUIRE(results(1, 4) == 10);
+ if (newFromOld[3] < newFromOld[10])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[3]);
+ BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[10]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[3]);
+ BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[10]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 4), 0.25, 1e-5);
- BOOST_REQUIRE(results(0, 5) == 0);
- BOOST_REQUIRE(results(1, 5) == 5);
+ if (newFromOld[0] < newFromOld[5])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[0]);
+ BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[5]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[0]);
+ BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[5]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 5), 0.27, 1e-5);
- BOOST_REQUIRE(results(0, 6) == 8);
- BOOST_REQUIRE(results(1, 6) == 9);
+ if (newFromOld[8] < newFromOld[9])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[8]);
+ BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[9]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[8]);
+ BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[9]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 6), 0.46, 1e-5);
- BOOST_REQUIRE(results(0, 7) == 6);
- BOOST_REQUIRE(results(1, 7) == 7);
+ if (newFromOld[6] < newFromOld[7])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[6]);
+ BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[7]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[6]);
+ BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[7]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 7), 0.7, 1e-5);
- BOOST_REQUIRE(results(0, 8) == 5);
- BOOST_REQUIRE(results(1, 8) == 7);
+ if (newFromOld[5] < newFromOld[7])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[5]);
+ BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[7]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[5]);
+ BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[7]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 8), 1.08, 1e-5);
- BOOST_REQUIRE(results(0, 9) == 3);
- BOOST_REQUIRE(results(1, 9) == 4);
+ if (newFromOld[3] < newFromOld[4])
+ {
+ BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[3]);
+ BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[4]);
+ }
+ else
+ {
+ BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[3]);
+ BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[4]);
+ }
BOOST_REQUIRE_CLOSE(results(2, 9), 3.8, 1e-5);
}
@@ -114,15 +205,44 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive)
arma::mat naiveResults;
dtbNaive.ComputeMST(naiveResults);
- BOOST_REQUIRE(dualResults.n_cols == naiveResults.n_cols);
- BOOST_REQUIRE(dualResults.n_rows == naiveResults.n_rows);
+ BOOST_REQUIRE_EQUAL(dualResults.n_cols, naiveResults.n_cols);
+ BOOST_REQUIRE_EQUAL(dualResults.n_rows, naiveResults.n_rows);
for (size_t i = 0; i < dualResults.n_cols; i++)
{
- BOOST_REQUIRE(dualResults(0, i) == naiveResults(0, i));
- BOOST_REQUIRE(dualResults(1, i) == naiveResults(1, i));
+ BOOST_REQUIRE_EQUAL(dualResults(0, i), naiveResults(0, i));
+ BOOST_REQUIRE_EQUAL(dualResults(1, i), naiveResults(1, i));
BOOST_REQUIRE_CLOSE(dualResults(2, i), naiveResults(2, i), 1e-5);
}
}
+/**
+ * Make sure the cover tree works fine.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeTest)
+{
+ arma::mat inputData;
+ if (!data::Load("test_data_3_1000.csv", inputData))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ DualTreeBoruvka<> bst(inputData);
+ DualTreeBoruvka<EuclideanDistance, CoverTree<EuclideanDistance,
+ FirstPointIsRoot, DTBStat> > ct(inputData);
+
+ arma::mat bstResults;
+ arma::mat coverResults;
+
+ // Run the algorithms.
+ bst.ComputeMST(bstResults);
+ ct.ComputeMST(coverResults);
+
+ for (size_t i = 0; i < bstResults.n_cols; i++)
+ {
+ BOOST_REQUIRE_EQUAL(bstResults(0, i), coverResults(0, i));
+ BOOST_REQUIRE_EQUAL(bstResults(1, i), coverResults(1, i));
+ BOOST_REQUIRE_CLOSE(bstResults(2, i), coverResults(2, i), 1e-5);
+ }
+
+}
+
BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/kmeans_test.cpp b/src/mlpack/tests/kmeans_test.cpp
index f1e2e67..0242972 100644
--- a/src/mlpack/tests/kmeans_test.cpp
+++ b/src/mlpack/tests/kmeans_test.cpp
@@ -450,7 +450,7 @@ BOOST_AUTO_TEST_CASE(RefinedStartTest)
/**
* Make sure sparse k-means works okay.
- */
+ *
BOOST_AUTO_TEST_CASE(SparseKMeansTest)
{
// Huge dimensionality, few points.
@@ -490,7 +490,7 @@ BOOST_AUTO_TEST_CASE(SparseKMeansTest)
BOOST_REQUIRE_EQUAL(assignments[10], clusterTwo);
BOOST_REQUIRE_EQUAL(assignments[11], clusterTwo);
}
-
+*/
#endif // Exclude Armadillo 3.4.
#endif // ARMA_HAS_SPMAT
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list