[mlpack] 156/324: R tree now has dataset and indices
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:06 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 016b20a87245c1bcf1fbff5f6fbf29a5cbc74863
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 9 18:43:09 2014 +0000
R tree now has dataset and indices
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16795 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 5 ++--
.../core/tree/rectangle_tree/rectangle_tree.hpp | 7 +++++
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 11 ++++---
src/mlpack/tests/rectangle_tree_test.cpp | 34 ++++++++++++++++++++--
4 files changed, 49 insertions(+), 8 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 6e587d0..ad4c51b 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -197,7 +197,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetPointSeeds(
for(int j = i+1; j < tree.Count(); j++) {
double score = 1.0;
for(int k = 0; k < tree.Bound().Dim(); k++) {
- score *= std::abs(tree.Dataset().at(k, tree.Points()[i]) - tree.Dataset().at(k, tree.Points()[j])); // Points (in the dataset) are stored by column, but this function takes (row, col).
+ score *= std::abs(tree.LocalDataset().at(k, i) - tree.LocalDataset().at(k, j)); // Points (in the dataset) are stored by column, but this function takes (row, col).
}
if(score > worstPairScore) {
worstPairScore = score;
@@ -312,7 +312,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
double newVolOne = 1.0;
double newVolTwo = 1.0;
for(int i = 0; i < oldTree->Bound().Dim(); i++) {
- double c = oldTree->Dataset().col(oldTree->Points()[index])[i];
+ double c = oldTree->LocalDataset().col(index)[i];
newVolOne *= treeOne->Bound()[i].Contains(c) ? treeOne->Bound()[i].Width() :
(c < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
newVolTwo *= treeTwo->Bound()[i].Contains(c) ? treeTwo->Bound()[i].Width() :
@@ -347,6 +347,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
}
oldTree->Points()[bestIndex] = oldTree->Points()[--end]; // decrement end.
+ oldTree->LocalDataset().col(bestIndex) = oldTree->LocalDataset().col(end);
}
// See if we need to satisfy the minimum fill.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 550cf74..76f879e 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -74,6 +74,8 @@ class RectangleTree
MatType& dataset;
//! The mapping to the dataset
std::vector<size_t> points;
+ //! The local dataset
+ MatType* localDataset;
public:
//! So other classes can use TreeType::Mat.
@@ -226,6 +228,11 @@ class RectangleTree
const std::vector<size_t>& Points() const { return points; }
//! Modify the points vector for this node. Be careful!
std::vector<size_t>& Points() { return points; }
+
+ //! Get the local dataset of this node.
+ const arma::mat& LocalDataset() const { return *localDataset; }
+ //! Modify the local dataset of this node.
+ arma::mat& LocalDataset() { return *localDataset; }
//! Get the metric which the tree uses.
typename HRectBound<>::MetricType Metric() const { return bound.Metric(); }
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index ad6f038..8df89ec 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -40,7 +40,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
bound(data.n_rows),
parentDistance(0),
dataset(data),
- points(maxLeafSize+1) // Add one to make splitting the node simpler.
+ points(maxLeafSize+1), // Add one to make splitting the node simpler.
+ localDataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
{
stat = StatisticType(*this);
@@ -71,7 +72,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
bound(parentNode->Bound().Dim()),
parentDistance(0),
dataset(parentNode->Dataset()),
- points(maxLeafSize+1) // Add one to make splitting the node simpler.
+ points(maxLeafSize+1), // Add one to make splitting the node simpler.
+ localDataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
{
stat = StatisticType(*this);
}
@@ -92,7 +94,7 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::
delete children[i];
}
//if(numChildren == 0)
- //delete points;
+ delete localDataset;
}
@@ -127,7 +129,7 @@ template<typename SplitType,
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
NullifyData()
{
- //points = NULL;
+ localDataset = NULL;
}
@@ -148,6 +150,7 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
// If this is a leaf node, we stop here and add the point.
if(numChildren == 0) {
points[count++] = point;
+ localDataset->col(count) = dataset.col(point);
SplitNode();
return;
}
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index c6bb651..28a14fd 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -128,7 +128,7 @@ bool checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeu
BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
{
- arma::mat dataset;
+ arma::mat dataset;
dataset.randu(8, 1000); // 1000 points in 8 dimensions.
RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
@@ -138,6 +138,37 @@ BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
assert(checkContainment(tree) == true);
}
+bool checkSync(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ if(tree.IsLeaf()) {
+ for(size_t i = 0; i < tree.Count(); i++) {
+ for(size_t j = 0; j < tree.LocalDataset().n_rows; j++) {
+ if(tree.LocalDataset().col(i)[j] != tree.Dataset().col(tree.Points()[i])[j])
+ return false;
+ }
+ }
+ } else {
+ for(size_t i = 0; i < tree.NumChildren(); i++) {
+ if(!checkSync(tree.Children()[i]))
+ return false;
+ }
+ }
+ return true;
+}
+
+BOOST_AUTO_TEST_CASE(TreeLocalDatasetInSync) {
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+ assert(checkSync(tree) == true);
+}
+
BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
{
arma::mat dataset;
@@ -174,5 +205,4 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
}
}
-
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list