[mlpack] 52/149: Properly handle the case where the tree doesn't rearrange points -- like the cover tree. Then create a CoverTreeDTNNKMeans template typedef so that a user can easily use cover tree DTNNKMeans with KMeans<>.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Sat May 2 09:11:08 UTC 2015
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 5d9b8ad704c10a315910ef2b04d8233d738ddcc6
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Tue Oct 14 06:14:51 2014 +0000
Properly handle trees that do not rearrange points, such as the cover tree.
Also add a CoverTreeDTNNKMeans template typedef so that a user can easily use
DTNNKMeans with a cover tree through KMeans<>.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17260 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 9 ++++
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 57 +++++++++++++++++---------
2 files changed, 46 insertions(+), 20 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 4dacd6e..fddca15 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -12,6 +12,7 @@
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
namespace mlpack {
namespace kmeans {
@@ -78,9 +79,17 @@ class DTNNKMeans
void UpdateTree(TreeType& node, const double tolerance);
};
+//! A template typedef for the DTNNKMeans algorithm with the default tree type
+//! (a kd-tree).
template<typename MetricType, typename MatType>
using DefaultDTNNKMeans = DTNNKMeans<MetricType, MatType>;
+//! DTNNKMeans using a cover tree built with the same MetricType as the search.
+template<typename MetricType, typename MatType>
+using CoverTreeDTNNKMeans = DTNNKMeans<MetricType, MatType,
+ tree::CoverTree<MetricType, tree::FirstPointIsRoot,
+ neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> > >;
+
} // namespace kmeans
} // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 44b62c1..fde2b44 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -16,6 +16,32 @@
namespace mlpack {
namespace kmeans {
+//! Build a tree that rearranges `dataset`; oldFromNew receives the mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+ typename TreeType::Mat& dataset,
+ std::vector<size_t>& oldFromNew,
+ typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+ >::type = 0)
+{
+ // Hack: assume TreeType is BinarySpaceTree and pass a leaf size of ONE (the
+ // third argument). NOTE(review): earlier code used 2 -- confirm intent.
+ return new TreeType(dataset, oldFromNew, 1);
+}
+
+//! Non-rearranging tree: construct directly on the unmodified (const) dataset.
+template<typename TreeType>
+TreeType* BuildTree(
+ const typename TreeType::Mat& dataset,
+ const std::vector<size_t>& /* oldFromNew: unused, no mapping produced */,
+ const typename boost::enable_if_c<
+ !tree::TreeTraits<TreeType>::RearrangesDataset, TreeType*
+ >::type = 0)
+{
+ return new TreeType(dataset);
+}
+
template<typename MetricType, typename MatType, typename TreeType>
DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
MetricType& metric) :
@@ -53,27 +79,15 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
{
newCentroids.zeros(centroids.n_rows, centroids.n_cols);
counts.zeros(centroids.n_cols);
- arma::mat centroidsCopy;
// Build a tree on the centroids.
std::vector<size_t> oldFromNewCentroids;
- TreeType* centroidTree;
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- {
- // Manually set leaf size of 2. This may not always be appropriate.
- centroidsCopy = centroids;
- centroidTree = new TreeType(centroidsCopy, oldFromNewCentroids, 2);
- }
- else
- {
- centroidTree = new TreeType(centroidsCopy);
- }
+ TreeType* centroidTree = BuildTree<TreeType>(
+ const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
typedef neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType,
TreeType> AllkNNType;
- AllkNNType allknn(centroidTree, tree,
- (tree::TreeTraits<TreeType>::RearrangesDataset) ? centroidsCopy :
- centroids, dataset, false, metric);
+ AllkNNType allknn(centroidTree, tree, centroids, dataset, false, metric);
// This is a lot of overhead. We don't need the distances.
arma::mat distances;
@@ -92,7 +106,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
else
{
newCentroids.col(assignments[i]) += dataset.col(i);
- ++counts(i);
+ ++counts(assignments[i]);
}
}
@@ -101,15 +115,18 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
double maxMovement = 0.0;
for (size_t c = 0; c < centroids.n_cols; ++c)
{
- if (counts[c] == 0)
+ // Get the mapping to the old cluster, if necessary.
+ const size_t old = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+ oldFromNewCentroids[c] : c;
+ if (counts[old] == 0)
{
- newCentroids.col(c).fill(DBL_MAX); // Should have happened anyway I think.
+ newCentroids.col(old).fill(DBL_MAX);
}
else
{
- newCentroids.col(c) /= counts(c);
+ newCentroids.col(old) /= counts(old);
const double movement = metric.Evaluate(centroids.col(c),
- newCentroids.col(c));
+ newCentroids.col(old));
residual += std::pow(movement, 2.0);
if (movement > maxMovement)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list