[mlpack] 29/207: Refactored range search model to include boost variant and visitor paradigm
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:37 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 13e8c1bdbfcd4fb8c1c3031b63ea1b128afb6a81
Author: dinesh Raj <dinu.iota at gmail.com>
Date: Mon Feb 6 23:11:20 2017 +0530
Refactored range search model to include boost variant and visitor paradigm
---
src/mlpack/methods/range_search/range_search.hpp | 4 +-
src/mlpack/methods/range_search/rs_model.cpp | 349 ++----------------
src/mlpack/methods/range_search/rs_model.hpp | 261 +++++++++++--
src/mlpack/methods/range_search/rs_model_impl.hpp | 428 +++++++++++-----------
4 files changed, 468 insertions(+), 574 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 2182795..9d4c3e1 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -22,7 +22,7 @@ namespace mlpack {
namespace range /** Range-search routines. */ {
//! Forward declaration.
-class RSModel;
+class TrainVisitor;
/**
* The RangeSearch class is a template class for performing range searches. It
@@ -323,7 +323,7 @@ class RangeSearch
size_t scores;
//! For access to mappings when building models.
- friend RSModel;
+ friend class TrainVisitor;
};
} // namespace range
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
index 6857f34..8a280a3 100644
--- a/src/mlpack/methods/range_search/rs_model.cpp
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -23,21 +23,7 @@ using namespace mlpack::range;
RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
treeType(treeType),
leafSize(0),
- randomBasis(randomBasis),
- kdTreeRS(NULL),
- coverTreeRS(NULL),
- rTreeRS(NULL),
- rStarTreeRS(NULL),
- ballTreeRS(NULL),
- xTreeRS(NULL),
- hilbertRTreeRS(NULL),
- rPlusTreeRS(NULL),
- rPlusPlusTreeRS(NULL),
- vpTreeRS(NULL),
- rpTreeRS(NULL),
- maxRPTreeRS(NULL),
- ubTreeRS(NULL),
- octreeRS(NULL)
+ randomBasis(randomBasis)
{
// Nothing to do.
+ // rSearch is left default-constructed here; BuildModel() assigns the
+ // RangeSearch object for the chosen tree type.
}
@@ -45,7 +31,7 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
// Clean memory, if necessary.
RSModel::~RSModel()
{
- CleanMemory();
+ // DeleteVisitor receives the stored pointer by value and does not reset
+ // rSearch; that is harmless here because the model is being destroyed.
+ boost::apply_visitor(DeleteVisitor(), rSearch);
}
void RSModel::BuildModel(arma::mat&& referenceSet,
@@ -63,7 +49,7 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
this->leafSize = leafSize;
// Clean memory, if necessary.
- CleanMemory();
+ boost::apply_visitor(DeleteVisitor(), rSearch);
// Do we need to modify the reference set?
if (randomBasis)
@@ -78,126 +64,65 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
switch (treeType)
{
case KD_TREE:
- // If necessary, build the tree.
- if (naive)
- {
- kdTreeRS = new RSType<tree::KDTree>(move(referenceSet), naive,
- singleMode);
- }
- else
- {
- vector<size_t> oldFromNewReferences;
- RSType<tree::KDTree>::Tree* kdTree = new RSType<tree::KDTree>::Tree(
- move(referenceSet), oldFromNewReferences, leafSize);
- kdTreeRS = new RSType<tree::KDTree>(kdTree, singleMode);
-
- // Give the model ownership of the tree and the mappings.
- kdTreeRS->treeOwner = true;
- kdTreeRS->oldFromNewReferences = move(oldFromNewReferences);
- }
-
+ rSearch = new RSType<tree::KDTree> (naive, singleMode);
break;
case COVER_TREE:
- coverTreeRS = new RSType<tree::StandardCoverTree>(move(referenceSet),
- naive, singleMode);
+ rSearch = new RSType<tree::StandardCoverTree>(naive, singleMode);
break;
case R_TREE:
- rTreeRS = new RSType<tree::RTree>(move(referenceSet), naive,
- singleMode);
+ rSearch = new RSType<tree::RTree>(naive,singleMode);
break;
case R_STAR_TREE:
- rStarTreeRS = new RSType<tree::RStarTree>(move(referenceSet), naive,
- singleMode);
+ rSearch = new RSType<tree::RStarTree>(naive, singleMode);
break;
case BALL_TREE:
- // If necessary, build the ball tree.
- if (naive)
- {
- ballTreeRS = new RSType<tree::BallTree>(move(referenceSet), naive,
- singleMode);
- }
- else
- {
- vector<size_t> oldFromNewReferences;
- RSType<tree::BallTree>::Tree* ballTree =
- new RSType<tree::BallTree>::Tree(move(referenceSet),
- oldFromNewReferences, leafSize);
- ballTreeRS = new RSType<tree::BallTree>(ballTree, singleMode);
-
- // Give the model ownership of the tree and the mappings.
- ballTreeRS->treeOwner = true;
- ballTreeRS->oldFromNewReferences = move(oldFromNewReferences);
- }
-
+ rSearch = new RSType<tree::BallTree>(naive, singleMode);
break;
case X_TREE:
- xTreeRS = new RSType<tree::XTree>(move(referenceSet), naive,
- singleMode);
+ rSearch = new RSType<tree::XTree>(naive, singleMode);
break;
case HILBERT_R_TREE:
- hilbertRTreeRS = new RSType<tree::HilbertRTree>(move(referenceSet), naive,
- singleMode);
+ rSearch = new RSType<tree::HilbertRTree>(naive, singleMode);
break;
case R_PLUS_TREE:
- rPlusTreeRS = new RSType<tree::RPlusTree>(move(referenceSet), naive,
- singleMode);
+ rSearch = new RSType<tree::RPlusTree>(naive, singleMode);
break;
case R_PLUS_PLUS_TREE:
- rPlusPlusTreeRS = new RSType<tree::RPlusPlusTree>(move(referenceSet),
- naive, singleMode);
+ rSearch = new RSType<tree::RPlusPlusTree>(naive, singleMode);
break;
case VP_TREE:
- vpTreeRS = new RSType<tree::VPTree>(move(referenceSet), naive,
- singleMode);
+ rSearch = new RSType<tree::VPTree>(naive, singleMode);
break;
case RP_TREE:
- rpTreeRS = new RSType<tree::RPTree>(move(referenceSet), naive,
- singleMode);
+ rSearch = new RSType<tree::RPTree>(naive, singleMode);
break;
case MAX_RP_TREE:
- maxRPTreeRS = new RSType<tree::MaxRPTree>(move(referenceSet),
- naive, singleMode);
+ rSearch = new RSType<tree::MaxRPTree>(naive, singleMode);
break;
case UB_TREE:
- ubTreeRS = new RSType<tree::UBTree>(move(referenceSet),
- naive, singleMode);
+ rSearch = new RSType<tree::UBTree>(naive, singleMode);
break;
case OCTREE:
- // If necessary, build the octree.
- if (naive)
- {
- octreeRS = new RSType<tree::Octree>(move(referenceSet), naive,
- singleMode);
- }
- else
- {
- vector<size_t> oldFromNewReferences;
- RSType<tree::Octree>::Tree* octree =
- new RSType<tree::Octree>::Tree(move(referenceSet),
- oldFromNewReferences, leafSize);
- octreeRS = new RSType<tree::Octree>(octree, singleMode);
-
- // Give the model ownership of the tree and the mappings.
- octreeRS->treeOwner = true;
- octreeRS->oldFromNewReferences = move(oldFromNewReferences);
- }
-
+ rSearch = new RSType<tree::Octree>(naive, singleMode);
break;
}
+ TrainVisitor tn(std::move(referenceSet), leafSize);
+ boost::apply_visitor(tn, rSearch);
+
if (!naive)
{
Timer::Stop("tree_building");
@@ -224,148 +149,10 @@ void RSModel::Search(arma::mat&& querySet,
else
Log::Info << "brute-force (naive) search..." << endl;
- switch (treeType)
- {
- case KD_TREE:
- if (!kdTreeRS->Naive() && !kdTreeRS->SingleMode())
- {
- // Build a second tree and search.
- Timer::Start("tree_building");
- Log::Info << "Building query tree..." << endl;
- vector<size_t> oldFromNewQueries;
- RSType<tree::KDTree>::Tree queryTree(move(querySet), oldFromNewQueries,
- leafSize);
- Log::Info << "Tree built." << endl;
- Timer::Stop("tree_building");
-
- vector<vector<size_t>> neighborsOut;
- vector<vector<double>> distancesOut;
- kdTreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
-
- // Remap the query points.
- neighbors.resize(queryTree.Dataset().n_cols);
- distances.resize(queryTree.Dataset().n_cols);
- for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
- {
- neighbors[oldFromNewQueries[i]] = neighborsOut[i];
- distances[oldFromNewQueries[i]] = distancesOut[i];
- }
- }
- else
- {
- // Search without building a second tree.
- kdTreeRS->Search(querySet, range, neighbors, distances);
- }
- break;
-
- case COVER_TREE:
- coverTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case R_TREE:
- rTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case R_STAR_TREE:
- rStarTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case BALL_TREE:
- if (!ballTreeRS->Naive() && !ballTreeRS->SingleMode())
- {
- // Build a second tree and search.
- Timer::Start("tree_building");
- Log::Info << "Building query tree..." << endl;
- vector<size_t> oldFromNewQueries;
- RSType<tree::BallTree>::Tree queryTree(move(querySet),
- oldFromNewQueries, leafSize);
- Log::Info << "Tree built." << endl;
- Timer::Stop("tree_building");
-
- vector<vector<size_t>> neighborsOut;
- vector<vector<double>> distancesOut;
- ballTreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
-
- // Remap the query points.
- neighbors.resize(queryTree.Dataset().n_cols);
- distances.resize(queryTree.Dataset().n_cols);
- for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
- {
- neighbors[oldFromNewQueries[i]] = neighborsOut[i];
- distances[oldFromNewQueries[i]] = distancesOut[i];
- }
- }
- else
- {
- // Search without building a second tree.
- ballTreeRS->Search(querySet, range, neighbors, distances);
- }
- break;
-
- case X_TREE:
- xTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case HILBERT_R_TREE:
- hilbertRTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case R_PLUS_TREE:
- rPlusTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case R_PLUS_PLUS_TREE:
- rPlusPlusTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case VP_TREE:
- vpTreeRS->Search(querySet, range, neighbors, distances);
- break;
- case RP_TREE:
- rpTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case MAX_RP_TREE:
- maxRPTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case UB_TREE:
- ubTreeRS->Search(querySet, range, neighbors, distances);
- break;
-
- case OCTREE:
- if (!octreeRS->Naive() && !octreeRS->SingleMode())
- {
- // Build a query tree and search.
- Timer::Start("tree_building");
- Log::Info << "Building query tree..." << endl;
- vector<size_t> oldFromNewQueries;
- RSType<tree::Octree>::Tree queryTree(move(querySet), oldFromNewQueries,
- leafSize);
- Log::Info << "Tree built." << endl;
- Timer::Stop("tree_building");
-
- vector<vector<size_t>> neighborsOut;
- vector<vector<double>> distancesOut;
- octreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
-
- // Remap the query points.
- neighbors.resize(queryTree.Dataset().n_cols);
- distances.resize(queryTree.Dataset().n_cols);
- for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
- {
- neighbors[oldFromNewQueries[i]] = neighborsOut[i];
- distances[oldFromNewQueries[i]] = distancesOut[i];
- }
- }
- else
- {
- // Search without building a second tree.
- octreeRS->Search(querySet, range, neighbors, distances);
- }
- break;
- }
+ BiSearchVisitor search(querySet, range, neighbors, distances,
+ leafSize);
+ boost::apply_visitor(search, rSearch);
}
// Perform range search (monochromatic case).
@@ -382,64 +169,8 @@ void RSModel::Search(const math::Range& range,
else
Log::Info << "brute-force (naive) search..." << endl;
- switch (treeType)
- {
- case KD_TREE:
- kdTreeRS->Search(range, neighbors, distances);
- break;
-
- case COVER_TREE:
- coverTreeRS->Search(range, neighbors, distances);
- break;
-
- case R_TREE:
- rTreeRS->Search(range, neighbors, distances);
- break;
-
- case R_STAR_TREE:
- rStarTreeRS->Search(range, neighbors, distances);
- break;
-
- case BALL_TREE:
- ballTreeRS->Search(range, neighbors, distances);
- break;
-
- case X_TREE:
- xTreeRS->Search(range, neighbors, distances);
- break;
-
- case HILBERT_R_TREE:
- hilbertRTreeRS->Search(range, neighbors, distances);
- break;
-
- case R_PLUS_TREE:
- rPlusTreeRS->Search(range, neighbors, distances);
- break;
-
- case R_PLUS_PLUS_TREE:
- rPlusPlusTreeRS->Search(range, neighbors, distances);
- break;
-
- case VP_TREE:
- vpTreeRS->Search(range, neighbors, distances);
- break;
-
- case RP_TREE:
- rpTreeRS->Search(range, neighbors, distances);
- break;
-
- case MAX_RP_TREE:
- maxRPTreeRS->Search(range, neighbors, distances);
- break;
-
- case UB_TREE:
- ubTreeRS->Search(range, neighbors, distances);
- break;
-
- case OCTREE:
- octreeRS->Search(range, neighbors, distances);
- break;
- }
+ MonoSearchVisitor search(range, neighbors, distances);
+ boost::apply_visitor(search, rSearch);
}
// Get the name of the tree type.
@@ -483,33 +214,5 @@ std::string RSModel::TreeName() const
// Clean memory.
void RSModel::CleanMemory()
{
- delete kdTreeRS;
- delete coverTreeRS;
- delete rTreeRS;
- delete rStarTreeRS;
- delete ballTreeRS;
- delete xTreeRS;
- delete hilbertRTreeRS;
- delete rPlusTreeRS;
- delete rPlusPlusTreeRS;
- delete vpTreeRS;
- delete rpTreeRS;
- delete maxRPTreeRS;
- delete ubTreeRS;
- delete octreeRS;
-
- kdTreeRS = NULL;
- coverTreeRS = NULL;
- rTreeRS = NULL;
- rStarTreeRS = NULL;
- ballTreeRS = NULL;
- xTreeRS = NULL;
- hilbertRTreeRS = NULL;
- rPlusTreeRS = NULL;
- rPlusPlusTreeRS = NULL;
- vpTreeRS = NULL;
- rpTreeRS = NULL;
- maxRPTreeRS = NULL;
- ubTreeRS = NULL;
- octreeRS = NULL;
+ boost::apply_visitor(DeleteVisitor(), rSearch);
+ // DeleteVisitor gets the pointer by value and cannot clear it, so reset
+ // rSearch explicitly (as the old code set every pointer to NULL); otherwise
+ // it keeps a dangling pointer that the destructor would delete again.
+ rSearch = static_cast<RSType<tree::KDTree>*>(NULL);
}
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
index 3812521..3ef6424 100644
--- a/src/mlpack/methods/range_search/rs_model.hpp
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -19,12 +19,212 @@
#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/core/tree/rectangle_tree.hpp>
#include <mlpack/core/tree/octree.hpp>
-
+#include <boost/variant.hpp>
#include "range_search.hpp"
namespace mlpack {
namespace range {
+/**
+ * Alias template for the RangeSearch type used by RSModel: Euclidean distance
+ * on arma::mat data, parameterized only by the tree type.
+ */
+template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+using RSType = RangeSearch<metric::EuclideanDistance, arma::mat, TreeType>;
+
+
+/**
+ * Provides the name under which RSModel serializes its RangeSearch object.
+ */
+struct RSModelName
+{
+ static const std::string Name() { return "range_search_model"; }
+};
+
+/**
+ * MonoSearchVisitor executes a monochromatic range search on the given RSType:
+ * the reference set is searched against itself, so no query set is needed.
+ * All members are references, so the arguments must outlive the visitor.
+ */
+class MonoSearchVisitor : public boost::static_visitor<void>
+{
+ private:
+ //! Range of distances to search for.
+ const math::Range& range;
+ //! Output: neighbor indices found by the search.
+ std::vector<std::vector<size_t>>& neighbors;
+ //! Output: distances corresponding to neighbors.
+ std::vector<std::vector<double>>& distances;
+
+ public:
+ //! Run the search; throws std::runtime_error if rs is null.
+ template<typename RSType>
+ void operator()(RSType* rs) const;
+
+ //! Store references to the search range and the output containers.
+ MonoSearchVisitor(const math::Range& range,
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances):
+ range(range),
+ neighbors(neighbors),
+ distances(distances)
+ {};
+
+};
+
+/**
+ * BiSearchVisitor executes a bichromatic range search on the given RSType.
+ * Overloads distinguish the tree types that accept leafSize as a parameter
+ * (KDTree, BallTree, Octree). For those, when the model is neither naive nor in
+ * single-tree mode, a query tree with the given leafSize is built from the
+ * querySet before searching.
+ */
+class BiSearchVisitor : public boost::static_visitor<void>
+{
+ private:
+ //! The query set for the bichromatic search.
+ const arma::mat& querySet;
+ //! Range of distances to search for.
+ const math::Range& range;
+ //! The result vector for neighbors.
+ std::vector<std::vector<size_t>>& neighbors;
+ //! The result vector for distances.
+ std::vector<std::vector<double>>& distances;
+ //! The number of points in a leaf (for trees built with a leaf size).
+ const size_t leafSize;
+
+ //! Bichromatic range search on the given RSType considering the leafSize.
+ template<typename RSType>
+ void SearchLeaf(RSType* rs) const;
+
+ public:
+ //! Alias template necessary for visual c++ compiler.
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ using RSTypeT = RSType<TreeType>;
+
+ //! Default bichromatic range search on the given RSType instance.
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ void operator()(RSTypeT<TreeType>* rs) const;
+
+ //! Bichromatic range search on the given RSType specialized for KDTrees.
+ void operator()(RSTypeT<tree::KDTree>* rs) const;
+
+ //! Bichromatic range search on the given RSType specialized for BallTrees.
+ void operator()(RSTypeT<tree::BallTree>* rs) const;
+
+ //! Bichromatic range search specialized for octrees.
+ void operator()(RSTypeT<tree::Octree>* rs) const;
+
+ //! Construct the BiSearchVisitor. Every argument except leafSize is held by
+ //! reference, so those arguments must outlive the visitor.
+ BiSearchVisitor(const arma::mat& querySet,
+ const math::Range& range,
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances,
+ const size_t leafSize
+ );
+};
+
+/**
+ * TrainVisitor sets the reference set to a new reference set on the given
+ * RSType. Overloads distinguish the tree types that accept leafSize as a
+ * parameter (KDTree, BallTree, Octree); for those, unless the model is naive, a
+ * reference tree with the given leafSize is built from the referenceSet.
+ */
+class TrainVisitor : public boost::static_visitor<void>
+{
+ private:
+ //! The reference set to use for training. This is an rvalue-reference
+ //! member: the matrix it binds to must outlive the visitor, and visiting
+ //! moves from it.
+ arma::mat&& referenceSet;
+ //! The leaf size, used only by trees that are built with a leaf size.
+ size_t leafSize;
+ //! Train on the given RSType considering the leafSize.
+ template<typename RSType>
+ void TrainLeaf(RSType* rs) const;
+
+ public:
+ //! Alias template necessary for visual c++ compiler.
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ using RSTypeT = RSType<TreeType>;
+
+ //! Default Train on the given RSType instance.
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ void operator()(RSTypeT<TreeType>* rs) const;
+
+ //! Train on the given RSType specialized for KDTrees.
+ void operator()(RSTypeT<tree::KDTree>* rs) const;
+
+ //! Train on the given RSType specialized for BallTrees.
+ void operator()(RSTypeT<tree::BallTree>* rs) const;
+
+ //! Train specialized for octrees.
+ void operator()(RSTypeT<tree::Octree>* rs) const;
+
+ //! Construct the TrainVisitor object with the given reference set and the
+ //! leafSize used when building reference trees.
+ TrainVisitor(arma::mat&& referenceSet,
+ const size_t leafSize
+ );
+};
+
+
+/**
+ * ReferenceSetVisitor exposes the referenceSet of the given RSType.
+ */
+class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
+{
+ public:
+ //! Return the reference set; throws std::runtime_error if rs is null.
+ template<typename RSType>
+ const arma::mat& operator()(RSType *rs) const;
+};
+
+/**
+ * DeleteVisitor deletes the given RSType instance.
+ *
+ * The pointer is received by value, so the pointer stored in the variant is
+ * not reset; callers that keep using the variant afterwards must reassign it.
+ */
+class DeleteVisitor : public boost::static_visitor<void>
+{
+ public:
+ //! Delete the RSType object.
+ template<typename RSType>
+ void operator()(RSType *rs) const;
+};
+
+/**
+ * Exposes the serialize method of the given RSType.
+ *
+ * NOTE(review): operator() takes the pointer by value, so when loading, any
+ * pointer the archive produces is written only into that local copy and never
+ * reaches the variant. Consider RSType*& (declaration and definition together).
+ */
+template<typename Archive>
+ class SerializeVisitor : public boost::static_visitor<void>
+ {
+ private:
+ //! Archive to serialize to or from.
+ Archive& ar;
+ //! Name used for the serialized object.
+ const std::string& name;
+
+ public:
+ template<typename RSType>
+ void operator()(RSType *rs) const;
+
+ SerializeVisitor(Archive& ar, const std::string& name);
+ };
+
+/**
+ * SingleModeVisitor exposes the SingleMode() method of the given RSType.
+ */
+ class SingleModeVisitor : public boost::static_visitor<bool&>
+ {
+ public:
+ //! Return a reference to the single-mode flag; throws if ns is null.
+ template<typename NSType>
+ bool& operator()(NSType *ns) const;
+ };
+
+/**
+ * NaiveVisitor exposes the Naive() method of the given RSType.
+ */
+ class NaiveVisitor : public boost::static_visitor<bool&>
+ {
+ public:
+ //! Return a reference to the naive flag; throws if ns is null.
+ template<typename NSType>
+ bool& operator()(NSType *ns) const;
+ };
+
class RSModel
{
public:
@@ -55,45 +255,26 @@ class RSModel
//! Random projection matrix.
arma::mat q;
- //! The mostly-specified type of the range search model.
- template<template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType>
- using RSType = RangeSearch<metric::EuclideanDistance, arma::mat, TreeType>;
-
- // Only one of these pointers will be non-NULL.
- //! kd-tree based range search object (NULL if not in use).
- RSType<tree::KDTree>* kdTreeRS;
- //! Cover tree based range search object (NULL if not in use).
- RSType<tree::StandardCoverTree>* coverTreeRS;
- //! R tree based range search object (NULL if not in use).
- RSType<tree::RTree>* rTreeRS;
- //! R* tree based range search object (NULL if not in use).
- RSType<tree::RStarTree>* rStarTreeRS;
- //! Ball tree based range search object (NULL if not in use).
- RSType<tree::BallTree>* ballTreeRS;
- //! X tree based range search object (NULL if not in use).
- RSType<tree::XTree>* xTreeRS;
- //! Hilbert R tree based range search object (NULL if not in use).
- RSType<tree::HilbertRTree>* hilbertRTreeRS;
- //! R+ tree based range search object (NULL if not in use).
- RSType<tree::RPlusTree>* rPlusTreeRS;
- //! R++ tree based range search object (NULL if not in use).
- RSType<tree::RPlusPlusTree>* rPlusPlusTreeRS;
- //! VP tree based range search object (NULL if not in use).
- RSType<tree::VPTree>* vpTreeRS;
- //! Random projection tree (mean) based range search object
- //! (NULL if not in use).
- RSType<tree::RPTree>* rpTreeRS;
- //! Random projection tree (max) based range search object
- //! (NULL if not in use).
- RSType<tree::MaxRPTree>* maxRPTreeRS;
- //! Universal B tree based range search object
- //! (NULL if not in use).
- RSType<tree::UBTree>* ubTreeRS;
- //! Octree-based range search object (NULL if not in use).
- RSType<tree::Octree>* octreeRS;
-
+ /**
+ * rSearch holds an instance of the RangeSearch class for the current
+ * treeType. It is initialized every time BuildModel is executed.
+ * We access to the contained value through the visitor classes defined above.
+ */
+ boost::variant<RSType<tree::KDTree> *,
+ RSType<tree::StandardCoverTree> *,
+ RSType<tree::RTree> *,
+ RSType<tree::RStarTree> *,
+ RSType<tree::BallTree> *,
+ RSType<tree::XTree> *,
+ RSType<tree::HilbertRTree> *,
+ RSType<tree::RPlusTree> *,
+ RSType<tree::RPlusPlusTree> *,
+ RSType<tree::VPTree> *,
+ RSType<tree::RPTree> *,
+ RSType<tree::MaxRPTree> *,
+ RSType<tree::UBTree> *,
+ RSType<tree::Octree> *> rSearch;
+
public:
/**
* Initialize the RSModel with the given type and whether or not a random
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index 4bc73b4..3fde0f5 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -18,6 +18,216 @@
namespace mlpack {
namespace range {
+//! Monochromatic range search on the given RSType instance; throws
+//! std::runtime_error if the stored pointer is null.
+template<typename RSType>
+void MonoSearchVisitor::operator()(RSType *rs) const
+{
+ if (rs)
+ return rs->Search(range, neighbors, distances);
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Save parameters for bichromatic range search. This non-template definition
+//! lives in a header, so it is inline to avoid multiple-definition link errors.
+inline BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet,
+ const math::Range& range,
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances,
+ const size_t leafSize
+ ) :
+ querySet(querySet),
+ range(range),
+ neighbors(neighbors),
+ distances(distances),
+ leafSize(leafSize)
+{}
+
+//! Default bichromatic range search on the given RSType instance; throws
+//! std::runtime_error if the stored pointer is null.
+template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void BiSearchVisitor::operator()(RSTypeT<TreeType>* rs) const
+{
+ if (rs)
+ return rs->Search(querySet, range, neighbors, distances);
+ throw std::runtime_error("no range search model initialized");
+}
+
+// The three overloads below are ordinary (non-template) member functions
+// defined in a header, so they are inline to avoid multiple-definition link
+// errors when the header is included from several translation units.
+
+//! Bichromatic range search on the given RSType specialized for KDTrees.
+inline void BiSearchVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
+{
+ if (rs)
+ return SearchLeaf(rs);
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Bichromatic range search on the given RSType specialized for BallTrees.
+inline void BiSearchVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
+{
+ if (rs)
+ return SearchLeaf(rs);
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Bichromatic range search specialized for Octrees.
+inline void BiSearchVisitor::operator()(RSTypeT<tree::Octree>* rs) const
+{
+ if (rs)
+ return SearchLeaf(rs);
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Bichromatic range search on the given RSType considering the leafSize.
+//! When the model is neither naive nor in single-tree mode, a query tree is
+//! built with leafSize (timed under "tree_building"); otherwise the query set
+//! is searched directly.
+template<typename RSType>
+void BiSearchVisitor::SearchLeaf(RSType* rs) const
+{
+ if (!rs->Naive() && !rs->SingleMode())
+ {
+ // Build a second tree and search.
+ Timer::Start("tree_building");
+ Log::Info << "Building query tree..." << std::endl;
+ std::vector<size_t> oldFromNewQueries;
+ // querySet is a const reference, so the tree is built from a copy of it.
+ typename RSType::Tree queryTree(querySet, oldFromNewQueries, leafSize);
+ Log::Info << "Tree built." << std::endl;
+ Timer::Stop("tree_building");
+
+ std::vector<std::vector<size_t>> neighborsOut;
+ std::vector<std::vector<double>> distancesOut;
+ rs->Search(&queryTree, range, neighborsOut, distancesOut);
+
+ // Tree construction reorders the query points; map the results back to
+ // the original query order.
+ neighbors.resize(queryTree.Dataset().n_cols);
+ distances.resize(queryTree.Dataset().n_cols);
+ for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
+ {
+ neighbors[oldFromNewQueries[i]] = neighborsOut[i];
+ distances[oldFromNewQueries[i]] = distancesOut[i];
+ }
+ }
+ else
+ {
+ // Search without building a second tree.
+ rs->Search(querySet, range, neighbors, distances);
+ }
+}
+
+//! Save parameters for Train. This non-template definition lives in a header,
+//! so it is inline to avoid multiple-definition link errors.
+inline TrainVisitor::TrainVisitor(arma::mat&& referenceSet,
+ const size_t leafSize
+ ) :
+ referenceSet(std::move(referenceSet)),
+ leafSize(leafSize)
+{}
+
+//! Default Train on the given RSType instance; throws std::runtime_error if
+//! the stored pointer is null.
+template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void TrainVisitor::operator()(RSTypeT<TreeType>* rs) const
+{
+ if (rs)
+ return rs->Train(std::move(referenceSet));
+ throw std::runtime_error("no range search model initialized");
+}
+
+// The three overloads below are ordinary (non-template) member functions
+// defined in a header, so they are inline to avoid multiple-definition link
+// errors when the header is included from several translation units.
+
+//! Train on the given RSType specialized for KDTrees.
+inline void TrainVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
+{
+ if (rs)
+ return TrainLeaf(rs);
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Train on the given RSType specialized for BallTrees.
+inline void TrainVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
+{
+ if (rs)
+ return TrainLeaf(rs);
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Train specialized for Octrees.
+inline void TrainVisitor::operator()(RSTypeT<tree::Octree>* rs) const
+{
+ if (rs)
+ return TrainLeaf(rs);
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Train on the given RSType considering the leafSize: naive models take the
+//! matrix directly; otherwise a reference tree is built with leafSize and the
+//! model is given ownership of it and of the point mappings.
+template<typename RSType>
+void TrainVisitor::TrainLeaf(RSType* rs) const
+{
+ // Brute-force search needs no tree: hand the matrix over directly.
+ if (rs->Naive())
+ {
+ rs->Train(std::move(referenceSet));
+ return;
+ }
+
+ // Build the reference tree here so that leafSize is honored.
+ typedef typename RSType::Tree TreeT;
+ std::vector<size_t> oldFromNewReferences;
+ TreeT* referenceTree = new TreeT(std::move(referenceSet),
+ oldFromNewReferences, leafSize);
+ rs->Train(referenceTree);
+
+ // The model now owns the tree and keeps the point mappings.
+ rs->treeOwner = true;
+ rs->oldFromNewReferences = std::move(oldFromNewReferences);
+}
+
+
+//! Expose the referenceSet of the given RSType; throws std::runtime_error if
+//! the stored pointer is null.
+template<typename RSType>
+const arma::mat& ReferenceSetVisitor::operator()(RSType* rs) const
+{
+ if (rs)
+ return rs->ReferenceSet();
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Free the RangeSearch object pointed to by the visited pointer.
+template<typename RSType>
+void DeleteVisitor::operator()(RSType* rs) const
+{
+ // delete on a null pointer is a no-op, so no explicit check is needed. The
+ // pointer arrives by value: the caller's copy is not reset by this call.
+ delete rs;
+}
+
+//! Save references to the archive and to the name of the serialized object.
+template<typename Archive>
+ SerializeVisitor<Archive>::SerializeVisitor(Archive& ar,
+ const std::string& name) :
+ ar(ar),
+ name(name)
+ {}
+
+ //! Serializes the given RSType instance.
+ //! NOTE(review): rs is a by-value copy of the pointer held in the variant.
+ //! When loading, whatever pointer the archive produces is stored only in
+ //! this local copy, so the variant is never updated. Consider taking
+ //! RSType*& (declaration and definition together).
+ template<typename Archive>
+ template<typename RSType>
+ void SerializeVisitor<Archive>::operator()(RSType *rs) const
+ {
+ ar & data::CreateNVP(rs, name);
+ }
+
+//! Return a reference to the single-mode flag of the given RSType; throws
+//! std::runtime_error if the stored pointer is null.
+template<typename RSType>
+bool& SingleModeVisitor::operator()(RSType *rs) const
+{
+ if (rs)
+ return rs->SingleMode();
+ throw std::runtime_error("no range search model initialized");
+}
+
+//! Return a reference to the naive flag of the given RSType; throws
+//! std::runtime_error if the stored pointer is null.
+template<typename RSType>
+bool& NaiveVisitor::operator()(RSType *rs) const
+{
+ if (rs)
+ return rs->Naive();
+ throw std::runtime_error("no range search model initialized");
+}
+
// Serialize the model.
template<typename Archive>
void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
@@ -30,237 +240,37 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
// This should never happen, but just in case...
if (Archive::is_loading::value)
- CleanMemory();
+ boost::apply_visitor(DeleteVisitor(), rSearch);
// We'll only need to serialize one of the model objects, based on the type.
- switch (treeType)
- {
- case KD_TREE:
- ar & CreateNVP(kdTreeRS, "range_search_model");
- break;
-
- case COVER_TREE:
- ar & CreateNVP(coverTreeRS, "range_search_model");
- break;
-
- case R_TREE:
- ar & CreateNVP(rTreeRS, "range_search_model");
- break;
-
- case R_STAR_TREE:
- ar & CreateNVP(rStarTreeRS, "range_search_model");
- break;
-
- case BALL_TREE:
- ar & CreateNVP(ballTreeRS, "range_search_model");
- break;
-
- case X_TREE:
- ar & CreateNVP(xTreeRS, "range_search_model");
- break;
-
- case HILBERT_R_TREE:
- ar & CreateNVP(hilbertRTreeRS, "range_search_model");
- break;
-
- case R_PLUS_TREE:
- ar & CreateNVP(rPlusTreeRS, "range_search_model");
- break;
-
- case R_PLUS_PLUS_TREE:
- ar & CreateNVP(rPlusPlusTreeRS, "range_search_model");
- break;
-
- case VP_TREE:
- ar & CreateNVP(vpTreeRS, "range_search_model");
- break;
-
- case RP_TREE:
- ar & CreateNVP(rpTreeRS, "range_search_model");
- break;
-
- case MAX_RP_TREE:
- ar & CreateNVP(maxRPTreeRS, "range_search_model");
- break;
-
- case UB_TREE:
- ar & CreateNVP(ubTreeRS, "range_search_model");
- break;
-
- case OCTREE:
- ar & CreateNVP(octreeRS, "range_search_model");
- break;
- }
+ const std::string& name = RSModelName::Name();
+ SerializeVisitor<Archive> s(ar, name);
+ boost::apply_visitor(s, rSearch);
}
+//! Return the reference set held by the active RangeSearch object.
+//! ReferenceSetVisitor throws std::runtime_error if the stored pointer is null.
inline const arma::mat& RSModel::Dataset() const
{
- if (kdTreeRS)
- return kdTreeRS->ReferenceSet();
- else if (coverTreeRS)
- return coverTreeRS->ReferenceSet();
- else if (rTreeRS)
- return rTreeRS->ReferenceSet();
- else if (rStarTreeRS)
- return rStarTreeRS->ReferenceSet();
- else if (ballTreeRS)
- return ballTreeRS->ReferenceSet();
- else if (xTreeRS)
- return xTreeRS->ReferenceSet();
- else if (hilbertRTreeRS)
- return hilbertRTreeRS->ReferenceSet();
- else if (rPlusTreeRS)
- return rPlusTreeRS->ReferenceSet();
- else if (rPlusPlusTreeRS)
- return rPlusPlusTreeRS->ReferenceSet();
- else if (vpTreeRS)
- return vpTreeRS->ReferenceSet();
- else if (rpTreeRS)
- return rpTreeRS->ReferenceSet();
- else if (maxRPTreeRS)
- return maxRPTreeRS->ReferenceSet();
- else if (ubTreeRS)
- return ubTreeRS->ReferenceSet();
- else if (octreeRS)
- return octreeRS->ReferenceSet();
-
- throw std::runtime_error("no range search model initialized");
+ return boost::apply_visitor(ReferenceSetVisitor(), rSearch);
}
+//! Return (by value) whether the active model uses single-tree mode.
inline bool RSModel::SingleMode() const
{
- if (kdTreeRS)
- return kdTreeRS->SingleMode();
- else if (coverTreeRS)
- return coverTreeRS->SingleMode();
- else if (rTreeRS)
- return rTreeRS->SingleMode();
- else if (rStarTreeRS)
- return rStarTreeRS->SingleMode();
- else if (ballTreeRS)
- return ballTreeRS->SingleMode();
- else if (xTreeRS)
- return xTreeRS->SingleMode();
- else if (hilbertRTreeRS)
- return hilbertRTreeRS->SingleMode();
- else if (rPlusTreeRS)
- return rPlusTreeRS->SingleMode();
- else if (rPlusPlusTreeRS)
- return rPlusPlusTreeRS->SingleMode();
- else if (vpTreeRS)
- return vpTreeRS->SingleMode();
- else if (rpTreeRS)
- return rpTreeRS->SingleMode();
- else if (maxRPTreeRS)
- return maxRPTreeRS->SingleMode();
- else if (ubTreeRS)
- return ubTreeRS->SingleMode();
- else if (octreeRS)
- return octreeRS->SingleMode();
-
- throw std::runtime_error("no range search model initialized");
+ return boost::apply_visitor(SingleModeVisitor(), rSearch);
}
+//! Return a modifiable reference to the active model's single-mode flag.
inline bool& RSModel::SingleMode()
{
- if (kdTreeRS)
- return kdTreeRS->SingleMode();
- else if (coverTreeRS)
- return coverTreeRS->SingleMode();
- else if (rTreeRS)
- return rTreeRS->SingleMode();
- else if (rStarTreeRS)
- return rStarTreeRS->SingleMode();
- else if (ballTreeRS)
- return ballTreeRS->SingleMode();
- else if (xTreeRS)
- return xTreeRS->SingleMode();
- else if (hilbertRTreeRS)
- return hilbertRTreeRS->SingleMode();
- else if (rPlusTreeRS)
- return rPlusTreeRS->SingleMode();
- else if (rPlusPlusTreeRS)
- return rPlusPlusTreeRS->SingleMode();
- else if (vpTreeRS)
- return vpTreeRS->SingleMode();
- else if (rpTreeRS)
- return rpTreeRS->SingleMode();
- else if (maxRPTreeRS)
- return maxRPTreeRS->SingleMode();
- else if (ubTreeRS)
- return ubTreeRS->SingleMode();
- else if (octreeRS)
- return octreeRS->SingleMode();
-
- throw std::runtime_error("no range search model initialized");
+ return boost::apply_visitor(SingleModeVisitor(), rSearch);
}
+//! Return (by value) whether the active model uses brute-force (naive) search.
inline bool RSModel::Naive() const
{
- if (kdTreeRS)
- return kdTreeRS->Naive();
- else if (coverTreeRS)
- return coverTreeRS->Naive();
- else if (rTreeRS)
- return rTreeRS->Naive();
- else if (rStarTreeRS)
- return rStarTreeRS->Naive();
- else if (ballTreeRS)
- return ballTreeRS->Naive();
- else if (xTreeRS)
- return xTreeRS->Naive();
- else if (hilbertRTreeRS)
- return hilbertRTreeRS->Naive();
- else if (rPlusTreeRS)
- return rPlusTreeRS->Naive();
- else if (rPlusPlusTreeRS)
- return rPlusPlusTreeRS->Naive();
- else if (vpTreeRS)
- return vpTreeRS->Naive();
- else if (rpTreeRS)
- return rpTreeRS->Naive();
- else if (maxRPTreeRS)
- return maxRPTreeRS->Naive();
- else if (ubTreeRS)
- return ubTreeRS->Naive();
- else if (octreeRS)
- return octreeRS->Naive();
-
- throw std::runtime_error("no range search model initialized");
+ return boost::apply_visitor(NaiveVisitor(), rSearch);
}
+//! Return a modifiable reference to the active model's naive flag.
inline bool& RSModel::Naive()
{
- if (kdTreeRS)
- return kdTreeRS->Naive();
- else if (coverTreeRS)
- return coverTreeRS->Naive();
- else if (rTreeRS)
- return rTreeRS->Naive();
- else if (rStarTreeRS)
- return rStarTreeRS->Naive();
- else if (ballTreeRS)
- return ballTreeRS->Naive();
- else if (xTreeRS)
- return xTreeRS->Naive();
- else if (hilbertRTreeRS)
- return hilbertRTreeRS->Naive();
- else if (rPlusTreeRS)
- return rPlusTreeRS->Naive();
- else if (rPlusPlusTreeRS)
- return rPlusPlusTreeRS->Naive();
- else if (vpTreeRS)
- return vpTreeRS->Naive();
- else if (rpTreeRS)
- return rpTreeRS->Naive();
- else if (maxRPTreeRS)
- return maxRPTreeRS->Naive();
- else if (ubTreeRS)
- return ubTreeRS->Naive();
- else if (octreeRS)
- return octreeRS->Naive();
-
- throw std::runtime_error("no range search model initialized");
+ return boost::apply_visitor(NaiveVisitor(), rSearch);
}
} // namespace range
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list