[mlpack] 178/324: Point deletion. Massive refactoring of Children()[i] to Child(i). Really minor point deletion test.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:08 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit f554c6d7ce5f039d9ac2e39a71a4dfde06603605
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Sat Jul 12 18:25:03 2014 +0000
Point deletion. Massive refactoring of Children()[i] to Child(i). Really minor point deletion test.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16817 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../rectangle_tree/r_tree_descent_heuristic.hpp | 5 +-
.../r_tree_descent_heuristic_impl.hpp | 52 ++-
.../core/tree/rectangle_tree/r_tree_split.hpp | 8 +-
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 70 +--
.../core/tree/rectangle_tree/rectangle_tree.hpp | 45 +-
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 512 +++++++++++++--------
src/mlpack/tests/rectangle_tree_test.cpp | 13 +
7 files changed, 467 insertions(+), 238 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic.hpp
index d757777..da5b7a8 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic.hpp
@@ -28,8 +28,11 @@ class RTreeDescentHeuristic
* @param bound The bound used for the node that is being evaluated.
* @param point The point that is being inserted.
*/
- static double EvalNode(const HRectBound<>& bound, const arma::vec& point);
+ template<typename TreeType>
+ static size_t ChooseDescentNode(const TreeType* node, const arma::vec& point);
+ template<typename TreeType>
+ static size_t ChooseDescentNode(const TreeType* node, const TreeType* insertedNode);
};
}; // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
index 5542119..3d61c75 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
@@ -14,17 +14,51 @@ namespace mlpack {
namespace tree {
// Return the increase in volume required when inserting point into bound.
-inline double RTreeDescentHeuristic::EvalNode(const HRectBound<>& bound, const arma::vec& point)
+
+template<typename TreeType>
+inline size_t RTreeDescentHeuristic::ChooseDescentNode(const TreeType* node, const arma::vec& point)
{
- double v1 = 1.0;
- double v2 = 1.0;
- for(size_t i = 0; i < bound.Dim(); i++) {
- v1 *= bound[i].Width();
- v2 *= bound[i].Contains(point[i]) ? bound[i].Width() : (bound[i].Hi() < point[i] ? (point[i] - bound[i].Lo()) :
- (bound[i].Hi() - point[i]));
+ double minScore = DBL_MAX;
+ int bestIndex = 0;
+
+ for (size_t j = 0; j < node->NumChildren(); j++) {
+ double v1 = 1.0;
+ double v2 = 1.0;
+ for (size_t i = 0; i < node->Child(j)->Bound().Dim(); i++) {
+ v1 *= node->Child(j)->Bound()[i].Width();
+ v2 *= node->Child(j)->Bound()[i].Contains(point[i]) ? node->Child(j)->Bound()[i].Width() : (node->Child(j)->Bound()[i].Hi() < point[i] ? (point[i] - node->Child(j)->Bound()[i].Lo()) :
+ (node->Child(j)->Bound()[i].Hi() - point[i]));
+ }
+ assert(v2 - v1 >= 0);
+ if (v2 - v1 < minScore) {
+ minScore = v2 - v1;
+ bestIndex = j;
+ }
+ }
+ return bestIndex;
+}
+
+template<typename TreeType>
+inline size_t RTreeDescentHeuristic::ChooseDescentNode(const TreeType* node, const TreeType* insertedNode)
+{
+ double minScore = DBL_MAX;
+ int bestIndex = 0;
+
+ for (size_t j = 0; j < node->NumChildren(); j++) {
+ double v1 = 1.0;
+ double v2 = 1.0;
+ for (size_t i = 0; i < node->Child(j)->Bound().Dim(); i++) {
+ v1 *= node->Child(j)->Bound()[i].Width();
+ v2 *= node->Child(j)->Bound()[i].Contains(insertedNode->Bound()[i]) ? node->Child(j)->Bound()[i].Width() :
+ (insertedNode->Bound()[i].Contains(node->Child(j)->Bound()[i]) ? insertedNode->Bound()[i].Width() : (insertedNode->Bound()[i].Lo() < node->Child(j)->Bound()[i].Lo() ? (node->Child(j)->Bound()[i].Hi() - insertedNode->Bound()[i].Lo()) : (insertedNode->Bound()[i].Hi() - node->Child(j)->Bound()[i].Lo())));
+ }
+ assert(v2 - v1 >= 0);
+ if (v2 - v1 < minScore) {
+ minScore = v2 - v1;
+ bestIndex = j;
+ }
}
- assert(v2 - v1 >= 0);
- return v2 - v1;
+ return bestIndex;
}
}; // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
index ccac982..e57ec63 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
@@ -32,14 +32,14 @@ public:
*/
static void SplitLeafNode(RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree);
-private:
-
/**
- * Split a non-leaf node using the "default" algorithm. If this is the root node and
- * we need to move up the tree, a new root node is created.
+ * Split a non-leaf node using the "default" algorithm. If this is a root node, the
+ * tree increases in depth.
*/
static bool SplitNonLeafNode(RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* tree);
+private:
+
/**
* Get the seeds for splitting a leaf node.
*/
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index ad4c51b..fc04839 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -33,7 +33,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
copy->Parent() = tree;
tree->Count() = 0;
tree->NullifyData();
- tree->Children()[(tree->NumChildren())++] = copy; // Because this was a leaf node, numChildren must be 0.
+ tree->Child((tree->NumChildren())++) = copy; // Because this was a leaf node, numChildren must be 0.
assert(tree->NumChildren() == 1);
RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(copy);
return;
@@ -59,13 +59,13 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* par = tree->Parent();
int index = 0;
for(int i = 0; i < par->NumChildren(); i++) {
- if(par->Children()[i] == tree) {
+ if(par->Child(i) == tree) {
index = i;
break;
}
}
- par->Children()[index] = treeOne;
- par->Children()[par->NumChildren()++] = treeTwo;
+ par->Child(index) = treeOne;
+ par->Child(par->NumChildren()++) = treeTwo;
// we only add one at a time, so we should only need to test for equality
// just in case, we use an assert.
@@ -80,7 +80,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
// We need to delete this carefully since references to points are used.
- tree->softDelete();
+ tree->SoftDelete();
return;
}
@@ -106,7 +106,7 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
copy->Parent() = tree;
tree->NumChildren() = 0;
tree->NullifyData();
- tree->Children()[(tree->NumChildren())++] = copy;
+ tree->Child((tree->NumChildren())++) = copy;
RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(copy);
return true;
}
@@ -132,18 +132,18 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* par = tree->Parent();
int index = -1;
for(int i = 0; i < par->NumChildren(); i++) {
- if(par->Children()[i] == tree) {
+ if(par->Child(i) == tree) {
index = i;
break;
}
}
assert(index != -1);
- par->Children()[index] = treeOne;
- par->Children()[par->NumChildren()++] = treeTwo;
+ par->Child(index) = treeOne;
+ par->Child(par->NumChildren()++) = treeTwo;
for(int i = 0; i < par->NumChildren(); i++) {
- if(par->Children()[i] == tree) {
- assert(par->Children()[i] != tree);
+ if(par->Child(i) == tree) {
+ assert(par->Child(i) != tree);
}
}
@@ -158,10 +158,10 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
// We have to update the children of each of these new nodes so that they record the
// correct parent.
for(int i = 0; i < treeOne->NumChildren(); i++) {
- treeOne->Children()[i]->Parent() = treeOne;
+ treeOne->Child(i)->Parent() = treeOne;
}
for(int i = 0; i < treeTwo->NumChildren(); i++) {
- treeTwo->Children()[i]->Parent() = treeTwo;
+ treeTwo->Child(i)->Parent() = treeTwo;
}
assert(treeOne->NumChildren() < treeOne->MaxNumChildren());
@@ -170,7 +170,7 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
// Because we now have pointers to the information stored under this tree,
// we need to delete this node carefully.
- tree->softDelete(); //currently does nothing but leak memory.
+ tree->SoftDelete(); //currently does nothing but leak memory.
return false;
}
@@ -231,8 +231,8 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetBoundSeeds(
for(int j = i+1; j < tree.NumChildren(); j++) {
double score = 1.0;
for(int k = 0; k < tree.Bound().Dim(); k++) {
- score *= std::max(tree.Children()[i]->Bound()[k].Hi(), tree.Children()[j]->Bound()[k].Hi()) -
- std::min(tree.Children()[i]->Bound()[k].Lo(), tree.Children()[j]->Bound()[k].Lo());
+ score *= std::max(tree.Child(i)->Bound()[k].Hi(), tree.Child(j)->Bound()[k].Hi()) -
+ std::min(tree.Child(i)->Bound()[k].Lo(), tree.Child(j)->Bound()[k].Lo());
}
if(score > worstPairScore) {
worstPairScore = score;
@@ -382,20 +382,20 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
for(int i = 0; i < oldTree->NumChildren(); i++) {
for(int j = i+1; j < oldTree->NumChildren(); j++) {
- assert(oldTree->Children()[i] != oldTree->Children()[j]);
+ assert(oldTree->Child(i) != oldTree->Child(j));
}
}
- insertNodeIntoTree(treeOne, oldTree->Children()[intI]);
- insertNodeIntoTree(treeTwo, oldTree->Children()[intJ]);
+ insertNodeIntoTree(treeOne, oldTree->Child(intI));
+ insertNodeIntoTree(treeTwo, oldTree->Child(intJ));
// If intJ is the last node in the tree, we need to switch the order so that we remove the correct nodes.
if(intI > intJ) {
- oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
- oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
+ oldTree->Child(intI) = oldTree->Child(--end); // decrement end
+ oldTree->Child(intJ) = oldTree->Child(--end); // decrement end
} else {
- oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
- oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
+ oldTree->Child(intJ) = oldTree->Child(--end); // decrement end
+ oldTree->Child(intI) = oldTree->Child(--end); // decrement end
}
assert(treeOne->NumChildren() == 1);
@@ -403,16 +403,16 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
for(int i = 0; i < end; i++) {
for(int j = i+1; j < end; j++) {
- assert(oldTree->Children()[i] != oldTree->Children()[j]);
+ assert(oldTree->Child(i) != oldTree->Child(j));
}
}
for(int i = 0; i < end; i++) {
- assert(oldTree->Children()[i] != treeOne->Children()[0]);
+ assert(oldTree->Child(i) != treeOne->Child(0));
}
for(int i = 0; i < end; i++) {
- assert(oldTree->Children()[i] != treeTwo->Children()[0]);
+ assert(oldTree->Child(i) != treeTwo->Child(0));
}
@@ -441,7 +441,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
for(int i = 0; i < oldTree->Bound().Dim(); i++) {
// For each of the new rectangles, find the width in this dimension if we add the rectangle at index to
// the new rectangle.
- math::Range range = oldTree->Children()[index]->Bound()[i];
+ math::Range range = oldTree->Child(index)->Bound()[i];
newVolOne *= treeOne->Bound()[i].Contains(range) ? treeOne->Bound()[i].Width() :
(range.Contains(treeOne->Bound()[i]) ? range.Width() : (range.Lo() < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - range.Lo()) :
(range.Hi() - treeOne->Bound()[i].Lo())));
@@ -469,26 +469,26 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
// Assign the rectangle that causes the least increase in volume
// to the appropriate rectangle.
if(bestRect == 1) {
- insertNodeIntoTree(treeOne, oldTree->Children()[bestIndex]);
+ insertNodeIntoTree(treeOne, oldTree->Child(bestIndex));
numAssignTreeOne++;
}
else {
- insertNodeIntoTree(treeTwo, oldTree->Children()[bestIndex]);
+ insertNodeIntoTree(treeTwo, oldTree->Child(bestIndex));
numAssignTreeTwo++;
}
- oldTree->Children()[bestIndex] = oldTree->Children()[--end]; // Decrement end.
+ oldTree->Child(bestIndex) = oldTree->Child(--end); // Decrement end.
}
// See if we need to satisfy the minimum fill.
if(end > 0) {
if(numAssignTreeOne < numAssignTreeTwo) {
for(int i = 0; i < end; i++) {
- insertNodeIntoTree(treeOne, oldTree->Children()[i]);
+ insertNodeIntoTree(treeOne, oldTree->Child(i));
numAssignTreeOne++;
}
} else {
for(int i = 0; i < end; i++) {
- insertNodeIntoTree(treeTwo, oldTree->Children()[i]);
+ insertNodeIntoTree(treeTwo, oldTree->Child(i));
numAssignTreeTwo++;
}
}
@@ -496,12 +496,12 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
for(int i = 0; i < treeOne->NumChildren(); i++) {
for(int j = i+1; j < treeOne->NumChildren(); j++) {
- assert(treeOne->Children()[i] != treeOne->Children()[j]);
+ assert(treeOne->Child(i) != treeOne->Child(j));
}
}
for(int i = 0; i < treeTwo->NumChildren(); i++) {
for(int j = i+1; j < treeTwo->NumChildren(); j++) {
- assert(treeTwo->Children()[i] != treeTwo->Children()[j]);
+ assert(treeTwo->Child(i) != treeTwo->Child(j));
}
}
assert(treeOne->NumChildren() == numAssignTreeOne);
@@ -520,7 +520,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::insertNodeIntoTree(
RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* srcNode)
{
destTree->Bound() |= srcNode->Bound();
- destTree->Children()[destTree->NumChildren()++] = srcNode;
+ destTree->Child(destTree->NumChildren()++) = srcNode;
}
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 89a80fe..56f6a5e 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -132,7 +132,7 @@ class RectangleTree
* This is used when splitting a node, where the data in this tree is moved to two
* other trees.
*/
- void softDelete();
+ void SoftDelete();
/**
* Set dataset to null. Used for memory management. Be cafeful.
@@ -150,6 +150,8 @@ class RectangleTree
/**
* Deletes a point in the tree. The point will be removed from the data matrix
* of the leaf node where it is store and the bounding rectangles will be updated.
+ * However, the point will be kept in the central dataset. (The user may remove it
+ * from there if desired, but must not change the indices of the other points.)
* Returns true if the point is successfully removed and false if it is not.
* (ie. the point is not in the tree)
*/
@@ -282,7 +284,6 @@ class RectangleTree
*
* @param child Index of child to return.
*/
-
inline RectangleTree<SplitType, DescentType, StatisticType, MatType>*
Child(const size_t child) const
{
@@ -443,6 +444,46 @@ class RectangleTree
public:
/**
+ * Condense the bounding rectangles for this node based on the removal of
+ * the point specified by the arma::vec&. This recurses up the tree. If a node
+ * goes below the minimum fill, this function will fix the tree.
+ *
+ * @param point The arma::vec& of the point that was removed to require this
+ * condensation of the tree.
+ */
+ void CondenseTree(const arma::vec& point);
+
+ /**
+ * Shrink the bound object of this node for the removal of a point.
+ *
+ * @param point The arma::vec& of the point that was removed to require this
+ * shrinking.
+ * @return true if the bound needed to be changed, false if it did not.
+ */
+ bool ShrinkBoundForPoint(const arma::vec& point);
+
+ /**
+ * Shrink the bound object of this node for the removal of a child node.
+ *
+ * @param changedBound The HRectBound<>& of the bound that was removed to require this
+ * shrinking.
+ * @return true if the bound needed to be changed, false if it did not.
+ */
+ bool ShrinkBoundForBound(const HRectBound<>& changedBound);
+
+ /**
+ * Inserts a node into the tree. The node will be inserted so that the tree
+ * remains valid.
+ *
+ * @param node The node to insert into the tree.
+ * @param level The depth that should match the node where this node is finally inserted.
+ * This should be the number returned by calling TreeDepth() from the node that originally
+ * contained "node".
+ */
+ void InsertNode(const RectangleTree<SplitType, DescentType, StatisticType, MatType>* node,
+ const size_t level);
+
+ /**
* Returns a string representation of this object.
*/
std::string ToString() const;
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 8df89ec..17bada5 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -18,30 +18,30 @@ namespace mlpack {
namespace tree {
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
- MatType& data,
- const size_t maxLeafSize,
- const size_t minLeafSize,
- const size_t maxNumChildren,
- const size_t minNumChildren,
- const size_t firstDataIndex):
- maxNumChildren(maxNumChildren),
- minNumChildren(minNumChildren),
- numChildren(0),
- children(maxNumChildren+1), // Add one to make splitting the node simpler
- parent(NULL),
- begin(0),
- count(0),
- maxLeafSize(maxLeafSize),
- minLeafSize(minLeafSize),
- bound(data.n_rows),
- parentDistance(0),
- dataset(data),
- points(maxLeafSize+1), // Add one to make splitting the node simpler.
- localDataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+ MatType& data,
+ const size_t maxLeafSize,
+ const size_t minLeafSize,
+ const size_t maxNumChildren,
+ const size_t minNumChildren,
+ const size_t firstDataIndex) :
+maxNumChildren(maxNumChildren),
+minNumChildren(minNumChildren),
+numChildren(0),
+children(maxNumChildren + 1), // Add one to make splitting the node simpler
+parent(NULL),
+begin(0),
+count(0),
+maxLeafSize(maxLeafSize),
+minLeafSize(minLeafSize),
+bound(data.n_rows),
+parentDistance(0),
+dataset(data),
+points(maxLeafSize + 1), // Add one to make splitting the node simpler.
+localDataset(new MatType(data.n_rows, static_cast<int> (maxLeafSize) + 1)) // Add one to make splitting the node simpler
{
stat = StatisticType(*this);
@@ -49,31 +49,31 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
RectangleTree* root = this;
//for(int i = firstDataIndex; i < 57; i++) { // 56,57 are the bound for where it works/breaks
- for(size_t i = firstDataIndex; i < data.n_cols; i++) {
+ for (size_t i = firstDataIndex; i < data.n_cols; i++) {
root->InsertPoint(i);
}
}
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
- RectangleTree<SplitType, DescentType, StatisticType, MatType>* parentNode):
- maxNumChildren(parentNode->MaxNumChildren()),
- minNumChildren(parentNode->MinNumChildren()),
- numChildren(0),
- children(maxNumChildren+1),
- parent(parentNode),
- begin(0),
- count(0),
- maxLeafSize(parentNode->MaxLeafSize()),
- minLeafSize(parentNode->MinLeafSize()),
- bound(parentNode->Bound().Dim()),
- parentDistance(0),
- dataset(parentNode->Dataset()),
- points(maxLeafSize+1), // Add one to make splitting the node simpler.
- localDataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+ RectangleTree<SplitType, DescentType, StatisticType, MatType>* parentNode) :
+maxNumChildren(parentNode->MaxNumChildren()),
+minNumChildren(parentNode->MinNumChildren()),
+numChildren(0),
+children(maxNumChildren + 1),
+parent(parentNode),
+begin(0),
+count(0),
+maxLeafSize(parentNode->MaxLeafSize()),
+minLeafSize(parentNode->MinLeafSize()),
+bound(parentNode->Bound().Dim()),
+parentDistance(0),
+dataset(parentNode->Dataset()),
+points(maxLeafSize + 1), // Add one to make splitting the node simpler.
+localDataset(new MatType(static_cast<int> (parentNode->Bound().Dim()), static_cast<int> (maxLeafSize) + 1)) // Add one to make splitting the node simpler
{
stat = StatisticType(*this);
}
@@ -84,35 +84,34 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
* to any nodes which are children of this one.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- ~RectangleTree()
+~RectangleTree()
{
- for(int i = 0; i < numChildren; i++) {
+ for (int i = 0; i < numChildren; i++) {
delete children[i];
}
//if(numChildren == 0)
delete localDataset;
}
-
/**
- * Deletes this node but leaves the children untouched. Needed for when we
- * split nodes and remove nodes (inserting and deleting points).
- */
+ * Deletes this node but leaves the children untouched. Needed for when we
+ * split nodes and remove nodes (inserting and deleting points).
+ */
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- softDelete()
+SoftDelete()
{
//if(numChildren != 0)
- //dataset = NULL;
+ //dataset = NULL;
parent = NULL;
- for(int i = 0; i < children.size(); i++) {
+ for (int i = 0; i < children.size(); i++) {
children[i] = NULL;
}
numChildren = 0;
@@ -120,35 +119,34 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
}
/**
- * Set the dataset to null.
- */
+ * Set the dataset to null.
+ */
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- NullifyData()
+NullifyData()
{
localDataset = NULL;
}
-
/**
* Recurse through the tree and insert the point at the leaf node chosen
* by the heuristic.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- InsertPoint(const size_t point)
+InsertPoint(const size_t point)
{
// Expand the bound regardless of whether it is a leaf node.
bound |= dataset.col(point);
// If this is a leaf node, we stop here and add the point.
- if(numChildren == 0) {
+ if (numChildren == 0) {
points[count++] = point;
localDataset->col(count) = dataset.col(point);
SplitNode();
@@ -157,17 +155,7 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
// If it is not a leaf node, we use the DescentHeuristic to choose a child
// to which we recurse.
- double minScore = DescentType::EvalNode(children[0]->Bound(), dataset.col(point));
- int bestIndex = 0;
-
- for(int i = 1; i < numChildren; i++) {
- double score = DescentType::EvalNode(children[i]->Bound(), dataset.col(point));
- if(score < minScore) {
- minScore = score;
- bestIndex = i;
- }
- }
- children[bestIndex]->InsertPoint(point);
+ children[DescentType::ChooseDescentNode(this, dataset.col(point))]->InsertPoint(point);
}
/**
@@ -175,98 +163,70 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
* shrink the rectangles if necessary.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- DeletePoint(const size_t point)
+DeletePoint(const size_t point)
{
- if(numChildren == 0) {
- for(int i = 0; i < count; i++) {
- if(points[i] == point) {
- points[i] = points[--count];
- for(int j = 0; j < bound.Dim(); j++) {
- if(bound[j].Lo() == dataset.col(point)[j]) {
- int loIndx = 0;
- double lo = dataset(points[0])[j];
- for(int k = 1; k < count; k++) {
- if(dataset(points[k])[j] < lo) {
- lo = dataset(points[k])[j];
- loIndx = k;
- }
- }
- bound[j].Lo() = lo;
- } else if(bound[j].Hi() == dataset.col(point)[j]) {
- int hiIndx = 0;
- double hi = dataset(points[0])[j];
- for(int k = 1; k < count; k++) {
- if(dataset(points[k])[j] > hi) {
- hi = dataset(points[k])[j];
- hiIndx = k;
- }
- }
- bound[j].Hi() = hi;
- }
- }
- return true;
- }
- }
- } else {
- for(int i = 0; i < numChildren; i++) {
- if(children[i].Bound().Contains(dataset.col(point))) {
- if(children[i].DeletePoint(dataset.col(point))) {
- return true;
- }
- }
+ if (numChildren == 0) {
+ for (size_t i = 0; i < count; i++) {
+ if (points[i] == point) {
+ localDataset->col(i) = localDataset->col(--count); // decrement count
+ points[i] = points[count];
+ CondenseTree(dataset.col(point)); // This function will ensure that minFill is satisfied.
+ return true;
}
+ }
+ }
+ for (size_t i = 0; i < numChildren; i++) {
+ if (children[i]->Bound().Contains(dataset.col(point)))
+ if (children[i]->DeletePoint(point))
+ return true;
}
return false;
}
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- TreeSize() const
+TreeSize() const
{
int n = 0;
- for(int i = 0; i < numChildren; i++) {
+ for (int i = 0; i < numChildren; i++) {
n += children[i]->TreeSize();
}
return n + 1; // we add one for this node
}
-
-
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- TreeDepth() const
+TreeDepth() const
{
/* Because R trees are balanced, we could simplify this. However, X trees are not
guaranteed to be balanced so I keep it as is: */
-
- // Recursively count the depth of each subtree. The plus one is
- // because we have to count this node, too.
- int maxSubDepth = 0;
- for(int i = 0; i < numChildren; i++) {
- int d = children[i]->TreeDepth();
- if(d > maxSubDepth)
- maxSubDepth = d;
+ int n = 1;
+ RectangleTree<SplitType, DescentType, StatisticType, MatType>* currentNode =
+ const_cast<RectangleTree*> (this);
+ while (!currentNode->IsLeaf()) {
+ currentNode = currentNode->Children()[0];
+ n++;
}
- return maxSubDepth + 1;
+ return n;
}
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- IsLeaf() const
+IsLeaf() const
{
return numChildren == 0;
}
@@ -276,13 +236,13 @@ inline bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
* This returns 0 unless the node is a leaf.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline double RectangleTree<SplitType, DescentType, StatisticType, MatType>::
FurthestPointDistance() const
{
- if(!IsLeaf())
+ if (!IsLeaf())
return 0.0;
// Otherwise return the distance from the centroid to a corner of the bound.
@@ -297,11 +257,11 @@ FurthestPointDistance() const
* it will never be greater than this).
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline double RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- FurthestDescendantDistance() const
+FurthestDescendantDistance() const
{
//return the distance from the centroid to a corner of the bound.
return 0.5 * bound.Diameter();
@@ -311,13 +271,13 @@ inline double RectangleTree<SplitType, DescentType, StatisticType, MatType>::
* Return the number of points contained in this node. Zero if it is a non-leaf node.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- NumPoints() const
+NumPoints() const
{
- if(numChildren != 0) // This is not a leaf node.
+ if (numChildren != 0) // This is not a leaf node.
return 0;
return count;
@@ -327,17 +287,17 @@ inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
* Return the number of descendants under or in this node.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- NumDescendants() const
+NumDescendants() const
{
- if(numChildren == 0)
+ if (numChildren == 0)
return count;
else {
size_t n = 0;
- for(int i = 0; i < numChildren; i++) {
+ for (int i = 0; i < numChildren; i++) {
n += children[i]->NumDescendants();
}
return n;
@@ -348,11 +308,11 @@ inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
* Return the index of a particular descendant contained in this node.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- Descendant(const size_t index) const
+Descendant(const size_t index) const
{
return (points[index]);
}
@@ -361,11 +321,11 @@ inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
* Return the index of a particular point contained in this node.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- Point(const size_t index) const
+Point(const size_t index) const
{
return dataset(points[index]);
}
@@ -375,26 +335,26 @@ inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
* SO THIS IS RATHER POINTLESS.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::End() const
{
- if(numChildren)
+ if (numChildren)
return begin + count;
- return children[numChildren-1]->End();
+ return children[numChildren - 1]->End();
}
- //have functions for returning the list of modified indices if we end up doing it that way.
+//have functions for returning the list of modified indices if we end up doing it that way.
/**
* Split the tree. This calls the SplitType code to split a node. This method should only
* be called on a leaf node.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode()
{
// This should always be a leaf node. When we need to split other nodes,
@@ -402,7 +362,7 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode()
assert(numChildren == 0);
// Check to see if we are full.
- if(count < maxLeafSize)
+ if (count < maxLeafSize)
return; // We don't need to split.
// If we are full, then we need to split (or at least try). The SplitType takes
@@ -410,14 +370,192 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode()
SplitType::SplitLeafNode(this);
}
+/**
+ * Condense the tree after a point has been removed from it.
+ *
+ * Starting at this node, one of the following happens:
+ *  - A non-root leaf holding fewer than minLeafSize points is unlinked from its
+ *    parent and its remaining points are reinserted starting at the root.
+ *  - A non-root internal node holding fewer than minNumChildren children is
+ *    unlinked from its parent and its children are reinserted starting at the
+ *    root, at the depth this node had.
+ *  - An under-filled root with exactly one child absorbs that child's contents.
+ * If the node is not deleted, its bound is tightened around the removed point
+ * and, when the bound actually shrank, the parent (if any) is condensed too.
+ *
+ * @param point The point that was removed from the tree.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::CondenseTree(const arma::vec& point)
+{
+  // First delete the node if we need to.  There's no point in shrinking the
+  // bound first.
+  if (IsLeaf() && count < minLeafSize && parent != NULL) { // We can't delete the root node.
+    for (size_t i = 0; i < parent->NumChildren(); i++) {
+      if (parent->Children()[i] == this) {
+        // Unlink this node: move the parent's last child into our slot and
+        // decrement the parent's child count.
+        parent->Children()[i] = parent->Children()[--parent->NumChildren()];
+        parent->ShrinkBoundForBound(bound); // We want to do this before reinserting points.
+
+        // Reinsert the points at the root node.
+        RectangleTree<SplitType, DescentType, StatisticType, MatType>* root = this;
+        while (root->Parent() != NULL)
+          root = root->Parent();
+        // A leaf has no children, so the loop must run over the points it
+        // still holds (count), not over numChildren.
+        for (size_t j = 0; j < count; j++)
+          root->InsertPoint(points[j]);
+
+        parent->CondenseTree(point); // This will check the MinFill of the parent.
+        // Now it should be safe to delete this node.
+        SoftDelete();
+        return;
+      }
+    }
+    // Control should never reach here: a node is always one of its parent's
+    // children.
+    assert(false);
+  } else if (!IsLeaf() && numChildren < minNumChildren) {
+    if (parent != NULL) { // The normal case.  We need to be careful with the root.
+      for (size_t j = 0; j < parent->NumChildren(); j++) {
+        if (parent->Children()[j] == this) {
+          // Unlink this node from its parent (see the leaf case above).
+          parent->Children()[j] = parent->Children()[--parent->NumChildren()];
+          parent->ShrinkBoundForBound(bound); // We want to do this before reinserting nodes.
+
+          size_t level = TreeDepth();
+          // Reinsert the nodes at the root node.
+          RectangleTree<SplitType, DescentType, StatisticType, MatType>* root = this;
+          while (root->Parent() != NULL)
+            root = root->Parent();
+          for (size_t i = 0; i < numChildren; i++)
+            root->InsertNode(children[i], level);
+
+          parent->CondenseTree(point); // This will check the MinFill of the parent.
+          // Now it should be safe to delete this node.
+          SoftDelete();
+          return;
+        }
+      }
+    } else if (numChildren == 1) { // If there are multiple children, we can't do anything to the root.
+      // Collapse the single child into the root: the root takes over the
+      // child's children and points, and the child is discarded.
+      RectangleTree<SplitType, DescentType, StatisticType, MatType>* child = children[0];
+      for (size_t i = 0; i < child->NumChildren(); i++) {
+        children[i] = child->Children()[i];
+        children[i]->parent = this; // The old parent is about to be deleted.
+      }
+      numChildren = child->NumChildren();
+      for (size_t i = 0; i < child->Count(); i++) { // In case the tree has a height of two.
+        points[i] = child->Points()[i];
+        localDataset->col(i) = child->LocalDataset().col(i);
+      }
+      count = child->Count();
+      child->SoftDelete();
+    }
+  }
+
+  // If we didn't delete it, shrink the bound if we need to.  The root has no
+  // parent to propagate the change to.
+  if (ShrinkBoundForPoint(point) && parent != NULL)
+    parent->CondenseTree(point);
+}
+
+/**
+ * Tighten this node's bound after the given point has been removed.
+ *
+ * Only dimensions in which the removed point sat exactly on the lower or the
+ * upper edge of the bound are recomputed.  A leaf rescans the points in its
+ * local dataset; an internal node rescans the bounds of its children.
+ *
+ * @param point The point that was removed.
+ * @return True if at least one edge of the bound actually moved inward.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::ShrinkBoundForPoint(const arma::vec& point)
+{
+  const bool leaf = IsLeaf();
+  // Number of entries to rescan: points for a leaf, children otherwise.
+  const size_t entries = leaf ? count : numChildren;
+  bool tightened = false;
+
+  for (size_t dim = 0; dim < point.n_elem; ++dim) {
+    if (bound[dim].Lo() == point[dim]) {
+      // The removed point sat on the lower edge; find the new minimum.
+      double lowest = DBL_MAX;
+      for (size_t e = 0; e < entries; ++e) {
+        const double candidate = leaf ? localDataset->col(e)[dim]
+                                      : children[e]->Bound()[dim].Lo();
+        if (candidate < lowest)
+          lowest = candidate;
+      }
+      if (bound[dim].Lo() < lowest)
+        tightened = true;
+      bound[dim].Lo() = lowest;
+    } else if (bound[dim].Hi() == point[dim]) {
+      // The removed point sat on the upper edge; find the new maximum.
+      double highest = -1 * DBL_MAX;
+      for (size_t e = 0; e < entries; ++e) {
+        const double candidate = leaf ? localDataset->col(e)[dim]
+                                      : children[e]->Bound()[dim].Hi();
+        if (candidate > highest)
+          highest = candidate;
+      }
+      if (bound[dim].Hi() > highest)
+        tightened = true;
+      bound[dim].Hi() = highest;
+    }
+  }
+
+  return tightened;
+}
+
+/**
+ * Shrink the bound so it fits tightly after the removal of a child bound.
+ *
+ * The bound is reset and rebuilt from scratch as the union of the bounds of
+ * all current children; the removed bound itself is not consulted.
+ *
+ * @param (unused) The bound that was removed.
+ * @return True if the bound shrank, i.e. the summed width over all dimensions
+ *     differs from what it was before the rebuild.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::ShrinkBoundForBound(const HRectBound<>& /* unused */)
+{
+  double sum = 0; // Using the sum is safe since none of the dimensions can increase.
+  // I think it may be faster to just recalculate the whole thing.
+  for (size_t i = 0; i < bound.Dim(); i++) {
+    sum += bound[i].Width();
+    // Reset to an empty range so the unions below rebuild it.
+    bound[i].Lo() = DBL_MAX;
+    bound[i].Hi() = -1 * DBL_MAX;
+  }
+  for (size_t i = 0; i < numChildren; i++) {
+    bound |= children[i]->Bound();
+  }
+  double sum2 = 0;
+  for (size_t i = 0; i < bound.Dim(); i++)
+    sum2 += bound[i].Width();
+  // No dimension can have grown, so any difference in total width means the
+  // bound shrank.
+  return sum != sum2;
+}
+
+/**
+ * Insert the node into the tree at the appropriate depth.
+ *
+ * This node's bound is always expanded to cover the inserted node.  If this
+ * node's TreeDepth() equals the requested level, the node is adopted as a
+ * direct child (and this node is split if it becomes full); otherwise the
+ * insertion descends into the child chosen by the DescentType heuristic.
+ *
+ * @param node The subtree to insert.
+ * @param level The TreeDepth() of the node that should adopt the subtree.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::InsertNode(
+    const RectangleTree<SplitType, DescentType, StatisticType, MatType>* node,
+    const size_t level)
+{
+  // Expand the bound regardless of the level.
+  bound |= node->Bound();
+
+  if (level == TreeDepth()) {
+    RectangleTree* adopted = const_cast<RectangleTree*> (node);
+    children[numChildren++] = adopted;
+    // Re-parent the adopted node; otherwise it keeps pointing at the node it
+    // was detached from.
+    adopted->parent = this;
+    assert(numChildren <= maxNumChildren); // We should never have increased without splitting.
+    if (numChildren == maxNumChildren)
+      SplitType::SplitNonLeafNode(this);
+  } else {
+    children[DescentType::ChooseDescentNode(this, node)]->InsertNode(node, level);
+  }
+}
/**
* Returns a string representation of this object.
*/
template<typename SplitType,
- typename DescentType,
- typename StatisticType,
- typename MatType>
+typename DescentType,
+typename StatisticType,
+typename MatType>
std::string RectangleTree<SplitType, DescentType, StatisticType, MatType>::ToString() const
{
std::ostringstream convert;
@@ -436,8 +574,8 @@ std::string RectangleTree<SplitType, DescentType, StatisticType, MatType>::ToStr
convert << " Parent address: " << parent << std::endl;
// How many levels should we print? This will print 3 levels (counting the root).
- if(parent == NULL || parent->Parent() == NULL) {
- for(int i = 0; i < numChildren; i++) {
+ if (parent == NULL || parent->Parent() == NULL) {
+ for (int i = 0; i < numChildren; i++) {
convert << children[i]->ToString();
}
}
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 28a14fd..6331a68 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -169,6 +169,19 @@ BOOST_AUTO_TEST_CASE(TreeLocalDatasetInSync) {
assert(checkSync(tree) == true);
}
+// Deleting a single point from the tree must reduce the number of descendants
+// by exactly one.
+BOOST_AUTO_TEST_CASE(PointDeletion) {
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+          tree::RTreeDescentHeuristic,
+          NeighborSearchStat<NearestNeighborSort>,
+          arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+  tree.DeletePoint(999);
+  // Use the Boost.Test check rather than assert(): assert() is compiled out
+  // under NDEBUG and does not report through the test framework.
+  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), static_cast<size_t>(999));
+}
+
BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
{
arma::mat dataset;
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list