[mlpack] 90/324: Rectangle tree and tests. Construction seems to work.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:21:59 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit cbb12971a01f95c8a875e71cfd5e3d1d6a925c44
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Fri Jun 27 15:30:54 2014 +0000
Rectangle tree and tests. Construction seems to work.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16728 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../r_tree_descent_heuristic_impl.hpp | 1 -
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 161 +++++++++++++--------
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 31 ++--
src/mlpack/methods/neighbor_search/allknn_main.cpp | 2 +-
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/rectangle_tree_test.cpp | 142 ++++++++++++++++++
src/mlpack/tests/tree_test.cpp | 1 +
src/mlpack/tests/tree_traits_test.cpp | 1 +
8 files changed, 255 insertions(+), 85 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
index bc4e701..7c15a22 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
@@ -15,7 +15,6 @@ namespace tree {
inline double RTreeDescentHeuristic::EvalNode(const HRectBound<>& bound, const arma::vec& point)
{
- std::cout << "eval node called" << std::endl;
+  // Descent score for a node with the given bound: 0 if the bound already
+  // contains the point, otherwise the minimum distance from the point to the
+  // bound. Presumably the caller descends into the lowest-scoring child.
return bound.Contains(point) ? 0 : bound.MinDistance(point);
}
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 79f75b4..32dbee0 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -13,6 +13,9 @@
namespace mlpack {
namespace tree {
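+ // Debugging note: command-line options (presumably for allknn) used while
+ // developing this split code:
+ //   -r ../test_data_3_1000.csv -n neighbors_out.csv -d distances_out.csv -k 3 -v --r_tree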
+
+
/**
* We call GetPointSeeds to get the two points which will be the initial points in the new nodes
* We then call AssignPointDestNode to assign the remaining points to the two new nodes.
@@ -25,23 +28,19 @@ template<typename DescentType,
void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree)
{
-
- std::cout << "splitting a leaf node." << std::endl;
-
// If we are splitting the root node, we need will do things differently so that the constructor
// and other methods don't confuse the end user by giving an address of another node.
if(tree->Parent() == NULL) {
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* copy =
new RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(*tree); // We actually want to copy this way. Pointers and everything.
- std::cout << "copy made ." << std::endl;
-
copy->Parent() = tree;
tree->Count() = 0;
tree->Children()[(tree->NumChildren())++] = copy; // Because this was a leaf node, numChildren must be 0.
+ assert(tree->NumChildren() == 1);
RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(copy);
- std::cout << "finished split" << std::endl;
return;
}
+ assert(tree->Parent()->NumChildren() < tree->Parent()->MaxNumChildren());
// Use the quadratic split method from: Guttman "R-Trees: A Dynamic Index Structure for
// Spatial Searching" It is simplified since we don't handle rectangles, only points.
@@ -49,23 +48,15 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
int i = 0;
int j = 0;
GetPointSeeds(*tree, &i, &j);
-
- std::cout << "point seeds found." << std::endl;
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType> *treeOne = new
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(tree->Parent());
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType> *treeTwo = new
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(tree->Parent());
-
- std::cout << "new trees made." << std::endl;
-
// This will assign the ith and jth point appropriately.
AssignPointDestNode(tree, treeOne, treeTwo, i, j);
- std::cout << "assignments made." << std::endl;
-
-
//Remove this node and insert treeOne and treeTwo
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* par = tree->Parent();
int index = 0;
@@ -77,23 +68,21 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
}
par->Children()[index] = treeOne;
par->Children()[par->NumChildren()++] = treeTwo;
-
-
- std::cout << "points copied." << std::endl;
-
-
- //because we copied the points to treeOne and treeTwo, we can just delete this node
- // I THOUGHT?
- //delete tree;
+
+ // We need to delete this carefully since references to points are used.
+ tree->softDelete();
// we only add one at a time, so we should only need to test for equality
// just in case, we use an assert.
assert(par->NumChildren() <= par->MaxNumChildren());
if(par->NumChildren() == par->MaxNumChildren()) {
- std::cout << "leaf split calls non-leaf split" << std::endl;
SplitNonLeafNode(par);
}
- std::cout << "about to end leaf split." << std::endl;
+
+ assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
+ assert(treeTwo->Parent()->NumChildren() < treeTwo->MaxNumChildren());
+ assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
return;
}
@@ -109,61 +98,56 @@ template<typename DescentType,
typename MatType>
bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree)
-{
- std::cout << "splitting non-leaf node." << std::endl;
-
+{
// If we are splitting the root node, we need will do things differently so that the constructor
// and other methods don't confuse the end user by giving an address of another node.
if(tree->Parent() == NULL) {
- std::cout << "root node" << std::endl;
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* copy =
new RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(*tree); // We actually want to copy this way. Pointers and everything.
copy->Parent() = tree;
tree->NumChildren() = 0;
tree->Children()[(tree->NumChildren())++] = copy;
RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(copy);
-
- std::cout << tree->ToString() << std::endl;
- std::cout << "root split finished" << std::endl;
-
return true;
}
-
- std::cout << "about to get bound seeds" << std::endl;
+
int i = 0;
int j = 0;
GetBoundSeeds(*tree, &i, &j);
- std::cout << "bound seeds" << std::endl;
+ assert(i != j);
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* treeOne = new
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(tree->Parent());
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* treeTwo = new
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(tree->Parent());
-
- std::cout << "new nodes created" << std::endl;
// This will assign the ith and jth rectangles appropriately.
AssignNodeDestNode(tree, treeOne, treeTwo, i, j);
-
- std::cout << "nodes assigned" << std::endl;
-
+
//Remove this node and insert treeOne and treeTwo
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* par = tree->Parent();
- int index = 0;
+ int index = -1;
for(int i = 0; i < par->NumChildren(); i++) {
if(par->Children()[i] == tree) {
index = i;
break;
}
}
+ assert(index != -1);
par->Children()[index] = treeOne;
par->Children()[par->NumChildren()++] = treeTwo;
+
+ for(int i = 0; i < par->NumChildren(); i++) {
+ if(par->Children()[i] == tree) {
+ assert(par->Children()[i] != tree);
+ }
+ }
// Because we now have pointers to the information stored under this tree,
// we need to delete this node carefully.
tree->softDelete(); //currently does nothing but leak memory.
-
+
// we only add one at a time, so should only need to test for equality
// just in case, we use an assert.
assert(par->NumChildren() <= par->MaxNumChildren());
@@ -179,9 +163,12 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
}
for(int i = 0; i < treeTwo->NumChildren(); i++) {
treeTwo->Children()[i]->Parent() = treeTwo;
- }
+ }
- std::cout << "about to end split non-leaf" << std::endl;
+ assert(treeOne->NumChildren() < treeOne->MaxNumChildren());
+ assert(treeTwo->NumChildren() < treeTwo->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
+
return false;
}
@@ -269,26 +256,27 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
{
int end = oldTree->Count();
+
assert(end > 1); // If this isn't true, the tree is really weird.
-
// Restart the point counts since we are going to move them.
oldTree->Count() = 0;
treeOne->Count() = 0;
treeTwo->Count() = 0;
-
- std::cout << " about to assign i and j" << std::endl;
-
- treeOne->InsertPoint(oldTree->Dataset().col(intI));
- std::cout << "assignment of i made." << std::endl;
- oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+ treeOne->InsertPoint(oldTree->Dataset().col(intI));
treeTwo->InsertPoint(oldTree->Dataset().col(intJ));
- oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
-
-
- std::cout << "i and j assigned" << std::endl;
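+ // Overwrite the higher of the two seed indices first. If the lower slot were
+ // filled first and the higher seed was the last point, that seed would be
+ // copied back into the live range and a different point would be dropped.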
+ if(intI > intJ) {
+ oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+ oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
+ } else {
+ oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
+ oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+ }
+
+
int numAssignedOne = 1;
int numAssignedTwo = 1;
@@ -300,9 +288,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
// The below is safe because if end decreases and the right hand side of the second part of the conjunction changes
// on the same iteration, we added the point to the node with fewer points anyways.
while(end > 0 && end > oldTree->MinLeafSize() - std::min(numAssignedOne, numAssignedTwo)) {
-
- std::cout << "while loop entered with end = "<< end << std::endl;
-
+
int bestIndex = 0;
double bestScore = DBL_MAX;
int bestRect = 1;
@@ -388,11 +374,44 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
int end = oldTree->NumChildren();
assert(end > 1); // If this isn't true, the tree is really weird.
- treeOne->Children()[0] = oldTree->Children()[intI];
- oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
- treeTwo->Children()[0] = oldTree->Children()[intJ];
- oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
-
+ assert(intI != intJ);
+
+ for(int i = 0; i < oldTree->NumChildren(); i++) {
+ for(int j = i+1; j < oldTree->NumChildren(); j++) {
+ assert(oldTree->Children()[i] != oldTree->Children()[j]);
+ }
+ }
+
+ insertNodeIntoTree(treeOne, oldTree->Children()[intI]);
+ insertNodeIntoTree(treeTwo, oldTree->Children()[intJ]);
+
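+ // Overwrite the higher of the two seed indices first. If the lower slot were
+ // filled first and the higher seed was the last child, that seed would be
+ // copied back into the live range and a different child would be dropped.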
+ if(intI > intJ) {
+ oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
+ oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
+ } else {
+ oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
+ oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
+ }
+
+ assert(treeOne->NumChildren() == 1);
+ assert(treeTwo->NumChildren() == 1);
+
+ for(int i = 0; i < end; i++) {
+ for(int j = i+1; j < end; j++) {
+ assert(oldTree->Children()[i] != oldTree->Children()[j]);
+ }
+ }
+
+ for(int i = 0; i < end; i++) {
+ assert(oldTree->Children()[i] != treeOne->Children()[0]);
+ }
+
+ for(int i = 0; i < end; i++) {
+ assert(oldTree->Children()[i] != treeTwo->Children()[0]);
+ }
+
+
int numAssignTreeOne = 1;
int numAssignTreeTwo = 1;
@@ -461,13 +480,29 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
if(numAssignTreeOne < numAssignTreeTwo) {
for(int i = 0; i < end; i++) {
insertNodeIntoTree(treeOne, oldTree->Children()[i]);
+ numAssignTreeOne++;
}
} else {
for(int i = 0; i < end; i++) {
insertNodeIntoTree(treeTwo, oldTree->Children()[i]);
+ numAssignTreeTwo++;
}
}
}
+
+ for(int i = 0; i < treeOne->NumChildren(); i++) {
+ for(int j = i+1; j < treeOne->NumChildren(); j++) {
+ assert(treeOne->Children()[i] != treeOne->Children()[j]);
+ }
+ }
+ for(int i = 0; i < treeTwo->NumChildren(); i++) {
+ for(int j = i+1; j < treeTwo->NumChildren(); j++) {
+ assert(treeTwo->Children()[i] != treeTwo->Children()[j]);
+ }
+ }
+ assert(treeOne->NumChildren() == numAssignTreeOne);
+ assert(treeTwo->NumChildren() == numAssignTreeTwo);
+ assert(numAssignTreeOne+numAssignTreeTwo == 5);
}
/**
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index c5a0ed7..6cd63ae 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -43,20 +43,13 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
{
stat = StatisticType(*this);
- std::cout << ToString() << std::endl;
-
-
// For now, just insert the points in order.
RectangleTree* root = this;
//for(int i = firstDataIndex; i < 57; i++) { // 56,57 are the bound for where it works/breaks
for(int i = firstDataIndex; i < data.n_cols; i++) {
- std::cout << "inserting point number: " << i << std::endl;
root->InsertPoint(data.col(i));
- std::cout << "finished inserting point number: " << i << std::endl;
- std::cout << ToString() << std::endl;
}
-
}
template<typename SplitType,
@@ -93,8 +86,6 @@ template<typename SplitType,
RectangleTree<SplitType, DescentType, StatisticType, MatType>::
~RectangleTree()
{
- //LEAK MEMORY
-
for(int i = 0; i < numChildren; i++) {
delete children[i];
}
@@ -126,14 +117,11 @@ template<typename SplitType,
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
InsertPoint(const arma::vec& point)
{
-
- std::cout << "insert point called" << std::endl;
// Expand the bound regardless of whether it is a leaf node.
bound |= point;
// If this is a leaf node, we stop here and add the point.
if(numChildren == 0) {
- std::cout << "count = " << count << std::endl;
dataset->col(count++) = point;
SplitNode();
return;
@@ -254,8 +242,7 @@ inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
}
/**
- * Return the number of descendants contained in this node. MEANINIGLESS AS IT CURRENTLY STANDS.
- * USE NumPoints() INSTEAD.
+ * Return the number of points stored in this node or in any node beneath it.
*/
template<typename SplitType,
typename DescentType,
@@ -264,7 +251,15 @@ template<typename SplitType,
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
NumDescendants() const
{
- return count;
+  // A leaf holds its points directly, so its count is the answer.
+  if (numChildren == 0)
+    return count;
+
+  // An internal node: sum the point counts of all subtrees. The index is
+  // size_t to match the (presumably unsigned) child count.
+  size_t n = 0;
+  for (size_t i = 0; i < numChildren; i++)
+    n += children[i]->NumDescendants();
+  return n;
}
/**
@@ -328,13 +323,9 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode()
if(count < maxLeafSize)
return; // We don't need to split.
- std::cout << "we are actually splitting the node." << std::endl;
// If we are full, then we need to split (or at least try). The SplitType takes
// care of this and of moving up the tree if necessary.
SplitType::SplitLeafNode(this);
- std::cout << "we finished actually splitting the node." << std::endl;
-
- std::cout << ToString() << std::endl;
}
@@ -362,7 +353,7 @@ std::string RectangleTree<SplitType, DescentType, StatisticType, MatType>::ToStr
convert << " Min num of children: " << minNumChildren << std::endl;
convert << " Parent address: " << parent << std::endl;
- // How many levels should we print? This will print the root and it's children.
+ // How many levels should we print? This will print 3 levels (counting the root).
if(parent == NULL || parent->Parent() == NULL) {
for(int i = 0; i < numChildren; i++) {
convert << children[i]->ToString();
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 0541470..432c23e 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -278,7 +278,7 @@ int main(int argc, char *argv[])
arma::mat>
refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
Timer::Stop("tree_building");
- std::cout << "completed tree building" << std::endl;
+ std::cout << "completed tree building " << refTree.NumDescendants() << std::endl;
}
}
else // Cover trees.
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index ec36031..c49b70f 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -38,6 +38,7 @@ add_executable(mlpack_test
perceptron_test.cpp
radical_test.cpp
range_search_test.cpp
+ rectangle_tree_test.cpp
save_restore_utility_test.cpp
sgd_test.cpp
sort_policy_test.cpp
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
new file mode 100644
index 0000000..bbfa4ee
--- /dev/null
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -0,0 +1,142 @@
+
+/**
+ * @file rectangle_tree_test.cpp
+ * @author Andrew Wells
+ *
+ * Tests for the RectangleTree class. These should ensure that the class works
+ * correctly and that subsequent changes don't break anything. Because this
+ * file exists only to test the trees, it favors thoroughness over speed and
+ * is slow.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/tree_traits.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+
+BOOST_AUTO_TEST_SUITE(RectangleTreeTest);
+
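+// Be careful! When writing new tests, always get the boolean value and store
+// it in a temporary, because the Boost unit test macros do weird things and
+// will cause bizarre problems. (Presumably because the macros take their
+// arguments by reference: passing a static const trait member such as
+// HasOverlappingChildren directly odr-uses it and can fail to link.)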
+
+// Test the traits on RectangleTrees.
+BOOST_AUTO_TEST_CASE(RectangeTreeTraitsTest)
+{
+  // Shorthand for the R tree instantiation whose traits are checked below.
+  typedef RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic,
+                                         NeighborSearchStat<NearestNeighborSort>,
+                                         arma::mat>,
+                        tree::RTreeDescentHeuristic,
+                        NeighborSearchStat<NearestNeighborSort>,
+                        arma::mat> RTreeType;
+
+  // The bounds of sibling nodes are allowed to overlap.
+  const bool overlapping = TreeTraits<RTreeType>::HasOverlappingChildren;
+  BOOST_REQUIRE_EQUAL(overlapping, true);
+
+  // A point is stored at only one level of the tree.
+  const bool selfChildren = TreeTraits<RTreeType>::HasSelfChildren;
+  BOOST_REQUIRE_EQUAL(selfChildren, false);
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeConstructionCountTest)
+{
+  arma::mat dataset;
+  dataset.randu(3, 1000); // 1000 points in 3 dimensions.
+
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+      tree::RTreeDescentHeuristic,
+      NeighborSearchStat<NearestNeighborSort>,
+      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+
+  // The total number of points stored in the tree must equal the number of
+  // points inserted. Both sides go through unsigned temporaries so that the
+  // Boost macro does not compare a size_t against a signed literal.
+  const size_t numDescendants = tree.NumDescendants();
+  const size_t numPoints = dataset.n_cols;
+  BOOST_REQUIRE_EQUAL(numDescendants, numPoints);
+}
+
+/**
+ * Collect a copy of every point stored in the leaves of the given tree.
+ *
+ * The returned vectors are heap-allocated; the caller owns them and must
+ * delete each one. The helper has internal linkage so it cannot collide with a
+ * same-named helper in another test file linked into the same test binary.
+ */
+static std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+    tree::RTreeDescentHeuristic,
+    NeighborSearchStat<NearestNeighborSort>,
+    arma::mat>& tree)
+{
+  std::vector<arma::vec*> vec;
+  if (tree.NumChildren() > 0)
+  {
+    for (size_t i = 0; i < tree.NumChildren(); i++)
+    {
+      std::vector<arma::vec*> tmp = getAllPointsInTree(*(tree.Children()[i]));
+      // Append rather than prepend: inserting at begin() shifts everything
+      // already collected once per child. Callers do not depend on the order.
+      vec.insert(vec.end(), tmp.begin(), tmp.end());
+    }
+  }
+  else
+  {
+    // Leaf: copy out each stored point.
+    for (size_t i = 0; i < tree.Count(); i++)
+      vec.push_back(new arma::vec(tree.Dataset().col(i)));
+  }
+  return vec;
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeConstructionRepeatTest)
+{
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+      tree::RTreeDescentHeuristic,
+      NeighborSearchStat<NearestNeighborSort>,
+      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+
+  // getAllPointsInTree() returns heap-allocated copies that we must free.
+  std::vector<arma::vec*> allPoints = getAllPointsInTree(tree);
+
+  // Look for any pair of stored points that are exactly identical, which
+  // would indicate that a point was duplicated during construction. Exact
+  // floating-point comparison is intended here.
+  bool duplicateFound = false;
+  for (size_t i = 0; i < allPoints.size() && !duplicateFound; i++)
+  {
+    const arma::vec& v1 = *(allPoints[i]);
+    for (size_t j = i + 1; j < allPoints.size() && !duplicateFound; j++)
+    {
+      const arma::vec& v2 = *(allPoints[j]);
+      bool same = true;
+      for (size_t k = 0; k < v1.n_rows && same; k++)
+        same = (v1[k] == v2[k]);
+      duplicateFound = same;
+    }
+  }
+
+  // Free the copies before checking, so a failing check does not leak them.
+  for (size_t i = 0; i < allPoints.size(); i++)
+    delete allPoints[i];
+
+  // Use a Boost check rather than assert(): assert() is compiled out when
+  // NDEBUG is defined, which would make this test vacuous in release builds.
+  BOOST_REQUIRE_EQUAL(duplicateFound, false);
+}
+
+/**
+ * Recursively verify that every bound in the tree encloses what lies below it.
+ * In a leaf, every stored point must lie inside the leaf's bound. In an
+ * internal node, each child's bound must lie inside the node's bound in every
+ * dimension, and each child subtree must itself pass this check.
+ *
+ * The helper has internal linkage so it cannot collide with a same-named
+ * helper in another test file linked into the same test binary.
+ *
+ * @return true if the whole subtree satisfies the containment property.
+ */
+static bool checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+    tree::RTreeDescentHeuristic,
+    NeighborSearchStat<NearestNeighborSort>,
+    arma::mat>& tree)
+{
+  if (tree.NumChildren() == 0)
+  {
+    for (size_t i = 0; i < tree.Count(); i++)
+      if (!tree.Bound().Contains(tree.Dataset().unsafe_col(i)))
+        return false;
+    return true;
+  }
+
+  for (size_t i = 0; i < tree.NumChildren(); i++)
+  {
+    // The child's extent must be inside this node's extent in each dimension.
+    for (size_t j = 0; j < tree.Bound().Dim(); j++)
+      if (!tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]))
+        return false;
+
+    if (!checkContainment(*(tree.Children()[i])))
+      return false;
+  }
+  return true;
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
+{
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+      tree::RTreeDescentHeuristic,
+      NeighborSearchStat<NearestNeighborSort>,
+      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+
+  // Use a Boost check rather than assert(): assert() is compiled out when
+  // NDEBUG is defined, which would make this test vacuous in release builds.
+  // The result goes through a temporary, as recommended for Boost macros.
+  const bool contained = checkContainment(tree);
+  BOOST_REQUIRE_EQUAL(contained, true);
+}
+
+
+BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 7c442ba..a5bc474 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -8,6 +8,7 @@
#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
#include <queue>
#include <stack>
diff --git a/src/mlpack/tests/tree_traits_test.cpp b/src/mlpack/tests/tree_traits_test.cpp
index df220fc..f4182b0 100644
--- a/src/mlpack/tests/tree_traits_test.cpp
+++ b/src/mlpack/tests/tree_traits_test.cpp
@@ -12,6 +12,7 @@
#include <mlpack/core/tree/tree_traits.hpp>
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list