[mlpack] 32/53: Refactor QDAFN to better handle sparse data matrices.

Barak A. Pearlmutter barak+git at pearlmutter.net
Mon Nov 14 00:46:48 UTC 2016


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 4b35ecc9e8490fcb2fa499c685494b6604789ce4
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 30 20:50:40 2016 +0900

    Refactor QDAFN to better handle sparse data matrices.
---
 src/mlpack/methods/approx_kfn/qdafn.hpp      | 8 +++++++-
 src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 7 ++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index ad9e206..f7949db 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -64,6 +64,11 @@ class QDAFN
   template<typename Archive>
   void Serialize(Archive& ar, const unsigned int /* version */);
 
+  //! Get the candidate set for the given projection table.
+  const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; }
+  //! Modify the candidate set for the given projection table.  Careful!
+  MatType& CandidateSet(const size_t t) { return candidateSet[t]; }
+
  private:
   //! The number of projections.
   size_t l;
@@ -79,7 +84,8 @@ class QDAFN
   //! Values of a_i * x for each point in S.
   arma::mat sValues;
 
-  arma::cube candidateSet;
+  // Candidate sets; one element in the vector for each table.
+  std::vector<MatType> candidateSet;
 };
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index 85ec99a..de6c882 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -43,9 +43,10 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
   // Loop over each projection and find the top m elements.
   sIndices.set_size(m, l);
   sValues.set_size(m, l);
-  candidateSet.set_size(referenceSet.n_rows, m, l);
+  candidateSet.resize(l);
   for (size_t i = 0; i < l; ++i)
   {
+    candidateSet[i].set_size(referenceSet.n_rows, m);
     arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend");
 
     // Grab the top m elements.
@@ -53,7 +54,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
     {
       sIndices(j, i) = sortedIndices[j];
       sValues(j, i) = projections(sortedIndices[j], i);
-      candidateSet.slice(i).col(j) = referenceSet.col(sortedIndices[j]);
+      candidateSet[i].col(j) = referenceSet.col(sortedIndices[j]);
     }
   }
 }
@@ -106,7 +107,7 @@ void QDAFN<MatType>::Search(const MatType& querySet,
 
       // Calculate distance from query point.
       const double dist = mlpack::metric::EuclideanDistance::Evaluate(
-          querySet.col(q), candidateSet.slice(p.second).col(tableIndex));
+          querySet.col(q), candidateSet[p.second].col(tableIndex));
 
       // Is this neighbor good enough to insert into the results?
       if (dist > resultsQueue.top().first)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list