[mlpack] 95/324: change the tree to store size_t in the nodes and keep the dataset together. Other misc changes.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:21:59 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 083e69e953faf30b827034a8f202019530d93e75
Author: andrewmw94 <andrewmw94 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Tue Jul 1 16:11:31 2014 +0000

    Change the tree to store size_t point indices in the nodes and keep the dataset together in one shared matrix.  Other misc changes.
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16733 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp | 26 +++----
 .../core/tree/rectangle_tree/rectangle_tree.hpp    | 25 +++++--
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 80 ++++++++++++++++++----
 .../rectangle_tree_traverser_impl.hpp              | 16 ++---
 src/mlpack/tests/rectangle_tree_test.cpp           |  4 +-
 5 files changed, 111 insertions(+), 40 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index fbe4519..ab9c55a 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -194,7 +194,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetPointSeeds(
     for(int j = i+1; j < tree.Count(); j++) {
       double score = 1.0;
       for(int k = 0; k < tree.Bound().Dim(); k++) {
-	score *= std::abs(tree.Dataset().at(k, i) - tree.Dataset().at(k, j)); // Points are stored by column, but this function takes (row, col).
+	score *= std::abs(tree.Dataset().at(k, tree.Points()[i]) - tree.Dataset().at(k, tree.Points()[j])); // Points (in the dataset) are stored by column, but this function takes (row, col).
       }
       if(score > worstPairScore) {
 	worstPairScore = score;
@@ -264,16 +264,16 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
   treeOne->Count() = 0;
   treeTwo->Count() = 0;
 
-  treeOne->InsertPoint(oldTree->Dataset().col(intI));
-  treeTwo->InsertPoint(oldTree->Dataset().col(intJ));
+  treeOne->InsertPoint(oldTree->Points()[intI]);
+  treeTwo->InsertPoint(oldTree->Points()[intJ]);
   
   // If intJ is the last point in the tree, we need to switch the order so that we remove the correct points.
   if(intI > intJ) {
-    oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
-    oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
+    oldTree->Points()[intI] = oldTree->Points()[--end]; // decrement end
+    oldTree->Points()[intJ] = oldTree->Points()[--end]; // decrement end
   } else {
-    oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
-    oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+    oldTree->Points()[intJ] = oldTree->Points()[--end]; // decrement end
+    oldTree->Points()[intI] = oldTree->Points()[--end]; // decrement end
   }
     
     
@@ -309,7 +309,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
       double newVolOne = 1.0;
       double newVolTwo = 1.0;
       for(int i = 0; i < oldTree->Bound().Dim(); i++) {
-	double c = oldTree->Dataset().col(index)[i];      
+	double c = oldTree->Dataset().col(oldTree->Points()[index])[i];      
 	newVolOne *= treeOne->Bound()[i].Contains(c) ? treeOne->Bound()[i].Width() :
 	  (c < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
 	newVolTwo *= treeTwo->Bound()[i].Contains(c) ? treeTwo->Bound()[i].Width() :
@@ -335,26 +335,26 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
     // Assign the point that causes the least increase in volume 
     // to the appropriate rectangle.
     if(bestRect == 1) {
-      treeOne->InsertPoint(oldTree->Dataset().col(bestIndex));
+      treeOne->InsertPoint(oldTree->Points()[bestIndex]);
       numAssignedOne++;
     }
     else {
-      treeTwo->InsertPoint(oldTree->Dataset().col(bestIndex));
+      treeTwo->InsertPoint(oldTree->Points()[bestIndex]);
       numAssignedTwo++;      
     }
 
-    oldTree->Dataset().col(bestIndex) = oldTree->Dataset().col(--end); // decrement end.
+    oldTree->Points()[bestIndex] = oldTree->Points()[--end]; // decrement end.
   }
   
   // See if we need to satisfy the minimum fill.
   if(end > 0) {
     if(numAssignedOne < numAssignedTwo) {
       for(int i = 0; i < end; i++) {
-        treeOne->InsertPoint(oldTree->Dataset().col(i));
+        treeOne->InsertPoint(oldTree->Points()[i]);
       }
     } else {
       for(int i = 0; i < end; i++) {
-        treeTwo->InsertPoint(oldTree->Dataset().col(i));
+        treeTwo->InsertPoint(oldTree->Points()[i]);
       }
     }
   }
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 3e4bdfd..e6833b9 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -71,7 +71,9 @@ class RectangleTree
   //! The discance to the furthest descendant, cached to speed things up.
   double furthestDescendantDistance;
   //! The dataset.
-  MatType* dataset;
+  MatType& dataset;
+  //! The mapping to the dataset
+  std::vector<size_t> points;
   
  public:
   //! So other classes can use TreeType::Mat.
@@ -138,9 +140,17 @@ class RectangleTree
    * it may be passed many times before it actually reaches a leaf.
    * @param point The point (arma::vec&) to be inserted.
    */
-  void InsertPoint(const arma::vec& point);
+  void InsertPoint(const size_t point);
 
   /**
+   * Deletes a point (given by its dataset column index) from the tree.  The
+   * index is removed from the points vector of the leaf node where it is
+   * stored and that leaf's bounding rectangle is updated.  Returns true if the
+   * point is removed and false if it is not (i.e. the point is not in the tree).
+   */
+  bool DeletePoint(const size_t point);
+  
+  /**
    * Find a node in this tree by its begin and count (const).
    *
    * Every node is uniquely identified by these two numbers.
@@ -205,10 +215,15 @@ class RectangleTree
   RectangleTree*& Parent() { return parent; }
 
   //! Get the dataset which the tree is built on.
-  const arma::mat& Dataset() const { return *dataset; }
+  const arma::mat& Dataset() const { return dataset; }
   //! Modify the dataset which the tree is built on.  Be careful!
-  arma::mat& Dataset() { return *dataset; }
-
+  arma::mat& Dataset() { return dataset; }
+  
+  //! Get the points vector for this node.
+  const std::vector<size_t>& Points() const { return points; }
+  //! Modify the points vector for this node.  Be careful!
+  std::vector<size_t>& Points() { return points; }
+  
   //! Get the metric which the tree uses.
   typename HRectBound<>::MetricType Metric() const { return bound.Metric(); }
 
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 1589343..0b16246 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -39,7 +39,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
     minLeafSize(minLeafSize),
     bound(data.n_rows),
     parentDistance(0),
-    dataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+    dataset(data),
+    points(maxLeafSize+1) // Add one to make splitting the node simpler.
 {
   stat = StatisticType(*this);
   
@@ -47,8 +48,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
   RectangleTree* root = this;
   
   //for(int i = firstDataIndex; i < 57; i++) { // 56,57 are the bound for where it works/breaks
-  for(int i = firstDataIndex; i < data.n_cols; i++) {
-    root->InsertPoint(data.col(i));
+  for(size_t i = firstDataIndex; i < data.n_cols; i++) {
+    root->InsertPoint(i);
   }
 }
 
@@ -69,7 +70,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
   minLeafSize(parentNode->MinLeafSize()),
   bound(parentNode->Bound().Dim()),
   parentDistance(0),
-  dataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+  dataset(parentNode->Dataset()),
+  points(maxLeafSize+1) // Add one to make splitting the node simpler.
 {
   stat = StatisticType(*this);
 }
@@ -90,7 +92,7 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     delete children[i];
   }
   //if(numChildren == 0)
-    delete dataset;
+  //delete points;
 }
 
 
@@ -125,7 +127,7 @@ template<typename SplitType,
 void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     NullifyData()
 {
-  dataset = NULL;
+  // No-op: nodes now hold a reference to a dataset they never delete.
 }
 
 
@@ -138,25 +140,25 @@ template<typename SplitType,
 	 typename StatisticType,
          typename MatType>
 void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
-    InsertPoint(const arma::vec& point)
+    InsertPoint(const size_t point)
 {
   // Expand the bound regardless of whether it is a leaf node.
-  bound |= point;
+  bound |= dataset.col(point);
 
   // If this is a leaf node, we stop here and add the point.
   if(numChildren == 0) {
-    dataset->col(count++) = point;
+    points[count++] = point;
     SplitNode();
     return;
   }
 
   // If it is not a leaf node, we use the DescentHeuristic to choose a child
   // to which we recurse.
-  double minScore = DescentType::EvalNode(children[0]->Bound(), point);
+  double minScore = DescentType::EvalNode(children[0]->Bound(), dataset.col(point));
   int bestIndex = 0;
   
   for(int i = 1; i < numChildren; i++) {
-    double score = DescentType::EvalNode(children[i]->Bound(), point);
+    double score = DescentType::EvalNode(children[i]->Bound(), dataset.col(point));
     if(score < minScore) {
       minScore = score;
       bestIndex = i;
@@ -165,6 +167,60 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
   children[bestIndex]->InsertPoint(point);
 }
 
+/**
+ * Recurse through the tree to remove the point.  Once we find it in a leaf,
+ * we shrink that leaf's rectangle if necessary.
+ *
+ * @param point Index (column) of the point in the shared dataset.
+ * @return true if the index was found and removed, false otherwise.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+    DeletePoint(const size_t point)
+{
+  if(numChildren == 0) {
+    for(size_t i = 0; i < count; i++) {
+      if(points[i] != point)
+        continue;
+      // Overwrite the removed index with the last stored one.
+      points[i] = points[--count];
+      // Refit every dimension on which the removed point touched the bound.
+      for(size_t j = 0; j < bound.Dim(); j++) {
+        const double removed = dataset(j, point); // (row, col) indexing.
+        if(count == 0 ||
+           (bound[j].Lo() != removed && bound[j].Hi() != removed))
+          continue; // NOTE(review): an emptied leaf keeps its old bound.
+        double lo = dataset(j, points[0]);
+        double hi = lo;
+        for(size_t k = 1; k < count; k++) {
+          const double v = dataset(j, points[k]);
+          if(v < lo)
+            lo = v;
+          if(v > hi)
+            hi = v;
+        }
+        bound[j].Lo() = lo;
+        bound[j].Hi() = hi;
+      }
+      return true;
+    }
+  } else {
+    // Descend only into children whose bound can contain the point.
+    for(size_t i = 0; i < numChildren; i++) {
+      if(children[i]->Bound().Contains(dataset.col(point)) &&
+         children[i]->DeletePoint(point)) {
+        // NOTE(review): ancestor bounds are not shrunk here, so they may
+        // stay looser than necessary after the deletion.
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 template<typename SplitType,
          typename DescentType,
 	 typename StatisticType,
@@ -308,7 +364,7 @@ template<typename SplitType,
 inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     Point(const size_t index) const
 {
-  return (begin + index);
+  return points[index]; // Dataset column index of this node's index'th point.
 }
 
 /**
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
index 482221a..93319cf 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
@@ -49,8 +49,8 @@ RectangleTreeTraverser<RuleType>::Traverse(
   
   // This is not a leaf node so we:
   // Sort the children of this node by their scores.
-  std::vector<RectangleTree*> nodes = new std::vector<RectangleTree*>(referenceNode.NumChildren());
-  std::vector<double> scores = new std::vector<double>(referenceNode.NumChildren());
+  std::vector<RectangleTree*> nodes(referenceNode.NumChildren());
+  std::vector<double> scores(referenceNode.NumChildren());
   for(int i = 0; i < referenceNode.NumChildren(); i++) {
     nodes[i] = referenceNode.Children()[i];
     scores[i] = rule.Score(nodes[i]);
@@ -60,12 +60,12 @@ RectangleTreeTraverser<RuleType>::Traverse(
   // Iterate through them starting with the best and stopping when we reach
   // one that isn't good enough.
   for(int i = 0; i < referenceNode.NumChildren(); i++) {
-   if(rule.Rescore(queryIndex, nodes[i], scores[i]) != DBL_MAX)
-     Traverse(queryIndex, nodes[i]);
-   else {
-    numPrunes += referenceNode.NumChildren - i;
-    return;
-   }
+    if(rule.Rescore(queryIndex, nodes[i], scores[i]) != DBL_MAX)
+      Traverse(queryIndex, nodes[i]);
+    else {
+      numPrunes += referenceNode.NumChildren - i;
+      return;
+    }
   }
   // We only get here if we couldn't prune any of them.
   return;
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index da04f58..83caf30 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -70,7 +70,7 @@ std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<
     }
   } else {
     for(size_t i = 0; i < tree.Count(); i++) {
-      arma::vec* c = new arma::vec(tree.Dataset().col(i)); 
+      arma::vec* c = new arma::vec(tree.Dataset().col(tree.Points()[i])); 
       vec.push_back(c);
     }
   }
@@ -112,7 +112,7 @@ bool checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeu
   bool passed = true;
   if(tree.NumChildren() == 0) {
     for(size_t i = 0; i < tree.Count(); i++) {
-      passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(i));
+      passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i]));
     }
   } else {
     for(size_t i = 0; i < tree.NumChildren(); i++) {

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list