[mlpack] 35/53: Add comprehensive tests for QDAFN.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Nov 14 00:46:49 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master in repository mlpack.
commit 15f4b073adc1410d182cc94d5fded590331eff71
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 30 21:32:24 2016 +0900
Add comprehensive tests for QDAFN.
There is a known bug at this point, but I have to push this commit so that I
can move to a system with a working gdb and debug it there.
---
src/mlpack/methods/approx_kfn/qdafn.hpp | 16 +++++
src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 9 +++
src/mlpack/tests/qdafn_test.cpp | 101 +++++++++++++++++++++++++++
3 files changed, 126 insertions(+)
diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index f7949db..6ba8b81 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -50,6 +50,19 @@ class QDAFN
const size_t m);
/**
+ * Train the QDAFN model on the given reference set, optionally setting new
+ * parameters for the number of projections/tables (l) and the number of
+ * elements stored for each projection/table (m).
+ *
+ * @param referenceSet Reference set to train on.
+ * @param l Number of projections.
+ * @param m Number of elements to store for each projection.
+ */
+ void Train(const MatType& referenceSet,
+ const size_t l = 0,
+ const size_t m = 0);
+
+ /**
* Search for the k furthest neighbors of the given query set. (The query set
* can contain just one point, that is okay.) The results will be stored in
* the given neighbors and distances matrices, in the same format as the
@@ -64,6 +77,9 @@ class QDAFN
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);
+ //! Get the number of projections.
+ size_t NumProjections() const { return candidateSet.size(); }
+
//! Get the candidate set for the given projection table.
const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; }
//! Modify the candidate set for the given projection table. Careful!
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index de6c882..475538c 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -28,6 +28,15 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
l(l),
m(m)
{
+ Train(referenceSet);
+}
+
+// Train the object.
+template<typename MatType>
+void QDAFN<MatType>::Train(const MatType& referenceSet,
+ const size_t l,
+ const size_t m)
+{
// Build tables. This is done by drawing random points from a Gaussian
// distribution as the vectors we project onto. The Gaussian should have zero
// mean and unit variance.
diff --git a/src/mlpack/tests/qdafn_test.cpp b/src/mlpack/tests/qdafn_test.cpp
index ea64b52..332b7c7 100644
--- a/src/mlpack/tests/qdafn_test.cpp
+++ b/src/mlpack/tests/qdafn_test.cpp
@@ -102,4 +102,105 @@ BOOST_AUTO_TEST_CASE(QDAFNUniformSet)
BOOST_REQUIRE_GE(successes, 700);
}
+/**
+ * Test re-training method.
+ */
+BOOST_AUTO_TEST_CASE(RetrainTest)
+{
+ arma::mat dataset = arma::randu<arma::mat>(25, 500);
+ arma::mat newDataset = arma::randu<arma::mat>(15, 600);
+
+ QDAFN<> qdafn(dataset, 20, 60);
+
+ qdafn.Train(newDataset, 10, 50);
+
+ BOOST_REQUIRE_EQUAL(qdafn.NumProjections(), 10);
+ for (size_t i = 0; i < 10; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_rows, 15);
+ BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_cols, 50);
+ }
+}
+
+/**
+ * Test serialization of QDAFN.
+ */
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+ // Use a random dataset.
+ arma::mat dataset = arma::randu<arma::mat>(15, 300);
+
+ QDAFN<> qdafn(dataset, 10, 50);
+
+ arma::mat fakeDataset1 = arma::randu<arma::mat>(10, 200);
+ arma::mat fakeDataset2 = arma::randu<arma::mat>(50, 500);
+ QDAFN<> qdafnXml(fakeDataset1, 5, 10);
+ QDAFN<> qdafnText(6, 50);
+ QDAFN<> qdafnBinary(7, 15);
+ qdafnBinary.Train(fakeDataset2);
+
+ // Serialize the objects.
+ SerializeObjectAll(qdafn, qdafnXml, qdafnText, qdafnBinary);
+
+ // Check that the tables are all the same.
+ BOOST_REQUIRE_EQUAL(qdafnXml.NumProjections(), qdafn.NumProjections());
+ BOOST_REQUIRE_EQUAL(qdafnText.NumProjections(), qdafn.NumProjections());
+ BOOST_REQUIRE_EQUAL(qdafnBinary.NumProjections(), qdafn.NumProjections());
+
+ for (size_t i = 0; i < qdafn.NumProjections(); ++i)
+ {
+ BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_rows,
+ qdafn.CandidateSet(i).n_rows);
+ BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_rows,
+ qdafn.CandidateSet(i).n_rows);
+ BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_rows,
+ qdafn.CandidateSet(i).n_rows);
+
+ BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_cols,
+ qdafn.CandidateSet(i).n_cols);
+ BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_cols,
+ qdafn.CandidateSet(i).n_cols);
+ BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_cols,
+ qdafn.CandidateSet(i).n_cols);
+
+ for (size_t j = 0; j < qdafn.CandidateSet(i).n_elem; ++j)
+ {
+ if (std::abs(qdafn.CandidateSet(i)[j]) < 1e-5)
+ {
+ BOOST_REQUIRE_SMALL(qdafnXml.CandidateSet(i)[j], 1e-5);
+ BOOST_REQUIRE_SMALL(qdafnText.CandidateSet(i)[j], 1e-5);
+ BOOST_REQUIRE_SMALL(qdafnBinary.CandidateSet(i)[j], 1e-5);
+ }
+ else
+ {
+ const double value = qdafn.CandidateSet(i)[j];
+ BOOST_REQUIRE_CLOSE(qdafnXml.CandidateSet(i)[j], value, 1e-5);
+ BOOST_REQUIRE_CLOSE(qdafnText.CandidateSet(i)[j], value, 1e-5);
+ BOOST_REQUIRE_CLOSE(qdafnBinary.CandidateSet(i)[j], value, 1e-5);
+ }
+ }
+ }
+}
+
+// Make sure QDAFN works with sparse data.
+BOOST_AUTO_TEST_CASE(SparseTest)
+{
+ arma::sp_mat dataset;
+ dataset.sprandu(200, 1000, 0.3);
+
+ // Create a sparse version.
+ QDAFN<arma::sp_mat> sparse(dataset, 15, 50);
+
+ // Make sure the results are of the right shape. It's hard to test anything
+ // more than that because we don't have easy-to-check performance guarantees.
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+ sparse.Search(dataset, 3, neighbors, distances);
+
+ BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3);
+ BOOST_REQUIRE_EQUAL(neighbors.n_cols, 1000);
+ BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
+ BOOST_REQUIRE_EQUAL(distances.n_cols, 1000);
+}
+
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list