[mlpack] 40/207: Add decision tree first implementation.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 4d3f1faebb3f25a4eaeaeecead99df69ed4c5866
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Jan 17 14:20:07 2017 -0500

    Add decision tree first implementation.
---
 .../decision_tree/all_categorical_split.hpp        |  63 +++
 .../decision_tree/all_categorical_split_impl.hpp   |  95 ++++
 .../decision_tree/best_binary_numeric_split.hpp    |  62 +++
 .../best_binary_numeric_split_impl.hpp             |  95 ++++
 src/mlpack/methods/decision_tree/decision_tree.hpp | 205 ++++++++
 .../methods/decision_tree/decision_tree_impl.hpp   | 577 +++++++++++++++++++++
 src/mlpack/methods/decision_tree/gini_gain.hpp     |  77 +++
 .../methods/decision_tree/information_gain.hpp     |  75 +++
 8 files changed, 1249 insertions(+)

diff --git a/src/mlpack/methods/decision_tree/all_categorical_split.hpp b/src/mlpack/methods/decision_tree/all_categorical_split.hpp
new file mode 100644
index 0000000..4999627
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/all_categorical_split.hpp
@@ -0,0 +1,63 @@
+/**
+ * @file all_categorical_split.hpp
+ * @author Ryan Curtin
+ *
+ * This file defines a tree splitter that splits a categorical feature into all
+ * of the possible categories.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
+#define MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
+
+#include <mlpack/prereqs.hpp>
+
+namespace mlpack {
+namespace tree {
+
+template<typename FitnessFunction>
+class AllCategoricalSplit
+{
+ public:
+  /**
+   * Extra per-split state.  Sending every category to its own child needs
+   * nothing beyond the child count (kept in classProbabilities), so this type
+   * is empty.
+   */
+  template<typename ElemType>
+  class AuxiliarySplitInfo { };
+
+  /**
+   * Try to split a node on a categorical dimension, giving each category its
+   * own child.  When the weighted gain of that split beats 'bestGain', the new
+   * gain is returned and classProbabilities is resized to a single element
+   * holding the number of children.  Otherwise 'bestGain' is returned and the
+   * outputs are left as they were.
+   *
+   * @param bestGain Gain that the split has to beat.
+   * @param data Values of the candidate dimension (category indices).
+   * @param numCategories Number of categories in the dimension.
+   * @param labels Label of each point.
+   * @param numClasses Number of classes.
+   * @param minimumLeafSize Minimum number of points each child must receive.
+   * @param classProbabilities Output split information (number of children).
+   * @param aux Output auxiliary information (unused by this split type).
+   */
+  template<typename VecType>
+  static double SplitIfBetter(
+      const double bestGain,
+      const VecType& data,
+      const size_t numCategories,
+      const arma::Row<size_t>& labels,
+      const size_t numClasses,
+      const size_t minimumLeafSize,
+      arma::Col<typename VecType::elem_type>& classProbabilities,
+      AuxiliarySplitInfo<typename VecType::elem_type>& aux);
+
+  /**
+   * Get the number of children of a split produced by SplitIfBetter(), read
+   * from the first element of classProbabilities.
+   */
+  template<typename ElemType>
+  static size_t NumChildren(const arma::Col<ElemType>& classProbabilities,
+                            const AuxiliarySplitInfo<ElemType>& /* aux */);
+
+  /**
+   * Get the index of the child that a point belongs to, given the point's
+   * value in the split dimension.  The child index is the category itself.
+   */
+  template<typename ElemType>
+  static size_t CalculateDirection(
+      const ElemType& point,
+      const arma::Col<ElemType>& classProbabilities,
+      const AuxiliarySplitInfo<ElemType>& /* aux */);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "all_categorical_split_impl.hpp"
+
+#endif
+
diff --git a/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
new file mode 100644
index 0000000..76aea39
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
@@ -0,0 +1,95 @@
+/**
+ * @file all_categorical_split_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the AllCategoricalSplit categorical split class.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_IMPL_HPP
+#define MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_IMPL_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Evaluate the split that sends each category to its own child.  The gain of
+ * the split is the gain of each child's labels weighted by the fraction of
+ * points in that child.  classProbabilities is only modified (to hold the
+ * number of children) when the split strictly improves on 'bestGain'.
+ *
+ * Each value in 'data' is assumed to be a category index in
+ * [0, numCategories).
+ */
+template<typename FitnessFunction>
+template<typename VecType>
+double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
+    const double bestGain,
+    const VecType& data,
+    const size_t numCategories,
+    const arma::Row<size_t>& labels,
+    const size_t numClasses,
+    const size_t minimumLeafSize,
+    arma::Col<typename VecType::elem_type>& classProbabilities,
+    AllCategoricalSplit::AuxiliarySplitInfo<typename VecType::elem_type>&
+        /* aux */)
+{
+  // Without categories or points there is nothing to split (and the child
+  // weights below would divide by zero).
+  if (numCategories == 0 || data.n_elem == 0)
+    return bestGain;
+
+  // Count the number of elements in each potential child.
+  arma::Col<size_t> counts(numCategories);
+  counts.zeros();
+  for (size_t i = 0; i < data.n_elem; ++i)
+    counts[static_cast<size_t>(data[i])]++;
+
+  // If each child will have the minimum number of points in it, we can split.
+  // Otherwise we can't.
+  if (arma::min(counts) < minimumLeafSize)
+    return bestGain;
+
+  // Gather the labels that each child would receive.  The vector of children
+  // must be sized before it is indexed, and every write cursor starts at 0.
+  std::vector<arma::Row<size_t>> childLabels(numCategories);
+  for (size_t i = 0; i < numCategories; ++i)
+    childLabels[i].zeros(counts[i]);
+
+  arma::Col<size_t> childPositions(numCategories);
+  childPositions.zeros();
+  for (size_t i = 0; i < data.n_elem; ++i)
+  {
+    const size_t category = static_cast<size_t>(data[i]);
+    childLabels[category][childPositions[category]++] = labels[i];
+  }
+
+  // Weighted sum of the children's gains.  Empty children (possible only when
+  // minimumLeafSize is 0) carry zero weight, so they are skipped rather than
+  // evaluated on an empty label set.
+  double overallGain = 0.0;
+  for (size_t i = 0; i < numCategories; ++i)
+  {
+    if (counts[i] == 0)
+      continue;
+
+    const double childPct = double(counts[i]) / double(data.n_elem);
+    const double childGain = FitnessFunction::Evaluate(childLabels[i],
+        numClasses);
+
+    overallGain += childPct * childGain;
+  }
+
+  if (overallGain > bestGain)
+  {
+    // This is better, so record the number of children and return.
+    classProbabilities.set_size(1);
+    classProbabilities[0] = numCategories;
+    return overallGain;
+  }
+
+  // Otherwise there was no improvement; leave classProbabilities untouched.
+  return bestGain;
+}
+
+/**
+ * SplitIfBetter() stores the number of children as the only element of
+ * classProbabilities; convert it back to a count.
+ */
+template<typename FitnessFunction>
+template<typename ElemType>
+size_t AllCategoricalSplit<FitnessFunction>::NumChildren(
+    const arma::Col<ElemType>& classProbabilities,
+    const AllCategoricalSplit::AuxiliarySplitInfo<ElemType>& /* aux */)
+{
+  const size_t numChildren = static_cast<size_t>(classProbabilities[0]);
+  return numChildren;
+}
+
+/**
+ * Each category is mapped to the child with the same index, so the direction
+ * is simply the point's category value.  The split information is not needed.
+ */
+template<typename FitnessFunction>
+template<typename ElemType>
+size_t AllCategoricalSplit<FitnessFunction>::CalculateDirection(
+    const ElemType& point,
+    const arma::Col<ElemType>& /* classProbabilities */,
+    const AllCategoricalSplit::AuxiliarySplitInfo<ElemType>& /* aux */)
+{
+  const size_t childIndex = static_cast<size_t>(point);
+  return childIndex;
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp b/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp
new file mode 100644
index 0000000..3d71229
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp
@@ -0,0 +1,62 @@
+/**
+ * @file best_binary_numeric_split.hpp
+ * @author Ryan Curtin
+ *
+ * A tree splitter that finds the best binary numeric split.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_HPP
+#define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_HPP
+
+#include <mlpack/prereqs.hpp>
+
+namespace mlpack {
+namespace tree {
+
+template<typename FitnessFunction>
+class BestBinaryNumericSplit
+{
+ public:
+  /**
+   * Extra per-split state.  The threshold of the split is kept in
+   * classProbabilities, so this type is empty.
+   */
+  template<typename ElemType>
+  class AuxiliarySplitInfo { };
+
+  /**
+   * Look for the numeric threshold that best splits the points into two
+   * children.  When a threshold beats 'bestGain', its gain is returned and
+   * classProbabilities holds the threshold.  Otherwise 'bestGain' is returned.
+   *
+   * @param bestGain Gain that the split has to beat.
+   * @param data Values of the candidate dimension.
+   * @param labels Label of each point.
+   * @param numClasses Number of classes.
+   * @param minimumLeafSize Minimum number of points each child must receive.
+   * @param classProbabilities Output split information (the threshold).
+   * @param aux Output auxiliary information (unused by this split type).
+   */
+  template<typename VecType>
+  static double SplitIfBetter(
+      const double bestGain,
+      const VecType& data,
+      const arma::Row<size_t>& labels,
+      const size_t numClasses,
+      const size_t minimumLeafSize,
+      arma::Col<typename VecType::elem_type>& classProbabilities,
+      AuxiliarySplitInfo<typename VecType::elem_type>& aux);
+
+  /**
+   * A threshold split always has a left and a right child, so this is 2.
+   */
+  template<typename ElemType>
+  static size_t NumChildren(const arma::Col<ElemType>& /* classProbabilities */,
+                            const AuxiliarySplitInfo<ElemType>& /* aux */)
+  {
+    return 2;
+  }
+
+  /**
+   * Get the child a point goes to, given its value in the split dimension:
+   * 0 (left) when the value is at most the threshold, 1 (right) otherwise.
+   */
+  template<typename ElemType>
+  static size_t CalculateDirection(
+      const ElemType& point,
+      const arma::Col<ElemType>& classProbabilities,
+      const AuxiliarySplitInfo<ElemType>& /* aux */);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "best_binary_numeric_split_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
new file mode 100644
index 0000000..0a2981a
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
@@ -0,0 +1,95 @@
+/**
+ * @file best_binary_numeric_split_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of strategy that finds the best binary numeric split.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP
+#define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Try every threshold between consecutive distinct sorted values of the
+ * dimension and keep the one whose weighted child gain is highest.
+ * classProbabilities is only modified (to hold the threshold) when a split
+ * strictly improves on the gain seen so far, starting from 'bestGain'.
+ */
+template<typename FitnessFunction>
+template<typename VecType>
+double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
+    const double bestGain,
+    const VecType& data,
+    const arma::Row<size_t>& labels,
+    const size_t numClasses,
+    const size_t minimumLeafSize,
+    arma::Col<typename VecType::elem_type>& classProbabilities,
+    BestBinaryNumericSplit::AuxiliarySplitInfo<typename VecType::elem_type>&
+        /* aux */)
+{
+  // Each child needs at least one point even if minimumLeafSize is 0;
+  // otherwise subvec(0, index - 1) below would be asked for an empty range.
+  const size_t minSize = (minimumLeafSize == 0) ? 1 : minimumLeafSize;
+
+  // First sanity check: if we don't have enough points, we can't split.
+  if (data.n_elem < (minSize * 2))
+    return bestGain;
+
+  // Next, sort the data and reorder the labels to match.
+  const arma::uvec sortedIndices = arma::sort_index(data);
+  arma::Row<size_t> sortedLabels(labels.n_elem);
+  for (size_t i = 0; i < sortedLabels.n_elem; ++i)
+    sortedLabels[i] = labels[sortedIndices[i]];
+
+  // A split at 'index' sends sorted points [0, index) left and [index, n)
+  // right, so both children hold at least minSize points exactly when
+  // minSize <= index <= n - minSize (the upper bound is inclusive).
+  double bestFoundGain = bestGain;
+  const double optimalGain = FitnessFunction::BestGain(numClasses);
+  for (size_t index = minSize; index <= data.n_elem - minSize; ++index)
+  {
+    // Equal values cannot be separated: CalculateDirection() sends every
+    // point <= the threshold left.  Skip these split points so that the
+    // partition evaluated here is the one that will actually be built.
+    if (data[sortedIndices[index - 1]] == data[sortedIndices[index]])
+      continue;
+
+    // Calculate the gain for the left and right child.
+    const double leftGain = FitnessFunction::Evaluate(sortedLabels.subvec(0,
+        index - 1), numClasses);
+    const double rightGain = FitnessFunction::Evaluate(sortedLabels.subvec(
+        index, sortedLabels.n_elem - 1), numClasses);
+
+    // Weight each child's gain by the fraction of points it holds.
+    const double leftRatio = double(index) / double(sortedLabels.n_elem);
+    const double rightRatio = 1.0 - leftRatio;
+    const double gain = leftRatio * leftGain + rightRatio * rightGain;
+
+    // Only record the split if it is an improvement; callers rely on
+    // classProbabilities being left alone otherwise.
+    if (gain > bestFoundGain)
+    {
+      bestFoundGain = gain;
+      // The split value is halfway between the values on either side.
+      classProbabilities.set_size(1);
+      classProbabilities[0] = (data[sortedIndices[index - 1]] +
+          data[sortedIndices[index]]) / 2.0;
+
+      // No split can beat the optimal gain, so stop searching.
+      if (gain >= optimalGain)
+        break;
+    }
+  }
+
+  return bestFoundGain;
+}
+
+/**
+ * The threshold is the first element of classProbabilities.  Points at or
+ * below it go to the left child (0); all other points go right (1).
+ */
+template<typename FitnessFunction>
+template<typename ElemType>
+size_t BestBinaryNumericSplit<FitnessFunction>::CalculateDirection(
+    const ElemType& point,
+    const arma::Col<ElemType>& classProbabilities,
+    const BestBinaryNumericSplit<FitnessFunction>::AuxiliarySplitInfo<ElemType>&
+        /* aux */)
+{
+  const ElemType splitValue = classProbabilities[0];
+  return (point <= splitValue) ? 0 : 1;
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/decision_tree/decision_tree.hpp b/src/mlpack/methods/decision_tree/decision_tree.hpp
new file mode 100644
index 0000000..4ab3975
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/decision_tree.hpp
@@ -0,0 +1,205 @@
+/**
+ * @file decision_tree.hpp
+ * @author Ryan Curtin
+ *
+ * A generic decision tree learner.  Its behavior can be controlled via template
+ * arguments.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
+#define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
+
+#include <mlpack/prereqs.hpp>
+#include "gini_gain.hpp"
+#include "best_binary_numeric_split.hpp"
+#include "all_categorical_split.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * A generic decision tree learner.  Splits are scored and generated by
+ * policy classes supplied as template arguments:
+ *
+ *  - FitnessFunction scores a set of labels (GiniGain by default).
+ *  - NumericSplitType, specialized on FitnessFunction, splits numeric
+ *    dimensions (BestBinaryNumericSplit by default).
+ *  - CategoricalSplitType, specialized on FitnessFunction, splits categorical
+ *    dimensions (AllCategoricalSplit by default).
+ *  - ElemType is the element type of the split policies' auxiliary
+ *    information.
+ *  - NoRecursion: NOTE(review): its effect is not visible in this header;
+ *    confirm the intended semantics.
+ */
+template<typename FitnessFunction = GiniGain,
+         template<typename> class NumericSplitType = BestBinaryNumericSplit,
+         template<typename> class CategoricalSplitType = AllCategoricalSplit,
+         typename ElemType = double,
+         bool NoRecursion = false>
+class DecisionTree
+{
+ public:
+  //! The numeric split policy used by this tree.
+  using NumericSplit = NumericSplitType<FitnessFunction>;
+  //! The categorical split policy used by this tree.
+  using CategoricalSplit = CategoricalSplitType<FitnessFunction>;
+
+  /**
+   * Build a tree from the given data, using datasetInfo to tell numeric
+   * dimensions from categorical ones.
+   *
+   * @param data Dataset to train on (one point per column).
+   * @param datasetInfo Type information for each dimension.
+   * @param labels Label of each point.
+   * @param numClasses Number of classes.
+   * @param minimumLeafSize Minimum number of points in a leaf.
+   */
+  template<typename MatType>
+  DecisionTree(const MatType& data,
+               const data::DatasetInfo& datasetInfo,
+               const arma::Row<size_t>& labels,
+               const size_t numClasses,
+               const size_t minimumLeafSize = 10);
+
+  /**
+   * Build a tree from the given data, treating every dimension as numeric.
+   *
+   * @param data Dataset to train on (one point per column).
+   * @param labels Label of each point.
+   * @param numClasses Number of classes.
+   * @param minimumLeafSize Minimum number of points in a leaf.
+   */
+  template<typename MatType>
+  DecisionTree(const MatType& data,
+               const arma::Row<size_t>& labels,
+               const size_t numClasses,
+               const size_t minimumLeafSize = 10);
+
+  /**
+   * Create an untrained tree: a single leaf that assigns the same
+   * probability to every class.
+   *
+   * @param numClasses Number of classes.
+   */
+  DecisionTree(const size_t numClasses = 1);
+
+  /**
+   * Deep-copy another tree, including all of its descendants.  For large
+   * trees this can use a lot of memory.
+   *
+   * @param other Tree to copy.
+   */
+  DecisionTree(const DecisionTree& other);
+
+  /**
+   * Take over the contents of another tree.
+   *
+   * @param other Tree whose contents are taken.
+   */
+  DecisionTree(DecisionTree&& other);
+
+  /**
+   * Replace this tree with a deep copy of another tree.  For large trees this
+   * can use a lot of memory.
+   *
+   * @param other Tree to copy.
+   */
+  DecisionTree& operator=(const DecisionTree& other);
+
+  /**
+   * Replace this tree with the contents of another tree.
+   *
+   * @param other Tree whose contents are taken.
+   */
+  DecisionTree& operator=(DecisionTree&& other);
+
+  /**
+   * Release every child node owned by this tree.
+   */
+  ~DecisionTree();
+
+  /**
+   * Retrain the tree from scratch, using datasetInfo to tell numeric
+   * dimensions from categorical ones.  Any existing model is discarded.
+   */
+  template<typename MatType>
+  void Train(const MatType& data,
+             const data::DatasetInfo& datasetInfo,
+             const arma::Row<size_t>& labels,
+             const size_t numClasses,
+             const size_t minimumLeafSize = 10);
+
+  /**
+   * Retrain the tree from scratch, treating every dimension as numeric.  Any
+   * existing model is discarded.
+   */
+  template<typename MatType>
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const size_t numClasses,
+             const size_t minimumLeafSize = 10);
+
+  /**
+   * Predict the label of a single point by walking the whole tree.
+   */
+  template<typename VecType>
+  size_t Classify(const VecType& point) const;
+
+  /**
+   * Predict the label of a single point, and also report the estimated
+   * probability of each class.
+   */
+  template<typename VecType>
+  void Classify(const VecType& point,
+                size_t& prediction,
+                arma::vec& probabilities) const;
+
+  /**
+   * Predict the label of every point in the given data, storing the results
+   * in predictions.
+   */
+  template<typename MatType>
+  void Classify(const MatType& data,
+                arma::Row<size_t>& predictions) const;
+
+  /**
+   * Predict the label of every point in the given data, and also report the
+   * estimated class probabilities of each point.
+   */
+  template<typename MatType>
+  void Classify(const MatType& data,
+                arma::Row<size_t>& predictions,
+                arma::mat& probabilities) const;
+
+  /**
+   * Serialize the tree to or from the given archive.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Get the number of children of this node.
+  size_t NumChildren() const { return children.size(); }
+
+  //! Get the child with index i.
+  const DecisionTree& Child(const size_t i) const { return *children[i]; }
+  //! Modify the child with index i (be careful!).
+  DecisionTree& Child(const size_t i) { return *children[i]; }
+
+  /**
+   * For a non-leaf node, compute the index of the child that the given point
+   * would be sent to.  Classify() relies on this, but it may also be called
+   * on its own.
+   *
+   * @param point Point to route.
+   */
+  template<typename VecType>
+  size_t CalculateDirection(const VecType& point) const;
+
+ private:
+  //! Child nodes owned by this node.
+  std::vector<DecisionTree*> children;
+  //! Index of the dimension this node splits on.
+  size_t splitDimension;
+  //! For an internal node, the type of the split dimension; for a leaf, the
+  //! index of the majority class.
+  size_t dimensionTypeOrMajorityClass;
+  /**
+   * Contents depend on the node kind.  A leaf always stores the probability
+   * of each class here.  An internal node stores whatever its split type's
+   * CalculateDirection() needs, which need not be class probabilities.
+   */
+  arma::vec classProbabilities;
+
+  //! Auxiliary information for a numeric split (possibly an empty type).
+  typename NumericSplit::template AuxiliarySplitInfo<ElemType> numericAux;
+  //! Auxiliary information for a categorical split (possibly an empty type).
+  typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
+      categoricalAux;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation; without this the DecisionTree templates declared
+// above are never defined.
+#include "decision_tree_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
new file mode 100644
index 0000000..fbb07d8
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
@@ -0,0 +1,577 @@
+/**
+ * @file decision_tree_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of generic decision tree class.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_IMPL_HPP
+#define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_IMPL_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Construct and train, using datasetInfo to tell numeric dimensions from
+ * categorical ones.  All work is delegated to Train().
+ *
+ * @param data Dataset to train on (one point per column).
+ * @param datasetInfo Type information for each dimension.
+ * @param labels Label of each point.
+ * @param numClasses Number of classes.
+ * @param minimumLeafSize Minimum number of points in a leaf.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename MatType>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::DecisionTree(const MatType& data,
+                                        const data::DatasetInfo& datasetInfo,
+                                        const arma::Row<size_t>& labels,
+                                        const size_t numClasses,
+                                        const size_t minimumLeafSize) :
+    splitDimension(0),
+    dimensionTypeOrMajorityClass(0)
+{
+  // Pass off work to the Train() method.
+  Train(data, datasetInfo, labels, numClasses, minimumLeafSize);
+}
+
+/**
+ * Construct and train, treating every dimension as numeric.  All work is
+ * delegated to Train().
+ *
+ * @param data Dataset to train on (one point per column).
+ * @param labels Label of each point.
+ * @param numClasses Number of classes.
+ * @param minimumLeafSize Minimum number of points in a leaf.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename MatType>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::DecisionTree(const MatType& data,
+                                        const arma::Row<size_t>& labels,
+                                        const size_t numClasses,
+                                        const size_t minimumLeafSize) :
+    splitDimension(0),
+    dimensionTypeOrMajorityClass(0)
+{
+  // Pass off work to the Train() method.
+  Train(data, labels, numClasses, minimumLeafSize);
+}
+
+/**
+ * Construct without training: the result is a single leaf that assigns
+ * probability 1 / numClasses to every class, with class 0 as its majority
+ * class.
+ *
+ * @param numClasses Number of classes.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::DecisionTree(const size_t numClasses) :
+    // Initializers are listed in member declaration order, and every scalar
+    // member gets a defined value.
+    splitDimension(0),
+    dimensionTypeOrMajorityClass(0),
+    classProbabilities(numClasses)
+{
+  // Start from a uniform distribution over the classes.
+  classProbabilities.fill(1.0 / (double) numClasses);
+}
+
+/**
+ * Deep-copy another tree.  Every child of 'other' is copied recursively, and
+ * this node owns the new copies.
+ *
+ * @param other Tree to copy.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::DecisionTree(const DecisionTree& other) :
+    splitDimension(other.splitDimension),
+    dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
+    classProbabilities(other.classProbabilities),
+    numericAux(other.numericAux),
+    categoricalAux(other.categoricalAux)
+{
+  children.reserve(other.children.size());
+  for (const DecisionTree* child : other.children)
+    children.push_back(new DecisionTree(*child));
+}
+
+/**
+ * Take ownership of another tree's nodes and state.  Afterwards 'other' is a
+ * single one-class leaf (P(class 0) = 1) that owns no children.
+ *
+ * @param other Tree whose contents are taken.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::DecisionTree(DecisionTree&& other) :
+    children(std::move(other.children)),
+    splitDimension(other.splitDimension),
+    dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
+    classProbabilities(std::move(other.classProbabilities)),
+    numericAux(std::move(other.numericAux)),
+    categoricalAux(std::move(other.categoricalAux))
+{
+  // A moved-from std::vector is only "valid but unspecified"; clear it
+  // explicitly so the child pointers can never be deleted by both trees.
+  other.children.clear();
+
+  // Reset the other object to a one-class leaf whose majority class is 0.
+  other.classProbabilities.ones(1); // One class, P(1) = 1.
+  other.dimensionTypeOrMajorityClass = 0;
+}
+
+/**
+ * Replace this tree with a deep copy of another tree.  Assigning a tree to
+ * itself leaves it unchanged.
+ *
+ * @param other Tree to copy.
+ * @return This tree.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>&
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::operator=(const DecisionTree& other)
+{
+  // Self-assignment would delete our children before copying them from
+  // ourselves, leaving a split node with no children.
+  if (this == &other)
+    return *this;
+
+  // Clean memory if needed.
+  for (size_t i = 0; i < children.size(); ++i)
+    delete children[i];
+  children.clear();
+
+  // Copy everything from the other tree.
+  splitDimension = other.splitDimension;
+  dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
+  classProbabilities = other.classProbabilities;
+  numericAux = other.numericAux;
+  categoricalAux = other.categoricalAux;
+
+  // Deep-copy the children.
+  for (size_t i = 0; i < other.children.size(); ++i)
+    children.push_back(new DecisionTree(*other.children[i]));
+
+  return *this;
+}
+
+/**
+ * Release this tree's nodes and take ownership of another tree's nodes and
+ * state.  Afterwards 'other' is a single one-class leaf (P(class 0) = 1) that
+ * owns no children.  Self-move-assignment leaves the tree unchanged.
+ *
+ * @param other Tree whose contents are taken.
+ * @return This tree.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>&
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::operator=(DecisionTree&& other)
+{
+  // Without this, self-move would delete our own children.
+  if (this == &other)
+    return *this;
+
+  // Clean memory if needed.
+  for (size_t i = 0; i < children.size(); ++i)
+    delete children[i];
+  children.clear();
+
+  // Take ownership of the other tree's components.
+  children = std::move(other.children);
+  // Make sure the child pointers are no longer owned by 'other'.
+  other.children.clear();
+  splitDimension = other.splitDimension;
+  dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
+  classProbabilities = std::move(other.classProbabilities);
+  numericAux = std::move(other.numericAux);
+  categoricalAux = std::move(other.categoricalAux);
+
+  // Reset the other object to a one-class leaf whose majority class is 0.
+  other.classProbabilities.ones(1); // One class, P(1) = 1.
+  other.dimensionTypeOrMajorityClass = 0;
+
+  return *this;
+}
+
+/**
+ * Destroy the tree.  Each child node is owned by its parent, so deleting the
+ * direct children releases the whole subtree recursively.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+DecisionTree<FitnessFunction,
+             NumericSplitType,
+             CategoricalSplitType,
+             ElemType,
+             NoRecursion>::~DecisionTree()
+{
+  for (DecisionTree* child : children)
+    delete child;
+}
+
+/**
+ * Train on the given data, using datasetInfo to tell numeric dimensions from
+ * categorical ones.  Any existing children are discarded.  The best split
+ * over all dimensions is chosen.  If no split improves on the gain of the
+ * unsplit labels, this node becomes a leaf.
+ *
+ * @param data Dataset to train on (one point per column).
+ * @param datasetInfo Type information for each dimension.
+ * @param labels Label of each point.
+ * @param numClasses Number of classes.
+ * @param minimumLeafSize Minimum number of points in a leaf.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename MatType>
+void DecisionTree<FitnessFunction,
+                  NumericSplitType,
+                  CategoricalSplitType,
+                  ElemType,
+                  NoRecursion>::Train(const MatType& data,
+                                      const data::DatasetInfo& datasetInfo,
+                                      const arma::Row<size_t>& labels,
+                                      const size_t numClasses,
+                                      const size_t minimumLeafSize)
+{
+  // Clear children if needed.
+  for (size_t i = 0; i < children.size(); ++i)
+    delete children[i];
+  children.clear();
+
+  // Look through the list of dimensions and obtain the gain of the best split.
+  // We'll cache the best numeric and categorical split auxiliary information in
+  // numericAux and categoricalAux (and clear them later if we do not split),
+  // and use classProbabilities as auxiliary information.  Later we'll overwrite
+  // classProbabilities with the empirical class probabilities if we do not
+  // split.
+  //
+  // This relies on SplitIfBetter() modifying classProbabilities and aux only
+  // when it returns a gain strictly better than the one it was given.
+  double bestGain = FitnessFunction::Evaluate(labels, numClasses);
+  size_t bestDim = datasetInfo.Dimensionality(); // This means "no split".
+  for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i)
+  {
+    // A dimension of unrecognized type cannot improve the gain.
+    double dimGain = bestGain;
+    if (datasetInfo.Type(i) == data::Datatype::categorical)
+      dimGain = CategoricalSplit::SplitIfBetter(bestGain, data.row(i),
+          datasetInfo.NumMappings(i), labels, numClasses, minimumLeafSize,
+          classProbabilities, categoricalAux);
+    else if (datasetInfo.Type(i) == data::Datatype::numeric)
+      dimGain = NumericSplit::SplitIfBetter(bestGain, data.row(i), labels,
+          numClasses, minimumLeafSize, classProbabilities, numericAux);
+
+    // Was there an improvement?  If so, this becomes the split to beat.
+    if (dimGain > bestGain)
+    {
+      bestDim = i;
+      bestGain = dimGain;
+    }
+  }
+
+  // Did we split or not?  If so, then split the data and create the children.
+  if (bestDim != datasetInfo.Dimensionality())
+  {
+    splitDimension = bestDim;
+    dimensionTypeOrMajorityClass =
+        static_cast<size_t>(datasetInfo.Type(bestDim));
+    const bool categoricalSplit =
+        (datasetInfo.Type(bestDim) == data::Datatype::categorical);
+
+    // Get the number of children we will have.
+    const size_t numChildren = categoricalSplit ?
+        CategoricalSplit::NumChildren(classProbabilities, categoricalAux) :
+        NumericSplit::NumChildren(classProbabilities, numericAux);
+
+    // Assign each point to a child from its value in the split dimension.
+    arma::Col<size_t> childAssignments(data.n_cols);
+    for (size_t j = 0; j < data.n_cols; ++j)
+    {
+      childAssignments[j] = categoricalSplit ?
+          CategoricalSplit::CalculateDirection(data(bestDim, j),
+              classProbabilities, categoricalAux) :
+          NumericSplit::CalculateDirection(data(bestDim, j),
+              classProbabilities, numericAux);
+    }
+
+    // Calculate counts of children in each node.
+    arma::Col<size_t> childCounts(numChildren);
+    childCounts.zeros();
+    for (size_t j = 0; j < childAssignments.n_elem; ++j)
+      childCounts[childAssignments[j]]++;
+
+    // Split into children.
+    for (size_t i = 0; i < numChildren; ++i)
+    {
+      // Now that we have the size of the matrix we need to extract, extract it.
+      MatType childPoints(data.n_rows, childCounts[i]);
+      arma::Row<size_t> childLabels(childCounts[i]);
+      size_t currentCol = 0;
+      for (size_t j = 0; j < data.n_cols; ++j)
+      {
+        if (childAssignments[j] == i)
+        {
+          childPoints.col(currentCol) = data.col(j);
+          childLabels[currentCol++] = labels[j];
+        }
+      }
+
+      // Now build the child recursively.  datasetInfo is passed along so that
+      // categorical dimensions are still treated as categorical below here.
+      children.push_back(new DecisionTree(childPoints, datasetInfo,
+          childLabels, numClasses, minimumLeafSize));
+    }
+  }
+  else
+  {
+    // Clear auxiliary info objects.
+    numericAux =
+        typename NumericSplit::template AuxiliarySplitInfo<ElemType>();
+    categoricalAux =
+        typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>();
+
+    // Calculate class probabilities because we are a leaf.
+    CalculateClassProbabilities(labels);
+  }
+}
+
+/**
+ * Train on the given data, treating every dimension as numeric.  Any existing
+ * children are discarded.  The best numeric split over all dimensions is
+ * chosen.  If no split improves on the gain of the unsplit labels, this node
+ * becomes a leaf.
+ *
+ * @param data Dataset to train on (one point per column).
+ * @param labels Label of each point.
+ * @param numClasses Number of classes.
+ * @param minimumLeafSize Minimum number of points in a leaf.
+ */
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename MatType>
+void DecisionTree<FitnessFunction,
+                  NumericSplitType,
+                  CategoricalSplitType,
+                  ElemType,
+                  NoRecursion>::Train(const MatType& data,
+                                      const arma::Row<size_t>& labels,
+                                      const size_t numClasses,
+                                      const size_t minimumLeafSize)
+{
+  // Clear children if needed.
+  for (size_t i = 0; i < children.size(); ++i)
+    delete children[i];
+  children.clear();
+
+  // We won't be using this.
+  categoricalAux =
+      typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>();
+
+  // Look through the list of dimensions and obtain the best split.  We'll cache
+  // the best numeric split auxiliary information in numericAux (and clear it
+  // later if we don't make a split), and use classProbabilities as auxiliary
+  // information.  Later we'll overwrite classProbabilities with the empirical
+  // class probabilities if we do not split.
+  //
+  // This relies on SplitIfBetter() modifying classProbabilities and aux only
+  // when it returns a gain strictly better than the one it was given.
+  double bestGain = FitnessFunction::Evaluate(labels, numClasses);
+  size_t bestDim = data.n_rows; // This means "no split".
+  for (size_t i = 0; i < data.n_rows; ++i)
+  {
+    const double dimGain = NumericSplit::SplitIfBetter(bestGain, data.row(i),
+        labels, numClasses, minimumLeafSize, classProbabilities, numericAux);
+
+    // Was there an improvement?  If so, this becomes the split to beat.
+    if (dimGain > bestGain)
+    {
+      bestDim = i;
+      bestGain = dimGain;
+    }
+  }
+
+  // Did we split or not?  If so, then split the data and create the children.
+  if (bestDim != data.n_rows)
+  {
+    // We know that the split is numeric.
+    splitDimension = bestDim;
+    dimensionTypeOrMajorityClass =
+        static_cast<size_t>(data::Datatype::numeric);
+    const size_t numChildren = NumericSplit::NumChildren(classProbabilities,
+        numericAux);
+
+    // Assign each point to a child from its value in the split dimension.
+    arma::Col<size_t> childAssignments(data.n_cols);
+    for (size_t j = 0; j < data.n_cols; ++j)
+      childAssignments[j] = NumericSplit::CalculateDirection(data(bestDim, j),
+          classProbabilities, numericAux);
+
+    // Calculate counts of children in each node.
+    arma::Col<size_t> childCounts(numChildren);
+    childCounts.zeros();
+    for (size_t j = 0; j < childAssignments.n_elem; ++j)
+      childCounts[childAssignments[j]]++;
+
+    for (size_t i = 0; i < numChildren; ++i)
+    {
+      // Now that we have the size of the matrix we need to extract, extract it.
+      MatType childPoints(data.n_rows, childCounts[i]);
+      arma::Row<size_t> childLabels(childCounts[i]);
+      size_t currentCol = 0;
+      for (size_t j = 0; j < data.n_cols; ++j)
+      {
+        if (childAssignments[j] == i)
+        {
+          childPoints.col(currentCol) = data.col(j);
+          childLabels[currentCol++] = labels[j];
+        }
+      }
+
+      // Now build the child recursively (all dimensions stay numeric).
+      children.push_back(new DecisionTree(childPoints, childLabels, numClasses,
+          minimumLeafSize));
+    }
+  }
+  else
+  {
+    // We won't be needing this.
+    numericAux =
+        typename NumericSplit::template AuxiliarySplitInfo<ElemType>();
+
+    // Calculate class probabilities because we are a leaf.
+    CalculateClassProbabilities(labels);
+  }
+}
+
+//! Return the predicted class of a single point.
+//!
+//! Internal nodes forward the point to the child chosen by
+//! CalculateDirection(); a leaf answers with its cached majority class.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename VecType>
+size_t DecisionTree<FitnessFunction,
+                    NumericSplitType,
+                    CategoricalSplitType,
+                    ElemType,
+                    NoRecursion>::Classify(const VecType& point) const
+{
+  // Internal node: descend into the child the point belongs to.
+  if (!children.empty())
+  {
+    const size_t direction = CalculateDirection(point);
+    return children[direction]->Classify(point);
+  }
+
+  // Leaf: the majority class is cached in dimensionTypeOrMajorityClass.
+  return dimensionTypeOrMajorityClass;
+}
+
+//! Return class probabilities for a given point.
+//!
+//! The point is routed down the tree to a leaf.  That leaf's cached majority
+//! class is written to 'prediction', and its class probabilities are copied
+//! into 'probabilities'.
+//!
+//! @param point Point to classify.
+//! @param prediction Set to the predicted class.
+//! @param probabilities Set to the reached leaf's class probabilities.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename VecType>
+void DecisionTree<FitnessFunction,
+                  NumericSplitType,
+                  CategoricalSplitType,
+                  ElemType,
+                  NoRecursion>::Classify(const VecType& point,
+                                         size_t& prediction,
+                                         arma::vec& probabilities) const
+{
+  if (children.size() == 0)
+  {
+    prediction = dimensionTypeOrMajorityClass;
+    probabilities = classProbabilities;
+    return;
+  }
+
+  // Recurse into the child the point belongs to, forwarding both outputs so
+  // that 'prediction' is filled in too.
+  children[CalculateDirection(point)]->Classify(point, prediction,
+      probabilities);
+}
+
+//! Return the predicted class of every column of 'data'.
+//!
+//! 'predictions' is resized to one entry per point.  A leaf assigns its cached
+//! majority class to all points; otherwise each point is classified on its own.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename MatType>
+void DecisionTree<FitnessFunction,
+                  NumericSplitType,
+                  CategoricalSplitType,
+                  ElemType,
+                  NoRecursion>::Classify(const MatType& data,
+                                         arma::Row<size_t>& predictions) const
+{
+  predictions.set_size(data.n_cols);
+
+  if (!children.empty())
+  {
+    // Route each column through the tree independently.
+    for (size_t col = 0; col < data.n_cols; ++col)
+      predictions[col] = Classify(data.col(col));
+    return;
+  }
+
+  // Leaf: every point receives the same cached majority class.
+  predictions.fill(dimensionTypeOrMajorityClass);
+}
+
+//! Return the class probabilities for a set of points.
+//!
+//! @param data Points to classify; each column is a point.
+//! @param predictions Resized to data.n_cols and set to the predicted classes.
+//! @param probabilities Resized so that column i holds the class probabilities
+//!     of the leaf that point i reaches.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename MatType>
+void DecisionTree<FitnessFunction,
+                  NumericSplitType,
+                  CategoricalSplitType,
+                  ElemType,
+                  NoRecursion>::Classify(const MatType& data,
+                                         arma::Row<size_t>& predictions,
+                                         arma::mat& probabilities) const
+{
+  predictions.set_size(data.n_cols);
+  if (children.size() == 0)
+  {
+    predictions.fill(dimensionTypeOrMajorityClass);
+    probabilities = arma::repmat(classProbabilities, 1, data.n_cols);
+    return;
+  }
+
+  // Otherwise we have to find the right number of rows for the probabilities
+  // matrix.  Internal nodes reuse classProbabilities as split auxiliary
+  // information, so walk down to a leaf and use its class probability vector.
+  const DecisionTree* node = children[0];
+  while (node->children.size() != 0)
+    node = node->children[0];
+  probabilities.set_size(node->classProbabilities.n_elem, data.n_cols);
+
+  // Classify into a reusable vector and copy it into place; a column view
+  // returned by unsafe_col() is a temporary and cannot bind to arma::vec&.
+  arma::vec pointProbabilities;
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    Classify(data.col(i), predictions[i], pointProbabilities);
+    probabilities.col(i) = pointProbabilities;
+  }
+}
+
+//! Serialize the tree.
+//!
+//! Saves or loads this node and, recursively, all of its children.  The
+//! children are handled first: their count, then each child under the name
+//! "child<i>".  This node's own members follow.  NOTE: the archive is
+//! presumably read back sequentially, so keep the order of the `ar &`
+//! statements stable.
+//!
+//! @param ar Archive to save to or load from.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename Archive>
+void DecisionTree<FitnessFunction,
+                  NumericSplitType,
+                  CategoricalSplitType,
+                  ElemType,
+                  NoRecursion>::Serialize(Archive& ar,
+                                          const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // Clean memory if needed: when loading, free any children this node already
+  // owns before the archived ones are read in.
+  if (Archive::is_loading::value)
+  {
+    for (size_t i = 0; i < children.size(); ++i)
+      delete children[i];
+    children.clear();
+  }
+
+  // Serialize the children first.  When saving, numChildren is the current
+  // child count; when loading, it receives the archived count.
+  size_t numChildren = children.size();
+  ar & CreateNVP(numChildren);
+  if (Archive::is_loading::value)
+  {
+    // Allocate default-constructed nodes for the archived children to be
+    // loaded into.
+    children.resize(numChildren, NULL);
+    for (size_t i = 0; i < numChildren; ++i)
+      children[i] = new DecisionTree();
+  }
+
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    // Each child gets a distinct name: "child0", "child1", ...
+    std::ostringstream name;
+    name << "child" << i;
+    ar & CreateNVP(*children[i], name.str());
+  }
+
+  // Now serialize the rest of the object.
+  ar & CreateNVP(splitDimension, "splitDimension");
+  ar & CreateNVP(dimensionTypeOrMajorityClass, "dimensionTypeOrMajorityClass");
+  ar & CreateNVP(classProbabilities, "classProbabilities");
+  ar & CreateNVP(numericAux, "numericAux");
+  ar & CreateNVP(categoricalAux, "categoricalAux");
+}
+
+//! Return the index of the child that the given point should be routed to.
+//!
+//! Only meaningful on internal nodes.  There, dimensionTypeOrMajorityClass
+//! holds the type of dimension splitDimension, and the split is evaluated on
+//! that single dimension of the point.
+//!
+//! @param point Point to route; must have at least splitDimension + 1 entries.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType,
+         typename ElemType,
+         bool NoRecursion>
+template<typename VecType>
+size_t DecisionTree<FitnessFunction,
+                    NumericSplitType,
+                    CategoricalSplitType,
+                    ElemType,
+                    NoRecursion>::CalculateDirection(const VecType& point) const
+{
+  // The splitters are one-dimensional (Train() hands them a single row of the
+  // data at a time), so pass only the value in the split dimension.
+  if (static_cast<data::Datatype>(dimensionTypeOrMajorityClass) ==
+      data::Datatype::categorical)
+    return CategoricalSplit::CalculateDirection(point[splitDimension],
+        classProbabilities, categoricalAux);
+  else
+    return NumericSplit::CalculateDirection(point[splitDimension],
+        classProbabilities, numericAux);
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/decision_tree/gini_gain.hpp b/src/mlpack/methods/decision_tree/gini_gain.hpp
new file mode 100644
index 0000000..e075c87
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/gini_gain.hpp
@@ -0,0 +1,77 @@
+/**
+ * @file gini_gain.hpp
+ * @author Ryan Curtin
+ *
+ * The GiniGain class, which is a fitness function (FitnessFunction) for
+ * decision trees.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license.  You should have received a copy of the
+ * 3-clause BSD license along with mlpack.  If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
+#define MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The Gini gain, a measure of set purity usable as a fitness function
+ * (FitnessFunction) for decision trees.  It is the well-known Gini impurity
+ * with its sign flipped: the decision tree maximizes gain, whereas impurity
+ * would have to be minimized.
+ */
+class GiniGain
+{
+ public:
+  /**
+   * Compute the Gini gain (the negated Gini impurity) of a set of labels.
+   * RowType should be an Armadillo vector holding size_t labels, each of
+   * which must be less than numClasses.
+   *
+   * @param labels Labels to evaluate.
+   * @param numClasses Number of possible classes.
+   * @return The negated Gini impurity (at most 0); 0 for an empty set.
+   */
+  template<typename RowType>
+  static double Evaluate(const RowType& labels,
+                         const size_t numClasses)
+  {
+    // An empty set is treated as perfectly pure.
+    const size_t total = labels.n_elem;
+    if (total == 0)
+      return 0.0;
+
+    // Histogram of how many labels fall into each class.
+    arma::Col<size_t> classCounts;
+    classCounts.zeros(numClasses);
+    for (size_t p = 0; p < total; ++p)
+      ++classCounts[labels[p]];
+
+    // Impurity is the sum over classes of f * (1 - f), where f is the
+    // fraction of labels belonging to that class.
+    const double denominator = (double) total;
+    double impurity = 0.0;
+    for (size_t c = 0; c < numClasses; ++c)
+    {
+      const double fraction = (double) classCounts[c] / denominator;
+      impurity += fraction * (1.0 - fraction);
+    }
+
+    return -impurity;
+  }
+
+  /**
+   * Return the range of the Gini impurity for the given number of classes,
+   * i.e. the maximum possible value minus the minimum possible value.
+   *
+   * @param numClasses Number of possible classes.
+   */
+  static double Range(const size_t numClasses)
+  {
+    // A single-class set has impurity 0.  An even spread over n classes gives
+    // the maximum, n * (1/n * (1 - 1/n)) = 1 - 1/n, which is also the range.
+    const double n = double(numClasses);
+    return 1.0 - (1.0 / n);
+  }
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/decision_tree/information_gain.hpp b/src/mlpack/methods/decision_tree/information_gain.hpp
new file mode 100644
index 0000000..85eef7e
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/information_gain.hpp
@@ -0,0 +1,75 @@
+/**
+ * @file information_gain.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of information gain, which can be used in place of Gini
+ * impurity.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license.  You should have received a copy of the
+ * 3-clause BSD license along with mlpack.  If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
+#define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
+
+#include <mlpack/prereqs.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The information gain, a measure of set purity usable as a fitness function
+ * (FitnessFunction) for decision trees, in place of GiniGain.  Evaluate()
+ * returns the negated Shannon entropy (in bits) of a label set.  Purer sets
+ * therefore score higher (a single-class set scores the maximum, 0), so the
+ * decision tree can maximize it.
+ */
+class InformationGain
+{
+ public:
+  /**
+   * Evaluate the information gain of the given set of labels: the sum over all
+   * classes of f * log2(f), where f is the fraction of labels in that class.
+   * Classes with no labels contribute nothing.  Each label must be less than
+   * numClasses.
+   *
+   * @param labels Set of labels to evaluate information gain on.
+   * @param numClasses Number of possible classes.
+   * @return The negated entropy of the labels (at most 0); 0 for an empty set.
+   */
+  static double Evaluate(const arma::Row<size_t>& labels,
+                         const size_t numClasses)
+  {
+    // Edge case: if there are no elements, the gain is zero.
+    if (labels.n_elem == 0)
+      return 0.0;
+
+    // Count the number of elements in each class.
+    arma::Col<size_t> counts(numClasses);
+    counts.zeros();
+    for (size_t i = 0; i < labels.n_elem; ++i)
+      counts[labels[i]]++;
+
+    // Accumulate f * log2(f); empty classes are skipped because log2(0) is
+    // -infinity (their limiting contribution is 0).
+    double gain = 0.0;
+    for (size_t i = 0; i < numClasses; ++i)
+    {
+      const double f = ((double) counts[i] / (double) labels.n_elem);
+      if (f > 0.0)
+        gain += f * std::log2(f);
+    }
+
+    return gain;
+  }
+
+  /**
+   * Return the range of the information gain for the given number of classes.
+   * (That is, the difference between the maximum possible value and the minimum
+   * possible value.)
+   *
+   * @param numClasses Number of possible classes.
+   */
+  static double Range(const size_t numClasses)
+  {
+    // The best possible case gives an information gain of 0.  The worst
+    // possible case is even distribution, which gives n * (1/n * log2(1/n)) =
+    // log2(1/n) = -log2(n).  So, the range is log2(n).
+    return std::log2(numClasses);
+  }
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#endif

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list