[mlpack] 35/207: Refactor DecisionStump to hold children when it is a stump.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:38 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit d8958268e6fd70208013be73ebfd700cc207c54e
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Dec 15 16:34:00 2016 -0500

    Refactor DecisionStump to hold children when it is a stump.
    
    Plus first steps towards generalizing to decision trees.
---
 .../methods/decision_stump/decision_stump.hpp      |  88 ++++--
 .../methods/decision_stump/decision_stump_impl.hpp | 301 ++++++++++++++++-----
 2 files changed, 303 insertions(+), 86 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 5918aaa..6ffc418 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -29,8 +29,11 @@ namespace decision_stump {
  * Points that are below the first bin will take the label of the first bin.
  *
  * @tparam MatType Type of matrix that is being used (sparse or dense).
+ * @tparam NoRecursion If true, this will create a stump (a one-level decision
+ *       tree).
  */
-template<typename MatType = arma::mat>
+template<typename MatType = arma::mat,
+         bool NoRecursion = true>
 class DecisionStump
 {
  public:
@@ -72,6 +75,39 @@ class DecisionStump
   DecisionStump();
 
   /**
+   * Copy the given decision stump.
+   *
+   * @param other Decision stump to copy.
+   */
+  DecisionStump(const DecisionStump& other);
+
+  /**
+   * Take ownership of the given decision stump.
+   *
+   * @param other Decision stump to take ownership of.
+   */
+  DecisionStump(DecisionStump&& other);
+
+  /**
+   * Copy the given decision stump.
+   *
+   * @param other Decision stump to copy.
+   */
+  DecisionStump& operator=(const DecisionStump& other);
+
+  /**
+   * Take ownership of the given decision stump.
+   *
+   * @param other Decision stump to take ownership of.
+   */
+  DecisionStump& operator=(DecisionStump&& other);
+
+  /**
+   * Destroy the decision stump.
+   */
+  ~DecisionStump();
+
+  /**
    * Train the decision stump on the given data.  This completely overwrites any
    * previous training data, so after training the stump may be completely
    * different.
@@ -97,36 +133,54 @@ class DecisionStump
   void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
 
   //! Access the splitting dimension.
-  size_t SplitDimension() const { return splitDimension; }
+  size_t SplitDimension() const { return splitDimensionOrLabel; }
   //! Modify the splitting dimension (be careful!).
-  size_t& SplitDimension() { return splitDimension; }
+  size_t& SplitDimension() { return splitDimensionOrLabel; }
 
   //! Access the splitting values.
-  const arma::vec& Split() const { return split; }
+  const arma::vec& Split() const { return splitOrClassProbs; }
   //! Modify the splitting values (be careful!).
-  arma::vec& Split() { return split; }
+  arma::vec& Split() { return splitOrClassProbs; }
 
-  //! Access the labels for each split bin.
-  const arma::Col<size_t> BinLabels() const { return binLabels; }
-  //! Modify the labels for each split bin (be careful!).
-  arma::Col<size_t>& BinLabels() { return binLabels; }
+  //! Access the label.
+  size_t Label() const { return splitDimensionOrLabel; }
+  //! Modify the label.
+  size_t& Label() { return splitDimensionOrLabel; }
+
+  //! Access the given child.
+  const DecisionStump& Child(const size_t i) const { return *children[i]; }
+  //! Modify the given child.
+  DecisionStump& Child(const size_t i) { return *children[i]; }
 
   //! Serialize the decision stump.
   template<typename Archive>
   void Serialize(Archive& ar, const unsigned int /* version */);
 
  private:
-  //! The number of classes (we must store this for boosting).
+  /**
+   * Construct a leaf with the given probabilities and class label.
+   *
+   * @param bucketSize Bucket size for training.
+   * @param label Majority label of leaf.
+   * @param probabilities Class probabilities of leaf.
+   */
+  DecisionStump(const size_t bucketSize,
+                const size_t label,
+                arma::vec&& probabilities);
+
+  //! The number of classes in the model.
   size_t classes;
-  //! The minimum number of points in a bucket.
+  //! The minimum number of points in a bucket (training parameter).
   size_t bucketSize;
 
-  //! Stores the value of the dimension on which to split.
-  size_t splitDimension;
-  //! Stores the splitting values after training.
-  arma::vec split;
-  //! Stores the labels for each splitting bin.
-  arma::Col<size_t> binLabels;
+  //! Stores the value of the dimension on which to split, or the label.
+  size_t splitDimensionOrLabel;
+  //! Stores either the splitting values after training, or the class
+  //! probabilities.
+  arma::vec splitOrClassProbs;
+
+  //! Stores the children (if any).
+  std::vector<DecisionStump*> children;
 
   /**
    * Sets up dimension as if it were splitting on it and finds entropy when
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index aa7201a..cea3d96 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -9,7 +9,6 @@
  * 3-clause BSD license along with mlpack.  If not, see
  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
  */
-
 #ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_IMPL_HPP
 #define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_IMPL_HPP
 
@@ -27,11 +26,12 @@ namespace decision_stump {
  * @param classes Number of distinct classes in labels.
  * @param bucketSize Minimum size of bucket when splitting.
  */
-template<typename MatType>
-DecisionStump<MatType>::DecisionStump(const MatType& data,
-                                      const arma::Row<size_t>& labels,
-                                      const size_t classes,
-                                      const size_t bucketSize) :
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>::DecisionStump(
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    const size_t classes,
+    const size_t bucketSize) :
     classes(classes),
     bucketSize(bucketSize)
 {
@@ -42,26 +42,111 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
 /**
  * Empty constructor.
  */
-template<typename MatType>
-DecisionStump<MatType>::DecisionStump() :
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>::DecisionStump() :
     classes(1),
     bucketSize(0),
-    splitDimension(0),
-    split(1),
-    binLabels(1)
+    splitDimensionOrLabel(0),
+    splitOrClassProbs(1)
 {
-  split[0] = DBL_MAX;
-  binLabels[0] = 0;
+  splitOrClassProbs[0] = 1.0;
+  if (NoRecursion)
+  {
+    // Give the fake stump two identical leaves (label 0), matching the
+    // splitOrClassProbs.n_elem + 1 children TrainOnDim() builds for a stump;
+    // with one split value, Classify() only ever reaches children[0].
+    children.push_back(new DecisionStump(0, 0, arma::vec("1.0")));
+    children.push_back(new DecisionStump(0, 0, arma::vec("1.0")));
+  }
+}
+
+// Copy constructor: duplicates the scalars and deep-copies every child node.
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>::DecisionStump(const DecisionStump& other) :
+    classes(other.classes),
+    bucketSize(other.bucketSize),
+    splitDimensionOrLabel(other.splitDimensionOrLabel),
+    splitOrClassProbs(other.splitOrClassProbs)
+{
+  for (const DecisionStump* child : other.children)
+    children.push_back(new DecisionStump(*child));
+}
+
+// Move constructor: steals other's state, leaving other with no children.
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>::DecisionStump(DecisionStump&& other) :
+    classes(other.classes),
+    bucketSize(other.bucketSize),
+    splitDimensionOrLabel(other.splitDimensionOrLabel),
+    splitOrClassProbs(std::move(other.splitOrClassProbs)),
+    children(std::move(other.children))
+{
+  // other is left childless: safe to destroy or assign to, not to Classify.
+  other.children.clear();
+}
+
+// Copy assignment operator: deep-copies other's children; a = a is a no-op.
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>&
+DecisionStump<MatType, NoRecursion>::operator=(const DecisionStump& other)
+{
+  if (this == &other) // Else we would copy from children we just deleted.
+    return *this;
+  for (size_t i = 0; i < children.size(); ++i) // Free our current children.
+    delete children[i];
+  children.clear();
+
+  classes = other.classes;
+  bucketSize = other.bucketSize;
+  splitDimensionOrLabel = other.splitDimensionOrLabel;
+  splitOrClassProbs = other.splitOrClassProbs;
+
+  // Create deep copies of the children.
+  for (size_t i = 0; i < other.children.size(); ++i)
+    children.push_back(new DecisionStump(*other.children[i]));
+  return *this;
+}
+
+// Move assignment operator: steals other's state, leaving other childless.
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>&
+DecisionStump<MatType, NoRecursion>::operator=(DecisionStump&& other)
+{
+  if (this == &other) // Self-move: freeing our children would destroy them.
+    return *this;
+  for (size_t i = 0; i < children.size(); ++i) // Free our current children.
+    delete children[i];
+
+  classes = other.classes;
+  bucketSize = other.bucketSize;
+  splitDimensionOrLabel = other.splitDimensionOrLabel;
+  splitOrClassProbs = std::move(other.splitOrClassProbs);
+  children = std::move(other.children);
+
+  // Do not write "other = DecisionStump()" here: that re-enters this operator
+  // on the temporary, which does the same again, recursing without end.
+  other.children.clear();
+  return *this;
+}
+
+// Destructor: deletes each owned child, which in turn frees its own subtree.
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>::~DecisionStump()
+{
+  for (DecisionStump* child : children)
+    delete child;
+  children.clear();
 }
 
 /**
  * Train on the given data and labels.
  */
-template<typename MatType>
-void DecisionStump<MatType>::Train(const MatType& data,
-                                   const arma::Row<size_t>& labels,
-                                   const size_t classes,
-                                   const size_t bucketSize)
+template<typename MatType, bool NoRecursion>
+void DecisionStump<MatType, NoRecursion>::Train(
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    const size_t classes,
+    const size_t bucketSize)
 {
   this->classes = classes;
   this->bucketSize = bucketSize;
@@ -78,11 +163,12 @@ void DecisionStump<MatType>::Train(const MatType& data,
  * @param labels Labels for dataset.
  * @param UseWeights Whether we need to run a weighted Decision Stump.
  */
-template<typename MatType>
+template<typename MatType, bool NoRecursion>
 template<bool UseWeights>
-void DecisionStump<MatType>::Train(const MatType& data,
-                                   const arma::Row<size_t>& labels,
-                                   const arma::rowvec& weights)
+void DecisionStump<MatType, NoRecursion>::Train(
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    const arma::rowvec& weights)
 {
   // If classLabels are not all identical, proceed with training.
   size_t bestDim = 0;
@@ -112,10 +198,10 @@ void DecisionStump<MatType>::Train(const MatType& data,
       }
     }
   }
-  splitDimension = bestDim;
+  splitDimensionOrLabel = bestDim;
 
   // Once the splitting column/dimension has been decided, train on it.
-  TrainOnDim(data.row(splitDimension), labels);
+  TrainOnDim(data.row(splitDimensionOrLabel), labels);
 }
 
 /**
@@ -126,9 +212,10 @@ void DecisionStump<MatType>::Train(const MatType& data,
  * @param predictedLabels Vector to store the predicted classes after
  *      classifying test
  */
-template<typename MatType>
-void DecisionStump<MatType>::Classify(const MatType& test,
-                                      arma::Row<size_t>& predictedLabels)
+template<typename MatType, bool NoRecursion>
+void DecisionStump<MatType, NoRecursion>::Classify(
+    const MatType& test,
+    arma::Row<size_t>& predictedLabels)
 {
   predictedLabels.set_size(test.n_cols);
   for (size_t i = 0; i < test.n_cols; i++)
@@ -137,17 +224,20 @@ void DecisionStump<MatType>::Classify(const MatType& test,
     // Assume first that it falls into the first bin, then proceed through the
     // bins until it is known which bin it falls into.
     size_t bin = 0;
-    const double val = test(splitDimension, i);
+    const double val = test(splitDimensionOrLabel, i);
 
-    while (bin < split.n_elem - 1)
+    while (bin < splitOrClassProbs.n_elem - 1)
     {
-      if (val < split(bin + 1))
+      if (val < splitOrClassProbs(bin + 1))
         break;
 
       ++bin;
     }
 
-    predictedLabels(i) = binLabels(bin);
+    if (NoRecursion)
+      predictedLabels(i) = children[bin]->Label();
+    else
+      children[bin]->Classify(test, predictedLabels);
   }
 }
 
@@ -163,11 +253,12 @@ void DecisionStump<MatType>::Classify(const MatType& test,
  * @param labels The labels of data.
  * @param UseWeights Whether we need to run a weighted Decision Stump.
  */
-template<typename MatType>
-DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
-                                      const MatType& data,
-                                      const arma::Row<size_t>& labels,
-                                      const arma::rowvec& weights) :
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>::DecisionStump(
+    const DecisionStump<>& other,
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    const arma::rowvec& weights) :
     classes(other.classes),
     bucketSize(other.bucketSize)
 {
@@ -177,10 +268,11 @@ DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
 /**
  * Serialize the decision stump.
  */
-template<typename MatType>
+template<typename MatType, bool NoRecursion>
 template<typename Archive>
-void DecisionStump<MatType>::Serialize(Archive& ar,
-                                       const unsigned int /* version */)
+void DecisionStump<MatType, NoRecursion>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
 {
   using data::CreateNVP;
 
@@ -188,9 +280,40 @@ void DecisionStump<MatType>::Serialize(Archive& ar,
   // None need special handling.
   ar & CreateNVP(classes, "classes");
   ar & CreateNVP(bucketSize, "bucketSize");
-  ar & CreateNVP(splitDimension, "splitDimension");
-  ar & CreateNVP(split, "split");
-  ar & CreateNVP(binLabels, "binLabels");
+  ar & CreateNVP(splitDimensionOrLabel, "splitDimensionOrLabel");
+  ar & CreateNVP(splitOrClassProbs, "splitOrClassProbs");
+
+  size_t numChildren = children.size();
+  ar & CreateNVP(numChildren, "numChildren");
+  if (Archive::is_loading::value)
+  {
+    // Clear memory and prepare for loading children.
+    for (size_t i = 0; i < children.size(); ++i)
+      delete children[i];
+    children.clear();
+    children.resize(numChildren);
+  }
+
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    std::ostringstream oss;
+    oss << "child" << i;
+    ar & CreateNVP(children[i], oss.str());
+  }
+}
+
+/**
+ * Create a leaf with the given majority label and class probabilities.
+ */
+template<typename MatType, bool NoRecursion>
+DecisionStump<MatType, NoRecursion>::DecisionStump(const size_t bucketSize,
+                                                   const size_t label,
+                                                   arma::vec&& probabilities) :
+    classes(probabilities.n_elem), // Declared first, so read before the move.
+    bucketSize(bucketSize),
+    splitDimensionOrLabel(label),
+    splitOrClassProbs(std::move(probabilities))
+{
 }
 
 /**
@@ -201,9 +324,9 @@ void DecisionStump<MatType>::Serialize(Archive& ar,
  *      the splitting dimension.
  * @param UseWeights Whether we need to run a weighted Decision Stump.
  */
-template<typename MatType>
+template<typename MatType, bool NoRecursion>
 template<bool UseWeights, typename VecType>
-double DecisionStump<MatType>::SetupSplitDimension(
+double DecisionStump<MatType, NoRecursion>::SetupSplitDimension(
     const VecType& dimension,
     const arma::Row<size_t>& labels,
     const arma::rowvec& weights)
@@ -291,10 +414,11 @@ double DecisionStump<MatType>::SetupSplitDimension(
  * @param dimension Dimension is the dimension decided by the constructor on
  *      which we now train the decision stump.
  */
-template<typename MatType>
+template<typename MatType, bool NoRecursion>
 template<typename VecType>
-void DecisionStump<MatType>::TrainOnDim(const VecType& dimension,
-                                        const arma::Row<size_t>& labels)
+void DecisionStump<MatType, NoRecursion>::TrainOnDim(
+    const VecType& dimension,
+    const arma::Row<size_t>& labels)
 {
   size_t i, count, begin, end;
 
@@ -302,6 +426,7 @@ void DecisionStump<MatType>::TrainOnDim(const VecType& dimension,
   arma::uvec sortedSplitIndexDim = arma::stable_sort_index(dimension.t());
   arma::Row<size_t> sortedLabels(dimension.n_elem);
   sortedLabels.fill(0);
+  arma::Col<size_t> binLabels;
 
   for (i = 0; i < dimension.n_elem; i++)
     sortedLabels(i) = labels(sortedSplitIndexDim(i));
@@ -320,8 +445,8 @@ void DecisionStump<MatType>::TrainOnDim(const VecType& dimension,
 
       mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
 
-      split.resize(split.n_elem + 1);
-      split(split.n_elem - 1) = sortedSplitDim(begin);
+      splitOrClassProbs.resize(splitOrClassProbs.n_elem + 1);
+      splitOrClassProbs(splitOrClassProbs.n_elem - 1) = sortedSplitDim(begin);
       binLabels.resize(binLabels.n_elem + 1);
       binLabels(binLabels.n_elem - 1) = mostFreq;
 
@@ -348,8 +473,8 @@ void DecisionStump<MatType>::TrainOnDim(const VecType& dimension,
       // the bucket of subCols.
       mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
 
-      split.resize(split.n_elem + 1);
-      split(split.n_elem - 1) = sortedSplitDim(begin);
+      splitOrClassProbs.resize(splitOrClassProbs.n_elem + 1);
+      splitOrClassProbs(splitOrClassProbs.n_elem - 1) = sortedSplitDim(begin);
       binLabels.resize(binLabels.n_elem + 1);
       binLabels(binLabels.n_elem - 1) = mostFreq;
 
@@ -362,32 +487,70 @@ void DecisionStump<MatType>::TrainOnDim(const VecType& dimension,
 
   // Now trim the split matrix so that buckets one after the after which point
   // to the same classLabel are merged as one big bucket.
-  MergeRanges();
-}
-
-/**
- * After the "split" matrix has been set up, merge ranges with identical class
- * labels.
- */
-template<typename MatType>
-void DecisionStump<MatType>::MergeRanges()
-{
-  for (size_t i = 1; i < split.n_rows; i++)
+  for (size_t i = 1; i < splitOrClassProbs.n_rows; i++)
   {
     if (binLabels(i) == binLabels(i - 1))
     {
       // Remove this row, as it has the same label as the previous bucket.
       binLabels.shed_row(i);
-      split.shed_row(i);
+      splitOrClassProbs.shed_row(i);
       // Go back to previous row.
       i--;
     }
   }
+
+  // Now create the children, either recursively (if we are not a tree) or not
+  // (if we are a stump).
+  if (NoRecursion)
+  {
+    size_t begin = 0;
+    for (size_t i = 0; i < splitOrClassProbs.n_elem; ++i)
+    {
+      // Calculate class probabilities for children.
+      arma::vec childClassProbs(classes);
+      childClassProbs.zeros();
+
+      size_t lastBegin = begin;
+      do
+      {
+        childClassProbs[sortedLabels[begin]]++;
+      } while (sortedSplitDim(++begin) < splitOrClassProbs[i]);
+
+      // Normalize probabilities.
+      childClassProbs /= (begin - lastBegin);
+
+      // Create child.
+      children.push_back(new DecisionStump(bucketSize, binLabels[i],
+          std::move(childClassProbs)));
+    }
+
+    // Create the last child.
+    arma::vec childClassProbs(classes);
+    childClassProbs.zeros();
+
+    size_t lastBegin = begin;
+    do
+    {
+      childClassProbs[sortedLabels[begin]]++;
+    } while (++begin < sortedSplitDim.n_elem);
+
+    // Normalize probabilities.
+    childClassProbs /= (begin - lastBegin);
+
+    // Create child.
+    children.push_back(new DecisionStump(bucketSize,
+        binLabels[binLabels.n_elem - 1], std::move(childClassProbs)));
+  }
+  else
+  {
+    // Do recursion.
+  }
 }
 
-template<typename MatType>
+template<typename MatType, bool NoRecursion>
 template<typename VecType>
-double DecisionStump<MatType>::CountMostFreq(const VecType& subCols)
+double DecisionStump<MatType, NoRecursion>::CountMostFreq(
+    const VecType& subCols)
 {
   // We'll create a map of elements and the number of times that each element is
   // seen.
@@ -424,9 +587,9 @@ double DecisionStump<MatType>::CountMostFreq(const VecType& subCols)
  *
  * @param featureRow The dimension which is checked for identical values.
  */
-template<typename MatType>
+template<typename MatType, bool NoRecursion>
 template<typename VecType>
-int DecisionStump<MatType>::IsDistinct(const VecType& featureRow)
+int DecisionStump<MatType, NoRecursion>::IsDistinct(const VecType& featureRow)
 {
   typename VecType::elem_type val = featureRow(0);
   for (size_t i = 1; i < featureRow.n_elem; ++i)
@@ -441,9 +604,9 @@ int DecisionStump<MatType>::IsDistinct(const VecType& featureRow)
  * @param labels Corresponding labels of the dimension.
  * @param UseWeights Whether we need to run a weighted Decision Stump.
  */
-template<typename MatType>
+template<typename MatType, bool NoRecursion>
 template<bool UseWeights, typename VecType, typename WeightVecType>
-double DecisionStump<MatType>::CalculateEntropy(
+double DecisionStump<MatType, NoRecursion>::CalculateEntropy(
     const VecType& labels,
     const WeightVecType& weights)
 {

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list