[mlpack] 122/324: Refactor tree construction so that arbitrary tree types can be constructed.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:22:02 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit a48e7911ba1d783f94ca330edb9425cd29376774
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Thu Jul 3 20:46:27 2014 +0000

    Refactor tree construction so that arbitrary tree types can be constructed.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16760 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/methods/emst/dtb_impl.hpp | 117 ++++++++++++++++-------------------
 src/mlpack/methods/emst/dtb_stat.hpp |  14 ++++-
 2 files changed, 65 insertions(+), 66 deletions(-)

diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp
index 94a4027..2a257bc 100644
--- a/src/mlpack/methods/emst/dtb_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_impl.hpp
@@ -12,36 +12,30 @@
 namespace mlpack {
 namespace emst {
 
-// DTBStat
-
-/**
- * A generic initializer.
- */
-inline DTBStat::DTBStat() :
-    maxNeighborDistance(DBL_MAX),
-    minNeighborDistance(DBL_MAX),
-    bound(DBL_MAX),
-    componentMembership(-1)
+//! Call the tree constructor that does mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0)
 {
-  // Nothing to do.
+  return new TreeType(dataset, oldFromNew);
 }
 
-/**
- * An initializer for leaves.
- */
+//! Call the tree constructor that does not do mapping.
 template<typename TreeType>
-DTBStat::DTBStat(const TreeType& node) :
-    maxNeighborDistance(DBL_MAX),
-    minNeighborDistance(DBL_MAX),
-    bound(DBL_MAX),
-    componentMembership(((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
-        node.Point(0) : -1)
+TreeType* BuildTree(
+    const typename TreeType::Mat& dataset,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0)
 {
-  // Nothing to do.
+  return new TreeType(dataset);
 }
 
-// DualTreeBoruvka
-
 /**
  * Takes in a reference to the data set.  Copies the data, builds the tree,
  * and initializes all of the member variables.
@@ -51,20 +45,24 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
     const typename TreeType::Mat& dataset,
     const bool naive,
     const MetricType metric) :
-    dataCopy(dataset),
-    data(dataCopy), // The reference points to our copy of the data.
+    data((tree::TreeTraits<TreeType>::RearrangesDataset && !naive) ? dataCopy : dataset),
     ownTree(!naive),
     naive(naive),
-    connections(data.n_cols),
+    connections(dataset.n_cols),
     totalDist(0.0),
     metric(metric)
 {
   Timer::Start("emst/tree_building");
 
-  // Default leaf size is 1; this gives the best pruning, empirically.  Use
-  // leaf_size = 1 unless space is a big concern.
   if (!naive)
-    tree = new TreeType(dataCopy, oldFromNew);
+  {
+    // Copy the dataset, if it will be modified during tree construction.
+    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+      dataCopy = dataset;
+
+    tree = BuildTree<TreeType>(const_cast<typename TreeType::Mat&>(data),
+        oldFromNew);
+  }
 
   Timer::Stop("emst/tree_building");
 
@@ -89,7 +87,7 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
     totalDist(0.0),
     metric(metric)
 {
-  edges.reserve(data.n_cols - 1); // fill with EdgePairs
+  edges.reserve(data.n_cols - 1); // Fill with EdgePairs.
 
   neighborsInComponent.set_size(data.n_cols);
   neighborsOutComponent.set_size(data.n_cols);
@@ -205,7 +203,7 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
   results.set_size(3, edges.size());
 
   // Need to unpermute the point labels.
-  if (!naive && ownTree)
+  if (!naive && ownTree && tree::TreeTraits<TreeType>::RearrangesDataset)
   {
     for (size_t i = 0; i < (data.n_cols - 1); i++)
     {
@@ -248,39 +246,34 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
 template<typename MetricType, typename TreeType>
 void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
 {
+  // Reset the statistic information.
   tree->Stat().MaxNeighborDistance() = DBL_MAX;
   tree->Stat().MinNeighborDistance() = DBL_MAX;
   tree->Stat().Bound() = DBL_MAX;
 
-  if (!tree->IsLeaf())
-  {
-    CleanupHelper(tree->Left());
-    CleanupHelper(tree->Right());
+  // Recurse into all children.
+  for (size_t i = 0; i < tree->NumChildren(); ++i)
+    CleanupHelper(&tree->Child(i));
 
-    if ((tree->Left()->Stat().ComponentMembership() >= 0)
-        && (tree->Left()->Stat().ComponentMembership() ==
-            tree->Right()->Stat().ComponentMembership()))
-    {
-      tree->Stat().ComponentMembership() =
-          tree->Left()->Stat().ComponentMembership();
-    }
-  }
-  else
-  {
-    size_t newMembership = connections.Find(tree->Begin());
+  // Get the component of the first child or point.  Then we will check to see
+  // if all other components of children and points are the same.
+  const int component = (tree->NumChildren() != 0) ?
+      tree->Child(0).Stat().ComponentMembership() :
+      connections.Find(tree->Point(0));
 
-    for (size_t i = tree->Begin(); i < tree->End(); ++i)
-    {
-      if (newMembership != connections.Find(i))
-      {
-        newMembership = -1;
-        Log::Assert(tree->Stat().ComponentMembership() < 0);
-        return;
-      }
-    }
-    tree->Stat().ComponentMembership() = newMembership;
-  }
-} // CleanupHelper
+  // Check components of children.
+  for (size_t i = 0; i < tree->NumChildren(); ++i)
+    if (tree->Child(i).Stat().ComponentMembership() != component)
+      return;
+
+  // Check components of points.
+  for (size_t i = 0; i < tree->NumPoints(); ++i)
+    if (connections.Find(tree->Point(i)) != component)
+      return;
+
+  // If we made it this far, all components are the same.
+  tree->Stat().ComponentMembership() = component;
+}
 
 /**
  * The values stored in the tree must be reset on each iteration.
@@ -289,14 +282,10 @@ template<typename MetricType, typename TreeType>
 void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
 {
   for (size_t i = 0; i < data.n_cols; i++)
-  {
     neighborsDistances[i] = DBL_MAX;
-  }
 
   if (!naive)
-  {
     CleanupHelper(tree);
-  }
 }
 
 // convert the object to a string
@@ -304,12 +293,12 @@ template<typename MetricType, typename TreeType>
 std::string DualTreeBoruvka<MetricType, TreeType>::ToString() const
 {
   std::ostringstream convert;
-  convert << "Dual Tree Boruvka [" << this << "]" << std::endl;
+  convert << "DualTreeBoruvka [" << this << "]" << std::endl;
   convert << "  Data: " << data.n_rows << "x" << data.n_cols <<std::endl;
   convert << "  Total Distance: " << totalDist <<std::endl;
   convert << "  Naive: " << naive << std::endl;
   convert << "  Metric: " << std::endl;
-  convert << mlpack::util::Indent(metric.ToString(),2);
+  convert << util::Indent(metric.ToString(), 2);
   convert << std::endl;
   return convert.str();
 }
diff --git a/src/mlpack/methods/emst/dtb_stat.hpp b/src/mlpack/methods/emst/dtb_stat.hpp
index 98351a2..0c6839c 100644
--- a/src/mlpack/methods/emst/dtb_stat.hpp
+++ b/src/mlpack/methods/emst/dtb_stat.hpp
@@ -41,7 +41,11 @@ class DTBStat
    * A generic initializer.  Sets the maximum neighbor distance to its default,
    * and the component membership to -1 (no component).
    */
-  DTBStat();
+  DTBStat() :
+      maxNeighborDistance(DBL_MAX),
+      minNeighborDistance(DBL_MAX),
+      bound(DBL_MAX),
+      componentMembership(-1) { }
 
   /**
    * This is called when a node is finished initializing.  We set the maximum
@@ -51,7 +55,13 @@ class DTBStat
    * @param node Node that has been finished.
    */
   template<typename TreeType>
-  DTBStat(const TreeType& node);
+  DTBStat(const TreeType& node) :
+      maxNeighborDistance(DBL_MAX),
+      minNeighborDistance(DBL_MAX),
+      bound(DBL_MAX),
+      componentMembership(
+          ((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
+            node.Point(0) : -1) { }
 
   //! Get the maximum neighbor distance.
   double MaxNeighborDistance() const { return maxNeighborDistance; }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list