[mlpack] 43/207: Refactor auxiliary infos to be part of the class.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 4a82bbcdc9b5d04acd8cd0f12bb02b00c22fdeab
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Jan 19 10:56:36 2017 -0500

    Refactor auxiliary infos to be part of the class.
    
    This way, empty auxiliary info structs add no size to the decision tree
    itself (thanks to the empty base optimization).
---
 src/mlpack/methods/decision_tree/decision_tree.hpp |  33 +++--
 .../methods/decision_tree/decision_tree_impl.hpp   | 143 ++++++++++++++-------
 2 files changed, 121 insertions(+), 55 deletions(-)

diff --git a/src/mlpack/methods/decision_tree/decision_tree.hpp b/src/mlpack/methods/decision_tree/decision_tree.hpp
index 4ab3975..39b5369 100644
--- a/src/mlpack/methods/decision_tree/decision_tree.hpp
+++ b/src/mlpack/methods/decision_tree/decision_tree.hpp
@@ -19,13 +19,20 @@ namespace tree {
 /**
  * This class implements a generic decision tree learner.  Its behavior can be
  * controlled via its template arguments.
+ *
+ * The class inherits from both auxiliary split information classes so that an
+ * empty one can take up no extra space (the empty base optimization).
  */
 template<typename FitnessFunction = GiniGain,
          template<typename> class NumericSplitType = BestBinaryNumericSplit,
          template<typename> class CategoricalSplitType = AllCategoricalSplit,
          typename ElemType = double,
          bool NoRecursion = false>
-class DecisionTree
+class DecisionTree :
+    public NumericSplitType<FitnessFunction>::template
+        AuxiliarySplitInfo<ElemType>,
+    public CategoricalSplitType<FitnessFunction>::template
+        AuxiliarySplitInfo<ElemType>
 {
  public:
   //! Allow access to the numeric split type.
@@ -190,16 +197,26 @@ class DecisionTree
    */
   arma::vec classProbabilities;
 
-  //! Auxiliary information in case this is a numeric split.  This class may be
-  //! empty.
-  typename NumericSplit::template AuxiliarySplitInfo<ElemType> numericAux;
-  //! Auxiliary information in case this is a categorical split.  This class may
-  //! be empty.
-  typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
-      categoricalAux;
+  //! Note that this class will also hold the members of the NumericSplit and
+  //! CategoricalSplit AuxiliarySplitInfo classes, since it inherits from them.
+  //! We'll define some convenience typedefs here.
+  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
+      NumericAuxiliarySplitInfo;
+  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
+      CategoricalAuxiliarySplitInfo;
+
+  /**
+   * Calculate the class probabilities of the given labels.
+   */
+  template<typename RowType>
+  void CalculateClassProbabilities(const RowType& labels,
+                                   const size_t numClasses);
 };
 
 } // namespace tree
 } // namespace mlpack
 
+// Include implementation.
+#include "decision_tree_impl.hpp"
+
 #endif
diff --git a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
index fbb07d8..00070d6 100644
--- a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
+++ b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
@@ -16,6 +16,7 @@ template<typename FitnessFunction,
          template<typename> class CategoricalSplitType,
          typename ElemType,
          bool NoRecursion>
+template<typename MatType>
 DecisionTree<FitnessFunction,
              NumericSplitType,
              CategoricalSplitType,
@@ -36,6 +37,7 @@ template<typename FitnessFunction,
          template<typename> class CategoricalSplitType,
          typename ElemType,
          bool NoRecursion>
+template<typename MatType>
 DecisionTree<FitnessFunction,
              NumericSplitType,
              CategoricalSplitType,
@@ -78,11 +80,11 @@ DecisionTree<FitnessFunction,
              CategoricalSplitType,
              ElemType,
              NoRecursion>::DecisionTree(const DecisionTree& other) :
+    NumericAuxiliarySplitInfo(other),
+    CategoricalAuxiliarySplitInfo(other),
     splitDimension(other.splitDimension),
     dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
-    classProbabilities(other.classProbabilities),
-    numericAux(other.numericAux),
-    categoricalAux(other.categoricalAux)
+    classProbabilities(other.classProbabilities)
 {
   // Copy each child.
   for (size_t i = 0; i < other.children.size(); ++i)
@@ -100,12 +102,12 @@ DecisionTree<FitnessFunction,
              CategoricalSplitType,
              ElemType,
              NoRecursion>::DecisionTree(DecisionTree&& other) :
+    NumericAuxiliarySplitInfo(std::move(other)),
+    CategoricalAuxiliarySplitInfo(std::move(other)),
     children(std::move(other.children)),
     splitDimension(other.splitDimension),
     dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
-    classProbabilities(std::move(other.classProbabilities)),
-    numericAux(std::move(other.numericAux)),
-    categoricalAux(std::move(other.categoricalAux))
+    classProbabilities(std::move(other.classProbabilities))
 {
   // Reset the other object.
   other.classProbabilities.ones(1); // One class, P(1) = 1.
@@ -137,13 +139,15 @@ DecisionTree<FitnessFunction,
   splitDimension = other.splitDimension;
   dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
   classProbabilities = other.classProbabilities;
-  numericAux = other.numericAux;
-  categoricalAux = other.categoricalAux;
 
   // Copy the children.
   for (size_t i = 0; i < other.children.size(); ++i)
     children.push_back(new DecisionTree(*other.children[i]));
 
+  // Copy the auxiliary info.
+  NumericAuxiliarySplitInfo::operator=(other);
+  CategoricalAuxiliarySplitInfo::operator=(other);
+
   return *this;
 }
 
@@ -174,12 +178,14 @@ DecisionTree<FitnessFunction,
   splitDimension = other.splitDimension;
   dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
   classProbabilities = std::move(other.classProbabilities);
-  numericAux = std::move(other.numericAux);
-  categoricalAux = std::move(other.categoricalAux);
 
   // Reset the class probabilities of the other object.
   other.classProbabilities.ones(1); // One class, P(1) = 1.
 
+  // Take ownership of the auxiliary info.
+  NumericAuxiliarySplitInfo::operator=(std::move(other));
+  CategoricalAuxiliarySplitInfo::operator=(std::move(other));
+
   return *this;
 }
 
@@ -231,47 +237,59 @@ void DecisionTree<FitnessFunction,
   for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i)
   {
     double dimGain;
-    if (datasetInfo.Type(i) == Datatype::categorical)
+    if (datasetInfo.Type(i) == data::Datatype::categorical)
       dimGain = CategoricalSplit::SplitIfBetter(bestGain, data.row(i),
-          info.NumMappings(i), labels, numClasses, minimumLeafSize,
-          classProbabilities, categoricalAux);
-    else if (datasetInfo.Type(i) == Datatype::numeric)
+          datasetInfo.NumMappings(i), labels, numClasses, minimumLeafSize,
+          classProbabilities, *this);
+    else if (datasetInfo.Type(i) == data::Datatype::numeric)
       dimGain = NumericSplit::SplitIfBetter(bestGain, data.row(i), labels,
-          numClasses, minimumLeafSize, classProbabilities, numericAux);
+          numClasses, minimumLeafSize, classProbabilities, *this);
 
     // Was there an improvement?  If so mark that it's the new best dimension.
     if (dimGain > bestGain)
+    {
       bestDim = i;
+      bestGain = dimGain;
+    }
+
+    // If the gain is the best possible, no need to keep looking.
+    if (bestGain == 0.0)
+      break;
   }
 
   // Did we split or not?  If so, then split the data and create the children.
   if (bestDim != datasetInfo.Dimensionality())
   {
     dimensionTypeOrMajorityClass = (size_t) datasetInfo.Type(bestDim);
+    splitDimension = bestDim;
 
     // Get the number of children we will have.
     size_t numChildren = 0;
-    if (datasetInfo.Type(bestDim) == Datatype::categorical)
-      numChildren = CategoricalSplit::NumChildren(classProbabilities,
-          categoricalAux);
+    if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
+      numChildren = CategoricalSplit::NumChildren(classProbabilities, *this);
     else
-      numChildren = NumericSplit::NumChildren(classProbabilities, numericAux);
+      numChildren = NumericSplit::NumChildren(classProbabilities, *this);
 
     // Calculate all child assignments.
     arma::Col<size_t> childAssignments(data.n_cols);
-    if (datasetInfo.Type(bestDim) == Datatype::categorical)
+    if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
     {
       for (size_t j = 0; j < data.n_cols; ++j)
         childAssignments[j] = CategoricalSplit::CalculateDirection(
-            classProbabilities, categoricalAux);
+            data(bestDim, j), classProbabilities, *this);
     }
     else
     {
       for (size_t j = 0; j < data.n_cols; ++j)
-        childAssignments[j] = NumericSplit::CalculateDirection(
-            classProbabilities, numericAux);
+        childAssignments[j] = NumericSplit::CalculateDirection(data(bestDim, j),
+            classProbabilities, *this);
     }
 
+    // Figure out counts of children.
+    arma::Row<size_t> childCounts(numClasses, arma::fill::zeros);
+    for (size_t i = 0; i < childAssignments.n_elem; ++i)
+      childCounts[childAssignments[i]]++;
+
     // Split into children.
     for (size_t i = 0; i < numChildren; ++i)
     {
@@ -296,8 +314,8 @@ void DecisionTree<FitnessFunction,
   else
   {
     // Clear auxiliary info objects.
-    numericAux = NumericSplit::AuxiliarySplitInfo<ElemType>();
-    categoricalAux = CategoricalSplit::AuxiliarySplitInfo<ElemType>();
+    NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
+    CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
 
     // Calculate class probabilities because we are a leaf.
     CalculateClassProbabilities(labels);
@@ -325,9 +343,8 @@ void DecisionTree<FitnessFunction,
     delete children[i];
   children.clear();
 
-  // We won't be using this.
-  categoricalAux =
-      CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo<ElemType>();
+  // We won't be using these members, so reset them.
+  CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
 
   // Look through the list of dimensions and obtain the best split.  We'll cache
   // the best numeric split auxiliary information in numericAux (and clear it
@@ -340,24 +357,32 @@ void DecisionTree<FitnessFunction,
   {
     double dimGain = NumericSplitType<FitnessFunction>::SplitIfBetter(bestGain,
         data.row(i), labels, numClasses, minimumLeafSize, classProbabilities,
-        numericAux);
+        *this);
 
     if (dimGain > bestGain)
+    {
       bestDim = i;
+      bestGain = dimGain;
+    }
+
+    // If the gain is the best possible, no need to keep looking.
+    if (bestGain == 0.0)
+      break;
   }
 
   // Did we split or not?  If so, then split the data and create the children.
   if (bestDim != data.n_rows)
   {
     // We know that the split is numeric.
-    size_t numChildren = NumericSplit::NumChildren(classProbabilities,
-        numericAux);
+    size_t numChildren = NumericSplit::NumChildren(classProbabilities, *this);
+    splitDimension = bestDim;
+    dimensionTypeOrMajorityClass = (size_t) data::Datatype::numeric;
 
     // Calculate all child assignments.
     arma::Col<size_t> childAssignments(data.n_cols);
     for (size_t j = 0; j < data.n_cols; ++j)
-      childAssignments[j] = NumericSplit::CalculateDirection(classProbabilities,
-          numericAux);
+      childAssignments[j] = NumericSplit::CalculateDirection(data(bestDim, j),
+          classProbabilities, *this);
 
     // Calculate counts of children in each node.
     arma::Col<size_t> childCounts(numChildren);
@@ -387,11 +412,11 @@ void DecisionTree<FitnessFunction,
   }
   else
   {
-    // We won't be needing this.
-    numericAux = NumericSplit::AuxiliarySplitInfo<ElemType>();
+    // We won't be needing these members, so reset them.
+    NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
 
     // Calculate class probabilities because we are a leaf.
-    CalculateClassProbabilities(labels);
+    CalculateClassProbabilities(labels, numClasses);
   }
 }
 
@@ -420,7 +445,7 @@ size_t DecisionTree<FitnessFunction,
 //! Return class probabilities for a given point.
 template<typename FitnessFunction,
          template<typename> class NumericSplitType,
-         template<typename> class CategoricalSplittype,
+         template<typename> class CategoricalSplitType,
          typename ElemType,
          bool NoRecursion>
 template<typename VecType>
@@ -439,7 +464,7 @@ void DecisionTree<FitnessFunction,
     return;
   }
 
-  children[CalculateDirection(point)]->Classify(point, probabilities);
+  children[CalculateDirection(point)]->Classify(point, prediction, probabilities);
 }
 
 //! Return the class for a set of points.
@@ -495,7 +520,7 @@ void DecisionTree<FitnessFunction,
   // be.
   DecisionTree* node = children[0];
   while (node->NumChildren() != 0)
-    node = &node->Child(i);
+    node = &node->Child(0);
   probabilities.set_size(node.classProbabilities.n_elem, data.n_cols);
 
   for (size_t i = 0; i < data.n_cols; ++i)
@@ -528,7 +553,7 @@ void DecisionTree<FitnessFunction,
 
   // Serialize the children first.
   size_t numChildren = children.size();
-  ar & CreateNVP(numChildren);
+  ar & CreateNVP(numChildren, "numChildren");
   if (Archive::is_loading::value)
   {
     children.resize(numChildren, NULL);
@@ -547,8 +572,6 @@ void DecisionTree<FitnessFunction,
   ar & CreateNVP(splitDimension, "splitDimension");
   ar & CreateNVP(dimensionTypeOrMajorityClass, "dimensionTypeOrMajorityClass");
   ar & CreateNVP(classProbabilities, "classProbabilities");
-  ar & CreateNVP(numericAux, "numericAux");
-  ar & CreateNVP(categoricalAux, "categoricalAux");
 }
 
 template<typename FitnessFunction,
@@ -563,12 +586,38 @@ size_t DecisionTree<FitnessFunction,
                     ElemType,
                     NoRecursion>::CalculateDirection(const VecType& point) const
 {
-  if ((Datatype) dimensionTypeOrMajorityClass == Datatype::categorical)
-    return CategoricalSplit::CalculateDirection(point, classProbabilities,
-        categoricalAux);
+  if ((data::Datatype) dimensionTypeOrMajorityClass ==
+      data::Datatype::categorical)
+    return CategoricalSplit::CalculateDirection(point[splitDimension],
+        classProbabilities, *this);
   else
-    return NumericSplit::CalculateDirection(point, classProbabilities,
-        numericAux);
+    return NumericSplit::CalculateDirection(point[splitDimension],
+        classProbabilities, *this);
+}
+
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename RowType> // RowType must provide n_elem and operator[].
+void DecisionTree<FitnessFunction,
+                  NumericSplitType,
+                  CategoricalSplitType,
+                  ElemType,
+                  NoRecursion>::CalculateClassProbabilities(
+    const RowType& labels,      // Per-point labels, used directly as indices.
+    const size_t numClasses)    // Length of the resulting probability vector.
+{ // Count labels into classProbabilities, then record the argmax class.
+  classProbabilities.zeros(numClasses); // Resize and clear all counts.
+  for (size_t i = 0; i < labels.n_elem; ++i)
+    classProbabilities[labels[i]]++; // Unchecked: labels must be < numClasses.
+
+  // Turn counts into frequencies.  NOTE(review): empty labels gives 0/0 (NaN).
+  classProbabilities /= labels.n_elem;
+  arma::uword maxIndex;
+  classProbabilities.max(maxIndex); // Locate the most frequent class.
+  dimensionTypeOrMajorityClass = (size_t) maxIndex; // Record majority class.
 }
 
 } // namespace tree

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list