[mlpack] 130/324: Refactor RASearch so that it does not accept a leafSize parameter and can build arbitrary tree types.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:03 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 265c8fdb60ec6a85b38ca0d2433fc9305f9a5bec
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Mon Jul 7 13:10:04 2014 +0000
Refactor RASearch so that it does not accept a leafSize parameter and can build
arbitrary tree types.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16769 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/rann/ra_search.hpp | 10 +-
src/mlpack/methods/rann/ra_search_impl.hpp | 194 ++++++++++++++++-------------
2 files changed, 110 insertions(+), 94 deletions(-)
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index c25e468..a5952bc 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -79,7 +79,6 @@ class RASearch
const typename TreeType::Mat& querySet,
const bool naive = false,
const bool singleMode = false,
- const size_t leafSize = 20,
const MetricType metric = MetricType());
/**
@@ -107,7 +106,6 @@ class RASearch
RASearch(const typename TreeType::Mat& referenceSet,
const bool naive = false,
const bool singleMode = false,
- const size_t leafSize = 20,
const MetricType metric = MetricType());
/**
@@ -258,10 +256,10 @@ class RASearch
//! Pointer to the root of the query tree (might not exist).
TreeType* queryTree;
- //! Indicates if we should free the reference tree at deletion time.
- bool ownReferenceTree;
- //! Indicates if we should free the query tree at deletion time.
- bool ownQueryTree;
+ //! If true, this object created the trees and is responsible for them.
+ bool treeOwner;
+ //! Indicates if a separate query set was passed.
+ bool hasQuerySet;
//! Indicates if naive random sampling on the set is being used.
bool naive;
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index dd46d18..50188b3 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -15,23 +15,50 @@
namespace mlpack {
namespace neighbor {
+namespace aux {
+// Exactly one BuildTree() overload is enabled per TreeType, via a dummy defaulted third parameter using boost::enable_if_c on TreeTraits<TreeType>::RearrangesDataset.
+//! Rearranging trees: `dataset` is non-const because the tree may reorder its points; the constructor records the index mapping in `oldFromNew`. Caller owns the returned tree.
+template<typename TreeType>
+TreeType* BuildTree(
+ typename TreeType::Mat& dataset,
+ std::vector<size_t>& oldFromNew,
+ typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+ >::type = 0)
+{
+ return new TreeType(dataset, oldFromNew);
+}
+
+//! Non-rearranging trees: `dataset` is only read and `oldFromNew` is ignored (no mapping is produced). Caller owns the returned tree.
+template<typename TreeType>
+TreeType* BuildTree(
+ const typename TreeType::Mat& dataset,
+ const std::vector<size_t>& /* oldFromNew */,
+ const typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+ >::type = 0)
+{
+ return new TreeType(dataset);
+}
+
+}; // namespace aux
+
// Construct the object.
template<typename SortPolicy, typename MetricType, typename TreeType>
RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
+RASearch(const typename TreeType::Mat& referenceSetIn,
+ const typename TreeType::Mat& querySetIn,
const bool naive,
const bool singleMode,
- const size_t leafSize,
const MetricType metric) :
- referenceCopy(referenceSet),
- queryCopy(querySet),
- referenceSet(referenceCopy),
- querySet(queryCopy),
+ referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy :
+ referenceSetIn),
+ querySet((tree::TreeTraits<TreeType>::RearrangesDataset && !singleMode) ?
+ queryCopy : querySetIn),
referenceTree(NULL),
queryTree(NULL),
- ownReferenceTree(true), // False if a tree was passed.
- ownQueryTree(true), // False if a tree was passed.
+ treeOwner(!naive),
+ hasQuerySet(true),
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
metric(metric),
@@ -40,12 +67,22 @@ RASearch(const typename TreeType::Mat& referenceSet,
// We'll time tree building.
Timer::Start("tree_building");
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ {
+ referenceCopy = referenceSetIn;
+ if (!singleMode)
+ queryCopy = querySetIn;
+ }
+
// Construct as a naive object if we need to.
if (!naive)
{
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
+ referenceTree = aux::BuildTree<TreeType>(const_cast<typename
+ TreeType::Mat&>(referenceSet), oldFromNewReferences);
- queryTree = new TreeType(queryCopy, oldFromNewQueries, leafSize);
+ if (!singleMode)
+ queryTree = aux::BuildTree<TreeType>(const_cast<typename
+ TreeType::Mat&>(querySet), oldFromNewQueries);
}
// Stop the timer we started above.
@@ -55,18 +92,18 @@ RASearch(const typename TreeType::Mat& referenceSet,
// Construct the object.
template<typename SortPolicy, typename MetricType, typename TreeType>
RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSet,
+RASearch(const typename TreeType::Mat& referenceSetIn,
const bool naive,
const bool singleMode,
- const size_t leafSize,
const MetricType metric) :
- referenceCopy(referenceSet),
- referenceSet(referenceCopy),
- querySet(referenceCopy),
+ referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy :
+ referenceSetIn),
+ querySet(tree::TreeTraits<TreeType>::RearrangesDataset && !singleMode ?
+ referenceCopy : referenceSetIn),
referenceTree(NULL),
queryTree(NULL),
- ownReferenceTree(true),
- ownQueryTree(false), // Since it will be the same as referenceTree.
+ treeOwner(!naive),
+ hasQuerySet(false),
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
metric(metric),
@@ -75,11 +112,13 @@ RASearch(const typename TreeType::Mat& referenceSet,
// We'll time tree building.
Timer::Start("tree_building");
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ referenceCopy = referenceSetIn;
+
// Construct as a naive object if we need to.
if (!naive)
- {
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
- }
+ referenceTree = aux::BuildTree<TreeType>(const_cast<typename
+ TreeType::Mat&>(referenceSet), oldFromNewReferences);
// Stop the timer we started above.
Timer::Stop("tree_building");
@@ -98,8 +137,8 @@ RASearch(TreeType* referenceTree,
querySet(querySet),
referenceTree(referenceTree),
queryTree(queryTree),
- ownReferenceTree(false),
- ownQueryTree(false),
+ treeOwner(false),
+ hasQuerySet(true),
naive(false),
singleMode(singleMode),
metric(metric),
@@ -114,16 +153,16 @@ RASearch(TreeType* referenceTree,
const typename TreeType::Mat& referenceSet,
const bool singleMode,
const MetricType metric) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(referenceTree),
- queryTree(NULL),
- ownReferenceTree(false),
- ownQueryTree(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numberOfPrunes(0)
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ referenceTree(referenceTree),
+ queryTree(NULL),
+ treeOwner(false),
+ hasQuerySet(false),
+ naive(false),
+ singleMode(singleMode),
+ metric(metric),
+ numberOfPrunes(0)
// Nothing else to initialize.
{ }
@@ -135,10 +174,13 @@ template<typename SortPolicy, typename MetricType, typename TreeType>
RASearch<SortPolicy, MetricType, TreeType>::
~RASearch()
{
- if (ownReferenceTree)
- delete referenceTree;
- if (ownQueryTree)
- delete queryTree;
+ if (treeOwner)
+ {
+ if (referenceTree)
+ delete referenceTree;
+ if (queryTree)
+ delete queryTree;
+ }
}
/**
@@ -165,11 +207,13 @@ Search(const size_t k,
arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
arma::mat* distancePtr = &distances;
- if (!naive) // If naive, no re-mapping required since points are not mapped.
+ // Mapping is only required if this tree type rearranges points and we are not
+ // in naive mode.
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
{
- if (ownQueryTree || (ownReferenceTree && !queryTree))
+ if (treeOwner && !(singleMode && hasQuerySet))
distancePtr = new arma::mat; // Query indices need to be mapped.
- if (ownReferenceTree || ownQueryTree)
+ if (treeOwner)
neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
}
@@ -201,12 +245,8 @@ Search(const size_t k,
// Run the base case on each combination of query point and sampled
// reference point.
for (size_t i = 0; i < querySet.n_cols; ++i)
- {
for (size_t j = 0; j < distinctSamples.n_elem; ++j)
- {
rules.BaseCase(i, (size_t) distinctSamples[j]);
- }
- }
}
else if (singleMode)
{
@@ -274,13 +314,12 @@ Search(const size_t k,
Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
// Now, do we need to do mapping of indices?
- if ((!ownReferenceTree && !ownQueryTree) || naive)
+ if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
{
- // No mapping needed if we do not own the trees or if we are doing naive
- // sampling. We are done.
+ // No mapping needed. We are done.
return;
}
- else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+ else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
{
// Set size of output matrices correctly.
resultingNeighbors.set_size(k, querySet.n_cols);
@@ -303,62 +342,41 @@ Search(const size_t k,
delete neighborPtr;
delete distancePtr;
}
- else if (ownReferenceTree)
+ else if (treeOwner && !hasQuerySet)
{
- if (!queryTree) // No query tree -- map both references and queries.
- {
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
+ // No query tree -- map both references and queries.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
- for (size_t i = 0; i < distances.n_cols; i++)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; j++)
- {
- resultingNeighbors(j, oldFromNewReferences[i]) =
- oldFromNewReferences[(*neighborPtr)(j, i)];
- }
- }
- }
- else // Map only references.
+ for (size_t i = 0; i < distances.n_cols; i++)
{
- // Set size of neighbor indices matrix correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
+ // Map distances (copy a column).
+ distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
// Map indices of neighbors.
- for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+ for (size_t j = 0; j < distances.n_rows; j++)
{
- for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
- {
- resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
- }
+ resultingNeighbors(j, oldFromNewReferences[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
}
}
-
- // Finished with temporary matrix.
- delete neighborPtr;
}
- else if (ownQueryTree)
+ else if (treeOwner && hasQuerySet && singleMode) // Map only references.
{
- // Set size of matrices correctly.
+ // Set size of neighbor indices matrix correctly.
resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
- for (size_t i = 0; i < distances.n_cols; i++)
+ // Map indices of neighbors.
+ for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
{
- // Map distances (copy a column).
- distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
+ for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
+ {
+ resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
}
- // Finished with temporary matrices.
+ // Finished with temporary matrix.
delete neighborPtr;
- delete distancePtr;
}
} // Search
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list