[mlpack] 22/53: Fix failing tests and bugs.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Nov 14 00:46:47 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 84bed62e1a30dbc7b2cea5db7333e78c980d4529
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Oct 24 05:08:49 2016 -0400
Fix failing tests and bugs.
---
.../methods/approx_kfn/drusilla_select_impl.hpp | 54 +++++++++++-----------
src/mlpack/tests/drusilla_select_test.cpp | 17 +++----
2 files changed, 36 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
index a84b304..f264e64 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -24,7 +24,8 @@ template<typename MatType>
DrusillaSelect<MatType>::DrusillaSelect(const MatType& referenceSet,
const size_t l,
const size_t m) :
- candidateSet(referenceSet.n_rows, l * m),
+ candidateSet(referenceSet.n_cols, l * m),
+ candidateIndices(l * m),
l(l),
m(m)
{
@@ -41,6 +42,8 @@ DrusillaSelect<MatType>::DrusillaSelect(const MatType& referenceSet,
// Constructor with no training.
template<typename MatType>
DrusillaSelect<MatType>::DrusillaSelect(const size_t l, const size_t m) :
+ candidateSet(0, l * m),
+ candidateIndices(l * m),
l(l),
m(m)
{
@@ -70,6 +73,9 @@ void DrusillaSelect<MatType>::Train(
"large! Choose smaller values. l*m must be smaller than the number "
"of points in the dataset.");
+ candidateSet.set_size(referenceSet.n_rows, l * m);
+ candidateIndices.set_size(l * m);
+
arma::vec dataMean = arma::mean(referenceSet, 1);
arma::vec norms(referenceSet.n_cols);
@@ -87,29 +93,24 @@ void DrusillaSelect<MatType>::Train(
arma::vec line = refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex));
const size_t n_nonzero = (size_t) arma::sum(norms > 0);
- // Calculate distortion and offset.
- arma::vec distortions(referenceSet.n_cols);
- arma::vec offsets(referenceSet.n_cols);
+ // Calculate distortion and offset and make scores.
+ std::vector<bool> closeAngle(referenceSet.n_cols, false);
+ arma::vec sums(referenceSet.n_cols);
for (size_t j = 0; j < referenceSet.n_cols; ++j)
{
if (norms[j] > 0.0)
{
- offsets[j] = arma::dot(refCopy.col(j), line);
- distortions[j] = arma::norm(refCopy.col(j) - offsets[j] *
- line);
+ const double offset = arma::dot(refCopy.col(j), line);
+ const double distortion = arma::norm(refCopy.col(j) - offset * line);
+ sums[j] = std::abs(offset) - std::abs(distortion);
+ closeAngle[j] =
+ (std::atan(distortion / std::abs(offset)) >= (M_PI / 8.0));
}
else
{
- offsets[j] = 0.0;
- distortions[j] = 0.0;
+ sums[j] = norms[j];
}
}
- arma::vec sums = arma::abs(offsets) - arma::abs(distortions);
- arma::uvec sortedSums = arma::sort_index(sums, "descend");
-
- arma::vec bestSums(m);
- arma::Col<size_t> bestIndices(m);
- bestSums.fill(-DBL_MAX);
// Find the top m elements using a priority queue.
typedef std::pair<double, size_t> Candidate;
@@ -117,11 +118,11 @@ void DrusillaSelect<MatType>::Train(
{
bool operator()(const Candidate& c1, const Candidate& c2)
{
- return c2.first > c1.first;
+ return c2.first < c1.first;
}
};
- std::vector<Candidate> clist(m, std::make_pair(size_t(-1), double(0.0)));
+ std::vector<Candidate> clist(m, std::make_pair(double(-1.0), size_t(-1)));
std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
pq(CandidateCmp(), std::move(clist));
@@ -141,16 +142,17 @@ void DrusillaSelect<MatType>::Train(
const size_t index = pq.top().second;
pq.pop();
candidateSet.col(i * m + j) = referenceSet.col(index);
+ candidateIndices[i * m + j] = index;
- // Mark the norm as 0 so we don't see this point again.
- norms[index] = 0.0;
+ // Mark the norm as -1 so we don't see this point again.
+ norms[index] = -1.0;
}
// Calculate angles from the current projection. Anything close enough,
// mark the norm as 0.
- arma::vec farPoints = arma::conv_to<arma::vec>::from(
- arma::atan(distortions / arma::abs(offsets)) >= (M_PI / 8.0));
- norms %= farPoints;
+ for (size_t j = 0; j < norms.n_elem; ++j)
+ if (norms[j] > 0.0 && closeAngle[j])
+ norms[j] = 0.0;
}
}
@@ -175,16 +177,14 @@ void DrusillaSelect<MatType>::Search(const MatType& querySet,
metric::EuclideanDistance metric;
NeighborSearchRules<FurthestNeighborSort, metric::EuclideanDistance,
tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, arma::mat>>
- rules(querySet, candidateSet, k, metric, 0, false);
-
- neighbors.set_size(k, querySet.n_cols);
- neighbors.fill(size_t() - 1);
- distances.zeros(k, querySet.n_cols);
+ rules(candidateSet, querySet, k, metric, 0, false);
for (size_t q = 0; q < querySet.n_cols; ++q)
for (size_t r = 0; r < candidateSet.n_cols; ++r)
rules.BaseCase(q, r);
+ rules.GetResults(neighbors, distances);
+
// Map the neighbors back to their original indices in the reference set.
for (size_t i = 0; i < neighbors.n_elem; ++i)
neighbors[i] = candidateIndices[neighbors[i]];
diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp
index 504fd62..b60a1ad 100644
--- a/src/mlpack/tests/drusilla_select_test.cpp
+++ b/src/mlpack/tests/drusilla_select_test.cpp
@@ -21,7 +21,7 @@ BOOST_AUTO_TEST_SUITE(DrusillaSelectTest);
BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
{
arma::mat dataset = arma::randu<arma::mat>(5, 100);
- dataset.col(100) += 100; // Make last column very large.
+ dataset.col(99) += 100; // Make last column very large.
// Construct with some reasonable parameters.
DrusillaSelect<> ds(dataset, 5, 5);
@@ -29,7 +29,7 @@ BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
// Query with every point except the extreme point.
arma::mat distances;
arma::Mat<size_t> neighbors;
- ds.Search(dataset.cols(0, 99), 1, neighbors, distances);
+ ds.Search(dataset.cols(0, 98), 1, neighbors, distances);
BOOST_REQUIRE_EQUAL(neighbors.n_cols, 99);
BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
@@ -37,7 +37,9 @@ BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
BOOST_REQUIRE_EQUAL(distances.n_rows, 1);
for (size_t i = 0; i < 99; ++i)
- BOOST_REQUIRE_EQUAL(neighbors[i], 100);
+ {
+ BOOST_REQUIRE_EQUAL(neighbors[i], 99);
+ }
}
// If we use only one projection with the number of points equal to what is in
@@ -82,7 +84,6 @@ BOOST_AUTO_TEST_CASE(RetrainTest)
arma::Mat<size_t> neighbors;
ds.Search(dataset, 1, neighbors, distances);
- BOOST_REQUIRE_EQUAL(dataset.n_elem, 0);
BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
@@ -97,11 +98,11 @@ BOOST_AUTO_TEST_CASE(SerializationTest)
DrusillaSelect<> ds(dataset, 3, 3);
- arma::mat fakeDataset1 = arma::randu<arma::mat>(2, 5);
- arma::mat fakeDataset2 = arma::randu<arma::mat>(10, 8);
- DrusillaSelect<> dsXml(fakeDataset1, 10, 10);
+ arma::mat fakeDataset1 = arma::randu<arma::mat>(2, 15);
+ arma::mat fakeDataset2 = arma::randu<arma::mat>(10, 18);
+ DrusillaSelect<> dsXml(fakeDataset1, 5, 3);
DrusillaSelect<> dsText(2, 2);
- DrusillaSelect<> dsBinary(5, 6);
+ DrusillaSelect<> dsBinary(5, 2);
dsBinary.Train(fakeDataset2);
// Now do the serialization.
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list