[mlpack] 70/149: This is an experimental method that I am working on. Right now it is not very useful as I have not implemented all of the pruning strategies that I intend to.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Sat May 2 09:11:10 UTC 2015
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 6370c12e5209b3234311ff919d2f67aff30354dd
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Nov 5 19:36:49 2014 +0000
This is an experimental method that I am working on. Right now it is not very
useful as I have not implemented all of the pruning strategies that I intend to.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17298 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/kmeans/dual_tree_kmeans.hpp | 71 +++++
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 116 ++++++++
.../methods/kmeans/dual_tree_kmeans_rules.hpp | 80 ++++++
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 317 +++++++++++++++++++++
.../methods/kmeans/dual_tree_kmeans_statistic.hpp | 96 +++++++
5 files changed, 680 insertions(+)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
new file mode 100644
index 0000000..f2b6376
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -0,0 +1,71 @@
+/**
+ * @file dual_tree_kmeans.hpp
+ * @author Ryan Curtin
+ *
+ * A dual-tree algorithm for a single k-means iteration.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_HPP
+
+#include "dual_tree_kmeans_statistic.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * A dual-tree algorithm for performing single iterations of k-means
+ * clustering.  A tree is built on the dataset once, at construction time, and
+ * a tree on the centroids is built on every call to Iterate().
+ *
+ * The object owns the tree built on the dataset and may refer to its own
+ * internal copy of the dataset, so it cannot be copied.
+ *
+ * @tparam MetricType Metric used for distance calculations.
+ * @tparam MatType Type of the dataset matrix.
+ * @tparam TreeType Tree type built on both the dataset and the centroids.
+ */
+template<
+    typename MetricType,
+    typename MatType,
+    typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+        DualTreeKMeansStatistic>
+>
+class DualTreeKMeans
+{
+ public:
+  /**
+   * Build the tree on the given dataset.  If the tree type rearranges its
+   * dataset, the tree is built on an internal copy instead.  A reference to
+   * the original dataset is kept, so it must outlive this object.
+   *
+   * @param dataset Points to cluster (one point per column).
+   * @param metric Metric to use (copied into this object).
+   */
+  DualTreeKMeans(const MatType& dataset, MetricType& metric);
+
+  //! Free the tree built on the dataset.
+  ~DualTreeKMeans();
+
+  // Copying would share (and later double-delete) the tree pointer, and the
+  // copy's dataset reference could point into the original object.
+  DualTreeKMeans(const DualTreeKMeans& other) = delete;
+  DualTreeKMeans& operator=(const DualTreeKMeans& other) = delete;
+
+  /**
+   * Run a single iteration of k-means using a dual-tree traversal.
+   *
+   * @param centroids Current centroids, one per column.
+   * @param newCentroids Output: the new centroids, one per column.
+   * @param counts Output: number of points assigned to each cluster.
+   * @return Square root of the sum of squared distances that the non-empty
+   *     clusters' centroids moved.
+   */
+  double Iterate(const arma::mat& centroids,
+                 arma::mat& newCentroids,
+                 arma::Col<size_t>& counts);
+
+  //! Return the number of distance calculations.
+  size_t DistanceCalculations() const { return distanceCalculations; }
+  //! Modify the number of distance calculations.
+  size_t& DistanceCalculations() { return distanceCalculations; }
+
+ private:
+  //! The original dataset reference.
+  const MatType& datasetOrig;
+  //! The dataset we are using (datasetOrig, or datasetCopy if the tree
+  //! rearranges its dataset).
+  const MatType& dataset;
+  //! A copy of the dataset, if necessary.
+  MatType datasetCopy;
+  //! The metric.
+  MetricType metric;
+
+  //! The tree built on the points (owned by this object).
+  TreeType* tree;
+
+  //! How far each centroid moved in the last iteration; the extra last
+  //! element holds the largest such movement.
+  arma::vec clusterDistances;
+  //! Best cluster found so far for each point (in the tree's point order).
+  arma::Col<size_t> assignments;
+  //! Distance from each point to its best cluster found so far.
+  arma::vec distances;
+  //! Iteration in which each entry of distances was last set.
+  arma::Col<size_t> distanceIteration;
+
+  //! The current iteration.
+  size_t iteration;
+
+  //! Track distance calculations.
+  size_t distanceCalculations;
+};
+
+/**
+ * Shorthand for a DualTreeKMeans that uses its default tree type.
+ */
+template<typename MetricType, typename MatType>
+using DefaultDualTreeKMeans =
+    DualTreeKMeans<MetricType, MatType>;
+
+} // namespace kmeans
+} // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_kmeans_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
new file mode 100644
index 0000000..35a3a9d
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -0,0 +1,116 @@
+/**
+ * @file dual_tree_kmeans_impl.hpp
+ * @author Ryan Curtin
+ *
+ * A dual-tree algorithm for a single k-means iteration.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "dual_tree_kmeans.hpp"
+#include "dual_tree_kmeans_rules.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename MatType, typename TreeType>
+DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
+    const MatType& dataset,
+    MetricType& metric) :
+    datasetOrig(dataset),
+    dataset(tree::TreeTraits<TreeType>::RearrangesDataset ? datasetCopy :
+        datasetOrig),
+    metric(metric),
+    iteration(0),
+    distanceCalculations(0)
+{
+  // Per-point bookkeeping.  No point has a candidate cluster yet, so every
+  // best distance starts at DBL_MAX.  (`dataset` here is the constructor
+  // parameter; it has the same number of columns as the member.)
+  const size_t numPoints = dataset.n_cols;
+  assignments.zeros(numPoints);
+  distanceIteration.zeros(numPoints);
+  distances.set_size(numPoints);
+  distances.fill(DBL_MAX);
+
+  Timer::Start("tree_building");
+
+  // A tree that rearranges its points is built on a private copy, so the
+  // caller's matrix is left untouched.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    datasetCopy = datasetOrig;
+
+  // Build the reference tree; the old-to-new point mappings are not needed.
+  typename TreeType::Mat& treeData =
+      const_cast<typename TreeType::Mat&>(this->dataset);
+  tree = new TreeType(treeData);
+
+  Timer::Stop("tree_building");
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+DualTreeKMeans<MetricType, MatType, TreeType>::~DualTreeKMeans()
+{
+  // Release the reference tree allocated in the constructor.  Deleting a null
+  // pointer is a no-op, so no explicit check is needed.
+  delete tree;
+}
+
+/**
+ * Run one k-means iteration with a dual-tree traversal.  A tree is built on a
+ * copy of the centroids, the points are assigned to their nearest centroids,
+ * and the new centroids are normalized.
+ *
+ * @param centroids Current centroids, one per column (not modified).
+ * @param newCentroids Output: new centroids, indexed like `centroids`; an
+ *     empty cluster's column is filled with DBL_MAX.
+ * @param counts Output: number of points assigned to each cluster.
+ * @return Square root of the sum of squared movements of non-empty clusters.
+ */
+template<typename MetricType, typename MatType, typename TreeType>
+double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
+    const arma::mat& centroids,
+    arma::mat& newCentroids,
+    arma::Col<size_t>& counts)
+{
+  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
+  counts.zeros(centroids.n_cols);
+  if (clusterDistances.n_elem != centroids.n_cols + 1)
+  {
+    clusterDistances.set_size(centroids.n_cols + 1);
+    clusterDistances.fill(DBL_MAX / 2.0); // To prevent overflow.
+  }
+
+  // Build a tree on a copy of the centroids.  The tree may rearrange the
+  // matrix it is built on, and the caller's centroids are const, so they must
+  // not be handed to the tree directly.  The copy must outlive the traversal
+  // because the tree and the rules refer to it.
+  arma::mat centroidsCopy(centroids);
+  std::vector<size_t> oldFromNewCentroids;
+  TreeType* centroidTree = BuildTree<TreeType>(centroidsCopy,
+      oldFromNewCentroids);
+
+  // Now run the dual-tree algorithm.  Query indices seen by the rules refer to
+  // columns of the (possibly rearranged) copy; oldFromNewCentroids maps them
+  // back to the original cluster indices.
+  typedef DualTreeKMeansRules<MetricType, TreeType> RulesType;
+  RulesType rules(dataset, centroidsCopy, newCentroids, counts,
+      oldFromNewCentroids, iteration, clusterDistances, distances, assignments,
+      distanceIteration, metric);
+
+  // Use the breadth-first dual-tree traverser.
+  typename TreeType::template BreadthFirstDualTreeTraverser<RulesType>
+      traverser(rules);
+  traverser.Traverse(*centroidTree, *tree);
+
+  distanceCalculations += rules.DistanceCalculations();
+
+  // The centroid tree is rebuilt on every call, so free it here instead of
+  // leaking it.  NOTE(review): reference-node statistics may still hold
+  // pointers into this tree; IterationUpdate() in the rules replaces them
+  // (from the parent, or with NULL at the root) before they are used again.
+  delete centroidTree;
+
+  // Normalize the new centroids and measure how far each one moved.  counts,
+  // newCentroids, centroids and clusterDistances are all indexed by the
+  // original cluster index, so no mapping is needed here.
+  double residual = 0.0;
+  clusterDistances.zeros();
+  for (size_t c = 0; c < centroids.n_cols; ++c)
+  {
+    if (counts[c] == 0)
+    {
+      // Empty cluster: mark its centroid as invalid.
+      newCentroids.col(c).fill(DBL_MAX);
+      continue;
+    }
+
+    newCentroids.col(c) /= counts(c);
+    const double dist = metric.Evaluate(centroids.col(c), newCentroids.col(c));
+    clusterDistances[c] = dist;
+    // The extra last element tracks the largest movement of any centroid.
+    if (dist > clusterDistances[centroids.n_cols])
+      clusterDistances[centroids.n_cols] = dist;
+    residual += dist * dist;
+  }
+  Log::Info << clusterDistances.t();
+
+  ++iteration;
+  return std::sqrt(residual);
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
new file mode 100644
index 0000000..e9320d1
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -0,0 +1,80 @@
+/**
+ * @file dual_tree_kmeans_rules.hpp
+ * @author Ryan Curtin
+ *
+ * A set of tree traversal rules for dual-tree k-means clustering.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_HPP
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Traversal rules for one dual-tree k-means iteration.  The query tree is
+ * built on the centroids and the reference tree on the dataset.  Base cases
+ * track, for each point, the best centroid found so far. Node-node scores
+ * try to prune whole groups of centroids for a reference node.
+ *
+ * All constructor arguments except `iteration` are stored by reference and
+ * must outlive the rules object.
+ */
+template<typename MetricType, typename TreeType>
+class DualTreeKMeansRules
+{
+ public:
+ //! Construct the rules; see the member documentation below for each argument.
+ DualTreeKMeansRules(const typename TreeType::Mat& dataset,
+ const arma::mat& centroids,
+ arma::mat& newCentroids,
+ arma::Col<size_t>& counts,
+ const std::vector<size_t>& mappings,
+ const size_t iteration,
+ const arma::vec& clusterDistances,
+ arma::vec& distances,
+ arma::Col<size_t>& assignments,
+ arma::Col<size_t>& distanceIteration,
+ MetricType& metric);
+
+ //! Evaluate the distance between centroid queryIndex and point
+ //! referenceIndex, updating that point's best cluster and committing the
+ //! point to newCentroids/counts once every cluster is accounted for.
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ //! Single-tree score; currently only updates the reference node statistic.
+ double Score(const size_t queryIndex, TreeType& referenceNode);
+
+ //! Dual-tree score; returns DBL_MAX when queryNode's centroids are pruned
+ //! for referenceNode.
+ double Score(TreeType& queryNode, TreeType& referenceNode);
+
+ //! Single-tree rescore; returns oldScore.
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ //! Dual-tree rescore; re-checks the prune condition using oldScore.
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ //! Get the number of distance calculations performed by these rules.
+ size_t DistanceCalculations() const { return distanceCalculations; }
+ //! Modify the number of distance calculations performed by these rules.
+ size_t& DistanceCalculations() { return distanceCalculations; }
+
+ //! Traversal information (only LastReferenceNode() is used by these rules).
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ //! Get the traversal information.
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ //! Modify the traversal information.
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
+ private:
+ //! The points being clustered (reference set).
+ const typename TreeType::Mat& dataset;
+ //! The centroids the query tree is built on (query set).
+ const arma::mat& centroids;
+ //! Accumulated sums of the points committed to each cluster.
+ arma::mat& newCentroids;
+ //! Number of points committed to each cluster.
+ arma::Col<size_t>& counts;
+ //! Maps query (centroid tree) indices to the original cluster indices.
+ const std::vector<size_t>& mappings;
+ //! The current k-means iteration.
+ const size_t iteration;
+ //! Movement of each cluster in the previous iteration; the last element is
+ //! the largest movement.
+ const arma::vec& clusterDistances;
+ //! Distance from each point to its best cluster found so far.
+ arma::vec& distances;
+ //! Best cluster (original index) found so far for each point.
+ arma::Col<size_t>& assignments;
+ //! Number of base cases evaluated for each point in this traversal.
+ arma::Col<size_t> visited;
+ //! Iteration in which each entry of distances was last set.
+ arma::Col<size_t>& distanceIteration;
+ //! The metric.
+ MetricType& metric;
+
+ //! Number of distance calculations performed.
+ size_t distanceCalculations;
+
+ //! Traversal information.
+ TraversalInfoType traversalInfo;
+
+ //! Bring the node's statistic up to the current iteration; returns 1 if it
+ //! was updated and 0 if it was already current.
+ size_t IterationUpdate(TreeType& referenceNode) const;
+
+ //! Return true if potentialChild is a strict descendant of potentialParent.
+ bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
+ potentialChild) const;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#include "dual_tree_kmeans_rules_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
new file mode 100644
index 0000000..adcedad
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -0,0 +1,317 @@
+/**
+ * @file dual_tree_kmeans_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * A set of tree traversal rules for dual-tree k-means clustering.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "dual_tree_kmeans_rules.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Store references to all of the shared state and reset the per-point visit
+ * counters for this traversal.
+ */
+template<typename MetricType, typename TreeType>
+DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
+    const typename TreeType::Mat& dataset,
+    const arma::mat& centroids,
+    arma::mat& newCentroids,
+    arma::Col<size_t>& counts,
+    const std::vector<size_t>& mappings,
+    const size_t iteration,
+    const arma::vec& clusterDistances,
+    arma::vec& distances,
+    arma::Col<size_t>& assignments,
+    arma::Col<size_t>& distanceIteration,
+    MetricType& metric) :
+    dataset(dataset), centroids(centroids), newCentroids(newCentroids),
+    counts(counts), mappings(mappings), iteration(iteration),
+    clusterDistances(clusterDistances), distances(distances),
+    assignments(assignments), distanceIteration(distanceIteration),
+    metric(metric), distanceCalculations(0)
+{
+  // One visit counter per reference point, all starting at zero because no
+  // base case has been evaluated yet.
+  const size_t numPoints = dataset.n_cols;
+  visited.zeros(numPoints);
+}
+
+/**
+ * Evaluate the distance between a centroid and a point.  The point's best
+ * cluster is updated, and once every cluster has been either evaluated or
+ * pruned for the point, the point is committed to newCentroids and counts.
+ *
+ * @return The distance, or 0 if all clusters were already accounted for.
+ */
+template<typename MetricType, typename TreeType>
+inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
+{
+  // Number of clusters already pruned for the last scored reference node.
+  // That node may be NULL (LastReferenceNode() is only set by the node-node
+  // Score()), and its statistic is only meaningful if it belongs to the
+  // current iteration.
+  const TreeType* lastReferenceNode = traversalInfo.LastReferenceNode();
+  const size_t traversalPruned = (lastReferenceNode != NULL &&
+      lastReferenceNode->Stat().Iteration() == iteration) ?
+      lastReferenceNode->Stat().ClustersPruned() : 0;
+
+  // The reference point may already be fully accounted for (pruned clusters
+  // plus evaluated clusters).  Use the guarded value computed above instead
+  // of dereferencing the possibly-NULL last reference node again.
+  if (traversalPruned + visited[referenceIndex] == centroids.n_cols)
+    return 0.0;
+
+  ++distanceCalculations;
+
+  const double distance = metric.Evaluate(centroids.col(queryIndex),
+      dataset.col(referenceIndex));
+
+  // Keep the best candidate cluster for this point.  A value left over from
+  // an earlier iteration is overwritten unconditionally.
+  if (distanceIteration[referenceIndex] < iteration ||
+      distance < distances[referenceIndex])
+  {
+    distanceIteration[referenceIndex] = iteration;
+    distances[referenceIndex] = distance;
+    assignments[referenceIndex] = mappings[queryIndex];
+  }
+
+  ++visited[referenceIndex];
+
+  // Every cluster is now accounted for: commit the point to its best cluster.
+  if (visited[referenceIndex] + traversalPruned == centroids.n_cols)
+  {
+    newCentroids.col(assignments[referenceIndex]) +=
+        dataset.col(referenceIndex);
+    ++counts(assignments[referenceIndex]);
+  }
+
+  return distance;
+}
+
+/**
+ * Single-tree score.  No pruning is done here yet; the reference node's
+ * statistic is only brought up to date for the current iteration.
+ */
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Score(
+    const size_t /* queryIndex */,
+    TreeType& referenceNode)
+{
+  IterationUpdate(referenceNode);
+  return 0.0;
+}
+
+/**
+ * Dual-tree score.  The reference node's closest query node is updated when
+ * possible.  Otherwise the rules try an Elkan-style prune of all centroids
+ * in queryNode for referenceNode.
+ *
+ * @return DBL_MAX if queryNode was pruned, 0 if the closest query node was
+ *     updated, and the minimum distance between queryNode and the closest
+ *     query node otherwise.
+ */
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode)
+{
+  IterationUpdate(referenceNode);
+
+  traversalInfo.LastReferenceNode() = &referenceNode;
+
+  TreeType* closestQueryNode =
+      static_cast<TreeType*>(referenceNode.Stat().ClosestQueryNode());
+
+  // Can we update the closest query node for this reference node?  This is
+  // done when there is none yet (it must not be dereferenced then), when
+  // queryNode is closer, or when queryNode is a descendant (a more specific
+  // version) of the current closest query node.
+  const double minDistance = referenceNode.MinDistance(&queryNode);
+  ++distanceCalculations;
+  if (closestQueryNode == NULL ||
+      minDistance < referenceNode.Stat().MinQueryNodeDistance() ||
+      IsDescendantOf(*closestQueryNode, queryNode))
+  {
+    referenceNode.Stat().ClosestQueryNode() = static_cast<void*>(&queryNode);
+    referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+    referenceNode.Stat().MaxQueryNodeDistance() =
+        referenceNode.MaxDistance(&queryNode);
+    ++distanceCalculations;
+    return 0.0; // Pruning is not possible.
+  }
+
+  // See if we can do an Elkan-type prune on between-centroid distances.
+  const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
+  const double minQueryDistance = queryNode.MinDistance(closestQueryNode);
+  ++distanceCalculations;
+
+  if (minQueryDistance > 2.0 * maxDistance)
+  {
+    // By the triangle inequality, using MaxQueryNodeDistance() as the bound
+    // on d_max(N_r, best(N_r)): every centroid in N_q is farther from every
+    // point in N_r than the centroids in best(N_r). So N_q cannot own any
+    // point in N_r.
+    referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+    // Have we pruned everything?
+    if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
+    {
+      // Then the best query node must contain just one point.
+      const size_t cluster = mappings[closestQueryNode->Descendant(0)];
+
+      referenceNode.Stat().Owner() = cluster;
+      newCentroids.col(cluster) += referenceNode.NumDescendants() *
+          referenceNode.Stat().Centroid();
+      counts(cluster) += referenceNode.NumDescendants();
+      referenceNode.Stat().ClustersPruned()++;
+    }
+    else if (referenceNode.Stat().ClustersPruned() +
+        visited[referenceNode.Descendant(0)] == centroids.n_cols)
+    {
+      // The pruned clusters plus the clusters already evaluated for the first
+      // descendant point cover everything: commit the points held directly
+      // in this node.
+      for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+      {
+        const size_t cluster = assignments[referenceNode.Point(i)];
+        newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+        counts(cluster)++;
+      }
+    }
+
+    return DBL_MAX;
+  }
+
+  return minQueryDistance;
+}
+
+/**
+ * Single-tree rescore: nothing new can be pruned, so the old score is handed
+ * back unchanged.
+ */
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
+    const size_t /* queryIndex */,
+    TreeType& /* referenceNode */,
+    const double oldScore) const
+{
+  const double unchangedScore = oldScore;
+  return unchangedScore;
+}
+
+/**
+ * Dual-tree rescore: re-check the Elkan-style prune using the score that
+ * Score() returned earlier for this node combination, since the reference
+ * node's MaxQueryNodeDistance() may have tightened in the meantime.
+ *
+ * Although this function is const, it modifies the reference node's
+ * statistic and the newCentroids/counts referenced by the rules.
+ *
+ * @return DBL_MAX if queryNode is pruned for referenceNode, else oldScore.
+ */
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
+ TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const
+{
+ if (oldScore == DBL_MAX)
+ return oldScore; // We can't unprune something. This shouldn't happen.
+
+ // The old score is the minimum distance between queryNode and the
+ // reference node's closest query node, as returned by Score().
+ // NOTE(review): this assumes the closest query node has not changed since
+ // that Score() call — confirm.
+ const double minQueryDistance = oldScore;
+
+ // See if we can do an Elkan-type prune on between-centroid distances.
+ const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
+
+ if (minQueryDistance > 2.0 * maxDistance)
+ {
+ // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
+ // means that N_q cannot possibly hold any clusters that own any points in
+ // N_r.
+ referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+ // Have we pruned everything?
+ if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
+ {
+ // Then the best query node must contain just one point.
+ const TreeType* bestQueryNode = (TreeType*)
+ referenceNode.Stat().ClosestQueryNode();
+ const size_t cluster = mappings[bestQueryNode->Descendant(0)];
+
+ // Commit every descendant of the reference node to that cluster at
+ // once, using the node's precomputed centroid.
+ referenceNode.Stat().Owner() = cluster;
+ newCentroids.col(cluster) += referenceNode.NumDescendants() *
+ referenceNode.Stat().Centroid();
+ counts(cluster) += referenceNode.NumDescendants();
+ referenceNode.Stat().ClustersPruned()++;
+ }
+ else if (referenceNode.Stat().ClustersPruned() +
+ visited[referenceNode.Descendant(0)] == centroids.n_cols)
+ {
+ // Pruned plus already-evaluated clusters (counted for the first
+ // descendant point) cover everything: commit the node's own points.
+ for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+ {
+ const size_t cluster = assignments[referenceNode.Point(i)];
+ newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+ counts(cluster)++;
+ }
+ }
+
+ return DBL_MAX;
+ }
+
+ return oldScore;
+}
+
+/**
+ * Bring a reference node's statistic up to date for the current iteration.
+ * The update inherits the pruning state, the closest query node and its
+ * maximum-distance bound from the parent. It then recomputes the
+ * minimum-distance bound and loosens it for centroid movement.
+ *
+ * @return 1 if the statistic was updated, 0 if it was already current.
+ */
+template<typename MetricType, typename TreeType>
+inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
+    TreeType& referenceNode) const
+{
+  if (referenceNode.Stat().Iteration() == iteration)
+    return 0;
+
+  // Number of iterations since this node was last updated.  This must be
+  // computed *before* the stored iteration is overwritten; otherwise it is
+  // always zero.  (With the statistic's initial value of size_t() - 1 the
+  // unsigned difference is 1 on the first visit in iteration 0.)
+  const size_t itDiff = iteration - referenceNode.Stat().Iteration();
+  referenceNode.Stat().Iteration() = iteration;
+
+  // Inherit state from the parent.  The parent is presumably already up to
+  // date, since traversals reach a node through its parent.
+  TreeType* parent = referenceNode.Parent();
+  referenceNode.Stat().ClustersPruned() = (parent == NULL) ?
+      0 : parent->Stat().ClustersPruned();
+  referenceNode.Stat().ClosestQueryNode() = (parent == NULL) ?
+      NULL : parent->Stat().ClosestQueryNode();
+  // The max-distance bound must describe the same query node as
+  // ClosestQueryNode().  This node's own stored value refers to a query node
+  // from an earlier iteration.  The parent's bound is valid here, assuming
+  // (as for kd-trees) that a child's bound lies within its parent's.
+  referenceNode.Stat().MaxQueryNodeDistance() = (parent == NULL) ?
+      DBL_MAX : parent->Stat().MaxQueryNodeDistance();
+
+  if (referenceNode.Stat().ClosestQueryNode() == NULL)
+  {
+    // No closest query node means no valid bound.  Reset it so that Score()
+    // adopts a query node instead of comparing against a stale value.
+    referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+    return 1;
+  }
+
+  referenceNode.Stat().MinQueryNodeDistance() = referenceNode.MinDistance(
+      static_cast<TreeType*>(referenceNode.Stat().ClosestQueryNode()));
+
+  if (itDiff > 1)
+  {
+    // The node was skipped for at least one iteration.  Maybe this can be
+    // tighter?
+    referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+  }
+  else if (referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX)
+  {
+    if (referenceNode.Stat().MaxQueryNodeDistance() == DBL_MAX)
+    {
+      referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+    }
+    else
+    {
+      // Loosen the bound by the largest distance any centroid moved in the
+      // previous iteration (the last element of clusterDistances).
+      referenceNode.Stat().MinQueryNodeDistance() +=
+          clusterDistances(centroids.n_cols);
+    }
+  }
+
+  return 1;
+}
+
+/**
+ * Return true if potentialParent is a strict ancestor of potentialChild (a
+ * node is not considered its own descendant).
+ */
+template<typename MetricType, typename TreeType>
+bool DualTreeKMeansRules<MetricType, TreeType>::IsDescendantOf(
+    const TreeType& potentialParent,
+    const TreeType& potentialChild) const
+{
+  // Walk up the ancestor chain; reaching the root without meeting
+  // potentialParent means it is not an ancestor.
+  for (const TreeType* ancestor = potentialChild.Parent(); ancestor != NULL;
+       ancestor = ancestor->Parent())
+  {
+    if (ancestor == &potentialParent)
+      return true;
+  }
+
+  return false;
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
new file mode 100644
index 0000000..21481da
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -0,0 +1,96 @@
+/**
+ * @file dual_tree_kmeans_statistic.hpp
+ * @author Ryan Curtin
+ *
+ * Statistic for dual-tree k-means traversal.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_STATISTIC_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_STATISTIC_HPP
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Statistic stored in each reference-tree node for dual-tree k-means.  It
+ * holds the empirical centroid of the node's points and the pruning state
+ * used by DualTreeKMeansRules.
+ */
+class DualTreeKMeansStatistic
+{
+ public:
+  //! Empty statistic.  Every bookkeeping field gets a well-defined "unset"
+  //! value so that it is never read uninitialized.
+  DualTreeKMeansStatistic() :
+      closestQueryNode(NULL),
+      minQueryNodeDistance(DBL_MAX),
+      maxQueryNodeDistance(DBL_MAX),
+      clustersPruned(0),
+      iteration(size_t() - 1),
+      owner(size_t() - 1)
+  { /* Nothing else to do. */ }
+
+  /**
+   * Build the statistic for the given node.  The centroid combines the
+   * points held directly in the node with the centroids of its children,
+   * weighted by their number of descendants. The children's statistics must
+   * therefore already have been computed.
+   */
+  template<typename TreeType>
+  DualTreeKMeansStatistic(TreeType& node) :
+      closestQueryNode(NULL),
+      minQueryNodeDistance(DBL_MAX),
+      maxQueryNodeDistance(DBL_MAX),
+      clustersPruned(0),
+      iteration(size_t() - 1),
+      owner(size_t() - 1)
+  {
+    // Empirically calculate the centroid.
+    centroid.zeros(node.Dataset().n_rows);
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+      centroid += node.Dataset().col(node.Point(i));
+
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+      centroid += node.Child(i).NumDescendants() *
+          node.Child(i).Stat().Centroid();
+
+    // Avoid dividing by zero (which would fill the centroid with NaNs) for a
+    // node without any points.
+    if (node.NumDescendants() > 0)
+      centroid /= node.NumDescendants();
+  }
+
+  //! Return the centroid.
+  const arma::vec& Centroid() const { return centroid; }
+  //! Modify the centroid.
+  arma::vec& Centroid() { return centroid; }
+
+  //! Get the current closest query node.
+  void* ClosestQueryNode() const { return closestQueryNode; }
+  //! Modify the current closest query node.
+  void*& ClosestQueryNode() { return closestQueryNode; }
+
+  //! Get the minimum distance to the closest query node.
+  double MinQueryNodeDistance() const { return minQueryNodeDistance; }
+  //! Modify the minimum distance to the closest query node.
+  double& MinQueryNodeDistance() { return minQueryNodeDistance; }
+
+  //! Get the maximum distance to the closest query node.
+  double MaxQueryNodeDistance() const { return maxQueryNodeDistance; }
+  //! Modify the maximum distance to the closest query node.
+  double& MaxQueryNodeDistance() { return maxQueryNodeDistance; }
+
+  //! Get the number of clusters that have been pruned during this iteration.
+  size_t ClustersPruned() const { return clustersPruned; }
+  //! Modify the number of clusters that have been pruned during this iteration.
+  size_t& ClustersPruned() { return clustersPruned; }
+
+  //! Get the current iteration.
+  size_t Iteration() const { return iteration; }
+  //! Modify the current iteration.
+  size_t& Iteration() { return iteration; }
+
+  //! Get the current owner (if any) of these reference points.
+  size_t Owner() const { return owner; }
+  //! Modify the current owner (if any) of these reference points.
+  size_t& Owner() { return owner; }
+
+ private:
+  //! The empirically calculated centroid of the node.
+  arma::vec centroid;
+
+  //! The current closest query node to this reference node (NULL if none).
+  void* closestQueryNode;
+  //! The minimum distance to the closest query node.
+  double minQueryNodeDistance;
+  //! The maximum distance to the closest query node.
+  double maxQueryNodeDistance;
+
+  //! The number of clusters that have been pruned.
+  size_t clustersPruned;
+  //! The iteration this statistic was last updated in (initially
+  //! size_t() - 1, i.e. never).
+  size_t iteration;
+  //! The owner of these reference points.  A value not smaller than the
+  //! number of clusters (initially size_t() - 1) means there is no owner.
+  size_t owner;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list