[mlpack] 56/324: Remove leafSize parameter from DTB constructor.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:21:55 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 0e2b33c09d712a14ea5799211d4ac144a29cca84
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Thu Jun 12 20:15:48 2014 +0000

    Remove leafSize parameter from DTB constructor.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16685 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/methods/emst/dtb.hpp            |  6 ++---
 src/mlpack/methods/emst/dtb_impl.hpp       | 43 +++++++++++++++++-------------
 src/mlpack/methods/emst/dtb_rules.hpp      | 19 +++++++++++--
 src/mlpack/methods/emst/dtb_rules_impl.hpp | 11 +++++---
 src/mlpack/methods/emst/emst_main.cpp      | 36 ++++++++++++++++++++++---
 5 files changed, 85 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/methods/emst/dtb.hpp b/src/mlpack/methods/emst/dtb.hpp
index 3852fba..667da69 100644
--- a/src/mlpack/methods/emst/dtb.hpp
+++ b/src/mlpack/methods/emst/dtb.hpp
@@ -79,7 +79,7 @@ class DualTreeBoruvka
   //! Copy of the data (if necessary).
   typename TreeType::Mat dataCopy;
   //! Reference to the data (this is what should be used for accessing data).
-  typename TreeType::Mat& data;
+  const typename TreeType::Mat& data;
 
   //! Pointer to the root of the tree.
   TreeType* tree;
@@ -130,7 +130,6 @@ class DualTreeBoruvka
    */
   DualTreeBoruvka(const typename TreeType::Mat& dataset,
                   const bool naive = false,
-                  const size_t leafSize = 1,
                   const MetricType metric = MetricType());
 
   /**
@@ -150,7 +149,8 @@ class DualTreeBoruvka
    * @param tree Pre-built tree.
    * @param dataset Dataset corresponding to the pre-built tree.
    */
-  DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
+  DualTreeBoruvka(TreeType* tree,
+                  const typename TreeType::Mat& dataset,
                   const MetricType metric = MetricType());
 
   /**
diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp
index 138ef55..94a4027 100644
--- a/src/mlpack/methods/emst/dtb_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_impl.hpp
@@ -50,11 +50,10 @@ template<typename MetricType, typename TreeType>
 DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
     const typename TreeType::Mat& dataset,
     const bool naive,
-    const size_t leafSize,
     const MetricType metric) :
     dataCopy(dataset),
     data(dataCopy), // The reference points to our copy of the data.
-    ownTree(true),
+    ownTree(!naive),
     naive(naive),
     connections(data.n_cols),
     totalDist(0.0),
@@ -62,17 +61,10 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
 {
   Timer::Start("emst/tree_building");
 
+  // Default leaf size is 1; this gives the best pruning, empirically.  Use
+  // leaf_size = 1 unless space is a big concern.
   if (!naive)
-  {
-    // Default leaf size is 1; this gives the best pruning, empirically.  Use
-    // leaf_size = 1 unless space is a big concern.
-    tree = new TreeType(data, oldFromNew, leafSize);
-  }
-  else
-  {
-    // Naive tree holds all data in one leaf.
-    tree = new TreeType(data, oldFromNew, data.n_cols);
-  }
+    tree = new TreeType(dataCopy, oldFromNew);
 
   Timer::Stop("emst/tree_building");
 
@@ -91,7 +83,7 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
     const MetricType metric) :
     data(dataset),
     tree(tree),
-    ownTree(true),
+    ownTree(false),
     naive(false),
     connections(data.n_cols),
     totalDist(0.0),
@@ -126,19 +118,32 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
   typedef DTBRules<MetricType, TreeType> RuleType;
   RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
                  neighborsOutComponent, metric);
-
   while (edges.size() < (data.n_cols - 1))
   {
-    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
-    traverser.Traverse(*tree, *tree);
+    if (naive)
+    {
+      // Full O(N^2) traversal.
+      for (size_t i = 0; i < data.n_cols; ++i)
+        for (size_t j = 0; j < data.n_cols; ++j)
+          rules.BaseCase(i, j);
+    }
+    else
+    {
+      typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+      traverser.Traverse(*tree, *tree);
+    }
 
     AddAllEdges();
 
     Cleanup();
 
     Log::Info << edges.size() << " edges found so far." << std::endl;
-    Log::Info << traverser.NumPrunes() << " nodes pruned." << std::endl;
+    if (!naive)
+    {
+      Log::Info << rules.BaseCases() << " cumulative base cases." << std::endl;
+      Log::Info << rules.Scores() << " cumulative node combinations scored."
+          << std::endl;
+    }
   }
 
   Timer::Stop("emst/mst_computation");
@@ -146,7 +151,7 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
   EmitResults(results);
 
   Log::Info << "Total spanning tree length: " << totalDist << std::endl;
-} // ComputeMST
+}
 
 /**
  * Adds a single edge to the edge list
diff --git a/src/mlpack/methods/emst/dtb_rules.hpp b/src/mlpack/methods/emst/dtb_rules.hpp
index 5024ff6..bc0c98e 100644
--- a/src/mlpack/methods/emst/dtb_rules.hpp
+++ b/src/mlpack/methods/emst/dtb_rules.hpp
@@ -74,7 +74,7 @@ class DTBRules
    * @param queryNode Candidate query node to recurse into.
    * @param referenceNode Candidate reference node to recurse into.
    */
-  double Score(TreeType& queryNode, TreeType& referenceNode) const;
+  double Score(TreeType& queryNode, TreeType& referenceNode);
 
   /**
    * Get the score for recursion order, passing the base case result (in the
@@ -88,7 +88,7 @@ class DTBRules
    */
   double Score(TreeType& queryNode,
                TreeType& referenceNode,
-               const double baseCaseResult) const;
+               const double baseCaseResult);
 
   /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
@@ -110,6 +110,16 @@ class DTBRules
   const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
   TraversalInfoType& TraversalInfo() { return traversalInfo; }
 
+  //! Get the number of base cases performed.
+  size_t BaseCases() const { return baseCases; }
+  //! Modify the number of base cases performed.
+  size_t& BaseCases() { return baseCases; }
+
+  //! Get the number of node combinations that have been scored.
+  size_t Scores() const { return scores; }
+  //! Modify the number of node combinations that have been scored.
+  size_t& Scores() { return scores; }
+
  private:
   //! The data points.
   const arma::mat& dataSet;
@@ -138,6 +148,11 @@ class DTBRules
 
   TraversalInfoType traversalInfo;
 
+  //! The number of base cases calculated.
+  size_t baseCases;
+  //! The number of node combinations that have been scored.
+  size_t scores;
+
 }; // class DTBRules
 
 } // emst namespace
diff --git a/src/mlpack/methods/emst/dtb_rules_impl.hpp b/src/mlpack/methods/emst/dtb_rules_impl.hpp
index 8405a17..6d70528 100644
--- a/src/mlpack/methods/emst/dtb_rules_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_rules_impl.hpp
@@ -24,7 +24,9 @@ DTBRules(const arma::mat& dataSet,
   neighborsDistances(neighborsDistances),
   neighborsInComponent(neighborsInComponent),
   neighborsOutComponent(neighborsOutComponent),
-  metric(metric)
+  metric(metric),
+  baseCases(0),
+  scores(0)
 {
   // Nothing else to do.
 }
@@ -46,6 +48,7 @@ double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
 
   if (queryComponentIndex != referenceComponentIndex)
   {
+    ++baseCases;
     double distance = metric.Evaluate(dataSet.col(queryIndex),
                                       dataSet.col(referenceIndex));
 
@@ -127,7 +130,7 @@ double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
 
 template<typename MetricType, typename TreeType>
 double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
-                                             TreeType& referenceNode) const
+                                             TreeType& referenceNode)
 {
   // If all the queries belong to the same component as all the references
   // then we prune.
@@ -136,6 +139,7 @@ double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
            referenceNode.Stat().ComponentMembership()))
     return DBL_MAX;
 
+  ++scores;
   const double distance = queryNode.MinDistance(&referenceNode);
   const double bound = CalculateBound(queryNode);
 
@@ -147,7 +151,7 @@ double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
 template<typename MetricType, typename TreeType>
 double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
                                              TreeType& referenceNode,
-                                             const double baseCaseResult) const
+                                             const double baseCaseResult)
 {
   // If all the queries belong to the same component as all the references
   // then we prune.
@@ -156,6 +160,7 @@ double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
            referenceNode.Stat().ComponentMembership()))
     return DBL_MAX;
 
+  ++scores;
   const double distance = queryNode.MinDistance(referenceNode, baseCaseResult);
   const double bound = CalculateBound(queryNode);
 
diff --git a/src/mlpack/methods/emst/emst_main.cpp b/src/mlpack/methods/emst/emst_main.cpp
index bba4b92..56a9c4b 100644
--- a/src/mlpack/methods/emst/emst_main.cpp
+++ b/src/mlpack/methods/emst/emst_main.cpp
@@ -80,18 +80,48 @@ int main(int argc, char* argv[])
           << ")!  Must be greater than or equal to 1." << std::endl;
     }
 
-    // Initialize the tree and get ready to compute the MST.
+    // Initialize the tree and get ready to compute the MST.  Compute the tree
+    // by hand.
     const size_t leafSize = (size_t) CLI::GetParam<int>("leaf_size");
-    DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
+
+    Timer::Start("tree_building");
+    std::vector<size_t> oldFromNew;
+    tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat> tree(dataPoints,
+        oldFromNew, leafSize);
+    metric::LMetric<2, true> metric;
+    Timer::Stop("tree_building");
+
+    DualTreeBoruvka<> dtb(&tree, dataPoints, metric);
 
     // Run the DTB algorithm.
     Log::Info << "Calculating minimum spanning tree." << endl;
     arma::mat results;
     dtb.ComputeMST(results);
 
+    // Unmap the results.
+    arma::mat unmappedResults(results.n_rows, results.n_cols);
+    for (size_t i = 0; i < results.n_cols; ++i)
+    {
+      const size_t indexA = oldFromNew[size_t(results(0, i))];
+      const size_t indexB = oldFromNew[size_t(results(1, i))];
+
+      if (indexA < indexB)
+      {
+        unmappedResults(0, i) = indexA;
+        unmappedResults(1, i) = indexB;
+      }
+      else
+      {
+        unmappedResults(0, i) = indexB;
+        unmappedResults(1, i) = indexA;
+      }
+
+      unmappedResults(2, i) = results(2, i);
+    }
+
     // Output the results.
     const string outputFilename = CLI::GetParam<string>("output_file");
 
-    data::Save(outputFilename, results, true);
+    data::Save(outputFilename, unmappedResults, true);
   }
 }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list