[mlpack] 75/149: Add a semi-hackish breadth-first traverser. The tree abstractions will need to change to support arbitrary traverser types (probably by adding a template parameter) but for now this works to make DualTreeKMeans work.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Sat May 2 09:11:11 UTC 2015
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit a4750ff8cd469b48905b9a4e8b9f9d1987725575
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Nov 5 21:40:37 2014 +0000
Add a semi-hackish breadth-first traverser. The tree abstractions will need to
change to support arbitrary traverser types (probably by adding a template
parameter) but for now this works to make DualTreeKMeans work.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17303 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../tree/binary_space_tree/binary_space_tree.hpp | 3 +
.../breadth_first_dual_tree_traverser.hpp | 92 +++++
.../breadth_first_dual_tree_traverser_impl.hpp | 442 +++++++++++++++++++++
3 files changed, 537 insertions(+)
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index eadd300..db1ece7 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -87,6 +87,9 @@ class BinarySpaceTree
template<typename RuleType>
class DualTreeTraverser;
+  //! Breadth-first variant of DualTreeTraverser; the class itself is defined
+  //! in breadth_first_dual_tree_traverser.hpp.
+  template<typename RuleType> class BreadthFirstDualTreeTraverser;
+
/**
* Construct this as the root node of a binary space tree using the given
* dataset. This will modify the ordering of the points in the dataset!
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
new file mode 100644
index 0000000..8a22a70
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
@@ -0,0 +1,92 @@
+/**
+ * @file breadth_first_dual_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the BreadthFirstDualTreeTraverser for the BinarySpaceTree tree type.
+ * This is a nested class of BinarySpaceTree which traverses two trees in a
+ * breadth-first manner with a given set of rules which indicate the branches
+ * which can be pruned and the order in which to recurse.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "binary_space_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename RuleType>
+class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    BreadthFirstDualTreeTraverser
+{
+ public:
+  /**
+   * Instantiate the breadth-first dual-tree traverser with the given rule set.
+   * Only a reference to the rules is stored, so the rule object must outlive
+   * the traverser.
+   *
+   * @param rule Rules used to score node combinations and evaluate base cases.
+   */
+  BreadthFirstDualTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the two trees in breadth-first order, starting from the given
+   * pair of nodes.  None of the statistics counters (prunes, visited
+   * combinations, scores, base cases) are reset by this call; they keep
+   * accumulating across calls.
+   *
+   * @param queryNode The query node to be traversed.
+   * @param referenceNode The reference node to be traversed.
+   */
+  void Traverse(BinarySpaceTree& queryNode,
+                BinarySpaceTree& referenceNode);
+
+  //! Get the number of prunes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of prunes.
+  size_t& NumPrunes() { return numPrunes; }
+
+  //! Get the number of visited combinations.
+  size_t NumVisited() const { return numVisited; }
+  //! Modify the number of visited combinations.
+  size_t& NumVisited() { return numVisited; }
+
+  //! Get the number of times a node combination was scored.
+  size_t NumScores() const { return numScores; }
+  //! Modify the number of times a node combination was scored.
+  size_t& NumScores() { return numScores; }
+
+  //! Get the number of times a base case was calculated.
+  size_t NumBaseCases() const { return numBaseCases; }
+  //! Modify the number of times a base case was calculated.
+  size_t& NumBaseCases() { return numBaseCases; }
+
+ private:
+  //! Reference to the rules with which the trees will be traversed (not
+  //! owned; must outlive this object).
+  RuleType& rule;
+
+  //! The number of prunes.
+  size_t numPrunes;
+
+  //! The number of node combinations that have been visited during traversal.
+  size_t numVisited;
+
+  //! The number of times a node combination was scored.
+  size_t numScores;
+
+  //! The number of times a base case was calculated.
+  size_t numBaseCases;
+
+  //! Traversal information, held in the class so that it isn't continually
+  //! being reallocated.
+  typename RuleType::TraversalInfoType traversalInfo;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "breadth_first_dual_tree_traverser_impl.hpp"
+
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
+
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..bd81df2
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -0,0 +1,442 @@
+/**
+ * @file breadth_first_dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the BreadthFirstDualTreeTraverser for BinarySpaceTree.
+ * This is a way to perform a dual-tree traversal of two trees. The trees must
+ * be the same type.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "breadth_first_dual_tree_traverser.hpp"
+
+#include <queue>
+
+namespace mlpack {
+namespace tree {
+
+// Construct the traverser around a rule set; every statistics counter starts
+// at zero.
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename RuleType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::BreadthFirstDualTreeTraverser(
+    RuleType& rules) :
+    rule(rules), numPrunes(0), numVisited(0), numScores(0), numBaseCases(0)
+{
+  // Nothing else to set up: traversalInfo is only assigned inside Traverse().
+}
+
+/**
+ * Breadth-first dual-tree traversal.  Pending (query node, reference node)
+ * combinations are held in FIFO queues, together with a copy of the rule's
+ * traversal information taken when each combination was scored.  A dequeued
+ * combination of two leaves runs all of its base cases; any other combination
+ * scores its child combinations and enqueues those whose score (or rescore)
+ * is not DBL_MAX, counting the rest as prunes.
+ */
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename RuleType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::Traverse(
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryRoot,
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+        referenceRoot)
+{
+  typedef BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>
+      TreeType;
+  typedef typename RuleType::TraversalInfoType TraversalInfoType;
+
+  // Remember the rule's traversal information on entry.  Each queued
+  // combination carries its own copy (traversalInfos below), and those copies
+  // are what is used to restore the rule's state while traversing.
+  traversalInfo = rule.TraversalInfo();
+
+  // Three parallel FIFO queues: the i-th entries of each queue together form
+  // one pending (query node, reference node, traversal info) combination.
+  // FIFO order is what makes this traversal breadth-first.
+  std::queue<TreeType*> queryList;
+  std::queue<TreeType*> referenceList;
+  std::queue<TraversalInfoType> traversalInfos;
+  queryList.push(&queryRoot);
+  referenceList.push(&referenceRoot);
+  traversalInfos.push(traversalInfo);
+
+  while (!queryList.empty())
+  {
+    TreeType& queryNode = *queryList.front();
+    TreeType& referenceNode = *referenceList.front();
+    const TraversalInfoType ti = traversalInfos.front();
+
+    queryList.pop();
+    referenceList.pop();
+    traversalInfos.pop();
+
+    // Count every combination actually visited, not just the root pair.
+    ++numVisited;
+
+    // Restore the traversal information that was current when this
+    // combination was scored and enqueued.
+    rule.TraversalInfo() = ti;
+
+    if (queryNode.IsLeaf() && referenceNode.IsLeaf())
+    {
+      // Both are leaves: evaluate the base case for every pair of points.
+      for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
+      {
+        for (size_t ref = referenceNode.Begin(); ref < referenceNode.End();
+             ++ref)
+          rule.BaseCase(query, ref);
+
+        numBaseCases += referenceNode.Count();
+      }
+    }
+    else if (referenceNode.IsLeaf())
+    {
+      // Only the query node can be split (it is not a leaf here).  Both query
+      // children are paired with the same reference leaf, so the order in
+      // which they are enqueued does not matter.
+      const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
+      ++numScores;
+
+      if (leftScore != DBL_MAX)
+      {
+        queryList.push(queryNode.Left());
+        referenceList.push(&referenceNode);
+        traversalInfos.push(rule.TraversalInfo());
+      }
+      else
+      {
+        ++numPrunes;
+      }
+
+      // Score the right child starting from the parent's traversal info.
+      rule.TraversalInfo() = ti;
+      const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
+      ++numScores;
+
+      if (rightScore != DBL_MAX)
+      {
+        queryList.push(queryNode.Right());
+        referenceList.push(&referenceNode);
+        traversalInfos.push(rule.TraversalInfo());
+      }
+      else
+      {
+        ++numPrunes;
+      }
+    }
+    else
+    {
+      // The reference node can be split.  A leaf query node is paired with
+      // both reference children; otherwise each query child (left first) is
+      // paired with both reference children.
+      TreeType* queryChildren[2] = { &queryNode, NULL };
+      size_t numQueryChildren = 1;
+      if (!queryNode.IsLeaf())
+      {
+        queryChildren[0] = queryNode.Left();
+        queryChildren[1] = queryNode.Right();
+        numQueryChildren = 2;
+      }
+
+      TreeType& leftRef = *referenceNode.Left();
+      TreeType& rightRef = *referenceNode.Right();
+
+      for (size_t i = 0; i < numQueryChildren; ++i)
+      {
+        TreeType& queryChild = *queryChildren[i];
+
+        // Both reference children are scored starting from the parent's
+        // traversal information.
+        rule.TraversalInfo() = ti;
+        double leftScore = rule.Score(queryChild, leftRef);
+        const TraversalInfoType leftInfo = rule.TraversalInfo();
+        rule.TraversalInfo() = ti;
+        double rightScore = rule.Score(queryChild, rightRef);
+        numScores += 2;
+
+        if (rightScore < leftScore)
+        {
+          // The right reference child looks more promising: enqueue it
+          // first, then check whether the left child survives a rescore.
+          queryList.push(&queryChild);
+          referenceList.push(&rightRef);
+          traversalInfos.push(rule.TraversalInfo());
+
+          leftScore = rule.Rescore(queryChild, leftRef, leftScore);
+          if (leftScore != DBL_MAX)
+          {
+            queryList.push(&queryChild);
+            referenceList.push(&leftRef);
+            traversalInfos.push(leftInfo);
+          }
+          else
+          {
+            ++numPrunes;
+          }
+        }
+        else if (leftScore < rightScore || leftScore != DBL_MAX)
+        {
+          // Left reference child first: either it scored strictly better, or
+          // the scores tie (ties go left) and the pair is not prunable.  Then
+          // check whether the right child survives a rescore.
+          queryList.push(&queryChild);
+          referenceList.push(&leftRef);
+          traversalInfos.push(leftInfo);
+
+          rightScore = rule.Rescore(queryChild, rightRef, rightScore);
+          if (rightScore != DBL_MAX)
+          {
+            queryList.push(&queryChild);
+            referenceList.push(&rightRef);
+            traversalInfos.push(rule.TraversalInfo());
+          }
+          else
+          {
+            ++numPrunes;
+          }
+        }
+        else
+        {
+          // Tied at DBL_MAX: both combinations are pruned.
+          numPrunes += 2;
+        }
+      }
+    }
+  }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list