[mlpack] 99/324: rectangle tree traverser
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:00 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit f8dd10288a0d4bae122b35099d2bd108673bc2f5
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 2 17:28:23 2014 +0000
rectangle tree traverser
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16737 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/core/tree/CMakeLists.txt | 6 +-
src/mlpack/core/tree/rectangle_tree.hpp | 3 +-
.../tree/rectangle_tree/dual_tree_traverser.hpp | 90 ++++++++++++++++++++++
.../rectangle_tree/dual_tree_traverser_impl.hpp | 49 ++++++++++++
.../core/tree/rectangle_tree/rectangle_tree.hpp | 9 ++-
src/mlpack/methods/neighbor_search/allknn_main.cpp | 70 +++++++++++++++--
.../neighbor_search/neighbor_search_impl.hpp | 2 +-
.../neighbor_search/neighbor_search_rules_impl.hpp | 2 +-
8 files changed, 216 insertions(+), 15 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index c526a70..7988285 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -32,8 +32,10 @@ set(SOURCES
rectangle_tree.hpp
rectangle_tree/rectangle_tree.hpp
rectangle_tree/rectangle_tree_impl.hpp
- rectangle_tree/rectangle_tree_traverser.hpp
- rectangle_tree/rectangle_tree_traverser_impl.hpp
+ rectangle_tree/single_tree_traverser.hpp
+ rectangle_tree/single_tree_traverser_impl.hpp
+ rectangle_tree/dual_tree_traverser.hpp
+ rectangle_tree/dual_tree_traverser_impl.hpp
rectangle_tree/r_tree_split.hpp
rectangle_tree/r_tree_split_impl.hpp
rectangle_tree/r_tree_descent_heuristic.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index 743f7f5..f0d156e 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -13,7 +13,8 @@
*/
#include "bounds.hpp"
#include "rectangle_tree/rectangle_tree.hpp"
-#include "rectangle_tree/rectangle_tree_traverser.hpp"
+#include "rectangle_tree/single_tree_traverser.hpp"
+#include "rectangle_tree/dual_tree_traverser.hpp"
#include "rectangle_tree/r_tree_split.hpp"
#include "rectangle_tree/r_tree_descent_heuristic.hpp"
#include "rectangle_tree/traits.hpp"
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
new file mode 100644
index 0000000..1091224
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -0,0 +1,90 @@
+/**
+ * @file dual_tree_traverser.hpp
+ * @author Andrew Wells
+ *
+ * A nested class of RectangleTree for traversing pairs of rectangle-type
+ * trees with a given set of rules, which indicate the branches to prune and
+ * the order in which to recurse. Currently only a stub so the code compiles.
+ */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "rectangle_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitType,
+ typename DescentType,
+ typename StatisticType,
+ typename MatType>
+template<typename RuleType>
+class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+ DualTreeTraverser
+{
+ public:
+ /**
+ * Instantiate the dual-tree traverser with the given rule set.
+ */
+ DualTreeTraverser(RuleType& rule);
+
+ /**
+ * Traverse the two trees. This does not reset the number of prunes.
+ *
+ * @param queryNode The query node to be traversed.
+ * @param referenceNode The reference node to be traversed.
+ * NOTE: currently a stub; no node combinations are visited or scored.
+ */
+ void Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+ RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode);
+
+ //! Get the number of prunes.
+ size_t NumPrunes() const { return numPrunes; }
+ //! Modify the number of prunes.
+ size_t& NumPrunes() { return numPrunes; }
+
+ //! Get the number of visited combinations.
+ size_t NumVisited() const { return numVisited; }
+ //! Modify the number of visited combinations.
+ size_t& NumVisited() { return numVisited; }
+
+ //! Get the number of times a node combination was scored.
+ size_t NumScores() const { return numScores; }
+ //! Modify the number of times a node combination was scored.
+ size_t& NumScores() { return numScores; }
+
+ //! Get the number of times a base case was calculated.
+ size_t NumBaseCases() const { return numBaseCases; }
+ //! Modify the number of times a base case was calculated.
+ size_t& NumBaseCases() { return numBaseCases; }
+
+ private:
+ //! Reference to the rules with which the trees will be traversed.
+ RuleType& rule;
+
+ //! The number of prunes.
+ size_t numPrunes;
+
+ //! The number of node combinations that have been visited during traversal.
+ size_t numVisited;
+
+ //! The number of times a node combination was scored.
+ size_t numScores;
+
+ //! The number of times a base case was calculated.
+ size_t numBaseCases;
+
+ //! Traversal information, held in the class so that it isn't continually
+ //! being reallocated.
+ typename RuleType::TraversalInfoType traversalInfo;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_traverser_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..8af988e
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -0,0 +1,49 @@
+/**
+ * @file dual_tree_traverser_impl.hpp
+ * @author Andrew Wells
+ *
+ * A class for traversing pairs of rectangle-type trees with a given set of
+ * rules which indicate the branches to prune and the order in which to recurse.
+ * Intended as a depth-first traverser; currently a stub that visits no nodes.
+ */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+#include "dual_tree_traverser.hpp"
+
+#include <algorithm>
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+
+// Store the rule set and start every traversal counter at zero, so that the
+// NumPrunes()/NumVisited()/NumScores()/NumBaseCases() accessors are defined.
+template<typename SplitType, typename DescentType, typename StatisticType,
+ typename MatType>
+template<typename RuleType>
+RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+ rule(rule), numPrunes(0), numVisited(0),
+ numScores(0), numBaseCases(0)
+{ /* Nothing else to do; traversalInfo is default-constructed. */ }
+
+template<typename SplitType,
+ typename DescentType,
+ typename StatisticType,
+ typename MatType>
+template<typename RuleType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+ RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode)
+{
+ // Placeholder: dual-tree traversal is not implemented yet, so no rules are
+ // invoked and no counters change. The (void) casts silence unused-parameter
+ // warnings without the dead NumDescendants() comparison used previously.
+ (void) queryNode; (void) referenceNode;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index e6833b9..e50d372 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -79,10 +79,13 @@ class RectangleTree
//! So other classes can use TreeType::Mat.
typedef MatType Mat;
- //! A traverser for rectangle type trees. See
- //! rectangle_tree_traverser.hpp for implementation.
+ //! A single traverser for rectangle type trees. See
+ //! single_tree_traverser.hpp for implementation.
template<typename RuleType>
- class RectangleTreeTraverser;
+ class SingleTreeTraverser;
+ //! A dual tree traverser for rectangle type trees.
+ template<typename RuleType>
+ class DualTreeTraverser;
/**
* Construct this as the root node of a rectangle type tree using the given
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 432c23e..b4dc0cf 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -58,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
"(experimental, may be slow).", "c");
PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
- "(experimental, may be slow. Currently automatically sets single_mode.).", "R");
+ "(experimental, may be slow. Currently automatically sets single_mode.).", "T");
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -132,6 +132,7 @@ int main(int argc, char *argv[])
} else if (!singleMode && CLI::HasParam("r_tree")) // R_tree requires single mode.
{
Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
+ singleMode = true;
}
if (naive)
@@ -269,16 +270,71 @@ int main(int argc, char *argv[])
// Make sure to notify the user that they are using an r tree.
Log::Info << "Using r tree for nearest-neighbor calculation." << endl;
- // Build the reference tree.
+ // Because we may construct it differently, we need a pointer.
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >* allknn = NULL;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
Log::Info << "Building reference tree..." << endl;
Timer::Start("tree_building");
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>
+ refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("tree_building");
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat>
- refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >(&refTree, queryTree,
+ referenceData, queryData, singleMode);
+ } else
+ {
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >(&refTree,
+ referenceData, singleMode);
+ }
+ Log::Info << "Tree built." << endl;
+
+ arma::mat distancesOut;
+ arma::Mat<size_t> neighborsOut;
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighborsOut, distancesOut);
+
+ Log::Info << "Neighbors computed." << endl;
+
+
+ delete allknn;
+
+
+ // Build the reference tree.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+
Timer::Stop("tree_building");
- std::cout << "completed tree building " << refTree.NumDescendants() << std::endl;
}
}
else // Cover trees.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index dac46c0..bcada1f 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -60,7 +60,7 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
metric(metric)
{
// C++11 will allow us to call out to other constructors so we can avoid this
- // copypasta problem.
+ // copy/paste problem.
// We'll time tree building, but only if we are building trees.
Timer::Start("tree_building");
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index c96b10a..5b4eca9 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -56,7 +56,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
++baseCases;
// If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the poto insert it into.
+ // SortDistance() function will give us the position to insert it into.
arma::vec queryDist = distances.unsafe_col(queryIndex);
arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
const size_t insertPosition = SortPolicy::SortDistance(queryDist,
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list