[mlpack] 209/324: Bug fix. Node splitting tests.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:11 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit cd98bf3d48f2ae038b142c60e7ada1e5c64123e3
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 23 15:43:54 2014 +0000
Bug fix. Node splitting tests.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16851 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../core/tree/rectangle_tree/r_star_tree_split.hpp | 4 +-
.../tree/rectangle_tree/r_star_tree_split_impl.hpp | 30 +-
.../core/tree/rectangle_tree/r_tree_split.hpp | 6 +-
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 42 +-
.../core/tree/rectangle_tree/rectangle_tree.hpp | 57 ++-
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 169 ++++++--
src/mlpack/tests/rectangle_tree_test.cpp | 471 ++++++++++++++++++---
7 files changed, 628 insertions(+), 151 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
index c7611e3..70d3300 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
@@ -30,13 +30,13 @@ public:
* for Points and Rectangles." If necessary, this split will propagate
* upwards through the tree.
*/
-static void SplitLeafNode(RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree);
+static void SplitLeafNode(RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree, std::vector<bool>& relevels);
/**
* Split a non-leaf node using the "default" algorithm. If this is a root node, the
* tree increases in depth.
*/
-static bool SplitNonLeafNode(RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree);
+static bool SplitNonLeafNode(RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree, std::vector<bool>& relevels);
private:
/**
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index 22fc327..5e612dc 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -24,7 +24,8 @@ template<typename DescentType,
typename StatisticType,
typename MatType>
void RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
- RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree)
+ RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree,
+ std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so that the constructor
// and other methods don't confuse the end user by giving an address of another node.
@@ -36,7 +37,7 @@ void RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
tree->NullifyData();
tree->Child((tree->NumChildren())++) = copy; // Because this was a leaf node, numChildren must be 0.
assert(tree->NumChildren() == 1);
- RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(copy);
+ RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(copy, relevels);
return;
}
@@ -168,14 +169,14 @@ void RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
// we only add one at a time, so we should only need to test for equality
// just in case, we use an assert.
- assert(par->NumChildren() <= par->MaxNumChildren());
- if (par->NumChildren() == par->MaxNumChildren()) {
- SplitNonLeafNode(par);
+ assert(par->NumChildren() <= par->MaxNumChildren()+1);
+ if (par->NumChildren() == par->MaxNumChildren()+1) {
+ SplitNonLeafNode(par, relevels);
}
- assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
- assert(treeTwo->Parent()->NumChildren() < treeTwo->MaxNumChildren());
+ assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
tree->SoftDelete();
@@ -194,7 +195,8 @@ template<typename DescentType,
typename StatisticType,
typename MatType>
bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
- RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree)
+ RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree,
+ std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so that the constructor
// and other methods don't confuse the end user by giving an address of another node.
@@ -205,7 +207,7 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
tree->NumChildren() = 0;
tree->NullifyData();
tree->Child((tree->NumChildren())++) = copy;
- RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(copy);
+ RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(copy, relevels);
return true;
}
@@ -433,9 +435,9 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
// we only add one at a time, so we should only need to test for equality
// just in case, we use an assert.
- assert(par->NumChildren() <= par->MaxNumChildren());
- if (par->NumChildren() == par->MaxNumChildren()) {
- SplitNonLeafNode(par);
+ assert(par->NumChildren() <= par->MaxNumChildren()+1);
+ if (par->NumChildren() == par->MaxNumChildren()+1) {
+ SplitNonLeafNode(par, relevels);
}
// We have to update the children of each of these new nodes so that they record the
@@ -447,9 +449,9 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
treeTwo->Child(i)->Parent() = treeTwo;
}
- assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
- assert(treeTwo->Parent()->NumChildren() < treeTwo->MaxNumChildren());
+ assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
tree->SoftDelete();
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
index e26475c..0397e04 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
@@ -29,13 +29,15 @@ public:
* Split a leaf node using the "default" algorithm. If necessary, this split will propagate
* upwards through the tree.
*/
-static void SplitLeafNode(RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree);
+static void SplitLeafNode(RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree,
+ std::vector<bool>& relevels);
/**
* Split a non-leaf node using the "default" algorithm. If this is a root node, the
* tree increases in depth.
*/
-static bool SplitNonLeafNode(RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree);
+static bool SplitNonLeafNode(RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree,
+ std::vector<bool>& relevels);
private:
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index f7f154f..54e4ac9 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -24,7 +24,8 @@ template<typename DescentType,
typename StatisticType,
typename MatType>
void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
- RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree)
+ RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree,
+ std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so that the constructor
// and other methods don't confuse the end user by giving an address of another node.
@@ -35,11 +36,10 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
tree->Count() = 0;
tree->NullifyData();
tree->Child((tree->NumChildren())++) = copy; // Because this was a leaf node, numChildren must be 0.
- assert(tree->NumChildren() == 1);
- RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(copy);
+ RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(copy, relevels);
return;
}
- assert(tree->Parent()->NumChildren() < tree->Parent()->MaxNumChildren());
+ assert(tree->Parent()->NumChildren() <= tree->Parent()->MaxNumChildren());
// Use the quadratic split method from: Guttman "R-Trees: A Dynamic Index Structure for
// Spatial Searching" It is simplified since we don't handle rectangles, only points.
@@ -70,14 +70,14 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
// we only add one at a time, so we should only need to test for equality
// just in case, we use an assert.
- assert(par->NumChildren() <= par->MaxNumChildren());
- if (par->NumChildren() == par->MaxNumChildren()) {
- SplitNonLeafNode(par);
+ assert(par->NumChildren() <= par->MaxNumChildren()+1);
+ if (par->NumChildren() == par->MaxNumChildren()+1) {
+ SplitNonLeafNode(par, relevels);
}
- assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
- assert(treeTwo->Parent()->NumChildren() < treeTwo->MaxNumChildren());
+ assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
// We need to delete this carefully since references to points are used.
@@ -97,7 +97,8 @@ template<typename DescentType,
typename StatisticType,
typename MatType>
bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
- RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree)
+ RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree,
+ std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so that the constructor
// and other methods don't confuse the end user by giving an address of another node.
@@ -108,7 +109,7 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
tree->NumChildren() = 0;
tree->NullifyData();
tree->Child((tree->NumChildren())++) = copy;
- RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(copy);
+ RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(copy, relevels);
return true;
}
@@ -140,17 +141,15 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
par->Child(par->NumChildren()++) = treeTwo;
for (int i = 0; i < par->NumChildren(); i++) {
- if (par->Child(i) == tree) {
- assert(par->Child(i) != tree);
- }
+ assert(par->Child(i) != tree);
}
// we only add one at a time, so should only need to test for equality
// just in case, we use an assert.
- assert(par->NumChildren() <= par->MaxNumChildren());
+ assert(par->NumChildren() <= par->MaxNumChildren()+1);
- if (par->NumChildren() == par->MaxNumChildren()) {
- SplitNonLeafNode(par);
+ if (par->NumChildren() == par->MaxNumChildren()+1) {
+ SplitNonLeafNode(par, relevels);
}
// We have to update the children of each of these new nodes so that they record the
@@ -162,9 +161,9 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
treeTwo->Child(i)->Parent() = treeTwo;
}
- assert(treeOne->NumChildren() < treeOne->MaxNumChildren());
- assert(treeTwo->NumChildren() < treeTwo->MaxNumChildren());
- assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
+ assert(treeOne->NumChildren() <= treeOne->MaxNumChildren());
+ assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
+ assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
// Because we now have pointers to the information stored under this tree,
// we need to delete this node carefully.
@@ -504,9 +503,6 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
assert(treeTwo->Child(i) != treeTwo->Child(j));
}
}
- assert(treeOne->NumChildren() == numAssignTreeOne);
- assert(treeTwo->NumChildren() == numAssignTreeTwo);
- assert(numAssignTreeOne + numAssignTreeTwo == 5);
}
/**
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index f3dea79..f8639de 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -146,6 +146,27 @@ class RectangleTree
* @param point The point (arma::vec&) to be inserted.
*/
void InsertPoint(const size_t point);
+
+ /**
+ * Inserts a point into the tree, tracking which levels have been inserted into.
+ * The point will be copied to the data matrix of the leaf node where it is
+ * finally inserted, but we pass by reference since it may be passed many times
+ * before it actually reaches a leaf.
+ * @param point The point (arma::vec&) to be inserted.
+ * @param relevels The levels that have been reinserted to on this top level insertion.
+ */
+ void InsertPoint(const size_t point, std::vector<bool>& relevels);
+
+ /**
+ * Inserts a node into the tree, tracking which levels have been inserted into.
+ * The node will be inserted so that the tree remains valid.
+ * @param node The node to be inserted.
+ * @param level The depth that should match the node where this node is finally inserted.
+ * This should be the number returned by calling TreeDepth() from the node that originally
+ * contained "node".
+ * @param relevels The levels that have been reinserted to on this top level insertion.
+ */
+ void InsertNode(const RectangleTree* node, const size_t level, std::vector<bool>& relevels);
/**
* Deletes a point in the tree. The point will be removed from the data matrix
@@ -156,6 +177,21 @@ class RectangleTree
* (ie. the point is not in the tree)
*/
bool DeletePoint(const size_t point);
+
+ /**
+ * Deletes a point in the tree, tracking levels. The point will be removed from the data matrix
+ * of the leaf node where it is stored and the bounding rectangles will be updated.
+ * However, the point will be kept in the central dataset. (The user may remove it
+ * from there if he wants, but he must not change the indices of the other points.)
+ * Returns true if the point is successfully removed and false if it is not.
+ * (ie. the point is not in the tree)
+ */
+ bool DeletePoint(const size_t point, std::vector<bool>& relevels);
+
+ /**
+ * Deletes a node from the tree (along with all descendants).
+ */
+ bool DeleteNode(const RectangleTree* node, std::vector<bool>& relevels);
/**
* Find a node in this tree by its begin and count (const).
@@ -429,9 +465,9 @@ class RectangleTree
/**
* Splits the current node, recursing up the tree.
*
- * @param tree The RectangleTree object (node) to split.
+ * @param relevels Vector to track which levels have been inserted to.
*/
- void SplitNode();
+ void SplitNode(std::vector<bool>& relevels);
/**
* Splits the current node, recursing up the tree.
@@ -450,8 +486,11 @@ class RectangleTree
*
* @param point The arma::vec& of the point that was removed to require this
* condesation of the tree.
+ * @param usePoint True if we use the optimized version of the algorithm that is
+ * possible when we know what point was deleted. False otherwise (e.g. if we
+ * deleted a node instead of a point).
*/
- void CondenseTree(const arma::vec& point);
+ void CondenseTree(const arma::vec& point, std::vector<bool>& relevels, const bool usePoint);
/**
* Shrink the bound object of this node for the removal of a point.
@@ -470,18 +509,6 @@ class RectangleTree
* @return true if the bound needed to be changed, false if it did not.
*/
bool ShrinkBoundForBound(const HRectBound<>& changedBound);
-
- /**
- * Inserts a node into the tree. The node will be inserted so that the tree
- * remains valid.
- *
- * @param node The node to insert into the tree.
- * @param level The depth that should match the node where this node is finally inserted.
- * This should be the number returned by calling TreeDepth() from the node that originally
- * contained "node".
- */
- void InsertNode(const RectangleTree<SplitType, DescentType, StatisticType, MatType>* node,
- const size_t level);
/**
* Returns a string representation of this object.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index f58aed9..b642048 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -149,7 +149,8 @@ InsertPoint(const size_t point)
if (numChildren == 0) {
localDataset->col(count) = dataset.col(point);
points[count++] = point;
- SplitNode();
+ std::vector<bool> lvls(TreeDepth());
+ SplitNode(lvls);
return;
}
@@ -159,6 +160,60 @@ InsertPoint(const size_t point)
}
/**
+ * Inserts a point into the tree, tracking which levels have been inserted into.
+ * The point will be copied to the data matrix
+ * of the leaf node where it is finally inserted, but we pass by reference since
+ * it may be passed many times before it actually reaches a leaf.
+ */
+template<typename SplitType,
+typename DescentType,
+typename StatisticType,
+typename MatType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::InsertPoint(const size_t point, std::vector<bool>& relevels)
+{
+ // Expand the bound regardless of whether it is a leaf node.
+ bound |= dataset.col(point);
+
+ // If this is a leaf node, we stop here and add the point.
+ if (numChildren == 0) {
+ localDataset->col(count) = dataset.col(point);
+ points[count++] = point;
+ std::vector<bool> lvls(TreeDepth());
+ SplitNode(lvls);
+ return;
+ }
+
+ // If it is not a leaf node, we use the DescentHeuristic to choose a child
+ // to which we recurse.
+ children[DescentType::ChooseDescentNode(this, dataset.col(point))]->InsertPoint(point);
+}
+
+/**
+ * Inserts a node into the tree, tracking which levels have been inserted into.
+ * @param node The node to be inserted.
+ * @param level The level on which this node should be inserted.
+ * @param relevels The levels that have been reinserted to on this top level insertion.
+ */
+template<typename SplitType,
+typename DescentType,
+typename StatisticType,
+typename MatType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::InsertNode(const RectangleTree* node, const size_t level, std::vector<bool>& relevels)
+{
+ // Expand the bound regardless of the level.
+ bound |= node->Bound();
+
+ if (level == TreeDepth()) {
+ children[numChildren++] = const_cast<RectangleTree*> (node);
+ assert(numChildren <= maxNumChildren); // We should never have increased without splitting.
+ if (numChildren == maxNumChildren)
+ SplitType::SplitNonLeafNode(this, relevels);
+ } else {
+ children[DescentType::ChooseDescentNode(this, node)]->InsertNode(node, level, relevels);
+ }
+}
+
+/**
* Recurse through the tree to remove the point. Once we find the point, we
* shrink the rectangles if necessary.
*/
@@ -174,7 +229,12 @@ DeletePoint(const size_t point)
if (points[i] == point) {
localDataset->col(i) = localDataset->col(--count); // decrement count
points[i] = points[count];
- CondenseTree(dataset.col(point)); // This function will ensure that minFill is satisfied.
+ //It is possible that this will cause a reinsertion, so we need to handle the lvls properly.
+ RectangleTree* root = this;
+ while(root->Parent() != NULL)
+ root = root->Parent();
+ std::vector<bool> lvls(root->TreeDepth());
+ CondenseTree(dataset.col(point), lvls, true); // This function will ensure that minFill is satisfied.
return true;
}
}
@@ -187,6 +247,59 @@ DeletePoint(const size_t point)
return false;
}
+/**
+ * Recurse through the tree to remove the point. Once we find the point, we
+ * shrink the rectangles if necessary.
+ */
+template<typename SplitType,
+typename DescentType,
+typename StatisticType,
+typename MatType>
+bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DeletePoint(const size_t point, std::vector<bool>& relevels)
+{
+ if (numChildren == 0) {
+ for (size_t i = 0; i < count; i++) {
+ if (points[i] == point) {
+ localDataset->col(i) = localDataset->col(--count); // decrement count
+ points[i] = points[count];
+ CondenseTree(dataset.col(point), relevels, true); // This function will ensure that minFill is satisfied.
+ return true;
+ }
+ }
+ }
+ for (size_t i = 0; i < numChildren; i++) {
+ if (children[i]->Bound().Contains(dataset.col(point)))
+ if (children[i]->DeletePoint(point))
+ return true;
+ }
+ return false;
+}
+
+/**
+ * Recurse through the tree to remove the node. Once we find the node, we
+ * shrink the rectangles if necessary.
+ */
+template<typename SplitType,
+typename DescentType,
+typename StatisticType,
+typename MatType>
+bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DeleteNode(const RectangleTree* node, std::vector<bool>& relevels)
+{
+ for (size_t i = 0; i < numChildren; i++) {
+ if (children[i] == node) {
+ children[i] = children[--numChildren]; // Decrement numChildren
+ CondenseTree(arma::vec(), false);
+ return true;
+ }
+ if (children[i]->Bound().Contains(node->Bound()))
+ if (children[i]->DeleteNode(node))
+ return true;
+ }
+ return false;
+}
+
template<typename SplitType,
typename DescentType,
typename StatisticType,
@@ -354,7 +467,7 @@ template<typename SplitType,
typename DescentType,
typename StatisticType,
typename MatType>
-void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode()
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode(std::vector<bool>& relevels)
{
// This should always be a leaf node. When we need to split other nodes,
// the split will be called from here but will take place in the SplitType code.
@@ -366,7 +479,7 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode()
// If we are full, then we need to split (or at least try). The SplitType takes
// care of this and of moving up the tree if necessary.
- SplitType::SplitLeafNode(this);
+ SplitType::SplitLeafNode(this, relevels);
}
/**
@@ -377,9 +490,8 @@ template<typename SplitType,
typename DescentType,
typename StatisticType,
typename MatType>
-void RectangleTree<SplitType, DescentType, StatisticType, MatType>::CondenseTree(const arma::vec& point)
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::CondenseTree(const arma::vec& point, std::vector<bool>& relevels, const bool usePoint)
{
-
// First delete the node if we need to. There's no point in shrinking the bound first.
if (IsLeaf() && count < minLeafSize && parent != NULL) { //We can't delete the root node
for (size_t i = 0; i < parent->NumChildren(); i++) {
@@ -391,12 +503,12 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::CondenseTree
RectangleTree<SplitType, DescentType, StatisticType, MatType>* root = parent;
while (root->Parent() != NULL)
root = root->Parent();
-
+
for (size_t j = 0; j < count; j++) {
root->InsertPoint(points[j]);
}
- parent->CondenseTree(point); // This will check the MinFill of the parent.
+ parent->CondenseTree(point, relevels, usePoint); // This will check the MinFill of the parent.
//Now it should be safe to delete this node.
SoftDelete();
return;
@@ -418,9 +530,9 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::CondenseTree
while (root->Parent() != NULL)
root = root->Parent();
for (size_t i = 0; i < numChildren; i++)
- root->InsertNode(children[i], level);
+ root->InsertNode(children[i], level, relevels);
- parent->CondenseTree(point); // This will check the MinFill of the parent.
+ parent->CondenseTree(point, relevels, usePoint); // This will check the MinFill of the parent.
//Now it should be safe to delete this node.
SoftDelete();
return;
@@ -442,10 +554,11 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::CondenseTree
}
// If we didn't delete it, shrink the bound if we need to.
- if (ShrinkBoundForPoint(point) && parent != NULL) {
- parent->CondenseTree(point);
+ if (usePoint && ShrinkBoundForPoint(point) && parent != NULL) {
+ parent->CondenseTree(point, relevels, usePoint);
+ } else if (!usePoint && ShrinkBoundForBound(bound) && parent != NULL) {
+ parent->CondenseTree(point, relevels, usePoint);
}
-
}
/**
@@ -469,7 +582,7 @@ bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::ShrinkBoundF
if (bound[i].Lo() < min) {
shrunk = true;
bound[i].Lo() = min;
- } else if(min < bound[i].Lo()) {
+ } else if (min < bound[i].Lo()) {
assert(true == false); // we have a problem.
}
} else if (bound[i].Hi() == point[i]) {
@@ -481,7 +594,7 @@ bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::ShrinkBoundF
if (bound[i].Hi() > max) {
shrunk = true;
bound[i].Hi() = max;
- } else if(max > bound[i].Hi()) {
+ } else if (max > bound[i].Hi()) {
assert(true == false); // we have a problem.
}
}
@@ -533,32 +646,8 @@ bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::ShrinkBoundF
double sum2 = 0;
for (size_t i = 0; i < bound.Dim(); i++)
sum2 += bound[i].Width();
-
- return sum != sum2;
-}
-
-/**
- * Insert the node into the tree at the appropriate depth.
- */
-template<typename SplitType,
-typename DescentType,
-typename StatisticType,
-typename MatType>
-void RectangleTree<SplitType, DescentType, StatisticType, MatType>::InsertNode(
- const RectangleTree<SplitType, DescentType, StatisticType, MatType>* node,
- const size_t level)
-{
- // Expand the bound regardless of the level.
- bound |= node->Bound();
- if (level == TreeDepth()) {
- children[numChildren++] = const_cast<RectangleTree*> (node);
- assert(numChildren <= maxNumChildren); // We should never have increased without splitting.
- if (numChildren == maxNumChildren)
- SplitType::SplitNonLeafNode(this);
- } else {
- children[DescentType::ChooseDescentNode(this, node)]->InsertNode(node, level);
- }
+ return sum != sum2;
}
/**
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index a9325e4..cc559c9 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -1,4 +1,3 @@
-
/**
* @file tree_traits_test.cpp
* @author Andrew Wells
@@ -44,6 +43,9 @@ BOOST_AUTO_TEST_CASE(RectangeTreeTraitsTest) {
BOOST_REQUIRE_EQUAL(b, false);
}
+// Test to make sure the tree contains the correct number of points after it is
+// constructed.
+
BOOST_AUTO_TEST_CASE(RectangleTreeConstructionCountTest) {
arma::mat dataset;
dataset.randu(3, 1000); // 1000 points in 3 dimensions.
@@ -55,6 +57,11 @@ BOOST_AUTO_TEST_CASE(RectangleTreeConstructionCountTest) {
BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000);
}
+/**
+ * A function to return a std::vector containing pointers to each point in the tree.
+ * @param tree The tree that we want to extract all of the points from.
+ * @return A vector containing pointers to each point in this tree.
+ */
std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
tree::RTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
@@ -74,6 +81,9 @@ std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<
return vec;
}
+// Test to ensure that none of the points in the tree are duplicates. This,
+// combined with the above test to see how many points are in the tree, should
+// ensure that we inserted all points.
BOOST_AUTO_TEST_CASE(RectangleTreeConstructionRepeatTest) {
arma::mat dataset;
dataset.randu(8, 1000); // 1000 points in 8 dimensions.
@@ -92,7 +102,7 @@ BOOST_AUTO_TEST_CASE(RectangleTreeConstructionRepeatTest) {
for (size_t k = 0; k < v1.n_rows; k++) {
same &= (v1[k] == v2[k]);
}
- assert(same != true);
+ BOOST_REQUIRE_NE(same, true);
}
}
for (size_t i = 0; i < allPoints.size(); i++) {
@@ -100,28 +110,33 @@ BOOST_AUTO_TEST_CASE(RectangleTreeConstructionRepeatTest) {
}
}
-bool checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+/**
+ * A function to check that each non-leaf node fully encloses its child nodes
+ * and that each leaf node encloses its points. It recurses so that it checks
+ * each node under (and including) this one.
+ * @param tree The tree to check.
+ */
+void checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
tree::RTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
arma::mat>& tree) {
- bool passed = true;
if (tree.NumChildren() == 0) {
for (size_t i = 0; i < tree.Count(); i++) {
- passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i]));
+ BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i])), true);
}
} else {
for (size_t i = 0; i < tree.NumChildren(); i++) {
- bool p1 = true;
for (size_t j = 0; j < tree.Bound().Dim(); j++) {
- p1 &= tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]);
+ BOOST_REQUIRE_EQUAL(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]), true);
}
- passed &= p1;
- passed &= checkContainment(*(tree.Child(i)));
+ checkContainment(*(tree.Child(i)));
}
}
- return passed;
+ return;
}
+// Test to see if the bounds of the tree are correct. (Cover all bounds and points
+// beneath this node of the tree).
BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest) {
arma::mat dataset;
dataset.randu(8, 1000); // 1000 points in 8 dimensions.
@@ -130,30 +145,34 @@ BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest) {
tree::RTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
arma::mat> tree(dataset, 20, 6, 5, 2, 0);
- assert(checkContainment(tree) == true);
+ checkContainment(tree);
}
-bool checkSync(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+/**
+ * A function to ensure that the dataset for the tree, and the datasets stored
+ * in each leaf node are in sync.
+ * @param tree The tree to check.
+ */
+void checkSync(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
tree::RTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
arma::mat>& tree) {
if (tree.IsLeaf()) {
for (size_t i = 0; i < tree.Count(); i++) {
for (size_t j = 0; j < tree.LocalDataset().n_rows; j++) {
- if (tree.LocalDataset().col(i)[j] != tree.Dataset().col(tree.Points()[i])[j]) {
- return false;
- }
+ BOOST_REQUIRE_EQUAL(tree.LocalDataset().col(i)[j], tree.Dataset().col(tree.Points()[i])[j]);
}
}
} else {
for (size_t i = 0; i < tree.NumChildren(); i++) {
- if (!checkSync(*tree.Child(i)))
- return false;
+ checkSync(*tree.Child(i));
}
}
- return true;
+ return;
}
+// Test to ensure that the dataset used by the whole tree (and the traversers)
+// is in sync with the datasets stored in each leaf node.
BOOST_AUTO_TEST_CASE(TreeLocalDatasetInSync) {
arma::mat dataset;
dataset.randu(8, 1000); // 1000 points in 8 dimensions.
@@ -162,12 +181,115 @@ BOOST_AUTO_TEST_CASE(TreeLocalDatasetInSync) {
tree::RTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
arma::mat> tree(dataset, 20, 6, 5, 2, 0);
- assert(checkSync(tree) == true);
+ checkSync(tree);
+}
+
+/**
+ * A function to check that each of the fill requirements is met. For a non-leaf node:
+ * MinNumChildren() <= NumChildren() <= MaxNumChildren()
+ * For a leaf node:
+ * MinLeafSize() <= Count() <= MaxLeafSize()
+ *
+ * A node with no parent (the root) is exempt from the minimums. It recurses so that it checks each node under (and including) this one.
+ * @param tree The tree to check.
+ */
+void checkFills(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ if (tree.IsLeaf()) {
+ BOOST_REQUIRE_EQUAL((tree.Count() >= tree.MinLeafSize() || tree.Parent() == NULL) && tree.Count() <= tree.MaxLeafSize(), true);
+ } else {
+ for (size_t i = 0; i < tree.NumChildren(); i++) {
+ BOOST_REQUIRE_EQUAL((tree.NumChildren() >= tree.MinNumChildren() || tree.Parent() == NULL) && tree.NumChildren() <= tree.MaxNumChildren(), true);
+ checkFills(*tree.Child(i));
+ }
+ }
+ return;
+}
+
+// Test to ensure that the minimum and maximum fills are satisfied.
+BOOST_AUTO_TEST_CASE(CheckMinAndMaxFills) {
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+ checkFills(tree);
+}
+
+/**
+ * A function to get the height of this tree, measured down to its deepest leaf. Though it should equal
+ * tree.TreeDepth(), computing it by recursion lets a test verify that every leaf node is on the same level.
+ * @param tree The tree for which we want the height.
+ * @return The height of this tree.
+ */
+int getMaxLevel(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ int max = 1;
+ if (!tree.IsLeaf()) {
+ int m = 0;
+ for (size_t i = 0; i < tree.NumChildren(); i++) {
+ int n = getMaxLevel(*tree.Child(i));
+ if (n > m)
+ m = n;
+ }
+ max += m;
+ }
+ return max;
+}
+
+/**
+ * A function to get the "shortest height" of this tree, measured down to its shallowest leaf. Though it should equal
+ * tree.TreeDepth(), computing it by recursion lets a test verify that every leaf node is on the same level.
+ * @param tree The tree for which we want the height.
+ * @return The "shortest height" of the tree.
+ */
+int getMinLevel(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ int min = 1;
+ if (!tree.IsLeaf()) {
+ int m = INT_MAX;
+ for (size_t i = 0; i < tree.NumChildren(); i++) {
+ int n = getMinLevel(*tree.Child(i));
+ if (n < m)
+ m = n;
+ }
+ min += m;
+ }
+ return min;
+}
+
+// A test to ensure that all leaf nodes are stored on the same level of the tree.
+BOOST_AUTO_TEST_CASE(TreeBalance) {
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+
+ BOOST_REQUIRE_EQUAL(getMinLevel(tree), getMaxLevel(tree));
+ BOOST_REQUIRE_EQUAL(tree.TreeDepth(), getMinLevel(tree));
}
+// A test to see if point deletion is working correctly.
+// We build a tree, then delete numIter points and test that the query gives correct
+// results. It is remotely possible that this test will fail spuriously (even though
+// the tree is correct) if two points happen to be the same distance from a third point.
BOOST_AUTO_TEST_CASE(PointDeletion) {
arma::mat dataset;
dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ arma::mat querySet;
+ querySet.randu(8, 500);
const int numIter = 50;
@@ -177,79 +299,185 @@ BOOST_AUTO_TEST_CASE(PointDeletion) {
arma::mat> tree(dataset, 20, 6, 5, 2, 0);
for (int i = 0; i < numIter; i++) {
- tree.DeletePoint(i);
+ tree.DeletePoint(999 - i);
+ }
+
+ // Do a few sanity checks. Ensure each point is unique, the tree has the correct
+ // number of points, the tree has legal containment, and the tree's data is in sync.
+ std::vector<arma::vec*> allPoints = getAllPointsInTree(tree);
+ for (size_t i = 0; i < allPoints.size(); i++) {
+ for (size_t j = i + 1; j < allPoints.size(); j++) {
+ arma::vec v1 = *(allPoints[i]);
+ arma::vec v2 = *(allPoints[j]);
+ bool same = true;
+ for (size_t k = 0; k < v1.n_rows; k++) {
+ same &= (v1[k] == v2[k]);
+ }
+ BOOST_REQUIRE_NE(same, true);
+ }
+ }
+ for (size_t i = 0; i < allPoints.size(); i++) {
+ delete allPoints[i];
}
- assert(tree.NumDescendants() == 1000 - numIter);
- assert(checkContainment(tree) == true);
- assert(checkSync(tree) == true);
+ BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000 - numIter);
+ checkContainment(tree);
+ checkSync(tree);
mlpack::neighbor::NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
tree::RTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> > allknn1(&tree, NULL,
+ dataset, querySet, true);
+
+ arma::Mat<size_t> neighbors1;
+ arma::mat distances1;
+ allknn1.Search(5, neighbors1, distances1);
+
+ arma::mat newDataset;
+ newDataset = dataset;
+ newDataset.resize(8, 1000-numIter);
+
+ arma::Mat<size_t> neighbors2;
+ arma::mat distances2;
+
+ // nearest neighbor search the naive way.
+ mlpack::neighbor::AllkNN allknn2(newDataset, querySet,
+ true, true);
+
+ allknn2.Search(5, neighbors2, distances2);
+
+ for (size_t i = 0; i < neighbors1.size(); i++) {
+ BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+ BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+ }
+}
+
+// A test to see if dynamic point insertion is working correctly.
+// We build a tree, then add numIter points and test that the query gives correct
+// results. It is remotely possible that this test will fail spuriously (even though
+// the tree is correct) if two points happen to be the same distance from a third point.
+// Note that this is extremely inefficient. You should not use dynamic insertion until
+// a better solution for resizing matrices is available.
+
+BOOST_AUTO_TEST_CASE(PointDynamicAdd) {
+ const int numIter = 50;
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+
+ // Add numIter new points to the dataset.
+ dataset.reshape(8, 1000+numIter);
+ arma::mat tmpData;
+ tmpData.randu(8, numIter);
+ for (int i = 0; i < numIter; i++) {
+ dataset.col(1000 + i) = tmpData.col(i);
+ tree.InsertPoint(1000 + i);
+ }
+
+ // Do a few sanity checks. Ensure each point is unique, the tree has the correct
+ // number of points, the tree has legal containment, and the tree's data is in sync.
+ std::vector<arma::vec*> allPoints = getAllPointsInTree(tree);
+ for (size_t i = 0; i < allPoints.size(); i++) {
+ for (size_t j = i + 1; j < allPoints.size(); j++) {
+ arma::vec v1 = *(allPoints[i]);
+ arma::vec v2 = *(allPoints[j]);
+ bool same = true;
+ for (size_t k = 0; k < v1.n_rows; k++) {
+ same &= (v1[k] == v2[k]);
+ }
+ BOOST_REQUIRE_NE(same, true);
+ }
+ }
+ for (size_t i = 0; i < allPoints.size(); i++) {
+ delete allPoints[i];
+ }
+ BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000 + numIter);
+ checkContainment(tree);
+ checkSync(tree);
+
+ // Now we will compare the output of the R Tree vs the output of a naive search.
+ arma::Mat<size_t> neighbors1;
+ arma::mat distances1;
+ arma::Mat<size_t> neighbors2;
+ arma::mat distances2;
+
+ // nearest neighbor search with the R tree.
+ mlpack::neighbor::NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
arma::mat> > allknn1(&tree,
dataset, true);
- arma::Mat<size_t> neighbors;
- arma::mat distances;
- allknn1.Search(5, neighbors, distances);
+ allknn1.Search(5, neighbors1, distances1);
- for (int i = 0; i < numIter; i++)
- assert(distances.at(0, i) > 0);
+ // nearest neighbor search the naive way.
+ mlpack::neighbor::AllkNN allknn2(dataset,
+ true, true);
- assert(checkContainment(tree) == true);
+ allknn2.Search(5, neighbors2, distances2);
+ for (size_t i = 0; i < neighbors1.size(); i++) {
+ BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+ BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+ }
}
-bool checkContainment(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+/**
+ * A function to check that each non-leaf node fully encloses its child nodes
+ * and that each leaf node encloses its points. It recurses so that it checks
+ * each node under (and including) this one.
+ * @param tree The tree to check.
+ */
+void checkContainment(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
tree::RStarTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
arma::mat>& tree) {
- bool passed = true;
if (tree.NumChildren() == 0) {
for (size_t i = 0; i < tree.Count(); i++) {
- passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i]));
- if(!passed)
- std::cout << ".................PointContainmentFailed" << std::endl;
+ BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i])), true);
}
} else {
for (size_t i = 0; i < tree.NumChildren(); i++) {
- bool p1 = true;
for (size_t j = 0; j < tree.Bound().Dim(); j++) {
- p1 &= tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]);
- if(!p1)
- std::cout << ".................BoundContainmentFailed" << std::endl;
+ BOOST_REQUIRE_EQUAL(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]), true);
}
- passed &= p1;
- passed &= checkContainment(*(tree.Child(i)));
+ checkContainment(*(tree.Child(i)));
}
}
- return passed;
+ return;
}
-
-bool checkSync(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+/**
+ * A function to ensure that the dataset for the tree, and the datasets stored
+ * in each leaf node are in sync.
+ * @param tree The tree to check.
+ */
+void checkSync(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
tree::RStarTreeDescentHeuristic,
NeighborSearchStat<NearestNeighborSort>,
arma::mat>& tree) {
if (tree.IsLeaf()) {
for (size_t i = 0; i < tree.Count(); i++) {
for (size_t j = 0; j < tree.LocalDataset().n_rows; j++) {
- if (tree.LocalDataset().col(i)[j] != tree.Dataset().col(tree.Points()[i])[j]) {
- return false;
- }
+ BOOST_REQUIRE_EQUAL(tree.LocalDataset().col(i)[j], tree.Dataset().col(tree.Points()[i])[j]);
}
}
} else {
for (size_t i = 0; i < tree.NumChildren(); i++) {
- if (!checkSync(*tree.Child(i)))
- return false;
+ checkSync(*tree.Child(i));
}
}
- return true;
+ return;
}
-
+// A test to ensure that the SingleTreeTraverser is working correctly by comparing
+// its results to the results of a naive search.
BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
arma::mat dataset;
dataset.randu(8, 1000); // 1000 points in 8 dimensions.
@@ -271,9 +499,9 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
arma::mat> > allknn1(&RTree,
dataset, true);
- assert(RTree.NumDescendants() == 1000);
- assert(checkSync(RTree) == true);
- assert(checkContainment(RTree) == true);
+ BOOST_REQUIRE_EQUAL(RTree.NumDescendants(), 1000);
+ checkSync(RTree);
+ checkContainment(RTree);
allknn1.Search(5, neighbors1, distances1);
@@ -285,9 +513,142 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
allknn2.Search(5, neighbors2, distances2);
for (size_t i = 0; i < neighbors1.size(); i++) {
- assert(neighbors1[i] == neighbors2[i]);
- assert(distances1[i] == distances2[i]);
+ BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+ BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
}
}
+// Test the R tree splitting (RTreeSplit). We set MaxLeafSize and MaxNumChildren rather low
+// to allow us to test by hand without adding hundreds of points.
+BOOST_AUTO_TEST_CASE(RTreeSplitTest) {
+ arma::mat data = arma::trans(arma::mat("0.0 0.0;"
+ "0.0 1.0;"
+ "1.0 0.1;"
+ "1.0 0.5;"
+ "0.7 0.3;"
+ "0.9 0.9;"
+ "0.5 0.6;"
+ "0.6 0.3;"
+ "0.1 0.5;"
+ "0.3 0.7;"));
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> RTree(data, 5, 2, 2, 1, 0);
+
+ //There's technically no reason they have to be in a certain order, so we
+ //use firstChild etc. to arbitrarily name them.
+ BOOST_REQUIRE_EQUAL(RTree.NumChildren(), 2);
+ BOOST_REQUIRE_EQUAL(RTree.NumDescendants(), 10);
+ BOOST_REQUIRE_EQUAL(RTree.TreeDepth(), 3);
+
+ int firstChild = 0, secondChild = 1;
+ if(RTree.Child(firstChild)->NumChildren() == 2) {
+ firstChild = 1;
+ secondChild = 0;
+ }
+
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[0].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[0].Hi(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[1].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[1].Hi(), 1.0);
+
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[0].Lo(), 0.3);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[0].Hi(), 1.0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[1].Lo(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[1].Hi(), 0.9);
+
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->NumChildren(), 1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[0].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[0].Hi(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[1].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[1].Hi(), 1.0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Count(), 3);
+
+ int firstPrime = 0, secondPrime = 1;
+ if(RTree.Child(secondChild)->Child(firstPrime)->Count() == 3) {
+ firstPrime = 1;
+ secondPrime = 0;
+ }
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->NumChildren(), 2);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Count(), 4);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[0].Lo(), 0.3);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[0].Hi(), 0.7);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[1].Lo(), 0.3);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[1].Hi(), 0.7);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Count(), 3);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[0].Lo(), 0.9);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[0].Hi(), 1.0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[1].Lo(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[1].Hi(), 0.9);
+
+}
+
+// Test the R* tree splitting (RStarTreeSplit). We set MaxLeafSize and MaxNumChildren rather low
+// to allow us to test by hand without adding hundreds of points.
+BOOST_AUTO_TEST_CASE(RStarTreeSplitTest) {
+ arma::mat data = arma::trans(arma::mat("0.0 0.0;"
+ "0.0 1.0;"
+ "1.0 0.1;"
+ "1.0 0.5;"
+ "0.7 0.3;"
+ "0.9 0.9;"
+ "0.5 0.6;"
+ "0.6 0.3;"
+ "0.1 0.5;"
+ "0.3 0.7;"));
+
+ RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> RTree(data, 5, 2, 2, 1, 0);
+
+ //There's technically no reason they have to be in a certain order, so we
+ //use firstChild etc. to arbitrarily name them.
+ BOOST_REQUIRE_EQUAL(RTree.NumChildren(), 2);
+ BOOST_REQUIRE_EQUAL(RTree.NumDescendants(), 10);
+ BOOST_REQUIRE_EQUAL(RTree.TreeDepth(), 3);
+
+ int firstChild = 0, secondChild = 1;
+ if(RTree.Child(firstChild)->NumChildren() == 2) {
+ firstChild = 1;
+ secondChild = 0;
+ }
+
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[0].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[0].Hi(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[1].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Bound()[1].Hi(), 1.0);
+
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[0].Lo(), 0.3);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[0].Hi(), 1.0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[1].Lo(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Bound()[1].Hi(), 0.9);
+
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->NumChildren(), 1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[0].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[0].Hi(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[1].Lo(), 0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Bound()[1].Hi(), 1.0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(firstChild)->Child(0)->Count(), 3);
+
+ int firstPrime = 0, secondPrime = 1;
+ if(RTree.Child(secondChild)->Child(firstPrime)->Count() == 3) {
+ firstPrime = 1;
+ secondPrime = 0;
+ }
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->NumChildren(), 2);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Count(), 4);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[0].Lo(), 0.3);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[0].Hi(), 1.0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[1].Lo(), 0.5);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(firstPrime)->Bound()[1].Hi(), 0.9);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Count(), 3);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[0].Lo(), 0.6);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[0].Hi(), 1.0);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[1].Lo(), 0.1);
+ BOOST_REQUIRE_EQUAL(RTree.Child(secondChild)->Child(secondPrime)->Bound()[1].Hi(), 0.3);
+}
+
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list