[mlpack] 81/149: Refactor Elkan-type prune into its own method, for simplicity.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Sat May 2 09:11:11 UTC 2015
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 1687a65ff6dfb10665ee02a0fcd407d5ad8d5f77
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Fri Nov 7 20:54:39 2014 +0000
Refactor Elkan-type prune into its own method, for simplicity.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17309 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../methods/kmeans/dual_tree_kmeans_rules.hpp | 33 +++++
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 143 +++++++++------------
2 files changed, 94 insertions(+), 82 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index e9320d1..4a54192 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -70,6 +70,39 @@ class DualTreeKMeansRules
bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
potentialChild) const;
+
+ /**
+ * See if an Elkan-type prune can be performed. If so, return DBL_MAX;
+ * otherwise, return a score. The Elkan-type prune can occur when the minimum
+ * distance between the query node and the current best query node for the
+ * reference node (referenceNode.Stat().ClosestQueryNode()) is greater than
+ * two times the maximum distance between the reference node and the current
+ * best query node (again, referenceNode.Stat().ClosestQueryNode()).
+ *
+ * @param queryNode Query node.
+ * @param referenceNode Reference node.
+ */
+ double ElkanTypeScore(TreeType& queryNode, TreeType& referenceNode) const;
+
+ /**
+ * See if an Elkan-type prune can be performed. If so, return DBL_MAX;
+ * otherwise, return a score. The Elkan-type prune can occur when the minimum
+ * distance between the query node and the current best query node for the
+ * reference node (referenceNode.Stat().ClosestQueryNode()) is greater than
+ * two times the maximum distance between the reference node and the current
+ * best query node (again, referenceNode.Stat().ClosestQueryNode()).
+ *
+ * This particular overload is for when the minimum distance between the query
+ * node and the current best query node has already been calculated.
+ *
+ * @param queryNode Query node.
+ * @param referenceNode Reference node.
+ * @param minQueryDistance Minimum distance between query node and current
+ * best query node for the reference node.
+ */
+ double ElkanTypeScore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double minQueryDistance) const;
};
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index adcedad..1e352de 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -140,48 +140,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
return 0.0; // Pruning is not possible.
}
- // See if we can do an Elkan-type prune on between-centroid distances.
- const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
- const double minQueryDistance = queryNode.MinDistance((TreeType*)
- referenceNode.Stat().ClosestQueryNode());
++distanceCalculations;
- if (minQueryDistance > 2.0 * maxDistance)
- {
- // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
- // means that N_q cannot possibly hold any clusters that own any points in
- // N_r.
- referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
-
- // Have we pruned everything?
- if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
- {
- // Then the best query node must contain just one point.
- const TreeType* bestQueryNode = (TreeType*)
- referenceNode.Stat().ClosestQueryNode();
- const size_t cluster = mappings[bestQueryNode->Descendant(0)];
-
- referenceNode.Stat().Owner() = cluster;
- newCentroids.col(cluster) += referenceNode.NumDescendants() *
- referenceNode.Stat().Centroid();
- counts(cluster) += referenceNode.NumDescendants();
- referenceNode.Stat().ClustersPruned()++;
- }
- else if (referenceNode.Stat().ClustersPruned() +
- visited[referenceNode.Descendant(0)] == centroids.n_cols)
- {
- for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
- {
- const size_t cluster = assignments[referenceNode.Point(i)];
- newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
- counts(cluster)++;
- }
- }
-
- return DBL_MAX;
- }
-
- return minQueryDistance;
+ return ElkanTypeScore(queryNode, referenceNode);
}
template<typename MetricType, typename TreeType>
@@ -202,48 +163,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
if (oldScore == DBL_MAX)
return oldScore; // We can't unprune something. This shouldn't happen.
- // Can we update the minimum query node distance for this reference node?
- const double minQueryDistance = oldScore;
-
- // See if we can do an Elkan-type prune on between-centroid distances.
- const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
-
- if (minQueryDistance > 2.0 * maxDistance)
- {
- // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
- // means that N_q cannot possibly hold any clusters that own any points in
- // N_r.
- referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
-
- // Have we pruned everything?
- if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
- {
- // Then the best query node must contain just one point.
- const TreeType* bestQueryNode = (TreeType*)
- referenceNode.Stat().ClosestQueryNode();
- const size_t cluster = mappings[bestQueryNode->Descendant(0)];
-
- referenceNode.Stat().Owner() = cluster;
- newCentroids.col(cluster) += referenceNode.NumDescendants() *
- referenceNode.Stat().Centroid();
- counts(cluster) += referenceNode.NumDescendants();
- referenceNode.Stat().ClustersPruned()++;
- }
- else if (referenceNode.Stat().ClustersPruned() +
- visited[referenceNode.Descendant(0)] == centroids.n_cols)
- {
- for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
- {
- const size_t cluster = assignments[referenceNode.Point(i)];
- newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
- counts(cluster)++;
- }
- }
-
- return DBL_MAX;
- }
-
- return oldScore;
+ return ElkanTypeScore(queryNode, referenceNode, oldScore);
}
template<typename MetricType, typename TreeType>
@@ -311,6 +231,65 @@ bool DualTreeKMeansRules<MetricType, TreeType>::IsDescendantOf(
return IsDescendantOf(potentialParent, *potentialChild.Parent());
}
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
+ TreeType& queryNode,
+ TreeType& referenceNode) const
+{
+ // We have to calculate the minimum distance between the query node and the
+ // reference node's best query node.
+ const double minQueryDistance = queryNode.MinDistance((TreeType*)
+ referenceNode.Stat().ClosestQueryNode());
+ return ElkanTypeScore(queryNode, referenceNode, minQueryDistance);
+}
+
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
+ TreeType& queryNode,
+ TreeType& referenceNode,
+ const double minQueryDistance) const
+{
+ // See if we can do an Elkan-type prune on between-centroid distances.
+ const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
+
+ if (minQueryDistance > 2.0 * maxDistance)
+ {
+ // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
+ // means that N_q cannot possibly hold any clusters that own any points in
+ // N_r.
+ referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+ // Have we pruned everything?
+ if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
+ {
+ // Then the best query node must contain just one point.
+ const TreeType* bestQueryNode = (TreeType*)
+ referenceNode.Stat().ClosestQueryNode();
+ const size_t cluster = mappings[bestQueryNode->Descendant(0)];
+
+ referenceNode.Stat().Owner() = cluster;
+ newCentroids.col(cluster) += referenceNode.NumDescendants() *
+ referenceNode.Stat().Centroid();
+ counts(cluster) += referenceNode.NumDescendants();
+ referenceNode.Stat().ClustersPruned()++;
+ }
+ else if (referenceNode.Stat().ClustersPruned() +
+ visited[referenceNode.Descendant(0)] == centroids.n_cols)
+ {
+ for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+ {
+ const size_t cluster = assignments[referenceNode.Point(i)];
+ newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+ counts(cluster)++;
+ }
+ }
+
+ return DBL_MAX;
+ }
+
+ return minQueryDistance;
+}
+
} // namespace kmeans
} // namespace mlpack
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list