[mlpack] 01/11: Refactor RSModel implementation to be inlined into rs_model_impl.hpp.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Aug 31 13:19:48 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 6ad1f8cc50f4e7c12a0e7396d96d746484b457bc
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Aug 6 21:13:13 2017 -0400

    Refactor RSModel implementation to be inlined into rs_model_impl.hpp.
---
 src/mlpack/methods/range_search/CMakeLists.txt    |   1 -
 src/mlpack/methods/range_search/rs_model_impl.hpp | 259 ++++++++++++++++++++++
 2 files changed, 259 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/range_search/CMakeLists.txt b/src/mlpack/methods/range_search/CMakeLists.txt
index 83fa9e4..628ae7e 100644
--- a/src/mlpack/methods/range_search/CMakeLists.txt
+++ b/src/mlpack/methods/range_search/CMakeLists.txt
@@ -8,7 +8,6 @@ set(SOURCES
   range_search_stat.hpp
   rs_model.hpp
   rs_model_impl.hpp
-  rs_model.cpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index f244719..37109aa 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -15,9 +15,268 @@
 // In case it hasn't been included yet.
 #include "rs_model.hpp"
 
+#include <mlpack/core/math/random_basis.hpp>
+
 namespace mlpack {
 namespace range {
 
+/**
+ * Initialize the RSModel with the given tree type and whether or not a random
+ * basis should be used.
+ */
+inline RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
+    treeType(treeType),
+    leafSize(0),
+    randomBasis(randomBasis)
+{
+  // Nothing to do.
+}
+
+/**
+ * Copy constructor.  Copies the tree type, leaf size, random-basis flag and
+ * the random basis matrix q itself.  Without q, a copy of a model built with a
+ * random basis would project query points with an empty matrix in Search().
+ *
+ * NOTE(review): rSearch holds raw pointers (BuildModel() assigns the result of
+ * `new RSType<...>`), so this copies the pointer, not the RangeSearch object.
+ * Both models' destructors then pass the same pointer to DeleteVisitor.  A
+ * deep copy (a visitor that clones the pointed-to RSType) is needed to make
+ * copying safe.
+ */
+inline RSModel::RSModel(const RSModel& other) :
+    treeType(other.treeType),
+    leafSize(other.leafSize),
+    randomBasis(other.randomBasis),
+    q(other.q),
+    rSearch(other.rSearch)
+{
+  // Nothing to do.
+}
+
+/**
+ * Move constructor.  Takes over other's settings, random basis q and tree
+ * pointer.  It then resets other to an empty KD-tree model, so other no longer
+ * refers to the tree now held by this model.
+ */
+inline RSModel::RSModel(RSModel&& other) :
+    treeType(other.treeType),
+    leafSize(other.leafSize),
+    randomBasis(other.randomBasis),
+    q(std::move(other.q)),
+    rSearch(other.rSearch)
+{
+  // Reset other model.
+  other.treeType = TreeTypes::KD_TREE;
+  other.leafSize = 0;
+  other.randomBasis = false;
+  other.rSearch = decltype(other.rSearch)();
+}
+
+/**
+ * Copy assignment.  Releases this model's current tree via DeleteVisitor,
+ * then takes other's settings, random basis q and tree pointer.
+ * Self-assignment is a no-op.  Without that check, the tree would be handed to
+ * DeleteVisitor and the same, now-released pointer stored back into rSearch.
+ *
+ * NOTE(review): like the copy constructor, this copies the tree pointer, not
+ * the tree, so this model and other end up sharing one RSType object.
+ */
+inline RSModel& RSModel::operator=(const RSModel& other)
+{
+  if (this == &other)
+    return *this;
+
+  boost::apply_visitor(DeleteVisitor(), rSearch);
+
+  treeType = other.treeType;
+  leafSize = other.leafSize;
+  randomBasis = other.randomBasis;
+  q = other.q;
+  rSearch = other.rSearch;
+
+  return *this;
+}
+
+/**
+ * Move assignment.  Releases this model's current tree via DeleteVisitor,
+ * takes over other's settings, random basis q and tree pointer, and resets
+ * other to an empty KD-tree model.  Self-move is a no-op.  Without that check,
+ * the model's own tree would be released and the model then reset to empty.
+ */
+inline RSModel& RSModel::operator=(RSModel&& other)
+{
+  if (this == &other)
+    return *this;
+
+  boost::apply_visitor(DeleteVisitor(), rSearch);
+
+  treeType = other.treeType;
+  leafSize = other.leafSize;
+  randomBasis = other.randomBasis;
+  q = std::move(other.q);
+  rSearch = other.rSearch;
+
+  // Reset other model.
+  other.treeType = TreeTypes::KD_TREE;
+  other.leafSize = 0;
+  other.randomBasis = false;
+  other.rSearch = decltype(other.rSearch)();
+
+  return *this;
+}
+
+/**
+ * Destructor: passes whatever tree pointer rSearch currently holds to
+ * DeleteVisitor, which is expected to release it.
+ */
+inline RSModel::~RSModel()
+{
+  DeleteVisitor deleter;
+  boost::apply_visitor(deleter, rSearch);
+}
+
+/**
+ * Build the range search model on the given reference set.
+ *
+ * If this model uses a random basis, a new basis q is generated with
+ * math::RandomBasis() and the reference set is multiplied by q before the
+ * search object is trained.  Any previously held search object is passed to
+ * DeleteVisitor first.
+ *
+ * @param referenceSet Reference points.  Taken by rvalue: it may be
+ *     overwritten by its projection and is then moved into TrainVisitor.
+ * @param leafSize Leaf size; stored in the model and passed to TrainVisitor.
+ * @param naive Forwarded to the RSType constructor.  When true, the
+ *     "tree_building" timer and tree-building log messages are skipped.
+ * @param singleMode Forwarded to the RSType constructor.
+ */
+inline void RSModel::BuildModel(arma::mat&& referenceSet,
+                                const size_t leafSize,
+                                const bool naive,
+                                const bool singleMode)
+{
+  // Initialize random basis if necessary.
+  if (randomBasis)
+  {
+    Log::Info << "Creating random basis..." << std::endl;
+    math::RandomBasis(q, referenceSet.n_rows);
+  }
+
+  // The parameter shadows the member, hence this->.
+  this->leafSize = leafSize;
+
+  // Clean memory, if necessary.
+  boost::apply_visitor(DeleteVisitor(), rSearch);
+
+  // Do we need to modify the reference set?
+  if (randomBasis)
+    referenceSet = q * referenceSet;
+
+  if (!naive)
+  {
+    Timer::Start("tree_building");
+    Log::Info << "Building reference tree..." << std::endl;
+  }
+
+  // Allocate the RangeSearch object matching the selected tree type.
+  // NOTE(review): there is no default case.  If treeType holds a value not
+  // listed here, rSearch is not reassigned after DeleteVisitor ran above, and
+  // TrainVisitor is then applied to the old pointer.  Confirm treeType is
+  // always a valid enumerator.
+  switch (treeType)
+  {
+    case KD_TREE:
+      rSearch = new RSType<tree::KDTree> (naive, singleMode);
+      break;
+
+    case COVER_TREE:
+      rSearch = new RSType<tree::StandardCoverTree>(naive, singleMode);
+      break;
+
+    case R_TREE:
+      rSearch = new RSType<tree::RTree>(naive, singleMode);
+      break;
+
+    case R_STAR_TREE:
+      rSearch = new RSType<tree::RStarTree>(naive, singleMode);
+      break;
+
+    case BALL_TREE:
+      rSearch = new RSType<tree::BallTree>(naive, singleMode);
+      break;
+
+    case X_TREE:
+      rSearch = new RSType<tree::XTree>(naive, singleMode);
+      break;
+
+    case HILBERT_R_TREE:
+      rSearch = new RSType<tree::HilbertRTree>(naive, singleMode);
+      break;
+
+    case R_PLUS_TREE:
+      rSearch = new RSType<tree::RPlusTree>(naive, singleMode);
+      break;
+
+    case R_PLUS_PLUS_TREE:
+      rSearch = new RSType<tree::RPlusPlusTree>(naive, singleMode);
+      break;
+
+    case VP_TREE:
+      rSearch = new RSType<tree::VPTree>(naive, singleMode);
+      break;
+
+    case RP_TREE:
+      rSearch = new RSType<tree::RPTree>(naive, singleMode);
+      break;
+
+    case MAX_RP_TREE:
+      rSearch = new RSType<tree::MaxRPTree>(naive, singleMode);
+      break;
+
+    case UB_TREE:
+      rSearch = new RSType<tree::UBTree>(naive, singleMode);
+      break;
+
+    case OCTREE:
+      rSearch = new RSType<tree::Octree>(naive, singleMode);
+      break;
+  }
+
+  // Hand the (possibly projected) reference set to the search object.
+  TrainVisitor tn(std::move(referenceSet), leafSize);
+  boost::apply_visitor(tn, rSearch);
+
+  if (!naive)
+  {
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << std::endl;
+  }
+}
+
+/**
+ * Bichromatic range search: search querySet against the reference set given
+ * to BuildModel(), for distances within range.  When the model uses a random
+ * basis, the queries are first multiplied by q, as the reference set was.
+ * Results are written into neighbors and distances by BiSearchVisitor.
+ */
+inline void RSModel::Search(arma::mat&& querySet,
+                            const math::Range& range,
+                            std::vector<std::vector<size_t>>& neighbors,
+                            std::vector<std::vector<double>>& distances)
+{
+  // Put the queries in the same (projected) space as the references.
+  if (randomBasis)
+    querySet = q * querySet;
+
+  Log::Info << "Search for points in the range [" << range.Lo() << ", "
+      << range.Hi() << "] with ";
+  if (Naive())
+    Log::Info << "brute-force (naive) search..." << std::endl;
+  else if (SingleMode())
+    Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
+  else
+    Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
+
+  BiSearchVisitor searcher(querySet, range, neighbors, distances, leafSize);
+  boost::apply_visitor(searcher, rSearch);
+}
+
+/**
+ * Monochromatic range search: search the reference set given to BuildModel()
+ * against itself, for distances within range.  Results are written into
+ * neighbors and distances by MonoSearchVisitor.
+ */
+inline void RSModel::Search(const math::Range& range,
+                            std::vector<std::vector<size_t>>& neighbors,
+                            std::vector<std::vector<double>>& distances)
+{
+  Log::Info << "Search for points in the range [" << range.Lo() << ", "
+      << range.Hi() << "] with ";
+  if (Naive())
+    Log::Info << "brute-force (naive) search..." << std::endl;
+  else if (SingleMode())
+    Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
+  else
+    Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
+
+  MonoSearchVisitor searcher(range, neighbors, distances);
+  boost::apply_visitor(searcher, rSearch);
+}
+
+/**
+ * Return a human-readable name for this model's tree type, or
+ * "unknown tree" if treeType is not one of the listed enumerators.
+ */
+inline std::string RSModel::TreeName() const
+{
+  const char* name = "unknown tree";
+  switch (treeType)
+  {
+    case KD_TREE:          name = "kd-tree"; break;
+    case COVER_TREE:       name = "cover tree"; break;
+    case R_TREE:           name = "R tree"; break;
+    case R_STAR_TREE:      name = "R* tree"; break;
+    case BALL_TREE:        name = "ball tree"; break;
+    case X_TREE:           name = "X tree"; break;
+    case HILBERT_R_TREE:   name = "Hilbert R tree"; break;
+    case R_PLUS_TREE:      name = "R+ tree"; break;
+    case R_PLUS_PLUS_TREE: name = "R++ tree"; break;
+    case VP_TREE:          name = "vantage point tree"; break;
+    case RP_TREE:          name = "random projection tree (mean split)"; break;
+    case MAX_RP_TREE:      name = "random projection tree (max split)"; break;
+    case UB_TREE:          name = "UB tree"; break;
+    case OCTREE:           name = "octree"; break;
+    default:               break;
+  }
+  return name;
+}
+
+/**
+ * Pass the tree pointer currently held in rSearch to DeleteVisitor.
+ *
+ * NOTE(review): rSearch is not reassigned here.  Unless DeleteVisitor also
+ * clears the stored pointer, a later destructor or BuildModel() call will pass
+ * the same pointer to DeleteVisitor again.  Confirm.
+ */
+inline void RSModel::CleanMemory()
+{
+  DeleteVisitor deleter;
+  boost::apply_visitor(deleter, rSearch);
+}
+
 //! Monochromatic range search on the given RSType instance.
 template<typename RSType>
 void MonoSearchVisitor::operator()(RSType* rs) const

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list