[mlpack] 56/324: Remove leafSize parameter from DTB constructor.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:21:55 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 0e2b33c09d712a14ea5799211d4ac144a29cca84
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Thu Jun 12 20:15:48 2014 +0000
Remove leafSize parameter from DTB constructor.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16685 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/emst/dtb.hpp | 6 ++---
src/mlpack/methods/emst/dtb_impl.hpp | 43 +++++++++++++++++-------------
src/mlpack/methods/emst/dtb_rules.hpp | 19 +++++++++++--
src/mlpack/methods/emst/dtb_rules_impl.hpp | 11 +++++---
src/mlpack/methods/emst/emst_main.cpp | 36 ++++++++++++++++++++++---
5 files changed, 85 insertions(+), 30 deletions(-)
diff --git a/src/mlpack/methods/emst/dtb.hpp b/src/mlpack/methods/emst/dtb.hpp
index 3852fba..667da69 100644
--- a/src/mlpack/methods/emst/dtb.hpp
+++ b/src/mlpack/methods/emst/dtb.hpp
@@ -79,7 +79,7 @@ class DualTreeBoruvka
//! Copy of the data (if necessary).
typename TreeType::Mat dataCopy;
//! Reference to the data (this is what should be used for accessing data).
- typename TreeType::Mat& data;
+ const typename TreeType::Mat& data;
//! Pointer to the root of the tree.
TreeType* tree;
@@ -130,7 +130,6 @@ class DualTreeBoruvka
*/
DualTreeBoruvka(const typename TreeType::Mat& dataset,
const bool naive = false,
- const size_t leafSize = 1,
const MetricType metric = MetricType());
/**
@@ -150,7 +149,8 @@ class DualTreeBoruvka
* @param tree Pre-built tree.
* @param dataset Dataset corresponding to the pre-built tree.
*/
- DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
+ DualTreeBoruvka(TreeType* tree,
+ const typename TreeType::Mat& dataset,
const MetricType metric = MetricType());
/**
diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp
index 138ef55..94a4027 100644
--- a/src/mlpack/methods/emst/dtb_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_impl.hpp
@@ -50,11 +50,10 @@ template<typename MetricType, typename TreeType>
DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
const typename TreeType::Mat& dataset,
const bool naive,
- const size_t leafSize,
const MetricType metric) :
dataCopy(dataset),
data(dataCopy), // The reference points to our copy of the data.
- ownTree(true),
+ ownTree(!naive),
naive(naive),
connections(data.n_cols),
totalDist(0.0),
@@ -62,17 +61,10 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
{
Timer::Start("emst/tree_building");
+ // Default leaf size is 1; this gives the best pruning, empirically. Use
+ // leaf_size = 1 unless space is a big concern.
if (!naive)
- {
- // Default leaf size is 1; this gives the best pruning, empirically. Use
- // leaf_size = 1 unless space is a big concern.
- tree = new TreeType(data, oldFromNew, leafSize);
- }
- else
- {
- // Naive tree holds all data in one leaf.
- tree = new TreeType(data, oldFromNew, data.n_cols);
- }
+ tree = new TreeType(dataCopy, oldFromNew);
Timer::Stop("emst/tree_building");
@@ -91,7 +83,7 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
const MetricType metric) :
data(dataset),
tree(tree),
- ownTree(true),
+ ownTree(false),
naive(false),
connections(data.n_cols),
totalDist(0.0),
@@ -126,19 +118,32 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
typedef DTBRules<MetricType, TreeType> RuleType;
RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
neighborsOutComponent, metric);
-
while (edges.size() < (data.n_cols - 1))
{
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- traverser.Traverse(*tree, *tree);
+ if (naive)
+ {
+ // Full O(N^2) traversal.
+ for (size_t i = 0; i < data.n_cols; ++i)
+ for (size_t j = 0; j < data.n_cols; ++j)
+ rules.BaseCase(i, j);
+ }
+ else
+ {
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ traverser.Traverse(*tree, *tree);
+ }
AddAllEdges();
Cleanup();
Log::Info << edges.size() << " edges found so far." << std::endl;
- Log::Info << traverser.NumPrunes() << " nodes pruned." << std::endl;
+ if (!naive)
+ {
+ Log::Info << rules.BaseCases() << " cumulative base cases." << std::endl;
+ Log::Info << rules.Scores() << " cumulative node combinations scored."
+ << std::endl;
+ }
}
Timer::Stop("emst/mst_computation");
@@ -146,7 +151,7 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
EmitResults(results);
Log::Info << "Total spanning tree length: " << totalDist << std::endl;
-} // ComputeMST
+}
/**
* Adds a single edge to the edge list
diff --git a/src/mlpack/methods/emst/dtb_rules.hpp b/src/mlpack/methods/emst/dtb_rules.hpp
index 5024ff6..bc0c98e 100644
--- a/src/mlpack/methods/emst/dtb_rules.hpp
+++ b/src/mlpack/methods/emst/dtb_rules.hpp
@@ -74,7 +74,7 @@ class DTBRules
* @param queryNode Candidate query node to recurse into.
* @param referenceNode Candidate reference node to recurse into.
*/
- double Score(TreeType& queryNode, TreeType& referenceNode) const;
+ double Score(TreeType& queryNode, TreeType& referenceNode);
/**
* Get the score for recursion order, passing the base case result (in the
@@ -88,7 +88,7 @@ class DTBRules
*/
double Score(TreeType& queryNode,
TreeType& referenceNode,
- const double baseCaseResult) const;
+ const double baseCaseResult);
/**
* Re-evaluate the score for recursion order. A low score indicates priority
@@ -110,6 +110,16 @@ class DTBRules
const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
TraversalInfoType& TraversalInfo() { return traversalInfo; }
+ //! Get the number of base cases performed.
+ size_t BaseCases() const { return baseCases; }
+ //! Modify the number of base cases performed.
+ size_t& BaseCases() { return baseCases; }
+
+ //! Get the number of node combinations that have been scored.
+ size_t Scores() const { return scores; }
+ //! Modify the number of node combinations that have been scored.
+ size_t& Scores() { return scores; }
+
private:
//! The data points.
const arma::mat& dataSet;
@@ -138,6 +148,11 @@ class DTBRules
TraversalInfoType traversalInfo;
+ //! The number of base cases calculated.
+ size_t baseCases;
+ //! The number of node combinations that have been scored.
+ size_t scores;
+
}; // class DTBRules
} // emst namespace
diff --git a/src/mlpack/methods/emst/dtb_rules_impl.hpp b/src/mlpack/methods/emst/dtb_rules_impl.hpp
index 8405a17..6d70528 100644
--- a/src/mlpack/methods/emst/dtb_rules_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_rules_impl.hpp
@@ -24,7 +24,9 @@ DTBRules(const arma::mat& dataSet,
neighborsDistances(neighborsDistances),
neighborsInComponent(neighborsInComponent),
neighborsOutComponent(neighborsOutComponent),
- metric(metric)
+ metric(metric),
+ baseCases(0),
+ scores(0)
{
// Nothing else to do.
}
@@ -46,6 +48,7 @@ double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
if (queryComponentIndex != referenceComponentIndex)
{
+ ++baseCases;
double distance = metric.Evaluate(dataSet.col(queryIndex),
dataSet.col(referenceIndex));
@@ -127,7 +130,7 @@ double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
- TreeType& referenceNode) const
+ TreeType& referenceNode)
{
// If all the queries belong to the same component as all the references
// then we prune.
@@ -136,6 +139,7 @@ double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
referenceNode.Stat().ComponentMembership()))
return DBL_MAX;
+ ++scores;
const double distance = queryNode.MinDistance(&referenceNode);
const double bound = CalculateBound(queryNode);
@@ -147,7 +151,7 @@ double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
TreeType& referenceNode,
- const double baseCaseResult) const
+ const double baseCaseResult)
{
// If all the queries belong to the same component as all the references
// then we prune.
@@ -156,6 +160,7 @@ double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
referenceNode.Stat().ComponentMembership()))
return DBL_MAX;
+ ++scores;
const double distance = queryNode.MinDistance(referenceNode, baseCaseResult);
const double bound = CalculateBound(queryNode);
diff --git a/src/mlpack/methods/emst/emst_main.cpp b/src/mlpack/methods/emst/emst_main.cpp
index bba4b92..56a9c4b 100644
--- a/src/mlpack/methods/emst/emst_main.cpp
+++ b/src/mlpack/methods/emst/emst_main.cpp
@@ -80,18 +80,48 @@ int main(int argc, char* argv[])
<< ")! Must be greater than or equal to 1." << std::endl;
}
- // Initialize the tree and get ready to compute the MST.
+ // Initialize the tree and get ready to compute the MST. Compute the tree
+ // by hand.
const size_t leafSize = (size_t) CLI::GetParam<int>("leaf_size");
- DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
+
+ Timer::Start("tree_building");
+ std::vector<size_t> oldFromNew;
+ tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat> tree(dataPoints,
+ oldFromNew, leafSize);
+ metric::LMetric<2, true> metric;
+ Timer::Stop("tree_building");
+
+ DualTreeBoruvka<> dtb(&tree, dataPoints, metric);
// Run the DTB algorithm.
Log::Info << "Calculating minimum spanning tree." << endl;
arma::mat results;
dtb.ComputeMST(results);
+ // Unmap the results.
+ arma::mat unmappedResults(results.n_rows, results.n_cols);
+ for (size_t i = 0; i < results.n_cols; ++i)
+ {
+ const size_t indexA = oldFromNew[size_t(results(0, i))];
+ const size_t indexB = oldFromNew[size_t(results(1, i))];
+
+ if (indexA < indexB)
+ {
+ unmappedResults(0, i) = indexA;
+ unmappedResults(1, i) = indexB;
+ }
+ else
+ {
+ unmappedResults(0, i) = indexB;
+ unmappedResults(1, i) = indexA;
+ }
+
+ unmappedResults(2, i) = results(2, i);
+ }
+
// Output the results.
const string outputFilename = CLI::GetParam<string>("output_file");
- data::Save(outputFilename, results, true);
+ data::Save(outputFilename, unmappedResults, true);
}
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list