[mlpack] 136/324: R tree traversal test code.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:22:03 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 8e5d648ffd2122f2a0457e4502413797e9a0eab0
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Mon Jul 7 20:45:00 2014 +0000

    R tree traversal test code.
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16775 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp |  7 ++--
 src/mlpack/tests/rectangle_tree_test.cpp           | 37 +++++++++++++++++++++-
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index ab9c55a..6e587d0 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -115,6 +115,9 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
   int j = 0;
   GetBoundSeeds(*tree, &i, &j);
   
+
+  if(i == j)
+    std::cout << i << ", " << j << "; " << tree->NumChildren() << std::endl;
   assert(i != j);
   
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* treeOne = new 
@@ -187,7 +190,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetPointSeeds(
   // Here we want to find the pair of points that it is worst to place in the same
   // node.  Because we are just using points, we will simply choose the two that would
   // create the most voluminous hyperrectangle.
-  double worstPairScore = 0.0;
+  double worstPairScore = -1.0;
   int worstI = 0;
   int worstJ = 0;
   for(int i = 0; i < tree.Count(); i++) {
@@ -221,7 +224,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetBoundSeeds(
   int* iRet,
   int* jRet)
 {
-  double worstPairScore = 0.0;
+  double worstPairScore = -1.0;
   int worstI = 0;
   int worstJ = 0;
   for(int i = 0; i < tree.NumChildren(); i++) {
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 5e00d75..c6bb651 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -54,7 +54,6 @@ BOOST_AUTO_TEST_CASE(RectangleTreeConstructionCountTest)
                       NeighborSearchStat<NearestNeighborSort>,
                       arma::mat> tree(dataset, 20, 6, 5, 2, 0);
   BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000);
-  std::cout << tree.ToString() << std::endl;
 }
 
 std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
@@ -139,5 +138,41 @@ BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
   assert(checkContainment(tree) == true);
 }
 
+BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
+{
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  arma::Mat<size_t> neighbors1;
+  arma::mat distances1;
+  arma::Mat<size_t> neighbors2;
+  arma::mat distances2;
+  
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> RTree(dataset, 20, 6, 5, 2, 0);
+
+  // nearest neighbor search with the R tree.
+  mlpack::neighbor::NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+        RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+	  	      tree::RTreeDescentHeuristic,
+  		      NeighborSearchStat<NearestNeighborSort>,
+  		      arma::mat> > allknn1(&RTree,
+        dataset, true);
+        
+  allknn1.Search(5, neighbors1, distances1);
+
+  // nearest neighbor search the naive way.
+  mlpack::neighbor::AllkNN allknn2(dataset,
+        true, true);
+
+  allknn2.Search(5, neighbors2, distances2);
+  
+  for(size_t i = 0; i < neighbors1.size(); i++) {
+    assert(neighbors1[i] == neighbors2[i]);
+    assert(distances1[i] == distances2[i]);
+  }
+}
+
 
 BOOST_AUTO_TEST_SUITE_END();

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list