[mlpack] 95/324: change the tree to store size_t in the nodes and keep the dataset together. Other misc changes.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:21:59 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 083e69e953faf30b827034a8f202019530d93e75
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Tue Jul 1 16:11:31 2014 +0000
change the tree to store size_t in the nodes and keep the dataset together. Other misc changes.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16733 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 26 +++----
.../core/tree/rectangle_tree/rectangle_tree.hpp | 25 +++++--
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 80 ++++++++++++++++++----
.../rectangle_tree_traverser_impl.hpp | 16 ++---
src/mlpack/tests/rectangle_tree_test.cpp | 4 +-
5 files changed, 111 insertions(+), 40 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index fbe4519..ab9c55a 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -194,7 +194,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetPointSeeds(
for(int j = i+1; j < tree.Count(); j++) {
double score = 1.0;
for(int k = 0; k < tree.Bound().Dim(); k++) {
- score *= std::abs(tree.Dataset().at(k, i) - tree.Dataset().at(k, j)); // Points are stored by column, but this function takes (row, col).
+ score *= std::abs(tree.Dataset().at(k, tree.Points()[i]) - tree.Dataset().at(k, tree.Points()[j])); // Points (in the dataset) are stored by column, but this function takes (row, col).
}
if(score > worstPairScore) {
worstPairScore = score;
@@ -264,16 +264,16 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
treeOne->Count() = 0;
treeTwo->Count() = 0;
- treeOne->InsertPoint(oldTree->Dataset().col(intI));
- treeTwo->InsertPoint(oldTree->Dataset().col(intJ));
+ treeOne->InsertPoint(oldTree->Points()[intI]);
+ treeTwo->InsertPoint(oldTree->Points()[intJ]);
// If intJ is the last point in the tree, we need to switch the order so that we remove the correct points.
if(intI > intJ) {
- oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
- oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
+ oldTree->Points()[intI] = oldTree->Points()[--end]; // decrement end
+ oldTree->Points()[intJ] = oldTree->Points()[--end]; // decrement end
} else {
- oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
- oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+ oldTree->Points()[intJ] = oldTree->Points()[--end]; // decrement end
+ oldTree->Points()[intI] = oldTree->Points()[--end]; // decrement end
}
@@ -309,7 +309,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
double newVolOne = 1.0;
double newVolTwo = 1.0;
for(int i = 0; i < oldTree->Bound().Dim(); i++) {
- double c = oldTree->Dataset().col(index)[i];
+ double c = oldTree->Dataset().col(oldTree->Points()[index])[i];
newVolOne *= treeOne->Bound()[i].Contains(c) ? treeOne->Bound()[i].Width() :
(c < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
newVolTwo *= treeTwo->Bound()[i].Contains(c) ? treeTwo->Bound()[i].Width() :
@@ -335,26 +335,26 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
// Assign the point that causes the least increase in volume
// to the appropriate rectangle.
if(bestRect == 1) {
- treeOne->InsertPoint(oldTree->Dataset().col(bestIndex));
+ treeOne->InsertPoint(oldTree->Points()[bestIndex]);
numAssignedOne++;
}
else {
- treeTwo->InsertPoint(oldTree->Dataset().col(bestIndex));
+ treeTwo->InsertPoint(oldTree->Points()[bestIndex]);
numAssignedTwo++;
}
- oldTree->Dataset().col(bestIndex) = oldTree->Dataset().col(--end); // decrement end.
+ oldTree->Points()[bestIndex] = oldTree->Points()[--end]; // decrement end.
}
// See if we need to satisfy the minimum fill.
if(end > 0) {
if(numAssignedOne < numAssignedTwo) {
for(int i = 0; i < end; i++) {
- treeOne->InsertPoint(oldTree->Dataset().col(i));
+ treeOne->InsertPoint(oldTree->Points()[i]);
}
} else {
for(int i = 0; i < end; i++) {
- treeTwo->InsertPoint(oldTree->Dataset().col(i));
+ treeTwo->InsertPoint(oldTree->Points()[i]);
}
}
}
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 3e4bdfd..e6833b9 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -71,7 +71,9 @@ class RectangleTree
//! The discance to the furthest descendant, cached to speed things up.
double furthestDescendantDistance;
//! The dataset.
- MatType* dataset;
+ MatType& dataset;
+ //! The mapping to the dataset
+ std::vector<size_t> points;
public:
//! So other classes can use TreeType::Mat.
@@ -138,9 +140,17 @@ class RectangleTree
* it may be passed many times before it actually reaches a leaf.
* @param point The point (arma::vec&) to be inserted.
*/
- void InsertPoint(const arma::vec& point);
+ void InsertPoint(const size_t point);
/**
+ * Deletes a point in the tree. The point will be removed from the data matrix
+ * of the leaf node where it is store and the bounding rectangles will be updated.
+ * Returns true if the point is successfully removed and false if it is not.
+ * (ie. the point is not in the tree)
+ */
+ bool DeletePoint(const size_t point);
+
+ /**
* Find a node in this tree by its begin and count (const).
*
* Every node is uniquely identified by these two numbers.
@@ -205,10 +215,15 @@ class RectangleTree
RectangleTree*& Parent() { return parent; }
//! Get the dataset which the tree is built on.
- const arma::mat& Dataset() const { return *dataset; }
+ const arma::mat& Dataset() const { return dataset; }
//! Modify the dataset which the tree is built on. Be careful!
- arma::mat& Dataset() { return *dataset; }
-
+ arma::mat& Dataset() { return dataset; }
+
+ //! Get the points vector for this node.
+ const std::vector<size_t>& Points() const { return points; }
+ //! Modify the points vector for this node. Be careful!
+ std::vector<size_t>& Points() { return points; }
+
//! Get the metric which the tree uses.
typename HRectBound<>::MetricType Metric() const { return bound.Metric(); }
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 1589343..0b16246 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -39,7 +39,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
minLeafSize(minLeafSize),
bound(data.n_rows),
parentDistance(0),
- dataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+ dataset(data),
+ points(maxLeafSize+1) // Add one to make splitting the node simpler.
{
stat = StatisticType(*this);
@@ -47,8 +48,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
RectangleTree* root = this;
//for(int i = firstDataIndex; i < 57; i++) { // 56,57 are the bound for where it works/breaks
- for(int i = firstDataIndex; i < data.n_cols; i++) {
- root->InsertPoint(data.col(i));
+ for(size_t i = firstDataIndex; i < data.n_cols; i++) {
+ root->InsertPoint(i);
}
}
@@ -69,7 +70,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
minLeafSize(parentNode->MinLeafSize()),
bound(parentNode->Bound().Dim()),
parentDistance(0),
- dataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+ dataset(parentNode->Dataset()),
+ points(maxLeafSize+1) // Add one to make splitting the node simpler.
{
stat = StatisticType(*this);
}
@@ -90,7 +92,7 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::
delete children[i];
}
//if(numChildren == 0)
- delete dataset;
+ //delete points;
}
@@ -125,7 +127,7 @@ template<typename SplitType,
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
NullifyData()
{
- dataset = NULL;
+ //points = NULL;
}
@@ -138,25 +140,25 @@ template<typename SplitType,
typename StatisticType,
typename MatType>
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- InsertPoint(const arma::vec& point)
+ InsertPoint(const size_t point)
{
// Expand the bound regardless of whether it is a leaf node.
- bound |= point;
+ bound |= dataset.col(point);
// If this is a leaf node, we stop here and add the point.
if(numChildren == 0) {
- dataset->col(count++) = point;
+ points[count++] = point;
SplitNode();
return;
}
// If it is not a leaf node, we use the DescentHeuristic to choose a child
// to which we recurse.
- double minScore = DescentType::EvalNode(children[0]->Bound(), point);
+ double minScore = DescentType::EvalNode(children[0]->Bound(), dataset.col(point));
int bestIndex = 0;
for(int i = 1; i < numChildren; i++) {
- double score = DescentType::EvalNode(children[i]->Bound(), point);
+ double score = DescentType::EvalNode(children[i]->Bound(), dataset.col(point));
if(score < minScore) {
minScore = score;
bestIndex = i;
@@ -165,6 +167,60 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
children[bestIndex]->InsertPoint(point);
}
+/**
+ * Recurse through the tree to remove the point. Once we find the point, we
+ * shrink the rectangles if necessary.
+ */
+template<typename SplitType,
+ typename DescentType,
+ typename StatisticType,
+ typename MatType>
+bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+ DeletePoint(const size_t point)
+{
+ if(numChildren == 0) {
+ for(int i = 0; i < count; i++) {
+ if(points[i] == point) {
+ points[i] = points[--count];
+ for(int j = 0; j < bound.Dim(); j++) {
+ if(bound[j].Lo() == dataset.col(point)[j]) {
+ int loIndx = 0;
+ double lo = dataset(points[0])[j];
+ for(int k = 1; k < count; k++) {
+ if(dataset(points[k])[j] < lo) {
+ lo = dataset(points[k])[j];
+ loIndx = k;
+ }
+ }
+ bound[j].Lo() = lo;
+ } else if(bound[j].Hi() == dataset.col(point)[j]) {
+ int hiIndx = 0;
+ double hi = dataset(points[0])[j];
+ for(int k = 1; k < count; k++) {
+ if(dataset(points[k])[j] > hi) {
+ hi = dataset(points[k])[j];
+ hiIndx = k;
+ }
+ }
+ bound[j].Hi() = hi;
+ }
+ }
+ return true;
+ }
+ }
+ } else {
+ for(int i = 0; i < numChildren; i++) {
+ if(children[i].Bound().Contains(dataset.col(point))) {
+ if(children[i].DeletePoint(dataset.col(point))) {
+
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
template<typename SplitType,
typename DescentType,
typename StatisticType,
@@ -308,7 +364,7 @@ template<typename SplitType,
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
Point(const size_t index) const
{
- return (begin + index);
+ return dataset(points[index]);
}
/**
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
index 482221a..93319cf 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
@@ -49,8 +49,8 @@ RectangleTreeTraverser<RuleType>::Traverse(
// This is not a leaf node so we:
// Sort the children of this node by their scores.
- std::vector<RectangleTree*> nodes = new std::vector<RectangleTree*>(referenceNode.NumChildren());
- std::vector<double> scores = new std::vector<double>(referenceNode.NumChildren());
+ std::vector<RectangleTree*> nodes(referenceNode.NumChildren());
+ std::vector<double> scores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++) {
nodes[i] = referenceNode.Children()[i];
scores[i] = rule.Score(nodes[i]);
@@ -60,12 +60,12 @@ RectangleTreeTraverser<RuleType>::Traverse(
// Iterate through them starting with the best and stopping when we reach
// one that isn't good enough.
for(int i = 0; i < referenceNode.NumChildren(); i++) {
- if(rule.Rescore(queryIndex, nodes[i], scores[i]) != DBL_MAX)
- Traverse(queryIndex, nodes[i]);
- else {
- numPrunes += referenceNode.NumChildren - i;
- return;
- }
+ if(rule.Rescore(queryIndex, nodes[i], scores[i]) != DBL_MAX)
+ Traverse(queryIndex, nodes[i]);
+ else {
+ numPrunes += referenceNode.NumChildren - i;
+ return;
+ }
}
// We only get here if we couldn't prune any of them.
return;
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index da04f58..83caf30 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -70,7 +70,7 @@ std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<
}
} else {
for(size_t i = 0; i < tree.Count(); i++) {
- arma::vec* c = new arma::vec(tree.Dataset().col(i));
+ arma::vec* c = new arma::vec(tree.Dataset().col(tree.Points()[i]));
vec.push_back(c);
}
}
@@ -112,7 +112,7 @@ bool checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeu
bool passed = true;
if(tree.NumChildren() == 0) {
for(size_t i = 0; i < tree.Count(); i++) {
- passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(i));
+ passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i]));
}
} else {
for(size_t i = 0; i < tree.NumChildren(); i++) {
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list