[mlpack] 34/53: Add tests for sparse operation and fix sparse bugs.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Nov 14 00:46:49 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit d01b20fe87591e51421308dde5340816193429e3
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 30 21:30:30 2016 +0900
Add tests for sparse operation and fix sparse bugs.
---
.../methods/approx_kfn/drusilla_select_impl.hpp | 13 ++++++++-----
src/mlpack/tests/drusilla_select_test.cpp | 19 +++++++++++++++++++
2 files changed, 27 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
index 9595374..942063b 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -76,12 +76,15 @@ void DrusillaSelect<MatType>::Train(
candidateSet.set_size(referenceSet.n_rows, l * m);
candidateIndices.set_size(l * m);
- arma::vec dataMean = arma::mean(referenceSet, 1);
+ arma::vec dataMean(arma::mean(referenceSet, 1));
arma::vec norms(referenceSet.n_cols);
- arma::mat refCopy = referenceSet.each_col() - dataMean;
+ MatType refCopy(referenceSet.n_rows, referenceSet.n_cols);
for (size_t i = 0; i < refCopy.n_cols; ++i)
- norms[i] = arma::norm(refCopy.col(i) - dataMean);
+ {
+ refCopy.col(i) = referenceSet.col(i) - dataMean;
+ norms[i] = arma::norm(refCopy.col(i));
+ }
// Find the top m points for each of the l projections...
for (size_t i = 0; i < l; ++i)
@@ -90,7 +93,7 @@ void DrusillaSelect<MatType>::Train(
arma::uword maxIndex;
norms.max(maxIndex);
- arma::vec line = refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex));
+ arma::vec line(refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex)));
const size_t n_nonzero = (size_t) arma::sum(norms > 0);
// Calculate distortion and offset and make scores.
@@ -176,7 +179,7 @@ void DrusillaSelect<MatType>::Search(const MatType& querySet,
// TreeType.
metric::EuclideanDistance metric;
NeighborSearchRules<FurthestNeighborSort, metric::EuclideanDistance,
- tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, arma::mat>>
+ tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, MatType>>
rules(candidateSet, querySet, k, metric, 0, false);
for (size_t q = 0; q < querySet.n_cols; ++q)
diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp
index b60a1ad..cce2704 100644
--- a/src/mlpack/tests/drusilla_select_test.cpp
+++ b/src/mlpack/tests/drusilla_select_test.cpp
@@ -143,4 +143,23 @@ BOOST_AUTO_TEST_CASE(SerializationTest)
}
}
+// Make sure we can create the object with a sparse matrix.
+BOOST_AUTO_TEST_CASE(SparseTest)
+{
+ arma::sp_mat dataset;
+ dataset.sprandu(50, 1000, 0.3);
+
+ DrusillaSelect<arma::sp_mat> ds(dataset, 5, 10);
+
+ // Run a search.
+ arma::mat distances;
+ arma::Mat<size_t> neighbors;
+ ds.Search(dataset, 3, neighbors, distances);
+
+ BOOST_REQUIRE_EQUAL(neighbors.n_cols, 1000);
+ BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3);
+ BOOST_REQUIRE_EQUAL(distances.n_cols, 1000);
+ BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
+}
+
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list