[mlpack] 27/53: Fix potential bug and simplify memory requirements.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Nov 14 00:46:48 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 6d7e0ee10359adff4a3dd15fa1cdeaf9f2f58921
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Oct 25 16:57:05 2016 +0900
Fix potential bug and simplify memory requirements.
---
src/mlpack/methods/approx_kfn/qdafn.hpp | 15 +++------
src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 50 +++++++++++-----------------
2 files changed, 25 insertions(+), 40 deletions(-)
diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index 694fbf9..7617fc2 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -51,10 +51,11 @@ class QDAFN
arma::Mat<size_t>& neighbors,
arma::mat& distances);
- private:
- //! The reference set.
- const MatType& referenceSet;
+ //! Serialize the model.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+ private:
//! The number of projections.
const size_t l;
//! The number of elements to store for each projection.
@@ -69,13 +70,7 @@ class QDAFN
//! Values of a_i * x for each point in S.
arma::mat sValues;
- //! Insert a neighbor into a set of results for a given query point.
- void InsertNeighbor(arma::mat& distances,
- arma::Mat<size_t>& neighbors,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance) const;
+ arma::cube candidateSet;
};
} // namespace neighbor
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index bf462da..f1d04fa 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -21,7 +21,6 @@ template<typename MatType>
QDAFN<MatType>::QDAFN(const MatType& referenceSet,
const size_t l,
const size_t m) :
- referenceSet(referenceSet),
l(l),
m(m)
{
@@ -40,6 +39,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
// Loop over each projection and find the top m elements.
sIndices.set_size(m, l);
sValues.set_size(m, l);
+ candidateSet.set_size(referenceSet.n_rows, m, l);
for (size_t i = 0; i < l; ++i)
{
arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend");
@@ -49,6 +49,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
{
sIndices(j, i) = sortedIndices[j];
sValues(j, i) = projections(sortedIndices[j], i);
+ candidateSet.slice(l).col(j) = referenceSet.col(sortedIndices[j]);
}
}
}
@@ -77,8 +78,8 @@ void QDAFN<MatType>::Search(const MatType& querySet,
std::priority_queue<std::pair<double, size_t>> queue;
for (size_t i = 0; i < l; ++i)
{
- const double val = projections(0, i) - arma::dot(querySet.col(q),
- lines.col(i));
+ const double val = sValues(0, i) - arma::dot(querySet.col(q),
+ lines.col(i));
queue.push(std::make_pair(val, i));
}
@@ -97,17 +98,17 @@ void QDAFN<MatType>::Search(const MatType& querySet,
queue.pop();
// Get index of reference point to look at.
- size_t referenceIndex = sIndices(tableLocations[p.second], p.second);
+ const size_t tableIndex = tableLocations[p.second];
// Calculate distance from query point.
const double dist = mlpack::metric::EuclideanDistance::Evaluate(
- querySet.col(q), referenceSet.col(referenceIndex));
+ querySet.col(q), candidateSet.slice(p.second).col(tableIndex));
// Is this neighbor good enough to insert into the results?
if (dist > resultsQueue.top().first)
{
resultsQueue.pop();
- resultsQueue.push(std::make_pair(dist, referenceIndex));
+ resultsQueue.push(std::make_pair(dist, sIndices(tableIndex, p.second)));
}
// Now (line 14) get the next element and insert into the queue. Do this
@@ -116,9 +117,8 @@ void QDAFN<MatType>::Search(const MatType& querySet,
if (i < m - 1)
{
tableLocations[p.second]++;
- const double val = p.first -
- projections(tableLocations[p.second] - 1, p.second) +
- projections(tableLocations[p.second], p.second);
+ const double val = p.first - sValues(tableIndex, p.second) +
+ sValues(tableIndex + 1, p.second);
queue.push(std::make_pair(val, p.second));
}
@@ -135,28 +135,18 @@ void QDAFN<MatType>::Search(const MatType& querySet,
}
template<typename MatType>
-void QDAFN<MatType>::InsertNeighbor(arma::mat& distances,
- arma::Mat<size_t>& neighbors,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance) const
+template<typename Archive>
+void QDAFN<MatType>::Serialize(Archive& ar, const unsigned int /* version */)
{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
- {
- const size_t len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
+ using data::CreateNVP;
+
+ ar & CreateNVP(l, "l");
+ ar & CreateNVP(m, "m");
+ ar & CreateNVP(lines, "lines");
+ ar & CreateNVP(projections, "projections");
+ ar & CreateNVP(sIndices, "sIndices");
+ ar & CreateNVP(sValues, "sValues");
+ ar & CreateNVP(candidateSet, "candidateSet");
}
} // namespace neighbor
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list