[mlpack] 321/324: X tree
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:22 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit d02aa8ac19bf7396dddab9ce26871550bb8f3278
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Sat Aug 16 15:12:06 2014 +0000
X tree
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17041 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 6 +-
.../core/tree/rectangle_tree/x_tree_split_impl.hpp | 63 ++++++---
src/mlpack/tests/rectangle_tree_test.cpp | 147 ++++++++++++++++++++-
3 files changed, 189 insertions(+), 27 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 9000f48..afbb5dc 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -665,8 +665,11 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
// If there are multiple children, we can't do anything to the root.
RectangleTree<SplitType, DescentType, StatisticType, MatType>* child =
children[0];
- for (size_t i = 0; i < child->NumChildren(); i++)
+ for (size_t i = 0; i < child->NumChildren(); i++) {
children[i] = child->Children()[i];
+ children[i]->Parent() = this;
+ }
+
numChildren = child->NumChildren();
for (size_t i = 0; i < child->Count(); i++)
@@ -677,6 +680,7 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
}
count = child->Count();
+ maxNumChildren = child->MaxNumChildren(); // Required for the X tree.
child->SoftDelete();
return;
}
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 1ea4018..f1530a2 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -232,11 +232,11 @@ void XTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
if (par->NumChildren() == par->MaxNumChildren()+1) {
SplitNonLeafNode(par, relevels);
}
-
- assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
- assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
- assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
- assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
+
+ assert(treeOne->Parent()->NumChildren() <= treeOne->Parent()->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() >= treeOne->Parent()->MinNumChildren());
+ assert(treeTwo->Parent()->NumChildren() <= treeTwo->Parent()->MaxNumChildren());
+ assert(treeTwo->Parent()->NumChildren() >= treeTwo->Parent()->MinNumChildren());
tree->SoftDelete();
@@ -600,17 +600,41 @@ bool XTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
+ // We don't create a supernode that would be the only child of the root.
+ // (Note that if you did try to do so, you would need to update the parent field on
+ // each child of this new node, as creating a supernode causes the function to return
+ // before that is done.)
+
+ // I thought commenting out the code below would make the tree less efficient but would still work.
+ // It doesn't. I should look into that to see if there is another bug.
+
+
+ if(tree->Parent()->Parent() == NULL && tree->Parent()->NumChildren() == 1) {
+ // We make the root a supernode instead.
+ tree->Parent()->MaxNumChildren() *= 2;
+ tree->Parent()->Children().resize(tree->Parent()->MaxNumChildren()+1);
+ tree->Parent()->NumChildren() = tree->NumChildren();
+ for(int i = 0; i < tree->NumChildren(); i++) {
+ tree->Parent()->Children()[i] = tree->Children()[i];
+ }
+ delete treeOne;
+ delete treeTwo;
+ tree->NullifyData();
+ tree->SoftDelete();
+ return false;
+ }
+
- // The min overlap split failed so we create a supernode instead.
+ // If we don't have to worry about the root, we just enlarge this node.
tree->MaxNumChildren() *= 2;
- tree->MaxLeafSize() *= 2;
- tree->LocalDataset().resize(tree->LocalDataset().n_rows, 2*tree->LocalDataset().n_cols);
tree->Children().resize(tree->MaxNumChildren()+1);
- tree->Points().resize(tree->MaxLeafSize()+1);
-
- return false;
+ for(int i = 0; i < tree->NumChildren(); i++)
+ tree->Child(i).Parent() = tree;
+ delete treeOne;
+ delete treeTwo;
+ return false;
}
}
@@ -637,11 +661,15 @@ bool XTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
break;
}
}
+
par->Children()[index] = treeOne;
par->Children()[par->NumChildren()++] = treeTwo;
// we only add one at a time, so we should only need to test for equality
// just in case, we use an assert.
+
+ if(!(par->NumChildren() <= par->MaxNumChildren()+1))
+ std::cout<<"error " << par->NumChildren() << ", "<<par->MaxNumChildren()+1<<std::endl;
assert(par->NumChildren() <= par->MaxNumChildren()+1);
if (par->NumChildren() == par->MaxNumChildren()+1) {
SplitNonLeafNode(par, relevels);
@@ -655,15 +683,12 @@ bool XTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
for (int i = 0; i < treeTwo->NumChildren(); i++) {
treeTwo->Children()[i]->Parent() = treeTwo;
}
-
-
- assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
- assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
- assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
- assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
- assert(treeOne->MaxNumChildren() < 7);
- assert(treeTwo->MaxNumChildren() < 7);
+
+ assert(treeOne->Parent()->NumChildren() <= treeOne->Parent()->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() >= treeOne->Parent()->MinNumChildren());
+ assert(treeTwo->Parent()->NumChildren() <= treeTwo->Parent()->MaxNumChildren());
+ assert(treeTwo->Parent()->NumChildren() >= treeTwo->Parent()->MinNumChildren());
tree->SoftDelete();
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index d1210db..fd2175f 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -224,6 +224,86 @@ void checkExactContainment(const RectangleTree<tree::RStarTreeSplit<tree::RStarT
}
}
+/**
+ * A function to check that containment is as tight as possible.
+ */
+void checkExactContainment(const RectangleTree<tree::XTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ if(tree.NumChildren() == 0) {
+ for(size_t i = 0; i < tree.Bound().Dim(); i++) {
+ double min = DBL_MAX;
+ double max = -1.0 * DBL_MAX;
+ for(size_t j = 0; j < tree.Count(); j++) {
+ if(tree.LocalDataset().col(j)[i] < min)
+ min = tree.LocalDataset().col(j)[i];
+ if(tree.LocalDataset().col(j)[i] > max)
+ max = tree.LocalDataset().col(j)[i];
+ }
+ BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi());
+ BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo());
+ }
+ } else {
+ for(size_t i = 0; i < tree.Bound().Dim(); i++) {
+ double min = DBL_MAX;
+ double max = -1.0 * DBL_MAX;
+ for(size_t j = 0; j < tree.NumChildren(); j++) {
+ if(tree.Child(j).Bound()[i].Lo() < min)
+ min = tree.Child(j).Bound()[i].Lo();
+ if(tree.Child(j).Bound()[i].Hi() > max)
+ max = tree.Child(j).Bound()[i].Hi();
+ }
+ BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi());
+ BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo());
+ }
+ for(size_t i = 0; i < tree.NumChildren(); i++)
+ checkExactContainment(tree.Child(i));
+ }
+}
+
+/**
+ * A function to check that parents and children are set correctly.
+ */
+void checkHierarchy(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ for(size_t i = 0; i < tree.NumChildren(); i++) {
+ BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
+ checkHierarchy(tree.Child(i));
+ }
+}
+
+/**
+ * A function to check that parents and children are set correctly.
+ */
+void checkHierarchy(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ for(size_t i = 0; i < tree.NumChildren(); i++) {
+ BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
+ checkHierarchy(tree.Child(i));
+ }
+}
+
+/**
+ * A function to check that parents and children are set correctly.
+ */
+void checkHierarchy(const RectangleTree<tree::XTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ for(size_t i = 0; i < tree.NumChildren(); i++) {
+ BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
+ checkHierarchy(tree.Child(i));
+ }
+}
+
+
+
+
// Test to see if the bounds of the tree are correct. (Cover all bounds and points
// beneath this node of the tree).
BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest) {
@@ -568,6 +648,29 @@ void checkSync(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHe
return;
}
+/**
+ * A function to ensure that the dataset for the tree, and the datasets stored
+ * in each leaf node are in sync.
+ * @param tree The tree to check.
+ */
+void checkSync(const RectangleTree<tree::XTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ if (tree.IsLeaf()) {
+ for (size_t i = 0; i < tree.Count(); i++) {
+ for (size_t j = 0; j < tree.LocalDataset().n_rows; j++) {
+ BOOST_REQUIRE_EQUAL(tree.LocalDataset().col(i)[j], tree.Dataset().col(tree.Points()[i])[j]);
+ }
+ }
+ } else {
+ for (size_t i = 0; i < tree.NumChildren(); i++) {
+ checkSync(*tree.Children()[i]);
+ }
+ }
+ return;
+}
+
// A test to ensure that the SingleTreeTraverser is working correctly by comparing
// its results to the results of a naive search.
BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
@@ -595,6 +698,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
checkSync(RTree);
checkContainment(RTree);
checkExactContainment(RTree);
+ checkHierarchy(RTree);
allknn1.Search(5, neighbors1, distances1);
@@ -631,14 +735,25 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
-/*
+
+
+
+
+
+
+
+
+
// A test to ensure that the SingleTreeTraverser is working correctly by comparing
// its results to the results of a naive search.
BOOST_AUTO_TEST_CASE(XTreeTraverserTest) {
arma::mat dataset;
- dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ const int numP = 1000;
+
+ dataset.randu(8, numP); // 1000 points in 8 dimensions.
arma::Mat<size_t> neighbors1;
arma::mat distances1;
arma::Mat<size_t> neighbors2;
@@ -657,10 +772,11 @@ BOOST_AUTO_TEST_CASE(XTreeTraverserTest) {
arma::mat> > allknn1(&RTree,
dataset, true);
- BOOST_REQUIRE_EQUAL(RTree.NumDescendants(), 1000);
-// checkSync(RTree);
-// checkContainment(RTree);
-// checkExactContainment(RTree);
+ BOOST_REQUIRE_EQUAL(RTree.NumDescendants(), numP);
+ checkSync(RTree);
+ //checkContainment(RTree);
+ checkExactContainment(RTree);
+ checkHierarchy(RTree);
allknn1.Search(5, neighbors1, distances1);
@@ -674,13 +790,30 @@ BOOST_AUTO_TEST_CASE(XTreeTraverserTest) {
BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
}
+
+ //std::cout<<""<<RTree.ToString()<<std::endl;
}
-*/
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list