[mlpack] 122/324: Refactor tree construction so that arbitrary tree types can be constructed.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:02 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit a48e7911ba1d783f94ca330edb9425cd29376774
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Thu Jul 3 20:46:27 2014 +0000
Refactor tree construction so that arbitrary tree types can be constructed.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16760 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/emst/dtb_impl.hpp | 117 ++++++++++++++++-------------------
src/mlpack/methods/emst/dtb_stat.hpp | 14 ++++-
2 files changed, 65 insertions(+), 66 deletions(-)
diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp
index 94a4027..2a257bc 100644
--- a/src/mlpack/methods/emst/dtb_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_impl.hpp
@@ -12,36 +12,30 @@
namespace mlpack {
namespace emst {
-// DTBStat
-
-/**
- * A generic initializer.
- */
-inline DTBStat::DTBStat() :
- maxNeighborDistance(DBL_MAX),
- minNeighborDistance(DBL_MAX),
- bound(DBL_MAX),
- componentMembership(-1)
+//! Call the tree constructor that does mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+ typename TreeType::Mat& dataset,
+ std::vector<size_t>& oldFromNew,
+ typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+ >::type = 0)
{
- // Nothing to do.
+ return new TreeType(dataset, oldFromNew);
}
-/**
- * An initializer for leaves.
- */
+//! Call the tree constructor that does not do mapping.
template<typename TreeType>
-DTBStat::DTBStat(const TreeType& node) :
- maxNeighborDistance(DBL_MAX),
- minNeighborDistance(DBL_MAX),
- bound(DBL_MAX),
- componentMembership(((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
- node.Point(0) : -1)
+TreeType* BuildTree(
+ const typename TreeType::Mat& dataset,
+ const std::vector<size_t>& /* oldFromNew */,
+ const typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+ >::type = 0)
{
- // Nothing to do.
+ return new TreeType(dataset);
}
-// DualTreeBoruvka
-
/**
* Takes in a reference to the data set. Copies the data, builds the tree,
* and initializes all of the member variables.
@@ -51,20 +45,24 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
const typename TreeType::Mat& dataset,
const bool naive,
const MetricType metric) :
- dataCopy(dataset),
- data(dataCopy), // The reference points to our copy of the data.
+ data((tree::TreeTraits<TreeType>::RearrangesDataset && !naive) ? dataCopy : dataset),
ownTree(!naive),
naive(naive),
- connections(data.n_cols),
+ connections(dataset.n_cols),
totalDist(0.0),
metric(metric)
{
Timer::Start("emst/tree_building");
- // Default leaf size is 1; this gives the best pruning, empirically. Use
- // leaf_size = 1 unless space is a big concern.
if (!naive)
- tree = new TreeType(dataCopy, oldFromNew);
+ {
+ // Copy the dataset, if it will be modified during tree construction.
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ dataCopy = dataset;
+
+ tree = BuildTree<TreeType>(const_cast<typename TreeType::Mat&>(data),
+ oldFromNew);
+ }
Timer::Stop("emst/tree_building");
@@ -89,7 +87,7 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
totalDist(0.0),
metric(metric)
{
- edges.reserve(data.n_cols - 1); // fill with EdgePairs
+ edges.reserve(data.n_cols - 1); // Fill with EdgePairs.
neighborsInComponent.set_size(data.n_cols);
neighborsOutComponent.set_size(data.n_cols);
@@ -205,7 +203,7 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
results.set_size(3, edges.size());
// Need to unpermute the point labels.
- if (!naive && ownTree)
+ if (!naive && ownTree && tree::TreeTraits<TreeType>::RearrangesDataset)
{
for (size_t i = 0; i < (data.n_cols - 1); i++)
{
@@ -248,39 +246,34 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
template<typename MetricType, typename TreeType>
void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
{
+ // Reset the statistic information.
tree->Stat().MaxNeighborDistance() = DBL_MAX;
tree->Stat().MinNeighborDistance() = DBL_MAX;
tree->Stat().Bound() = DBL_MAX;
- if (!tree->IsLeaf())
- {
- CleanupHelper(tree->Left());
- CleanupHelper(tree->Right());
+ // Recurse into all children.
+ for (size_t i = 0; i < tree->NumChildren(); ++i)
+ CleanupHelper(&tree->Child(i));
- if ((tree->Left()->Stat().ComponentMembership() >= 0)
- && (tree->Left()->Stat().ComponentMembership() ==
- tree->Right()->Stat().ComponentMembership()))
- {
- tree->Stat().ComponentMembership() =
- tree->Left()->Stat().ComponentMembership();
- }
- }
- else
- {
- size_t newMembership = connections.Find(tree->Begin());
+ // Get the component of the first child or point. Then we will check to see
+ // if all other components of children and points are the same.
+ const int component = (tree->NumChildren() != 0) ?
+ tree->Child(0).Stat().ComponentMembership() :
+ connections.Find(tree->Point(0));
- for (size_t i = tree->Begin(); i < tree->End(); ++i)
- {
- if (newMembership != connections.Find(i))
- {
- newMembership = -1;
- Log::Assert(tree->Stat().ComponentMembership() < 0);
- return;
- }
- }
- tree->Stat().ComponentMembership() = newMembership;
- }
-} // CleanupHelper
+ // Check components of children.
+ for (size_t i = 0; i < tree->NumChildren(); ++i)
+ if (tree->Child(i).Stat().ComponentMembership() != component)
+ return;
+
+ // Check components of points.
+ for (size_t i = 0; i < tree->NumPoints(); ++i)
+ if (connections.Find(tree->Point(i)) != component)
+ return;
+
+ // If we made it this far, all components are the same.
+ tree->Stat().ComponentMembership() = component;
+}
/**
* The values stored in the tree must be reset on each iteration.
@@ -289,14 +282,10 @@ template<typename MetricType, typename TreeType>
void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
{
for (size_t i = 0; i < data.n_cols; i++)
- {
neighborsDistances[i] = DBL_MAX;
- }
if (!naive)
- {
CleanupHelper(tree);
- }
}
// convert the object to a string
@@ -304,12 +293,12 @@ template<typename MetricType, typename TreeType>
std::string DualTreeBoruvka<MetricType, TreeType>::ToString() const
{
std::ostringstream convert;
- convert << "Dual Tree Boruvka [" << this << "]" << std::endl;
+ convert << "DualTreeBoruvka [" << this << "]" << std::endl;
convert << " Data: " << data.n_rows << "x" << data.n_cols <<std::endl;
convert << " Total Distance: " << totalDist <<std::endl;
convert << " Naive: " << naive << std::endl;
convert << " Metric: " << std::endl;
- convert << mlpack::util::Indent(metric.ToString(),2);
+ convert << util::Indent(metric.ToString(), 2);
convert << std::endl;
return convert.str();
}
diff --git a/src/mlpack/methods/emst/dtb_stat.hpp b/src/mlpack/methods/emst/dtb_stat.hpp
index 98351a2..0c6839c 100644
--- a/src/mlpack/methods/emst/dtb_stat.hpp
+++ b/src/mlpack/methods/emst/dtb_stat.hpp
@@ -41,7 +41,11 @@ class DTBStat
* A generic initializer. Sets the maximum neighbor distance to its default,
* and the component membership to -1 (no component).
*/
- DTBStat();
+ DTBStat() :
+ maxNeighborDistance(DBL_MAX),
+ minNeighborDistance(DBL_MAX),
+ bound(DBL_MAX),
+ componentMembership(-1) { }
/**
* This is called when a node is finished initializing. We set the maximum
@@ -51,7 +55,13 @@ class DTBStat
* @param node Node that has been finished.
*/
template<typename TreeType>
- DTBStat(const TreeType& node);
+ DTBStat(const TreeType& node) :
+ maxNeighborDistance(DBL_MAX),
+ minNeighborDistance(DBL_MAX),
+ bound(DBL_MAX),
+ componentMembership(
+ ((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
+ node.Point(0) : -1) { }
//! Get the maximum neighbor distance.
double MaxNeighborDistance() const { return maxNeighborDistance; }
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list