[mlpack] 99/324: rectangle tree traverser

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:22:00 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit f8dd10288a0d4bae122b35099d2bd108673bc2f5
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Wed Jul 2 17:28:23 2014 +0000

    rectangle tree traverser
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16737 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/core/tree/CMakeLists.txt                |  6 +-
 src/mlpack/core/tree/rectangle_tree.hpp            |  3 +-
 .../tree/rectangle_tree/dual_tree_traverser.hpp    | 90 ++++++++++++++++++++++
 .../rectangle_tree/dual_tree_traverser_impl.hpp    | 49 ++++++++++++
 .../core/tree/rectangle_tree/rectangle_tree.hpp    |  9 ++-
 src/mlpack/methods/neighbor_search/allknn_main.cpp | 70 +++++++++++++++--
 .../neighbor_search/neighbor_search_impl.hpp       |  2 +-
 .../neighbor_search/neighbor_search_rules_impl.hpp |  2 +-
 8 files changed, 216 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index c526a70..7988285 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -32,8 +32,10 @@ set(SOURCES
   rectangle_tree.hpp
   rectangle_tree/rectangle_tree.hpp
   rectangle_tree/rectangle_tree_impl.hpp
-  rectangle_tree/rectangle_tree_traverser.hpp
-  rectangle_tree/rectangle_tree_traverser_impl.hpp
+  rectangle_tree/single_tree_traverser.hpp
+  rectangle_tree/single_tree_traverser_impl.hpp
+  rectangle_tree/dual_tree_traverser.hpp
+  rectangle_tree/dual_tree_traverser_impl.hpp
   rectangle_tree/r_tree_split.hpp
   rectangle_tree/r_tree_split_impl.hpp
   rectangle_tree/r_tree_descent_heuristic.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index 743f7f5..f0d156e 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -13,7 +13,8 @@
  */ 
 #include "bounds.hpp"
 #include "rectangle_tree/rectangle_tree.hpp"
-#include "rectangle_tree/rectangle_tree_traverser.hpp"
+#include "rectangle_tree/single_tree_traverser.hpp"
+#include "rectangle_tree/dual_tree_traverser.hpp"
 #include "rectangle_tree/r_tree_split.hpp"
 #include "rectangle_tree/r_tree_descent_heuristic.hpp"
 #include "rectangle_tree/traits.hpp"
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
new file mode 100644
index 0000000..1091224
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -0,0 +1,90 @@
+/**
+  * @file dual_tree_traverser.hpp
+  * @author Andrew Wells
+  *
+  * A nested class of Rectangle Tree for traversing rectangle type trees
+  * with a given set of rules which indicate the branches to prune and the
+  * order in which to recurse.  This is just here to make it compile.
+  */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "rectangle_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitType,
+         typename DescentType,
+	 typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+    DualTreeTraverser
+{
+ public:
+  /**
+   * Instantiate the dual-tree traverser with the given rule set.
+   */
+  DualTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the two trees.  This does not reset the number of prunes.
+   * NOTE: the current implementation is a stub that performs no traversal.
+   *
+   * @param queryNode The query node to be traversed.
+   * @param referenceNode The reference node to be traversed.
+   */
+  void Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+		RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode);
+
+  //! Get the number of prunes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of prunes.
+  size_t& NumPrunes() { return numPrunes; }
+
+  //! Get the number of visited combinations.
+  size_t NumVisited() const { return numVisited; }
+  //! Modify the number of visited combinations.
+  size_t& NumVisited() { return numVisited; }
+
+  //! Get the number of times a node combination was scored.
+  size_t NumScores() const { return numScores; }
+  //! Modify the number of times a node combination was scored.
+  size_t& NumScores() { return numScores; }
+
+  //! Get the number of times a base case was calculated.
+  size_t NumBaseCases() const { return numBaseCases; }
+  //! Modify the number of times a base case was calculated.
+  size_t& NumBaseCases() { return numBaseCases; }
+
+ private:
+  //! Reference to the rules with which the trees will be traversed.
+  RuleType& rule;
+
+  //! The number of prunes.
+  size_t numPrunes;
+
+  //! The number of node combinations that have been visited during traversal.
+  size_t numVisited;
+
+  //! The number of times a node combination was scored.
+  size_t numScores;
+
+  //! The number of times a base case was calculated.
+  size_t numBaseCases;
+
+  //! Traversal information, held in the class so that it isn't continually
+  //! being reallocated.
+  typename RuleType::TraversalInfoType traversalInfo;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_traverser_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..8af988e
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -0,0 +1,49 @@
+/**
+  * @file dual_tree_traverser_impl.hpp
+  * @author Andrew Wells
+  *
+  * A class for traversing rectangle type trees with a given set of rules
+  * which indicate the branches to prune and the order in which to recurse.
+  * This is currently a stub that performs no traversal.
+  */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+#include "dual_tree_traverser.hpp"
+
+#include <algorithm>
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+
+// Bind the rule set and zero every statistics counter read by the accessors.
+template<typename SplitType, typename DescentType, typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0), numVisited(0),
+    numScores(0), numBaseCases(0)
+{ /* Nothing to do */ }
+
+template<typename SplitType,
+         typename DescentType,
+	 typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+		RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode)
+{
+  // Stub: does nothing; both nodes are touched only to silence warnings.
+  if(queryNode.NumDescendants() > referenceNode.NumDescendants())
+    return;
+  return;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index e6833b9..e50d372 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -79,10 +79,13 @@ class RectangleTree
   //! So other classes can use TreeType::Mat.
   typedef MatType Mat;
 
-  //! A traverser for rectangle type trees.  See
-  //! rectangle_tree_traverser.hpp for implementation.
+  //! A single-tree traverser for rectangle type trees.  See
+  //! single_tree_traverser.hpp for implementation.
   template<typename RuleType>
-  class RectangleTreeTraverser;
+  class SingleTreeTraverser;
+  //! A dual tree traverser for rectangle type trees.
+  template<typename RuleType>
+  class DualTreeTraverser;
 
   /**
    * Construct this as the root node of a rectangle type tree using the given
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 432c23e..b4dc0cf 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -58,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
 PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
     "(experimental, may be slow).", "c");
 PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
-    "(experimental, may be slow.  Currently automatically sets single_mode.).", "R");
+    "(experimental, may be slow.  Currently automatically sets single_mode.).", "T");
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -132,6 +132,7 @@ int main(int argc, char *argv[])
   } else if (!singleMode && CLI::HasParam("r_tree"))  // R_tree requires single mode.
   {
     Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
+    singleMode = true;
   }
   
   if (naive)
@@ -269,16 +270,71 @@ int main(int argc, char *argv[])
       // Make sure to notify the user that they are using an r tree.
       Log::Info << "Using r tree for nearest-neighbor calculation." << endl;
       
-      // Build the reference tree.
+      // Because we may construct it differently, we need a pointer.
+      NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat> >* allknn = NULL;
+
+      // Build trees by hand, so we can save memory: if we pass a tree to
+      // NeighborSearch, it does not copy the matrix.
       Log::Info << "Building reference tree..." << endl;
       Timer::Start("tree_building");
+
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat>
+      refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat>*
+      queryTree = NULL; // Empty for now.
+
+      Timer::Stop("tree_building");
+
+      if (CLI::GetParam<string>("query_file") != "")
+      {
+	Log::Info << "Loaded query data from '" << queryFile << "' ("
+	    << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+        allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
         RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-                      tree::RTreeDescentHeuristic,
-                      NeighborSearchStat<NearestNeighborSort>,
-                      arma::mat>
-        refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+	  	      tree::RTreeDescentHeuristic,
+  		      NeighborSearchStat<NearestNeighborSort>,
+  		      arma::mat> >(&refTree, queryTree,
+          referenceData, queryData, singleMode);
+      } else
+      {
+	      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat> >(&refTree,
+        referenceData, singleMode);
+      }
+      Log::Info << "Tree built." << endl;
+      
+      arma::mat distancesOut;
+      arma::Mat<size_t> neighborsOut;
+
+      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+      allknn->Search(k, neighborsOut, distancesOut);
+
+      Log::Info << "Neighbors computed." << endl;
+
+
+      delete allknn;
+            
+      
+      // Build the reference tree.
+      Log::Info << "Building reference tree..." << endl;
+      Timer::Start("tree_building");
+
       Timer::Stop("tree_building");
-      std::cout << "completed tree building " << refTree.NumDescendants() << std::endl;
     }
   }
   else // Cover trees.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index dac46c0..bcada1f 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -60,7 +60,7 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
     metric(metric)
 {
   // C++11 will allow us to call out to other constructors so we can avoid this
-  // copypasta problem.
+  // copy/paste problem.
 
   // We'll time tree building, but only if we are building trees.
   Timer::Start("tree_building");
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index c96b10a..5b4eca9 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -56,7 +56,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
   ++baseCases;
 
   // If this distance is better than any of the current candidates, the
-  // SortDistance() function will give us the poto insert it into.
+  // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
   arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
   const size_t insertPosition = SortPolicy::SortDistance(queryDist,

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list