[mlpack] 29/58: dual tree traverser bug fixes.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Tue Sep 9 13:19:41 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 005999bd7ecb05fe918732e1ce10e0c1dbc5026b
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Wed Aug 20 17:23:37 2014 +0000

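    Dual tree traverser bug fixes: handle query child 0 in the common loop, and reset, save and restore per-child traversal info around Score()/Rescore().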
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17084 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 .../core/tree/rectangle_tree/dual_tree_traverser.hpp     |  1 +
 .../tree/rectangle_tree/dual_tree_traverser_impl.hpp     | 16 +++++++---------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 6610da5..50f971d 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -67,6 +67,7 @@ class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
   public:
     RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
     double score;
+    typename RuleType::TraversalInfoType travInfo;
   };
 
   static bool nodeComparator(const NodeAndScore& obj1,
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 7ac9b20..8bdad30 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -76,13 +76,7 @@ DualTreeTraverser<RuleType>::Traverse(
   else if(!queryNode.IsLeaf() && referenceNode.IsLeaf())
   {
     // We only need to traverse down the query node.  Order doesn't matter here.
-    ++numScores;
-    if(rule.Score(queryNode.Child(0), referenceNode) < DBL_MAX)
-      Traverse(queryNode.Child(0), referenceNode);
-    else
-      numPrunes++;
-    
-    for(size_t i = 1; i < queryNode.NumChildren(); ++i)
+    for(size_t i = 0; i < queryNode.NumChildren(); ++i)
     {
       // Before recursing, we have to set the traversal information correctly.
       rule.TraversalInfo() = traversalInfo;
@@ -101,16 +95,18 @@ DualTreeTraverser<RuleType>::Traverse(
     std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
     for(int i = 0; i < referenceNode.NumChildren(); i++)
     {
+      rule.TraversalInfo() = traversalInfo;
       nodesAndScores[i].node = referenceNode.Children()[i];
       nodesAndScores[i].score = rule.Score(queryNode, *(nodesAndScores[i].node));
+      nodesAndScores[i].travInfo = rule.TraversalInfo();
     }
     std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
     numScores += nodesAndScores.size();
     
     for(int i = 0; i < nodesAndScores.size(); i++)
     {
+      rule.TraversalInfo() = nodesAndScores[i].travInfo;
       if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
-        rule.TraversalInfo() = traversalInfo;
         Traverse(queryNode, *(nodesAndScores[i].node));
       } else {
         numPrunes += nodesAndScores.size() - i;
@@ -130,16 +126,18 @@ DualTreeTraverser<RuleType>::Traverse(
       std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
       for(int i = 0; i < referenceNode.NumChildren(); i++)
       {
+        rule.TraversalInfo() = traversalInfo;
         nodesAndScores[i].node = referenceNode.Children()[i];
         nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+        nodesAndScores[i].travInfo = rule.TraversalInfo();
       }
       std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
       numScores += nodesAndScores.size();
     
       for(int i = 0; i < nodesAndScores.size(); i++)
       {
+        rule.TraversalInfo() = nodesAndScores[i].travInfo;
         if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
-          rule.TraversalInfo() = traversalInfo;
           Traverse(queryNode, *(nodesAndScores[i].node));
         } else {
           numPrunes += nodesAndScores.size() - i;

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list