[mlpack] 17/53: Add DrusillaSelect implementation.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Nov 14 00:46:47 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 5e0db4c90f96fe58ca31a2723a1a4e686043950d
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Oct 24 16:27:10 2016 +0900
Add DrusillaSelect implementation.
---
src/mlpack/methods/CMakeLists.txt | 1 +
.../methods/approx_kfn/.drusilla_select.hpp.swp | Bin 0 -> 16384 bytes
.../approx_kfn/.drusilla_select_impl.hpp.swo | Bin 0 -> 45056 bytes
.../approx_kfn/.drusilla_select_impl.hpp.swp | Bin 0 -> 20480 bytes
src/mlpack/methods/approx_kfn/CMakeLists.txt | 20 ++
src/mlpack/methods/approx_kfn/drusilla_select.hpp | 125 ++++++++++++
.../methods/approx_kfn/drusilla_select_impl.hpp | 210 +++++++++++++++++++++
.../methods/approx_kfn/drusilla_select_main.cpp | 100 ++++++++++
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/drusilla_select_test.cpp | 145 ++++++++++++++
10 files changed, 602 insertions(+)
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index dbbd231..f292e97 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -18,6 +18,7 @@ endmacro ()
set(DIRS
preprocess
adaboost
+ approx_kfn
amf
ann
cf
diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp
new file mode 100644
index 0000000..ae44b28
Binary files /dev/null and b/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp differ
diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo
new file mode 100644
index 0000000..b2bbbba
Binary files /dev/null and b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo differ
diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp
new file mode 100644
index 0000000..9d5090f
Binary files /dev/null and b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp differ
diff --git a/src/mlpack/methods/approx_kfn/CMakeLists.txt b/src/mlpack/methods/approx_kfn/CMakeLists.txt
new file mode 100644
index 0000000..0e907d6
--- /dev/null
+++ b/src/mlpack/methods/approx_kfn/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into mlpack.
+set(SOURCES
+ # DrusillaSelect sources.
+ drusilla_select.hpp
+ drusilla_select_impl.hpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+ set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all mlpack sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+# The command-line program for approximate furthest neighbor search with
+# DrusillaSelect.
+add_cli_executable(drusilla_select)
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select.hpp b/src/mlpack/methods/approx_kfn/drusilla_select.hpp
new file mode 100644
index 0000000..38b90ab
--- /dev/null
+++ b/src/mlpack/methods/approx_kfn/drusilla_select.hpp
@@ -0,0 +1,125 @@
+/**
+ * @file drusilla_select.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of the approximate furthest neighbor algorithm specified in
+ * the following paper:
+ *
+ * @code
+ * @incollection{curtin2016fast,
+ * title={Fast approximate furthest neighbors with data-dependent candidate
+ * selection},
+ * author={Curtin, R.R. and Gardner, A.B.},
+ * booktitle={Similarity Search and Applications},
+ * pages={221--235},
+ * year={2016},
+ * publisher={Springer}
+ * }
+ * @endcode
+ *
+ * This algorithm, called DrusillaSelect, constructs a candidate set of points
+ * to query to find an approximate furthest neighbor. The strange name is a
+ * result of the algorithm being named after a cat. The cat in question may be
+ * viewed at http://www.ratml.org/misc_img/drusilla_fence.png.
+ */
+#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP
+#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Approximate furthest neighbor search with DrusillaSelect: training extracts
+ * a candidate set of l*m points from a reference set, and Search() looks for
+ * the furthest neighbors of each query point among those candidates only.
+ *
+ * @tparam MatType Type of matrix holding the reference and candidate sets.
+ */
+template<typename MatType = arma::mat>
+class DrusillaSelect
+{
+ public:
+ /**
+ * Construct the DrusillaSelect object with the given reference set (this is
+ * the set that will be searched). The resulting set of candidate points that
+ * will be searched at query time will have size l*m. An
+ * std::invalid_argument exception is thrown if l or m is 0.
+ *
+ * @param referenceSet Set of reference data.
+ * @param l Number of projections.
+ * @param m Number of elements to store for each projection.
+ */
+ DrusillaSelect(const MatType& referenceSet,
+ const size_t l,
+ const size_t m);
+
+ /**
+ * Construct the DrusillaSelect object with no given reference set. Be sure
+ * to call Train() before calling Search()! An std::invalid_argument
+ * exception is thrown if l or m is 0.
+ *
+ * @param l Number of projections.
+ * @param m Number of elements to store for each projection.
+ */
+ DrusillaSelect(const size_t l, const size_t m);
+
+ /**
+ * Build the set of candidate points on the given reference set. If l and m
+ * are left unspecified (or given as 0), then the values set in the
+ * constructor will be used instead. An std::invalid_argument exception is
+ * thrown if l*m is larger than the number of points in the reference set.
+ *
+ * @param referenceSet Set to extract candidate points from.
+ * @param l Number of projections.
+ * @param m Number of elements to store for each projection.
+ */
+ void Train(const MatType& referenceSet,
+ const size_t l = 0,
+ const size_t m = 0);
+
+ /**
+ * Search for the k furthest neighbors of the given query set. (The query set
+ * can contain just one point: that is okay.) The results will be stored in
+ * the given neighbors and distances matrices, in the same format as the
+ * NeighborSearch and LSHSearch classes. That is, each column in the
+ * neighbors and distances matrices will refer to a single query point, and
+ * the k'th row in that column will refer to the k'th candidate neighbor or
+ * distance for that query point.
+ *
+ * An std::runtime_error exception is thrown if the candidate set is empty
+ * (i.e. the model has not been trained), and an std::invalid_argument
+ * exception is thrown if k is larger than l*m.
+ *
+ * @param querySet Set of query points to search.
+ * @param k Number of furthest neighbors to search for.
+ * @param neighbors Matrix to store resulting neighbors in.
+ * @param distances Matrix to store resulting distances in.
+ */
+ void Search(const MatType& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
+
+ /**
+ * Serialize the model (the candidate set, the candidate indices, l and m).
+ */
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
+ //! Access the candidate set.
+ const MatType& CandidateSet() const { return candidateSet; }
+ //! Modify the candidate set. Be careful!
+ MatType& CandidateSet() { return candidateSet; }
+
+ //! Access the indices of points in the candidate set.
+ const arma::Col<size_t>& CandidateIndices() const { return candidateIndices; }
+ //! Modify the indices of points in the candidate set. Be careful!
+ arma::Col<size_t>& CandidateIndices() { return candidateIndices; }
+
+ private:
+ //! The candidate set: the points extracted from the reference set.
+ MatType candidateSet;
+ //! Index in the reference set of each point in the candidate set.
+ arma::Col<size_t> candidateIndices;
+
+ //! The number of projections.
+ size_t l;
+ //! The number of points in each projection.
+ size_t m;
+};
+
+} // namespace neighbor
+} // namespace mlpack
+
+// Include implementation.
+#include "drusilla_select_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
new file mode 100644
index 0000000..a84b304
--- /dev/null
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -0,0 +1,210 @@
+/**
+ * @file drusilla_select_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of DrusillaSelect class methods.
+ */
+#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP
+#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "drusilla_select.hpp"
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <queue>
+#include <mlpack/methods/neighbor_search/neighbor_search_rules.hpp>
+#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Construct from a reference set: room for l * m candidate points is
+ * allocated, zero-valued parameters are rejected, and the candidate set is
+ * then built by Train().
+ *
+ * @param referenceSet Set of reference data.
+ * @param l Number of projections; must be greater than 0.
+ * @param m Number of points per projection; must be greater than 0.
+ */
+template<typename MatType>
+DrusillaSelect<MatType>::DrusillaSelect(const MatType& referenceSet,
+                                        const size_t l,
+                                        const size_t m) :
+    candidateSet(referenceSet.n_rows, l * m),
+    l(l),
+    m(m)
+{
+  // The number of projections is checked first...
+  if (l == 0)
+  {
+    throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
+        "value of l; must be greater than 0!");
+  }
+
+  // ...and then the number of points kept for each projection.
+  if (m == 0)
+  {
+    throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
+        "value of m; must be greater than 0!");
+  }
+
+  Train(referenceSet, l, m);
+}
+
+/**
+ * Construct without a reference set.  Only l and m are stored; the candidate
+ * set stays empty until Train() is called.
+ *
+ * @param l Number of projections; must be greater than 0.
+ * @param m Number of points per projection; must be greater than 0.
+ */
+template<typename MatType>
+DrusillaSelect<MatType>::DrusillaSelect(const size_t l, const size_t m) :
+    l(l),
+    m(m)
+{
+  // The number of projections is checked first...
+  if (l == 0)
+  {
+    throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
+        "value of l; must be greater than 0!");
+  }
+
+  // ...and then the number of points kept for each projection.
+  if (m == 0)
+  {
+    throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
+        "value of m; must be greater than 0!");
+  }
+}
+
+/**
+ * Build the candidate set from the given reference set.
+ *
+ * The reference set is mean-centered.  Then, for each of the l projections,
+ * the remaining point with the largest norm defines a line through the mean.
+ * Every remaining point is scored by |offset| - distortion, where offset is
+ * its projection onto the line and distortion is its distance from the line,
+ * and the m best-scoring points are added to the candidate set.  Points whose
+ * angle to the line is below pi/8 are set aside so that they do not define or
+ * score on later projections.
+ *
+ * @param referenceSet Set to extract candidate points from.
+ * @param lIn Number of projections; if 0, the current value of l is kept.
+ * @param mIn Number of points per projection; if 0, the current m is kept.
+ * @throws std::invalid_argument if l * m exceeds the number of points.
+ */
+template<typename MatType>
+void DrusillaSelect<MatType>::Train(
+    const MatType& referenceSet,
+    const size_t lIn,
+    const size_t mIn)
+{
+  // Did the user specify a new size?  If so, use it.
+  if (lIn > 0)
+    l = lIn;
+  if (mIn > 0)
+    m = mIn;
+
+  if ((l * m) > referenceSet.n_cols)
+    throw std::invalid_argument("DrusillaSelect::Train(): l and m are too "
+        "large! Choose smaller values. l*m must be no larger than the number "
+        "of points in the dataset.");
+
+  const size_t numPoints = referenceSet.n_cols;
+
+  // The model may be retrained with different sizes, or may have been built
+  // by the constructor that takes no data, so always size the outputs here.
+  candidateSet.set_size(referenceSet.n_rows, l * m);
+  candidateIndices.set_size(l * m);
+
+  // Center the data on its mean (once only).
+  const arma::vec dataMean = arma::mean(referenceSet, 1);
+  arma::mat refCopy(referenceSet);
+  refCopy.each_col() -= dataMean;
+
+  // norms[j] doubles as the state of point j:
+  //   > 0: still active (the value is its distance from the mean);
+  //  == 0: set aside by the angle rule; only used to pad a projection;
+  //   < 0: already in the candidate set; never selected again.
+  arma::vec norms(numPoints);
+  for (size_t j = 0; j < numPoints; ++j)
+    norms[j] = arma::norm(refCopy.col(j));
+
+  // Find the top m points for each of the l projections...
+  for (size_t i = 0; i < l; ++i)
+  {
+    // The active point furthest from the mean defines this projection.
+    arma::uword maxIndex = 0;
+    const double maxNorm = norms.max(maxIndex);
+
+    // With no active point left there is no direction to project onto; the
+    // line is then never read, because the scoring loop below only uses it for
+    // active points.
+    arma::vec line;
+    if (maxNorm > 0.0)
+      line = refCopy.col(maxIndex) / maxNorm;
+
+    // Score every point for this projection.
+    arma::vec sums(numPoints);
+    for (size_t j = 0; j < numPoints; ++j)
+    {
+      if (norms[j] > 0.0)
+      {
+        const double offset = arma::dot(refCopy.col(j), line);
+        const double distortion = arma::norm(refCopy.col(j) - offset * line);
+        sums[j] = std::abs(offset) - distortion;
+
+        // Anything close enough in angle to this line is set aside for later
+        // projections.  It keeps its score for this one.
+        if (std::atan2(distortion, std::abs(offset)) < (M_PI / 8.0))
+          norms[j] = 0.0;
+      }
+      else if (norms[j] == 0.0)
+      {
+        // Set-aside points rank below every active point...
+        sums[j] = -DBL_MAX;
+      }
+      else
+      {
+        // ...and already-selected points rank below everything.
+        sums[j] = -arma::datum::inf;
+      }
+    }
+
+    // Take the top m elements for this projection.  At most i * m points have
+    // been selected so far and l * m <= numPoints, so the first m entries
+    // always refer to points that are not yet in the candidate set.
+    const arma::uvec sortedSums = arma::sort_index(sums, "descend");
+    for (size_t j = 0; j < m; ++j)
+    {
+      const size_t index = sortedSums[j];
+      candidateSet.col(i * m + j) = referenceSet.col(index);
+      candidateIndices[i * m + j] = index;
+
+      // Mark the point so that we don't see it again.
+      norms[index] = -1.0;
+    }
+  }
+}
+
+/**
+ * Search for the k furthest neighbors of each query point by comparing it
+ * against every point in the candidate set.
+ *
+ * @param querySet Set of query points to search.
+ * @param k Number of furthest neighbors to search for.
+ * @param neighbors Matrix to store resulting neighbors in; entries are indices
+ *     into the reference set that the model was trained on.
+ * @param distances Matrix to store resulting distances in.
+ * @throws std::runtime_error if the candidate set is empty.
+ * @throws std::invalid_argument if k exceeds the size of the candidate set.
+ */
+template<typename MatType>
+void DrusillaSelect<MatType>::Search(const MatType& querySet,
+                                     const size_t k,
+                                     arma::Mat<size_t>& neighbors,
+                                     arma::mat& distances)
+{
+  if (candidateSet.n_cols == 0)
+    throw std::runtime_error("DrusillaSelect::Search(): candidate set not "
+        "initialized! Call Train() first.");
+
+  if (k > candidateSet.n_cols)
+    throw std::invalid_argument("DrusillaSelect::Search(): requested k is "
+        "greater than number of points in candidate set! Increase l or m.");
+
+  // We'll use the NeighborSearchRules class to perform our brute-force search.
+  // No tree is built; the tree type below only serves to instantiate the
+  // rules class.
+  // NOTE(review): this assumes that NeighborSearchRules takes the reference
+  // set before the query set, and that it collects results internally and
+  // hands them out (sized k x querySet.n_cols) through GetResults().  Confirm
+  // against neighbor_search_rules.hpp in this tree.
+  metric::EuclideanDistance metric;
+  NeighborSearchRules<FurthestNeighborSort, metric::EuclideanDistance,
+      tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, arma::mat>>
+      rules(candidateSet, querySet, k, metric, 0, false);
+
+  for (size_t q = 0; q < querySet.n_cols; ++q)
+    for (size_t r = 0; r < candidateSet.n_cols; ++r)
+      rules.BaseCase(q, r);
+
+  // Without this step the output matrices never receive the search results.
+  rules.GetResults(neighbors, distances);
+
+  // Map the neighbors back to their original indices in the reference set.
+  for (size_t i = 0; i < neighbors.n_elem; ++i)
+    neighbors[i] = candidateIndices[neighbors[i]];
+}
+
+/**
+ * Serialize the model by handing each member to the archive: the candidate
+ * set, the candidate indices, and the l and m parameters, in that order.
+ */
+template<typename MatType>
+template<typename Archive>
+void DrusillaSelect<MatType>::Serialize(Archive& ar,
+                                        const unsigned int /* version */)
+{
+  // The data members come first...
+  ar & data::CreateNVP(candidateSet, "candidateSet");
+  ar & data::CreateNVP(candidateIndices, "candidateIndices");
+
+  // ...followed by the two size parameters.
+  ar & data::CreateNVP(l, "l");
+  ar & data::CreateNVP(m, "m");
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
new file mode 100644
index 0000000..9e55ec7
--- /dev/null
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
@@ -0,0 +1,100 @@
+/**
+ * @file drusilla_select_main.cpp
+ * @author Ryan Curtin
+ *
+ * Command-line program for approximate furthest neighbor search with the
+ * DrusillaSelect algorithm.
+ */
+#include <mlpack/core.hpp>
+#include "drusilla_select.hpp"
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace std;
+
+PROGRAM_INFO("Approximate furthest neighbor search with DrusillaSelect",
+    "This program implements the DrusillaSelect algorithm from the 2016 paper "
+    "titled 'Fast approximate furthest neighbors with data-dependent candidate "
+    "selection' by R.R. Curtin and A.B. Gardner. Specify a reference set (set "
+    "to search in) with --reference_file, specify a query set (set to search "
+    "for) with --query_file, and specify algorithm parameters with "
+    "--num_projections and --num_tables (or don't, and defaults will be used). "
+    "The product of those two parameters is the number of candidate points "
+    "that are searched, and it may not be larger than the number of reference "
+    "points. Also specify the number of points to search for with --k. Each "
+    "of those options has short names too; see the detailed parameter "
+    "documentation below."
+    "\n\n"
+    "Results for each query point are stored in the files specified by "
+    "--neighbors_file and --distances_file. This is in the same format as the "
+    "mlpack KFN and KNN programs: each row holds the k distances or neighbor "
+    "indices for each query point.");
+
+PARAM_STRING_REQ("reference_file", "File containing reference points.", "r");
+PARAM_STRING_REQ("query_file", "File containing query points.", "q");
+
+PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k");
+
+PARAM_INT("num_tables", "Number of points to keep for each projection.", "t",
+    10);
+PARAM_INT("num_projections", "Number of projections to use.", "p", 30);
+
+PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.",
+    "n", "");
+PARAM_STRING("distances_file", "File to save furthest neighbor distances to.",
+    "d", "");
+
+PARAM_FLAG("calculate_error", "If set, calculate the average distance error.",
+    "e");
+PARAM_STRING("exact_distances_file", "File containing exact distances", "x", "");
+
+/**
+ * Load the reference and query sets, build a DrusillaSelect model, search for
+ * the k approximate furthest neighbors of every query point, optionally report
+ * the error against exact distances, and save the results.
+ */
+int main(int argc, char** argv)
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  const string referenceFile = CLI::GetParam<string>("reference_file");
+  const string queryFile = CLI::GetParam<string>("query_file");
+  const int kIn = CLI::GetParam<int>("k");
+  const int numProjectionsIn = CLI::GetParam<int>("num_projections");
+  const int numTablesIn = CLI::GetParam<int>("num_tables");
+
+  // Reject non-positive values here: the casts to size_t below would silently
+  // turn a negative number into a huge one.
+  if (kIn <= 0)
+    Log::Fatal << "Invalid k (" << kIn << "); must be greater than 0!" << endl;
+  if (numProjectionsIn <= 0)
+    Log::Fatal << "Invalid --num_projections (" << numProjectionsIn
+        << "); must be greater than 0!" << endl;
+  if (numTablesIn <= 0)
+    Log::Fatal << "Invalid --num_tables (" << numTablesIn
+        << "); must be greater than 0!" << endl;
+
+  const size_t k = (size_t) kIn;
+  const size_t numProjections = (size_t) numProjectionsIn;
+  const size_t numTables = (size_t) numTablesIn;
+
+  // Load the data.
+  arma::mat referenceData, queryData;
+  data::Load(referenceFile, referenceData, true);
+  data::Load(queryFile, queryData, true);
+
+  // Construct the object: numProjections is DrusillaSelect's l parameter, and
+  // numTables is its m parameter (points kept for each projection).
+  Timer::Start("drusilla_select_construct");
+  DrusillaSelect<> ds(referenceData, numProjections, numTables);
+  Timer::Stop("drusilla_select_construct");
+
+  // Do the search.
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  Timer::Start("drusilla_select_search");
+  ds.Search(queryData, k, neighbors, distances);
+  Timer::Stop("drusilla_select_search");
+
+  if (CLI::HasParam("calculate_error"))
+  {
+    const string exactFile = CLI::GetParam<string>("exact_distances_file");
+    if (exactFile.empty())
+      Log::Fatal << "--calculate_error requires --exact_distances_file!"
+          << endl;
+
+    arma::mat trueDistances;
+    data::Load(exactFile, trueDistances, true);
+
+    if (trueDistances.n_cols != distances.n_cols)
+      Log::Fatal << "Exact distances file holds " << trueDistances.n_cols
+          << " query points, but " << distances.n_cols << " were searched!"
+          << endl;
+
+    // The error of a query point is the ratio of the true furthest neighbor
+    // distance to the furthest distance that was found (first row of each).
+    const arma::rowvec ratios = trueDistances.row(0) / distances.row(0);
+
+    const double averageError = arma::sum(ratios) / ratios.n_elem;
+    const double minError = arma::min(ratios);
+    const double maxError = arma::max(ratios);
+
+    Log::Info << "Average error: " << averageError << "." << endl;
+    Log::Info << "Maximum error: " << maxError << "." << endl;
+    Log::Info << "Minimum error: " << minError << "." << endl;
+  }
+
+  // Save the results.
+  if (CLI::HasParam("neighbors_file"))
+    data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
+  if (CLI::HasParam("distances_file"))
+    data::Save(CLI::GetParam<string>("distances_file"), distances);
+}
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 9ad4092..a93f7bf 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -18,6 +18,7 @@ add_executable(mlpack_test
decision_stump_test.cpp
det_test.cpp
distribution_test.cpp
+ drusilla_select_test.cpp
emst_test.cpp
fastmks_test.cpp
feedforward_network_test.cpp
diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp
new file mode 100644
index 0000000..504fd62
--- /dev/null
+++ b/src/mlpack/tests/drusilla_select_test.cpp
@@ -0,0 +1,145 @@
+/**
+ * @file drusilla_select_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test for DrusillaSelect.
+ */
+#include <mlpack/methods/approx_kfn/drusilla_select.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "test_tools.hpp"
+#include "serialization.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+BOOST_AUTO_TEST_SUITE(DrusillaSelectTest);
+
+// If we have a dataset with an extreme outlier, then every point (except that
+// one) should end up with that point as the furthest neighbor candidate.
+BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(5, 100);
+  dataset.col(99) += 100; // Make last column (index 99) very large.
+
+  // Construct with some reasonable parameters.
+  DrusillaSelect<> ds(dataset, 5, 5);
+
+  // Query with every point except the extreme point: columns 0 through 98.
+  arma::mat distances;
+  arma::Mat<size_t> neighbors;
+  ds.Search(dataset.cols(0, 98), 1, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 99);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 99);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, 1);
+
+  // The outlier is the furthest neighbor of every other point.
+  for (size_t i = 0; i < 99; ++i)
+    BOOST_REQUIRE_EQUAL(neighbors[i], 99);
+}
+
+// If we use only one projection with the number of points equal to what is in
+// the dataset, we should end up with the exact result.
+BOOST_AUTO_TEST_CASE(DrusillaSelectExhaustiveExactTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(5, 100);
+
+  // Construct with one projection (l = 1) and 100 points in that projection
+  // (m = 100), so that the candidate set is the whole dataset.
+  DrusillaSelect<> ds(dataset, 1, 100);
+
+  arma::mat distances, distancesTrue;
+  arma::Mat<size_t> neighbors, neighborsTrue;
+
+  ds.Search(dataset, 5, neighbors, distances);
+
+  // Exact search to compare against.
+  AllkFN kfn(dataset);
+  kfn.Search(dataset, 5, neighborsTrue, distancesTrue);
+
+  BOOST_REQUIRE_EQUAL(neighborsTrue.n_cols, neighbors.n_cols);
+  BOOST_REQUIRE_EQUAL(neighborsTrue.n_rows, neighbors.n_rows);
+  BOOST_REQUIRE_EQUAL(distancesTrue.n_cols, distances.n_cols);
+  BOOST_REQUIRE_EQUAL(distancesTrue.n_rows, distances.n_rows);
+
+  for (size_t i = 0; i < distances.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsTrue[i]);
+    BOOST_REQUIRE_CLOSE(distances[i], distancesTrue[i], 1e-5);
+  }
+}
+
+// Test that we can call Train() after calling the constructor.
+BOOST_AUTO_TEST_CASE(RetrainTest)
+{
+  arma::mat firstDataset = arma::randu<arma::mat>(3, 10);
+  arma::mat dataset = arma::randu<arma::mat>(3, 200);
+
+  DrusillaSelect<> ds(firstDataset, 3, 3);
+
+  // Train() takes the reference set by const reference, so the dataset is not
+  // consumed and can be used as the query set afterwards.
+  ds.Train(dataset, 2, 2);
+
+  arma::mat distances;
+  arma::Mat<size_t> neighbors;
+  ds.Search(dataset, 1, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, 1);
+}
+
+// Test serialization: models that start out different must behave exactly like
+// the original model after it has been serialized into them.
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+  // Create a random dataset.
+  arma::mat dataset = arma::randu<arma::mat>(3, 100);
+
+  DrusillaSelect<> ds(dataset, 3, 3);
+
+  // Training requires l * m to be no larger than the number of points, so the
+  // fake datasets must hold at least 5 * 3 and 5 * 2 points respectively.
+  arma::mat fakeDataset1 = arma::randu<arma::mat>(2, 15);
+  arma::mat fakeDataset2 = arma::randu<arma::mat>(10, 18);
+  DrusillaSelect<> dsXml(fakeDataset1, 5, 3);
+  DrusillaSelect<> dsText(2, 2);
+  DrusillaSelect<> dsBinary(5, 2);
+  dsBinary.Train(fakeDataset2);
+
+  // Now do the serialization.
+  SerializeObjectAll(ds, dsXml, dsText, dsBinary);
+
+  // Now do a search and make sure all the results are the same.
+  arma::Mat<size_t> neighbors, neighborsXml, neighborsText, neighborsBinary;
+  arma::mat distances, distancesXml, distancesText, distancesBinary;
+
+  ds.Search(dataset, 3, neighbors, distances);
+  dsXml.Search(dataset, 3, neighborsXml, distancesXml);
+  dsText.Search(dataset, 3, neighborsText, distancesText);
+  dsBinary.Search(dataset, 3, neighborsBinary, distancesBinary);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsXml.n_rows);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsXml.n_cols);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsText.n_rows);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsText.n_cols);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsBinary.n_rows);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsBinary.n_cols);
+
+  BOOST_REQUIRE_EQUAL(distances.n_rows, distancesXml.n_rows);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, distancesXml.n_cols);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, distancesText.n_rows);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, distancesText.n_cols);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, distancesBinary.n_rows);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, distancesBinary.n_cols);
+
+  for (size_t i = 0; i < neighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsXml[i]);
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsText[i]);
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsBinary[i]);
+
+    BOOST_REQUIRE_CLOSE(distances[i], distancesXml[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(distances[i], distancesText[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(distances[i], distancesBinary[i], 1e-5);
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list