[mlpack] 29/207: Refactored range search model to include boost variant and visitor paradigm

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:37 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 13e8c1bdbfcd4fb8c1c3031b63ea1b128afb6a81
Author: dinesh Raj <dinu.iota at gmail.com>
Date:   Mon Feb 6 23:11:20 2017 +0530

    Refactored range search model to include boost variant and visitor paradigm
---
 src/mlpack/methods/range_search/range_search.hpp  |   4 +-
 src/mlpack/methods/range_search/rs_model.cpp      | 349 ++----------------
 src/mlpack/methods/range_search/rs_model.hpp      | 261 +++++++++++--
 src/mlpack/methods/range_search/rs_model_impl.hpp | 428 +++++++++++-----------
 4 files changed, 468 insertions(+), 574 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 2182795..9d4c3e1 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -22,7 +22,7 @@ namespace mlpack {
 namespace range /** Range-search routines. */ {
 
 //! Forward declaration.
-class RSModel;
+class TrainVisitor;
 
 /**
  * The RangeSearch class is a template class for performing range searches.  It
@@ -323,7 +323,7 @@ class RangeSearch
   size_t scores;
 
   //! For access to mappings when building models.
-  friend RSModel;
+  friend class TrainVisitor;
 };
 
 } // namespace range
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
index 6857f34..8a280a3 100644
--- a/src/mlpack/methods/range_search/rs_model.cpp
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -23,21 +23,7 @@ using namespace mlpack::range;
 RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
     treeType(treeType),
     leafSize(0),
-    randomBasis(randomBasis),
-    kdTreeRS(NULL),
-    coverTreeRS(NULL),
-    rTreeRS(NULL),
-    rStarTreeRS(NULL),
-    ballTreeRS(NULL),
-    xTreeRS(NULL),
-    hilbertRTreeRS(NULL),
-    rPlusTreeRS(NULL),
-    rPlusPlusTreeRS(NULL),
-    vpTreeRS(NULL),
-    rpTreeRS(NULL),
-    maxRPTreeRS(NULL),
-    ubTreeRS(NULL),
-    octreeRS(NULL)
+    randomBasis(randomBasis)
 {
   // Nothing to do.
 }
@@ -45,7 +31,7 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
 // Clean memory, if necessary.
 RSModel::~RSModel()
 {
-  CleanMemory();
+  boost::apply_visitor(DeleteVisitor(), rSearch); // Free the held model.
 }
 
 void RSModel::BuildModel(arma::mat&& referenceSet,
@@ -63,7 +49,7 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
   this->leafSize = leafSize;
 
   // Clean memory, if necessary.
-  CleanMemory();
+  boost::apply_visitor(DeleteVisitor(), rSearch);
 
   // Do we need to modify the reference set?
   if (randomBasis)
@@ -78,126 +64,65 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
   switch (treeType)
   {
     case KD_TREE:
-      // If necessary, build the tree.
-      if (naive)
-      {
-        kdTreeRS = new RSType<tree::KDTree>(move(referenceSet), naive,
-            singleMode);
-      }
-      else
-      {
-        vector<size_t> oldFromNewReferences;
-        RSType<tree::KDTree>::Tree* kdTree = new RSType<tree::KDTree>::Tree(
-            move(referenceSet), oldFromNewReferences, leafSize);
-        kdTreeRS = new RSType<tree::KDTree>(kdTree, singleMode);
-
-        // Give the model ownership of the tree and the mappings.
-        kdTreeRS->treeOwner = true;
-        kdTreeRS->oldFromNewReferences = move(oldFromNewReferences);
-      }
-
+      rSearch = new RSType<tree::KDTree> (naive, singleMode);
       break;
 
     case COVER_TREE:
-      coverTreeRS = new RSType<tree::StandardCoverTree>(move(referenceSet),
-          naive, singleMode);
+      rSearch = new RSType<tree::StandardCoverTree>(naive, singleMode);
       break;
 
     case R_TREE:
-      rTreeRS = new RSType<tree::RTree>(move(referenceSet), naive,
-          singleMode);
+      rSearch = new RSType<tree::RTree>(naive,singleMode);
       break;
 
     case R_STAR_TREE:
-      rStarTreeRS = new RSType<tree::RStarTree>(move(referenceSet), naive,
-          singleMode);
+      rSearch = new RSType<tree::RStarTree>(naive, singleMode);
       break;
 
     case BALL_TREE:
-      // If necessary, build the ball tree.
-      if (naive)
-      {
-        ballTreeRS = new RSType<tree::BallTree>(move(referenceSet), naive,
-            singleMode);
-      }
-      else
-      {
-        vector<size_t> oldFromNewReferences;
-        RSType<tree::BallTree>::Tree* ballTree =
-            new RSType<tree::BallTree>::Tree(move(referenceSet),
-            oldFromNewReferences, leafSize);
-        ballTreeRS = new RSType<tree::BallTree>(ballTree, singleMode);
-
-        // Give the model ownership of the tree and the mappings.
-        ballTreeRS->treeOwner = true;
-        ballTreeRS->oldFromNewReferences = move(oldFromNewReferences);
-      }
-
+      rSearch = new RSType<tree::BallTree>(naive, singleMode);  
       break;
 
     case X_TREE:
-      xTreeRS = new RSType<tree::XTree>(move(referenceSet), naive,
-          singleMode);
+      rSearch = new RSType<tree::XTree>(naive, singleMode);
       break;
 
     case HILBERT_R_TREE:
-      hilbertRTreeRS = new RSType<tree::HilbertRTree>(move(referenceSet), naive,
-          singleMode);
+      rSearch = new RSType<tree::HilbertRTree>(naive, singleMode);
       break;
 
     case R_PLUS_TREE:
-      rPlusTreeRS = new RSType<tree::RPlusTree>(move(referenceSet), naive,
-          singleMode);
+      rSearch = new RSType<tree::RPlusTree>(naive, singleMode);
       break;
 
     case R_PLUS_PLUS_TREE:
-      rPlusPlusTreeRS = new RSType<tree::RPlusPlusTree>(move(referenceSet),
-          naive, singleMode);
+      rSearch = new RSType<tree::RPlusPlusTree>(naive, singleMode);
       break;
 
     case VP_TREE:
-      vpTreeRS = new RSType<tree::VPTree>(move(referenceSet), naive,
-          singleMode);
+      rSearch = new RSType<tree::VPTree>(naive, singleMode);
       break;
 
     case RP_TREE:
-      rpTreeRS = new RSType<tree::RPTree>(move(referenceSet), naive,
-          singleMode);
+      rSearch = new RSType<tree::RPTree>(naive, singleMode);
       break;
 
     case MAX_RP_TREE:
-      maxRPTreeRS = new RSType<tree::MaxRPTree>(move(referenceSet),
-          naive, singleMode);
+      rSearch = new RSType<tree::MaxRPTree>(naive, singleMode);
       break;
 
     case UB_TREE:
-      ubTreeRS = new RSType<tree::UBTree>(move(referenceSet),
-          naive, singleMode);
+      rSearch = new RSType<tree::UBTree>(naive, singleMode);
       break;
 
     case OCTREE:
-      // If necessary, build the octree.
-      if (naive)
-      {
-        octreeRS = new RSType<tree::Octree>(move(referenceSet), naive,
-            singleMode);
-      }
-      else
-      {
-        vector<size_t> oldFromNewReferences;
-        RSType<tree::Octree>::Tree* octree =
-            new RSType<tree::Octree>::Tree(move(referenceSet),
-            oldFromNewReferences, leafSize);
-        octreeRS = new RSType<tree::Octree>(octree, singleMode);
-
-        // Give the model ownership of the tree and the mappings.
-        octreeRS->treeOwner = true;
-        octreeRS->oldFromNewReferences = move(oldFromNewReferences);
-      }
-
+      rSearch = new RSType<tree::Octree>(naive, singleMode);
       break;
   }
 
+  TrainVisitor tn(std::move(referenceSet), leafSize);
+  boost::apply_visitor(tn, rSearch);
+
   if (!naive)
   {
     Timer::Stop("tree_building");
@@ -224,148 +149,10 @@ void RSModel::Search(arma::mat&& querySet,
   else
     Log::Info << "brute-force (naive) search..." << endl;
 
-  switch (treeType)
-  {
-    case KD_TREE:
-      if (!kdTreeRS->Naive() && !kdTreeRS->SingleMode())
-      {
-        // Build a second tree and search.
-        Timer::Start("tree_building");
-        Log::Info << "Building query tree..." << endl;
-        vector<size_t> oldFromNewQueries;
-        RSType<tree::KDTree>::Tree queryTree(move(querySet), oldFromNewQueries,
-            leafSize);
-        Log::Info << "Tree built." << endl;
-        Timer::Stop("tree_building");
-
-        vector<vector<size_t>> neighborsOut;
-        vector<vector<double>> distancesOut;
-        kdTreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
-
-        // Remap the query points.
-        neighbors.resize(queryTree.Dataset().n_cols);
-        distances.resize(queryTree.Dataset().n_cols);
-        for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
-        {
-          neighbors[oldFromNewQueries[i]] = neighborsOut[i];
-          distances[oldFromNewQueries[i]] = distancesOut[i];
-        }
-      }
-      else
-      {
-        // Search without building a second tree.
-        kdTreeRS->Search(querySet, range, neighbors, distances);
-      }
-      break;
-
-    case COVER_TREE:
-      coverTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case R_TREE:
-      rTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case R_STAR_TREE:
-      rStarTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case BALL_TREE:
-      if (!ballTreeRS->Naive() && !ballTreeRS->SingleMode())
-      {
-        // Build a second tree and search.
-        Timer::Start("tree_building");
-        Log::Info << "Building query tree..." << endl;
-        vector<size_t> oldFromNewQueries;
-        RSType<tree::BallTree>::Tree queryTree(move(querySet),
-            oldFromNewQueries, leafSize);
-        Log::Info << "Tree built." << endl;
-        Timer::Stop("tree_building");
-
-        vector<vector<size_t>> neighborsOut;
-        vector<vector<double>> distancesOut;
-        ballTreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
-
-        // Remap the query points.
-        neighbors.resize(queryTree.Dataset().n_cols);
-        distances.resize(queryTree.Dataset().n_cols);
-        for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
-        {
-          neighbors[oldFromNewQueries[i]] = neighborsOut[i];
-          distances[oldFromNewQueries[i]] = distancesOut[i];
-        }
-      }
-      else
-      {
-        // Search without building a second tree.
-        ballTreeRS->Search(querySet, range, neighbors, distances);
-      }
-      break;
-
-    case X_TREE:
-      xTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case HILBERT_R_TREE:
-      hilbertRTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case R_PLUS_TREE:
-      rPlusTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case R_PLUS_PLUS_TREE:
-      rPlusPlusTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case VP_TREE:
-      vpTreeRS->Search(querySet, range, neighbors, distances);
-      break;
 
-    case RP_TREE:
-      rpTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case MAX_RP_TREE:
-      maxRPTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case UB_TREE:
-      ubTreeRS->Search(querySet, range, neighbors, distances);
-      break;
-
-    case OCTREE:
-      if (!octreeRS->Naive() && !octreeRS->SingleMode())
-      {
-        // Build a query tree and search.
-        Timer::Start("tree_building");
-        Log::Info << "Building query tree..." << endl;
-        vector<size_t> oldFromNewQueries;
-        RSType<tree::Octree>::Tree queryTree(move(querySet), oldFromNewQueries,
-            leafSize);
-        Log::Info << "Tree built." << endl;
-        Timer::Stop("tree_building");
-
-        vector<vector<size_t>> neighborsOut;
-        vector<vector<double>> distancesOut;
-        octreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
-
-        // Remap the query points.
-        neighbors.resize(queryTree.Dataset().n_cols);
-        distances.resize(queryTree.Dataset().n_cols);
-        for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
-        {
-          neighbors[oldFromNewQueries[i]] = neighborsOut[i];
-          distances[oldFromNewQueries[i]] = distancesOut[i];
-        }
-      }
-      else
-      {
-        // Search without building a second tree.
-        octreeRS->Search(querySet, range, neighbors, distances);
-      }
-      break;
-  }
+  BiSearchVisitor search(querySet, range, neighbors, distances,
+      leafSize);
+  boost::apply_visitor(search, rSearch);
 }
 
 // Perform range search (monochromatic case).
@@ -382,64 +169,8 @@ void RSModel::Search(const math::Range& range,
   else
     Log::Info << "brute-force (naive) search..." << endl;
 
-  switch (treeType)
-  {
-    case KD_TREE:
-      kdTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case COVER_TREE:
-      coverTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case R_TREE:
-      rTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case R_STAR_TREE:
-      rStarTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case BALL_TREE:
-      ballTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case X_TREE:
-      xTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case HILBERT_R_TREE:
-      hilbertRTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case R_PLUS_TREE:
-      rPlusTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case R_PLUS_PLUS_TREE:
-      rPlusPlusTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case VP_TREE:
-      vpTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case RP_TREE:
-      rpTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case MAX_RP_TREE:
-      maxRPTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case UB_TREE:
-      ubTreeRS->Search(range, neighbors, distances);
-      break;
-
-    case OCTREE:
-      octreeRS->Search(range, neighbors, distances);
-      break;
-  }
+  MonoSearchVisitor search(range, neighbors, distances);
+  boost::apply_visitor(search, rSearch);
 }
 
 // Get the name of the tree type.
@@ -483,33 +214,5 @@ std::string RSModel::TreeName() const
 // Clean memory.
 void RSModel::CleanMemory()
 {
-  delete kdTreeRS;
-  delete coverTreeRS;
-  delete rTreeRS;
-  delete rStarTreeRS;
-  delete ballTreeRS;
-  delete xTreeRS;
-  delete hilbertRTreeRS;
-  delete rPlusTreeRS;
-  delete rPlusPlusTreeRS;
-  delete vpTreeRS;
-  delete rpTreeRS;
-  delete maxRPTreeRS;
-  delete ubTreeRS;
-  delete octreeRS;
-
-  kdTreeRS = NULL;
-  coverTreeRS = NULL;
-  rTreeRS = NULL;
-  rStarTreeRS = NULL;
-  ballTreeRS = NULL;
-  xTreeRS = NULL;
-  hilbertRTreeRS = NULL;
-  rPlusTreeRS = NULL;
-  rPlusPlusTreeRS = NULL;
-  vpTreeRS = NULL;
-  rpTreeRS = NULL;
-  maxRPTreeRS = NULL;
-  ubTreeRS = NULL;
-  octreeRS = NULL;
+  boost::apply_visitor(DeleteVisitor(), rSearch); // NOTE(review): rSearch keeps the freed pointer.
 }
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
index 3812521..3ef6424 100644
--- a/src/mlpack/methods/range_search/rs_model.hpp
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -19,12 +19,212 @@
 #include <mlpack/core/tree/cover_tree.hpp>
 #include <mlpack/core/tree/rectangle_tree.hpp>
 #include <mlpack/core/tree/octree.hpp>
-
+#include <boost/variant.hpp>
 #include "range_search.hpp"
 
 namespace mlpack {
 namespace range {
 
+/**
+ * Alias template for the Euclidean-distance, arma::mat RangeSearch type.
+ */
+template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+using RSType = RangeSearch<metric::EuclideanDistance, arma::mat, TreeType>;
+
+//! Name under which RSModel::Serialize() stores the RangeSearch object.
+struct RSModelName
+{
+  static const std::string Name() { return "range_search_model"; }
+};
+
+/**
+ * MonoSearchVisitor executes a monochromatic range search on the given
+ * RSType: the reference set is searched against itself (no query set).
+ */
+class MonoSearchVisitor : public  boost::static_visitor<void>
+{
+  private:
+    const math::Range& range;
+    std::vector<std::vector<size_t>>& neighbors;
+    std::vector<std::vector<double>>& distances;
+
+  public:
+    template<typename RSType>
+    void operator()(RSType* rs) const;
+
+    MonoSearchVisitor(const math::Range& range,
+                      std::vector<std::vector<size_t>>& neighbors,
+                      std::vector<std::vector<double>>& distances):
+        range(range),
+        neighbors(neighbors),
+        distances(distances)
+      {};      
+
+};
+
+/**
+ * BiSearchVisitor executes a bichromatic range search on the given RSType.
+ * We use overloads to differentiate those tree types that accept leafSize
+ * as a parameter (KD, ball and octree). In these cases, before searching,
+ * a query tree with proper leafSize is built from the querySet.
+ */
+class BiSearchVisitor : public boost::static_visitor<void>
+{
+ private:
+  //! The query set (held by const reference; query trees copy it).
+  const arma::mat& querySet;
+  //! Range of distances to search for neighbors within.
+  const math::Range& range;
+  //! The result vector for neighbors.
+  std::vector<std::vector<size_t>>& neighbors;
+  //! The result vector for distances.
+  std::vector<std::vector<double>>& distances;
+  //! Leaf size for query trees (KD-tree, ball tree and octree only).
+  const size_t leafSize;
+
+  //! Search, building a query tree with leafSize unless naive/single mode.
+  template<typename RSType>
+  void SearchLeaf(RSType* rs) const;
+
+ public:
+  //! Alias template necessary for visual c++ compiler.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+ using RSTypeT = RSType<TreeType>;
+
+  //! Default bichromatic range search on the given RSType instance.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  void operator()(RSTypeT<TreeType>* rs) const;
+
+  //! Bichromatic range search on the given RSType specialized for KDTrees.
+  void operator()(RSTypeT<tree::KDTree>* rs) const;
+
+  //! Bichromatic range search specialized for BallTrees.
+  void operator()(RSTypeT<tree::BallTree>* rs) const;
+
+  //! Bichromatic range search specialized for octrees.
+  void operator()(RSTypeT<tree::Octree>* rs) const;
+
+  //! Construct the BiSearchVisitor.
+  BiSearchVisitor(const arma::mat& querySet,
+                  const math::Range& range,
+                  std::vector<std::vector<size_t>>& neighbors,
+                  std::vector<std::vector<double>>& distances,
+                  const size_t leafSize
+                  );
+};
+
+/**
+ * TrainVisitor trains the given RSType on a new reference set. Overloads
+ * select the tree types that accept leafSize (KD, ball and octree); for
+ * those, a reference tree with the requested leafSize is built from the
+ * referenceSet.
+ */
+class TrainVisitor : public boost::static_visitor<void>
+{
+ private:
+  //! Reference set to train on; bound by reference, so it must outlive this.
+  arma::mat&& referenceSet;
+  //! The leaf size, used only for KD-tree, ball tree and octree models.
+  size_t leafSize;
+  //! Train on the given RSType considering the leafSize.
+  template<typename RSType>
+  void TrainLeaf(RSType* rs) const;
+
+ public:
+  //! Alias template necessary for visual c++ compiler.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using RSTypeT = RSType<TreeType>;
+
+  //! Default Train on the given RSType instance.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  void operator()(RSTypeT<TreeType>* rs) const;
+
+  //! Train on the given RSType specialized for KDTrees.
+  void operator()(RSTypeT<tree::KDTree>* rs) const;
+
+  //! Train on the given RSType specialized for BallTrees.
+  void operator()(RSTypeT<tree::BallTree>* rs) const;
+
+  //! Train specialized for octrees.
+  void operator()(RSTypeT<tree::Octree>* rs) const;
+
+  //! Construct the TrainVisitor object with the given reference set and the
+  //! leafSize used when building KD-tree, ball tree and octree models.
+  TrainVisitor(arma::mat&& referenceSet,
+               const size_t leafSize
+               );
+};
+
+
+/**
+ * ReferenceSetVisitor exposes the referenceSet of the given RSType.
+ */
+class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
+{
+ public:
+  //! Return the reference set (throws std::runtime_error if rs is NULL).
+  template<typename RSType>
+  const arma::mat& operator()(RSType *rs) const;
+};
+
+/**
+ * DeleteVisitor deletes the RSType instance held by the visited variant.
+ */
+class DeleteVisitor : public boost::static_visitor<void>
+{
+ public:
+  //! Delete rs. The variant is not reset and keeps the now-dangling pointer.
+  template<typename RSType>
+  void operator()(RSType *rs) const;
+};
+
+/**
+  * Exposes the serialize method of the given RSType.
+  */
+template<typename Archive>
+ class SerializeVisitor : public boost::static_visitor<void>
+ {
+  private:
+    Archive& ar;
+    const std::string& name;
+    // NOTE(review): operator() takes rs by value; loads don't reach the variant.
+  public:
+    template<typename RSType>
+    void operator()(RSType *rs) const;
+ 
+    SerializeVisitor(Archive& ar, const std::string& name);
+ };
+
+/**
+ * SingleModeVisitor exposes the SingleMode() flag of the given RSType.
+ */
+ class SingleModeVisitor : public boost::static_visitor<bool&>
+ {
+  public:
+    template<typename NSType>
+    bool& operator()(NSType *ns) const;
+ };
+
+/**
+ * NaiveVisitor exposes the Naive() flag (by reference) of the given RSType.
+ */
+ class NaiveVisitor : public boost::static_visitor<bool&>
+ {
+  public:
+    template<typename NSType>
+    bool& operator()(NSType *ns) const;
+ };
+
 class RSModel
 {
  public:
@@ -55,45 +255,26 @@ class RSModel
   //! Random projection matrix.
   arma::mat q;
 
-  //! The mostly-specified type of the range search model.
-  template<template<typename TreeMetricType,
-                    typename TreeStatType,
-                    typename TreeMatType> class TreeType>
-  using RSType = RangeSearch<metric::EuclideanDistance, arma::mat, TreeType>;
-
-  // Only one of these pointers will be non-NULL.
-  //! kd-tree based range search object (NULL if not in use).
-  RSType<tree::KDTree>* kdTreeRS;
-  //! Cover tree based range search object (NULL if not in use).
-  RSType<tree::StandardCoverTree>* coverTreeRS;
-  //! R tree based range search object (NULL if not in use).
-  RSType<tree::RTree>* rTreeRS;
-  //! R* tree based range search object (NULL if not in use).
-  RSType<tree::RStarTree>* rStarTreeRS;
-  //! Ball tree based range search object (NULL if not in use).
-  RSType<tree::BallTree>* ballTreeRS;
-  //! X tree based range search object (NULL if not in use).
-  RSType<tree::XTree>* xTreeRS;
-  //! Hilbert R tree based range search object (NULL if not in use).
-  RSType<tree::HilbertRTree>* hilbertRTreeRS;
-  //! R+ tree based range search object (NULL if not in use).
-  RSType<tree::RPlusTree>* rPlusTreeRS;
-  //! R++ tree based range search object (NULL if not in use).
-  RSType<tree::RPlusPlusTree>* rPlusPlusTreeRS;
-  //! VP tree based range search object (NULL if not in use).
-  RSType<tree::VPTree>* vpTreeRS;
-  //! Random projection tree (mean) based range search object
-  //! (NULL if not in use).
-  RSType<tree::RPTree>* rpTreeRS;
-  //! Random projection tree (max) based range search object
-  //! (NULL if not in use).
-  RSType<tree::MaxRPTree>* maxRPTreeRS;
-  //! Universal B tree based range search object
-  //! (NULL if not in use).
-  RSType<tree::UBTree>* ubTreeRS;
-  //! Octree-based range search object (NULL if not in use).
-  RSType<tree::Octree>* octreeRS;
-
+  /**
+   * rSearch holds a pointer to the RangeSearch object for the current
+   * treeType. It is reassigned every time BuildModel is executed, and its
+   * contents are only accessed through the visitor classes defined above.
+   */
+  boost::variant<RSType<tree::KDTree> *,
+                 RSType<tree::StandardCoverTree> *,
+                 RSType<tree::RTree> *,
+                 RSType<tree::RStarTree> *,
+                 RSType<tree::BallTree> *,
+                 RSType<tree::XTree> *,
+                 RSType<tree::HilbertRTree> *,
+                 RSType<tree::RPlusTree> *,
+                 RSType<tree::RPlusPlusTree> *,
+                 RSType<tree::VPTree> *,
+                 RSType<tree::RPTree> *,
+                 RSType<tree::MaxRPTree> *,
+                 RSType<tree::UBTree> *,
+                 RSType<tree::Octree> *> rSearch;
+                 
  public:
   /**
    * Initialize the RSModel with the given type and whether or not a random
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index 4bc73b4..3fde0f5 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -18,6 +18,216 @@
 namespace mlpack {
 namespace range {
 
+//! Monochromatic range search with rs; throws std::runtime_error if NULL.
+template<typename RSType>
+void MonoSearchVisitor::operator()(RSType *rs) const
+{
+  if (rs)
+    return rs->Search(range, neighbors, distances);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Save parameters for bichromatic range search (inline: defined in a header).
+inline BiSearchVisitor::BiSearchVisitor(
+    const arma::mat& querySet,
+    const math::Range& range,
+    std::vector<std::vector<size_t>>& neighbors,
+    std::vector<std::vector<double>>& distances,
+    const size_t leafSize) :
+    querySet(querySet),
+    range(range),
+    neighbors(neighbors),
+    distances(distances),
+    leafSize(leafSize)
+{}
+
+//! Default bichromatic search: pass querySet directly, without a query tree.
+template<template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void BiSearchVisitor::operator()(RSTypeT<TreeType>* rs) const
+{
+  if (rs)
+    return rs->Search(querySet, range, neighbors, distances);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Bichromatic range search for KDTrees (inline since defined in a header).
+inline void BiSearchVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
+{
+  if (rs)
+    return SearchLeaf(rs);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Bichromatic range search for BallTrees (inline since defined in a header).
+inline void BiSearchVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
+{
+  if (rs)
+    return SearchLeaf(rs);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Bichromatic range search for Octrees (inline since defined in a header).
+inline void BiSearchVisitor::operator()(RSTypeT<tree::Octree>* rs) const
+{
+  if (rs)
+    return SearchLeaf(rs);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Bichromatic range search on the given RSType considering the leafSize.
+template<typename RSType>
+void BiSearchVisitor::SearchLeaf(RSType* rs) const
+{
+  if (!rs->Naive() && !rs->SingleMode())
+  {
+    // Build a query tree (this copies querySet, which is held by const&).
+    Timer::Start("tree_building");
+    Log::Info << "Building query tree..." << std::endl;
+    std::vector<size_t> oldFromNewQueries;
+    typename RSType::Tree queryTree(querySet, oldFromNewQueries,
+        leafSize);
+    Log::Info << "Tree built." << std::endl;
+    Timer::Stop("tree_building");
+
+    std::vector<std::vector<size_t>> neighborsOut;
+    std::vector<std::vector<double>> distancesOut;
+    rs->Search(&queryTree, range, neighborsOut, distancesOut);
+
+    // Tree building permutes the points; map results back to query order.
+    neighbors.resize(queryTree.Dataset().n_cols);
+    distances.resize(queryTree.Dataset().n_cols);
+    for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
+    {
+      neighbors[oldFromNewQueries[i]] = neighborsOut[i];
+      distances[oldFromNewQueries[i]] = distancesOut[i];
+    }
+  }
+  else
+  {
+    // Search without building a second tree.
+    rs->Search(querySet, range, neighbors, distances);
+  }
+}
+
+//! Save parameters for Train (inline: non-template defined in a header).
+//! referenceSet is held by reference and must outlive this visitor.
+inline TrainVisitor::TrainVisitor(arma::mat&& referenceSet,
+                                  const size_t leafSize) :
+    referenceSet(std::move(referenceSet)),
+    leafSize(leafSize)
+{}
+
+//! Default Train: move referenceSet into rs (leafSize is not used here).
+template<template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void TrainVisitor::operator()(RSTypeT<TreeType>* rs) const
+{
+  if (rs)
+    return rs->Train(std::move(referenceSet));
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Train on the given KDTree model (inline since defined in a header).
+inline void TrainVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
+{
+  if (rs)
+    return TrainLeaf(rs);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Train on the given BallTree model (inline since defined in a header).
+inline void TrainVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
+{
+  if (rs)
+    return TrainLeaf(rs);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Train on the given Octree model (inline since defined in a header).
+inline void TrainVisitor::operator()(RSTypeT<tree::Octree>* rs) const
+{
+  if (rs)
+    return TrainLeaf(rs);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Train on the given RSType, building a reference tree with leafSize.
+template<typename RSType>
+void TrainVisitor::TrainLeaf(RSType* rs) const
+{
+  // Naive models need no tree; hand the data straight to the model.
+  if (rs->Naive())
+  {
+    rs->Train(std::move(referenceSet));
+    return;
+  }
+
+  // Build the tree here so leafSize is honored (this permutes the points).
+  std::vector<size_t> oldFromNewReferences;
+  typedef typename RSType::Tree RefTree;
+  rs->Train(new RefTree(std::move(referenceSet), oldFromNewReferences,
+      leafSize));
+
+  // Give the model ownership of the tree and the mappings.
+  rs->treeOwner = true;
+  rs->oldFromNewReferences = std::move(oldFromNewReferences);
+}
+
+
+//! Expose the reference set of rs; throws std::runtime_error if rs is NULL.
+template<typename RSType>
+const arma::mat& ReferenceSetVisitor::operator()(RSType* rs) const
+{
+  if (rs)
+    return rs->ReferenceSet();
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Free rs. It is a copy, so the visited variant keeps the freed pointer.
+template<typename RSType>
+void DeleteVisitor::operator()(RSType* rs) const
+{
+  if (rs)
+    delete rs;
+}
+
+//! Save the archive and NVP name (both held by reference).
+template<typename Archive>
+ SerializeVisitor<Archive>::SerializeVisitor(Archive& ar,
+                                             const std::string& name) :
+                            ar(ar),
+                            name(name)
+ {}
+ 
+ //! Serialize *rs. NOTE(review): rs is a copy, so a loaded pointer is lost.
+ template<typename Archive>
+ template<typename RSType>
+ void SerializeVisitor<Archive>::operator()(RSType *rs) const
+ {
+   ar & data::CreateNVP(rs, name);
+ }
+
+//! Return a reference to the single-mode flag of rs; throws if rs is NULL.
+template<typename RSType>
+ bool& SingleModeVisitor::operator()(RSType *rs) const
+ {
+   if (rs)
+     return rs->SingleMode();
+   throw std::runtime_error("no neighbor search model initialized");
+ }
+
+//! Return a reference to the naive flag of rs; throws if rs is NULL.
+template<typename RSType>
+ bool& NaiveVisitor::operator()(RSType *rs) const
+ {
+   if (rs)
+     return rs->Naive();
+   throw std::runtime_error("no neighbor search model initialized");
+ }
+ 
 // Serialize the model.
 template<typename Archive>
 void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
@@ -30,237 +240,37 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
 
   // This should never happen, but just in case...
   if (Archive::is_loading::value)
-    CleanMemory();
+    boost::apply_visitor(DeleteVisitor(), rSearch);
 
   // We'll only need to serialize one of the model objects, based on the type.
-  switch (treeType)
-  {
-    case KD_TREE:
-      ar & CreateNVP(kdTreeRS, "range_search_model");
-      break;
-
-    case COVER_TREE:
-      ar & CreateNVP(coverTreeRS, "range_search_model");
-      break;
-
-    case R_TREE:
-      ar & CreateNVP(rTreeRS, "range_search_model");
-      break;
-
-    case R_STAR_TREE:
-      ar & CreateNVP(rStarTreeRS, "range_search_model");
-      break;
-
-    case BALL_TREE:
-      ar & CreateNVP(ballTreeRS, "range_search_model");
-      break;
-
-    case X_TREE:
-      ar & CreateNVP(xTreeRS, "range_search_model");
-      break;
-
-    case HILBERT_R_TREE:
-      ar & CreateNVP(hilbertRTreeRS, "range_search_model");
-      break;
-
-    case R_PLUS_TREE:
-      ar & CreateNVP(rPlusTreeRS, "range_search_model");
-      break;
-
-    case R_PLUS_PLUS_TREE:
-      ar & CreateNVP(rPlusPlusTreeRS, "range_search_model");
-      break;
-
-    case VP_TREE:
-      ar & CreateNVP(vpTreeRS, "range_search_model");
-      break;
-
-    case RP_TREE:
-      ar & CreateNVP(rpTreeRS, "range_search_model");
-      break;
-
-    case MAX_RP_TREE:
-      ar & CreateNVP(maxRPTreeRS, "range_search_model");
-      break;
-
-    case UB_TREE:
-      ar & CreateNVP(ubTreeRS, "range_search_model");
-      break;
-
-    case OCTREE:
-      ar & CreateNVP(octreeRS, "range_search_model");
-      break;
-  }
+  const std::string& name = RSModelName::Name();
+  SerializeVisitor<Archive> s(ar, name);
+  boost::apply_visitor(s, rSearch);
 }
 
 inline const arma::mat& RSModel::Dataset() const
 {
-  if (kdTreeRS)
-    return kdTreeRS->ReferenceSet();
-  else if (coverTreeRS)
-    return coverTreeRS->ReferenceSet();
-  else if (rTreeRS)
-    return rTreeRS->ReferenceSet();
-  else if (rStarTreeRS)
-    return rStarTreeRS->ReferenceSet();
-  else if (ballTreeRS)
-    return ballTreeRS->ReferenceSet();
-  else if (xTreeRS)
-    return xTreeRS->ReferenceSet();
-  else if (hilbertRTreeRS)
-    return hilbertRTreeRS->ReferenceSet();
-  else if (rPlusTreeRS)
-    return rPlusTreeRS->ReferenceSet();
-  else if (rPlusPlusTreeRS)
-    return rPlusPlusTreeRS->ReferenceSet();
-  else if (vpTreeRS)
-    return vpTreeRS->ReferenceSet();
-  else if (rpTreeRS)
-    return rpTreeRS->ReferenceSet();
-  else if (maxRPTreeRS)
-    return maxRPTreeRS->ReferenceSet();
-  else if (ubTreeRS)
-    return ubTreeRS->ReferenceSet();
-  else if (octreeRS)
-    return octreeRS->ReferenceSet();
-
-  throw std::runtime_error("no range search model initialized");
+  return boost::apply_visitor(ReferenceSetVisitor(), rSearch); // throws if null
 }
 
 inline bool RSModel::SingleMode() const
 {
-  if (kdTreeRS)
-    return kdTreeRS->SingleMode();
-  else if (coverTreeRS)
-    return coverTreeRS->SingleMode();
-  else if (rTreeRS)
-    return rTreeRS->SingleMode();
-  else if (rStarTreeRS)
-    return rStarTreeRS->SingleMode();
-  else if (ballTreeRS)
-    return ballTreeRS->SingleMode();
-  else if (xTreeRS)
-    return xTreeRS->SingleMode();
-  else if (hilbertRTreeRS)
-    return hilbertRTreeRS->SingleMode();
-  else if (rPlusTreeRS)
-    return rPlusTreeRS->SingleMode();
-  else if (rPlusPlusTreeRS)
-    return rPlusPlusTreeRS->SingleMode();
-  else if (vpTreeRS)
-    return vpTreeRS->SingleMode();
-  else if (rpTreeRS)
-    return rpTreeRS->SingleMode();
-  else if (maxRPTreeRS)
-    return maxRPTreeRS->SingleMode();
-  else if (ubTreeRS)
-    return ubTreeRS->SingleMode();
-  else if (octreeRS)
-    return octreeRS->SingleMode();
-
-  throw std::runtime_error("no range search model initialized");
+  return boost::apply_visitor(SingleModeVisitor(), rSearch); // throws if null
 }
 
 inline bool& RSModel::SingleMode()
 {
-  if (kdTreeRS)
-    return kdTreeRS->SingleMode();
-  else if (coverTreeRS)
-    return coverTreeRS->SingleMode();
-  else if (rTreeRS)
-    return rTreeRS->SingleMode();
-  else if (rStarTreeRS)
-    return rStarTreeRS->SingleMode();
-  else if (ballTreeRS)
-    return ballTreeRS->SingleMode();
-  else if (xTreeRS)
-    return xTreeRS->SingleMode();
-  else if (hilbertRTreeRS)
-    return hilbertRTreeRS->SingleMode();
-  else if (rPlusTreeRS)
-    return rPlusTreeRS->SingleMode();
-  else if (rPlusPlusTreeRS)
-    return rPlusPlusTreeRS->SingleMode();
-  else if (vpTreeRS)
-    return vpTreeRS->SingleMode();
-  else if (rpTreeRS)
-    return rpTreeRS->SingleMode();
-  else if (maxRPTreeRS)
-    return maxRPTreeRS->SingleMode();
-  else if (ubTreeRS)
-    return ubTreeRS->SingleMode();
-  else if (octreeRS)
-    return octreeRS->SingleMode();
-
-  throw std::runtime_error("no range search model initialized");
+  return boost::apply_visitor(SingleModeVisitor(), rSearch); // throws if null
 }
 
 inline bool RSModel::Naive() const
 {
-  if (kdTreeRS)
-    return kdTreeRS->Naive();
-  else if (coverTreeRS)
-    return coverTreeRS->Naive();
-  else if (rTreeRS)
-    return rTreeRS->Naive();
-  else if (rStarTreeRS)
-    return rStarTreeRS->Naive();
-  else if (ballTreeRS)
-    return ballTreeRS->Naive();
-  else if (xTreeRS)
-    return xTreeRS->Naive();
-  else if (hilbertRTreeRS)
-    return hilbertRTreeRS->Naive();
-  else if (rPlusTreeRS)
-    return rPlusTreeRS->Naive();
-  else if (rPlusPlusTreeRS)
-    return rPlusPlusTreeRS->Naive();
-  else if (vpTreeRS)
-    return vpTreeRS->Naive();
-  else if (rpTreeRS)
-    return rpTreeRS->Naive();
-  else if (maxRPTreeRS)
-    return maxRPTreeRS->Naive();
-  else if (ubTreeRS)
-    return ubTreeRS->Naive();
-  else if (octreeRS)
-    return octreeRS->Naive();
-
-  throw std::runtime_error("no range search model initialized");
+  return boost::apply_visitor(NaiveVisitor(), rSearch); // throws if null
 }
 
 inline bool& RSModel::Naive()
 {
-  if (kdTreeRS)
-    return kdTreeRS->Naive();
-  else if (coverTreeRS)
-    return coverTreeRS->Naive();
-  else if (rTreeRS)
-    return rTreeRS->Naive();
-  else if (rStarTreeRS)
-    return rStarTreeRS->Naive();
-  else if (ballTreeRS)
-    return ballTreeRS->Naive();
-  else if (xTreeRS)
-    return xTreeRS->Naive();
-  else if (hilbertRTreeRS)
-    return hilbertRTreeRS->Naive();
-  else if (rPlusTreeRS)
-    return rPlusTreeRS->Naive();
-  else if (rPlusPlusTreeRS)
-    return rPlusPlusTreeRS->Naive();
-  else if (vpTreeRS)
-    return vpTreeRS->Naive();
-  else if (rpTreeRS)
-    return rpTreeRS->Naive();
-  else if (maxRPTreeRS)
-    return maxRPTreeRS->Naive();
-  else if (ubTreeRS)
-    return ubTreeRS->Naive();
-  else if (octreeRS)
-    return octreeRS->Naive();
-
-  throw std::runtime_error("no range search model initialized");
+  return boost::apply_visitor(NaiveVisitor(), rSearch); // throws if null
 }
 
 } // namespace range

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list