[mlpack] 52/149: Properly handle the case where the tree doesn't rearrange points -- like the cover tree. Then create a CoverTreeDTNNKMeans template typedef so that a user can easily use cover tree DTNNKMeans with KMeans<>.

Barak A. Pearlmutter barak+git at pearlmutter.net
Sat May 2 09:11:08 UTC 2015


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 5d9b8ad704c10a315910ef2b04d8233d738ddcc6
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Tue Oct 14 06:14:51 2014 +0000

    Properly handle the case where the tree doesn't rearrange points -- like the
    cover tree.  Then create a CoverTreeDTNNKMeans template typedef so that a user
    can easily use cover tree DTNNKMeans with KMeans<>.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17260 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  9 ++++
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 57 +++++++++++++++++---------
 2 files changed, 46 insertions(+), 20 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 4dacd6e..fddca15 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -12,6 +12,7 @@
 
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
 
 namespace mlpack {
 namespace kmeans {
@@ -78,9 +79,17 @@ class DTNNKMeans
   void UpdateTree(TreeType& node, const double tolerance);
 };
 
+//! A template typedef for the DTNNKMeans algorithm with the default tree type
+//! (a kd-tree).
 template<typename MetricType, typename MatType>
 using DefaultDTNNKMeans = DTNNKMeans<MetricType, MatType>;
 
+//! A template typedef for the DTNNKMeans algorithm with the cover tree type.
+template<typename MetricType, typename MatType>
+using CoverTreeDTNNKMeans = DTNNKMeans<MetricType, MatType,
+    tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
+    neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> > >;
+
 } // namespace kmeans
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 44b62c1..fde2b44 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -16,6 +16,32 @@
 namespace mlpack {
 namespace kmeans {
 
+//! Call the tree constructor that does mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0)
+{
+  // This is a hack.  I know this will be BinarySpaceTree, so force a small
+  // leaf size (the value passed below is 1).
+  return new TreeType(dataset, oldFromNew, 1);
+}
+
+//! Call the tree constructor that does not do mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+    const typename TreeType::Mat& dataset,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0)
+{
+  return new TreeType(dataset);
+}
+
 template<typename MetricType, typename MatType, typename TreeType>
 DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
                                                       MetricType& metric) :
@@ -53,27 +79,15 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 {
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
-  arma::mat centroidsCopy;
 
   // Build a tree on the centroids.
   std::vector<size_t> oldFromNewCentroids;
-  TreeType* centroidTree;
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
-  {
-    // Manually set leaf size of 2.  This may not always be appropriate.
-    centroidsCopy = centroids;
-    centroidTree = new TreeType(centroidsCopy, oldFromNewCentroids, 2);
-  }
-  else
-  {
-    centroidTree = new TreeType(centroidsCopy);
-  }
+  TreeType* centroidTree = BuildTree<TreeType>(
+      const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
 
   typedef neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType,
       TreeType> AllkNNType;
-  AllkNNType allknn(centroidTree, tree,
-      (tree::TreeTraits<TreeType>::RearrangesDataset) ? centroidsCopy :
-      centroids, dataset, false, metric);
+  AllkNNType allknn(centroidTree, tree, centroids, dataset, false, metric);
 
   // This is a lot of overhead.  We don't need the distances.
   arma::mat distances;
@@ -92,7 +106,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
     else
     {
       newCentroids.col(assignments[i]) += dataset.col(i);
-      ++counts(i);
+      ++counts(assignments[i]);
     }
   }
 
@@ -101,15 +115,18 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   double maxMovement = 0.0;
   for (size_t c = 0; c < centroids.n_cols; ++c)
   {
-    if (counts[c] == 0)
+    // Get the mapping to the old cluster, if necessary.
+    const size_t old = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+        oldFromNewCentroids[c] : c;
+    if (counts[old] == 0)
     {
-      newCentroids.col(c).fill(DBL_MAX); // Should have happened anyway I think.
+      newCentroids.col(old).fill(DBL_MAX);
     }
     else
     {
-      newCentroids.col(c) /= counts(c);
+      newCentroids.col(old) /= counts(old);
       const double movement = metric.Evaluate(centroids.col(c),
-          newCentroids.col(c));
+          newCentroids.col(old));
       residual += std::pow(movement, 2.0);
 
       if (movement > maxMovement)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list