[mlpack] 05/53: Fix some bugs with the trivial test.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Nov 14 00:46:46 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 452943f059afbcaabe0c143676f35d91f54a7733
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Apr 13 13:04:33 2016 -0700
Fix some bugs with the trivial test.
---
qdafn.hpp | 4 +++-
qdafn_impl.hpp | 31 ++++++++++++++++++++-----------
2 files changed, 23 insertions(+), 12 deletions(-)
diff --git a/qdafn.hpp b/qdafn.hpp
index 557b421..860acb8 100644
--- a/qdafn.hpp
+++ b/qdafn.hpp
@@ -23,7 +23,7 @@
namespace qdafn {
-template<typename MatType>
+template<typename MatType = arma::mat>
class QDAFN
{
public:
@@ -60,6 +60,8 @@ class QDAFN
const size_t m;
//! The random lines we are projecting onto. Has l columns.
arma::mat lines;
+ //! Projections of each point onto each random line.
+ arma::mat projections;
//! Indices of the points for each S.
arma::Mat<size_t> sIndices;
diff --git a/qdafn_impl.hpp b/qdafn_impl.hpp
index 1b8cfaa..368b84c 100644
--- a/qdafn_impl.hpp
+++ b/qdafn_impl.hpp
@@ -10,6 +10,9 @@
// In case it hasn't been included yet.
#include "qdafn.hpp"
+#include <queue>
+#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
+
namespace qdafn {
// Constructor.
@@ -31,7 +34,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
// Now, project each of the reference points onto each line, and collect the
// top m elements.
- arma::mat projections = lines.t() * referenceSet;
+ projections = referenceSet.t() * lines;
// Loop over each projection and find the top m elements.
sIndices.set_size(m, l);
@@ -43,8 +46,8 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
// Grab the top m elements.
for (size_t j = 0; j < m; ++j)
{
- sIndices[j] = sortedIndices[j];
- sValues[j] = projections(sortedIndices[j], i);
+ sIndices(j, i) = sortedIndices[j];
+ sValues(j, i) = projections(sortedIndices[j], i);
}
}
}
@@ -61,6 +64,7 @@ void QDAFN<MatType>::Search(const MatType& querySet,
"value of m!");
neighbors.set_size(k, querySet.n_cols);
+ neighbors.fill(size_t() - 1);
distances.zeros(k, querySet.n_cols);
// Search for each point.
@@ -103,16 +107,21 @@ void QDAFN<MatType>::Search(const MatType& querySet,
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
- InsertNeighbor(distances, neighbors, q, referenceIndex, dist);
+ InsertNeighbor(distances, neighbors, q, insertPosition, referenceIndex,
+ dist);
// Now (line 14) get the next element and insert into the queue. Do this
- // by adjusting the previous value.
- tableLocations[p.second]++;
- const double val = p.first -
- projections(tableLocations[p.second] - 1, p.second) +
- projections(tableLocations[p.second], p.second);
-
- queue.push(std::make_pair(val, p.second));
+ // by adjusting the previous value. Don't insert anything if we are at
+ // the end of the search, though.
+ if (i < m - 1)
+ {
+ tableLocations[p.second]++;
+ const double val = p.first -
+ projections(tableLocations[p.second] - 1, p.second) +
+ projections(tableLocations[p.second], p.second);
+
+ queue.push(std::make_pair(val, p.second));
+ }
}
}
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list