[mlpack] 29/58: dual tree traverser bug fixes.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Tue Sep 9 13:19:41 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 005999bd7ecb05fe918732e1ce10e0c1dbc5026b
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Aug 20 17:23:37 2014 +0000
dual tree traverser bug fixes.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17084 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../core/tree/rectangle_tree/dual_tree_traverser.hpp | 1 +
.../tree/rectangle_tree/dual_tree_traverser_impl.hpp | 16 +++++++---------
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 6610da5..50f971d 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -67,6 +67,7 @@ class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
public:
RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
double score;
+ typename RuleType::TraversalInfoType travInfo;
};
static bool nodeComparator(const NodeAndScore& obj1,
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 7ac9b20..8bdad30 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -76,13 +76,7 @@ DualTreeTraverser<RuleType>::Traverse(
else if(!queryNode.IsLeaf() && referenceNode.IsLeaf())
{
// We only need to traverse down the query node. Order doesn't matter here.
- ++numScores;
- if(rule.Score(queryNode.Child(0), referenceNode) < DBL_MAX)
- Traverse(queryNode.Child(0), referenceNode);
- else
- numPrunes++;
-
- for(size_t i = 1; i < queryNode.NumChildren(); ++i)
+ for(size_t i = 0; i < queryNode.NumChildren(); ++i)
{
// Before recursing, we have to set the traversal information correctly.
rule.TraversalInfo() = traversalInfo;
@@ -101,16 +95,18 @@ DualTreeTraverser<RuleType>::Traverse(
std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++)
{
+ rule.TraversalInfo() = traversalInfo;
nodesAndScores[i].node = referenceNode.Children()[i];
nodesAndScores[i].score = rule.Score(queryNode, *(nodesAndScores[i].node));
+ nodesAndScores[i].travInfo = rule.TraversalInfo();
}
std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
numScores += nodesAndScores.size();
for(int i = 0; i < nodesAndScores.size(); i++)
{
+ rule.TraversalInfo() = nodesAndScores[i].travInfo;
if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
- rule.TraversalInfo() = traversalInfo;
Traverse(queryNode, *(nodesAndScores[i].node));
} else {
numPrunes += nodesAndScores.size() - i;
@@ -130,16 +126,18 @@ DualTreeTraverser<RuleType>::Traverse(
std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++)
{
+ rule.TraversalInfo() = traversalInfo;
nodesAndScores[i].node = referenceNode.Children()[i];
nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+ nodesAndScores[i].travInfo = rule.TraversalInfo();
}
std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
numScores += nodesAndScores.size();
for(int i = 0; i < nodesAndScores.size(); i++)
{
+ rule.TraversalInfo() = nodesAndScores[i].travInfo;
if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
- rule.TraversalInfo() = traversalInfo;
Traverse(queryNode, *(nodesAndScores[i].node));
} else {
numPrunes += nodesAndScores.size() - i;
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list