[mlpack] 26/37: Backport trunk fixes for NeighborSearch.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Feb 15 19:35:48 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to tag mlpack-1.0.10
in repository mlpack.
commit 33ba453de9e3efadd828787f6f1fb5b219260c9c
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Aug 25 21:49:46 2014 +0000
Backport trunk fixes for NeighborSearch.
---
.../neighbor_search/neighbor_search_impl.hpp | 23 +--
.../neighbor_search/neighbor_search_rules_impl.hpp | 209 ++++++++-------------
2 files changed, 80 insertions(+), 152 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index c8ee610..33b2577 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -4,21 +4,6 @@
*
* Implementation of Neighbor-Search class to perform all-nearest-neighbors on
* two specified data sets.
- *
- * This file is part of MLPACK 1.0.9.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
@@ -267,19 +252,21 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
}
else if (singleMode)
{
-
// The search doesn't work if the root node is also a leaf node.
// if this is the case, it is suggested that you use the naive method.
assert(!(referenceTree->IsLeaf()));
-
+
// Create the traverser.
typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < querySet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
+
+ Log::Info << rules.Scores() << " node combinations were scored.\n";
+ Log::Info << rules.BaseCases() << " base cases were calculated.\n";
}
- else // Dual-tree recursion.referenceTree
+ else // Dual-tree recursion.
{
// Create the traverser.
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 5d1cc99..bdafbb1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -3,21 +3,6 @@
* @author Ryan Curtin
*
* Implementation of NearestNeighborRules.
- *
- * This file is part of MLPACK 1.0.9.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
@@ -328,136 +313,92 @@ template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
CalculateBound(TreeType& queryNode) const
{
- // We have five possible bounds, and we must take the best of them all. We
- // don't use min/max here, but instead "best/worst", because this is general
- // to the nearest-neighbors/furthest-neighbors cases. For nearest neighbors,
- // min = best, max = worst.
- //
- // (1) worst ( worst_{all points p in queryNode} D_p[k],
- // worst_{all children c in queryNode} B(c) );
- // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
- // worst descendant distance;
- // (3) best_{all children c in queryNode} B(c) +
- // 2 ( worst descendant distance of queryNode -
- // worst descendant distance of c );
- // (4) B_1(parent of queryNode)
- // (5) B_2(parent of queryNode);
- //
- // D_p[k] is the current k'th candidate distance for point p.
- // So we will loop over the points in queryNode and the children in queryNode
- // to calculate all five of these quantities.
-
- // Hm, can we populate our distances vector with estimates from the parent?
- // This is written specifically for the cover tree and assumes only one point
- // in a node.
-// if (queryNode.Parent() != NULL && queryNode.NumPoints() > 0)
-// {
-// size_t parentIndexStart = 0;
-// for (size_t i = 0; i < neighbors.n_rows; ++i)
-// {
-// const double pointDistance = distances(i, queryNode.Point(0));
-// if (pointDistance == DBL_MAX)
-// {
-// // Cool, can we take an estimate from the parent?
-// const double parentWorstBound = distances(distances.n_rows - 1,
-// queryNode.Parent()->Point(0));
-// if (parentWorstBound != DBL_MAX)
-// {
-// const double parentAdjustedDistance = parentWorstBound +
-// queryNode.ParentDistance();
-// distances(i, queryNode.Point(0)) = parentAdjustedDistance;
-// }
-// }
-// }
-// }
-
- double worstPointDistance = SortPolicy::BestDistance();
- double bestPointDistance = SortPolicy::WorstDistance();
-
- // Loop over all points in this node to find the best and worst distance
- // candidates (for (1) and (2)).
+ // This is an adapted form of the B(N_q) function in the paper
+ // ``Tree-Independent Dual-Tree Algorithms'' by Curtin et al.; the goal is to
+ // place a bound on the worst possible distance a point combination could have
+ // to improve any of the current neighbor estimates. If the best possible
+ // distance between two nodes is greater than this bound, then the node
+ // combination can be pruned (see Score()).
+
+ // There are a couple ways we can assemble a bound. For simplicity, this is
+ // described for nearest neighbor search (SortPolicy = NearestNeighborSort),
+ // but the code that is written is adapted for whichever SortPolicy.
+
+ // First, we can consider the current worst neighbor candidate distance of any
+ // descendant point. This is assembled with 'worstDistance' by looping
+ // through the points held by the query node, and then by taking the cached
+ // worst distance from any child nodes (Stat().FirstBound()). This
+ // corresponds roughly to B_1(N_q) in the paper.
+
+ // The other way of bounding is to use the triangle inequality. To do this,
+ // we find the current best kth-neighbor candidate distance of any descendant
+ // query point, and use the triangle inequality to place a bound on the
+ // distance that candidate would have to any other descendant query point.
+ // This corresponds roughly to B_2(N_q) in the paper, and is the bounding
+ // style for cover trees.
+
+ // Then, to assemble the final bound, since both bounds are valid, we simply
+ // take the better of the two.
+
+ double worstDistance = SortPolicy::BestDistance();
+ double bestDistance = SortPolicy::WorstDistance();
+
+ // Loop over points held in the node.
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
- const double distance = distances(distances.n_rows - 1,
- queryNode.Point(i));
- if (SortPolicy::IsBetter(distance, bestPointDistance))
- bestPointDistance = distance;
- if (SortPolicy::IsBetter(worstPointDistance, distance))
- worstPointDistance = distance;
+ const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+ if (SortPolicy::IsBetter(worstDistance, distance))
+ worstDistance = distance;
+ if (SortPolicy::IsBetter(distance, bestDistance))
+ bestDistance = distance;
}
- // Loop over all the children in this node to find the worst bound (for (1))
- // and the best bound with the correcting factor for descendant distances (for
- // (3)).
- double worstChildBound = SortPolicy::BestDistance();
- double bestAdjustedChildBound = SortPolicy::WorstDistance();
- const double queryMaxDescendantDistance =
+ // Add triangle inequality adjustment to best distance. It is possible this
+ // could be tighter for certain types of trees.
+ bestDistance += queryNode.FurthestPointDistance() +
queryNode.FurthestDescendantDistance();
+ // Loop over children of the node, and use their cached information to
+ // assemble bounds.
for (size_t i = 0; i < queryNode.NumChildren(); ++i)
{
const double firstBound = queryNode.Child(i).Stat().FirstBound();
- const double secondBound = queryNode.Child(i).Stat().SecondBound();
- const double childMaxDescendantDistance =
- queryNode.Child(i).FurthestDescendantDistance();
-
- if (SortPolicy::IsBetter(worstChildBound, firstBound))
- worstChildBound = firstBound;
-
- // Now calculate adjustment for maximum descendant distances.
- const double adjustedBound = SortPolicy::CombineWorst(secondBound,
- 2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
- if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
- bestAdjustedChildBound = adjustedBound;
+ const double adjustedSecondBound = queryNode.Child(i).Stat().SecondBound() +
+ 2 * (queryNode.FurthestDescendantDistance() -
+ queryNode.Child(i).FurthestDescendantDistance());
+
+ if (SortPolicy::IsBetter(worstDistance, firstBound))
+ worstDistance = firstBound;
+ if (SortPolicy::IsBetter(adjustedSecondBound, bestDistance))
+ bestDistance = adjustedSecondBound;
}
- // This is bound (1).
- const double firstBound =
- (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
- worstChildBound : worstPointDistance;
-
- // This is bound (2).
- const double secondBound = SortPolicy::CombineWorst(
- SortPolicy::CombineWorst(bestPointDistance, queryMaxDescendantDistance),
- queryNode.FurthestPointDistance());
-
- // Bound (3) is bestAdjustedChildBound.
-
- // Bounds (4) and (5) are the parent bounds.
- const double fourthBound = (queryNode.Parent() != NULL) ?
- queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
-// const double fifthBound = (queryNode.Parent() != NULL) ?
-// queryNode.Parent()->Stat().SecondBound() -
-// queryNode.Parent()->FurthestDescendantDistance() -
-// queryNode.Parent()->FurthestPointDistance() + queryMaxDescendantDistance +
-// queryNode.FurthestPointDistance() + queryNode.ParentDistance() :
-// SortPolicy::WorstDistance();
-
- // Now, we will take the best of these. Unfortunately due to the way
- // IsBetter() is defined, this sort of has to be a little ugly.
- // The variable interA represents the first bound (B_1), which is the worst
- // candidate distance of any descendants of this node.
- // The variable interC represents the second bound (B_2), which is a bound on
- // the worst distance of any descendants of this node assembled using the best
- // descendant candidate distance modified using the furthest descendant
- // distance.
- const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
- firstBound : fourthBound;
- const double interB =
- (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
- bestAdjustedChildBound : secondBound;
-// const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
-// fifthBound;
-
- // Update the first and second bounds of the node.
- queryNode.Stat().FirstBound() = interA;
- queryNode.Stat().SecondBound() = interB;
-
- // Update the actual bound of the node.
- queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interB)) ? interB :
- interB;
-
- return queryNode.Stat().Bound();
+ // Now consider the parent bounds.
+ if (queryNode.Parent() != NULL)
+ {
+ // The parent's worst distance bound implies that the bound for this node
+ // must be at least as good. Thus, if the parent worst distance bound is
+ // better, then take it.
+ if (SortPolicy::IsBetter(queryNode.Parent()->Stat().FirstBound(),
+ worstDistance))
+ worstDistance = queryNode.Parent()->Stat().FirstBound();
+
+ // The parent's best distance bound implies that the bound for this node
+ // must be at least as good. Thus, if the parent best distance bound is
+ // better, then take it.
+ if (SortPolicy::IsBetter(queryNode.Parent()->Stat().SecondBound(),
+ bestDistance))
+ bestDistance = queryNode.Parent()->Stat().SecondBound();
+ }
+
+ // Cache bounds for later.
+ queryNode.Stat().FirstBound() = worstDistance;
+ queryNode.Stat().SecondBound() = bestDistance;
+
+ if (SortPolicy::IsBetter(worstDistance, bestDistance))
+ return worstDistance;
+ else
+ return bestDistance;
}
/**
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list