[mlpack] 43/207: Refactor auxiliary infos to be part of the class.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 4a82bbcdc9b5d04acd8cd0f12bb02b00c22fdeab
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 19 10:56:36 2017 -0500
Refactor auxiliary infos to be part of the class.
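This way, when the auxiliary split information structs are empty, they add no
size to the decision tree itself: as base classes, empty structs can occupy no
storage (the empty base optimization), whereas empty data members cannot.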
---
src/mlpack/methods/decision_tree/decision_tree.hpp | 33 +++--
.../methods/decision_tree/decision_tree_impl.hpp | 143 ++++++++++++++-------
2 files changed, 121 insertions(+), 55 deletions(-)
diff --git a/src/mlpack/methods/decision_tree/decision_tree.hpp b/src/mlpack/methods/decision_tree/decision_tree.hpp
index 4ab3975..39b5369 100644
--- a/src/mlpack/methods/decision_tree/decision_tree.hpp
+++ b/src/mlpack/methods/decision_tree/decision_tree.hpp
@@ -19,13 +19,20 @@ namespace tree {
/**
* This class implements a generic decision tree learner. Its behavior can be
* controlled via its template arguments.
+ *
+ * The class inherits from the auxiliary split information in order to prevent
+ * an empty auxiliary split information struct from taking any extra size.
*/
template<typename FitnessFunction = GiniGain,
template<typename> class NumericSplitType = BestBinaryNumericSplit,
template<typename> class CategoricalSplitType = AllCategoricalSplit,
typename ElemType = double,
bool NoRecursion = false>
-class DecisionTree
+class DecisionTree :
+ public NumericSplitType<FitnessFunction>::template
+ AuxiliarySplitInfo<ElemType>,
+ public CategoricalSplitType<FitnessFunction>::template
+ AuxiliarySplitInfo<ElemType>
{
public:
//! Allow access to the numeric split type.
@@ -190,16 +197,26 @@ class DecisionTree
*/
arma::vec classProbabilities;
- //! Auxiliary information in case this is a numeric split. This class may be
- //! empty.
- typename NumericSplit::template AuxiliarySplitInfo<ElemType> numericAux;
- //! Auxiliary information in case this is a categorical split. This class may
- //! be empty.
- typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
- categoricalAux;
+ //! Note that this class will also hold the members of the NumericSplit and
+ //! CategoricalSplit AuxiliarySplitInfo classes, since it inherits from them.
+ //! We'll define some convenience typedefs here.
+ typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
+ NumericAuxiliarySplitInfo;
+ typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
+ CategoricalAuxiliarySplitInfo;
+
+ /**
+ * Calculate the class probabilities of the given labels.
+ */
+ template<typename RowType>
+ void CalculateClassProbabilities(const RowType& labels,
+ const size_t numClasses);
};
} // namespace tree
} // namespace mlpack
+// Include implementation.
+#include "decision_tree_impl.hpp"
+
#endif
diff --git a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
index fbb07d8..00070d6 100644
--- a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
+++ b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
@@ -16,6 +16,7 @@ template<typename FitnessFunction,
template<typename> class CategoricalSplitType,
typename ElemType,
bool NoRecursion>
+template<typename MatType>
DecisionTree<FitnessFunction,
NumericSplitType,
CategoricalSplitType,
@@ -36,6 +37,7 @@ template<typename FitnessFunction,
template<typename> class CategoricalSplitType,
typename ElemType,
bool NoRecursion>
+template<typename MatType>
DecisionTree<FitnessFunction,
NumericSplitType,
CategoricalSplitType,
@@ -78,11 +80,11 @@ DecisionTree<FitnessFunction,
CategoricalSplitType,
ElemType,
NoRecursion>::DecisionTree(const DecisionTree& other) :
+ NumericAuxiliarySplitInfo(other),
+ CategoricalAuxiliarySplitInfo(other),
splitDimension(other.splitDimension),
dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
- classProbabilities(other.classProbabilities),
- numericAux(other.numericAux),
- categoricalAux(other.categoricalAux)
+ classProbabilities(other.classProbabilities)
{
// Copy each child.
for (size_t i = 0; i < other.children.size(); ++i)
@@ -100,12 +102,12 @@ DecisionTree<FitnessFunction,
CategoricalSplitType,
ElemType,
NoRecursion>::DecisionTree(DecisionTree&& other) :
+ NumericAuxiliarySplitInfo(std::move(other)),
+ CategoricalAuxiliarySplitInfo(std::move(other)),
children(std::move(other.children)),
splitDimension(other.splitDimension),
dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
- classProbabilities(std::move(other.classProbabilities)),
- numericAux(std::move(other.numericAux)),
- categoricalAux(std::move(other.categoricalAux))
+ classProbabilities(std::move(other.classProbabilities))
{
// Reset the other object.
other.classProbabilities.ones(1); // One class, P(1) = 1.
@@ -137,13 +139,15 @@ DecisionTree<FitnessFunction,
splitDimension = other.splitDimension;
dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
classProbabilities = other.classProbabilities;
- numericAux = other.numericAux;
- categoricalAux = other.categoricalAux;
// Copy the children.
for (size_t i = 0; i < other.children.size(); ++i)
children.push_back(new DecisionTree(*other.children[i]));
+ // Copy the auxiliary info.
+ NumericAuxiliarySplitInfo::operator=(other);
+ CategoricalAuxiliarySplitInfo::operator=(other);
+
return *this;
}
@@ -174,12 +178,14 @@ DecisionTree<FitnessFunction,
splitDimension = other.splitDimension;
dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
classProbabilities = std::move(other.classProbabilities);
- numericAux = std::move(other.numericAux);
- categoricalAux = std::move(other.categoricalAux);
// Reset the class probabilities of the other object.
other.classProbabilities.ones(1); // One class, P(1) = 1.
+ // Take ownership of the auxiliary info.
+ NumericAuxiliarySplitInfo::operator=(std::move(other));
+ CategoricalAuxiliarySplitInfo::operator=(std::move(other));
+
return *this;
}
@@ -231,47 +237,59 @@ void DecisionTree<FitnessFunction,
for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i)
{
double dimGain;
- if (datasetInfo.Type(i) == Datatype::categorical)
+ if (datasetInfo.Type(i) == data::Datatype::categorical)
dimGain = CategoricalSplit::SplitIfBetter(bestGain, data.row(i),
- info.NumMappings(i), labels, numClasses, minimumLeafSize,
- classProbabilities, categoricalAux);
- else if (datasetInfo.Type(i) == Datatype::numeric)
+ datasetInfo.NumMappings(i), labels, numClasses, minimumLeafSize,
+ classProbabilities, *this);
+ else if (datasetInfo.Type(i) == data::Datatype::numeric)
dimGain = NumericSplit::SplitIfBetter(bestGain, data.row(i), labels,
- numClasses, minimumLeafSize, classProbabilities, numericAux);
+ numClasses, minimumLeafSize, classProbabilities, *this);
// Was there an improvement? If so mark that it's the new best dimension.
if (dimGain > bestGain)
+ {
bestDim = i;
+ bestGain = dimGain;
+ }
+
+ // If the gain is the best possible, no need to keep looking.
+ if (bestGain == 0.0)
+ break;
}
// Did we split or not? If so, then split the data and create the children.
if (bestDim != datasetInfo.Dimensionality())
{
dimensionTypeOrMajorityClass = (size_t) datasetInfo.Type(bestDim);
+ splitDimension = bestDim;
// Get the number of children we will have.
size_t numChildren = 0;
- if (datasetInfo.Type(bestDim) == Datatype::categorical)
- numChildren = CategoricalSplit::NumChildren(classProbabilities,
- categoricalAux);
+ if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
+ numChildren = CategoricalSplit::NumChildren(classProbabilities, *this);
else
- numChildren = NumericSplit::NumChildren(classProbabilities, numericAux);
+ numChildren = NumericSplit::NumChildren(classProbabilities, *this);
// Calculate all child assignments.
arma::Col<size_t> childAssignments(data.n_cols);
- if (datasetInfo.Type(bestDim) == Datatype::categorical)
+ if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
{
for (size_t j = 0; j < data.n_cols; ++j)
childAssignments[j] = CategoricalSplit::CalculateDirection(
- classProbabilities, categoricalAux);
+ data(bestDim, j), classProbabilities, *this);
}
else
{
for (size_t j = 0; j < data.n_cols; ++j)
- childAssignments[j] = NumericSplit::CalculateDirection(
- classProbabilities, numericAux);
+ childAssignments[j] = NumericSplit::CalculateDirection(data(bestDim, j),
+ classProbabilities, *this);
}
+ // Figure out counts of children.
+ arma::Row<size_t> childCounts(numClasses, arma::fill::zeros);
+ for (size_t i = 0; i < childAssignments.n_elem; ++i)
+ childCounts[childAssignments[i]]++;
+
// Split into children.
for (size_t i = 0; i < numChildren; ++i)
{
@@ -296,8 +314,8 @@ void DecisionTree<FitnessFunction,
else
{
// Clear auxiliary info objects.
- numericAux = NumericSplit::AuxiliarySplitInfo<ElemType>();
- categoricalAux = CategoricalSplit::AuxiliarySplitInfo<ElemType>();
+ NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
+ CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
// Calculate class probabilities because we are a leaf.
CalculateClassProbabilities(labels);
@@ -325,9 +343,8 @@ void DecisionTree<FitnessFunction,
delete children[i];
children.clear();
- // We won't be using this.
- categoricalAux =
- CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo<ElemType>();
+ // We won't be using these members, so reset them.
+ CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
// Look through the list of dimensions and obtain the best split. We'll cache
// the best numeric split auxiliary information in numericAux (and clear it
@@ -340,24 +357,32 @@ void DecisionTree<FitnessFunction,
{
double dimGain = NumericSplitType<FitnessFunction>::SplitIfBetter(bestGain,
data.row(i), labels, numClasses, minimumLeafSize, classProbabilities,
- numericAux);
+ *this);
if (dimGain > bestGain)
+ {
bestDim = i;
+ bestGain = dimGain;
+ }
+
+ // If the gain is the best possible, no need to keep looking.
+ if (bestGain == 0.0)
+ break;
}
// Did we split or not? If so, then split the data and create the children.
if (bestDim != data.n_rows)
{
// We know that the split is numeric.
- size_t numChildren = NumericSplit::NumChildren(classProbabilities,
- numericAux);
+ size_t numChildren = NumericSplit::NumChildren(classProbabilities, *this);
+ splitDimension = bestDim;
+ dimensionTypeOrMajorityClass = (size_t) data::Datatype::numeric;
// Calculate all child assignments.
arma::Col<size_t> childAssignments(data.n_cols);
for (size_t j = 0; j < data.n_cols; ++j)
- childAssignments[j] = NumericSplit::CalculateDirection(classProbabilities,
- numericAux);
+ childAssignments[j] = NumericSplit::CalculateDirection(data(bestDim, j),
+ classProbabilities, *this);
// Calculate counts of children in each node.
arma::Col<size_t> childCounts(numChildren);
@@ -387,11 +412,11 @@ void DecisionTree<FitnessFunction,
}
else
{
- // We won't be needing this.
- numericAux = NumericSplit::AuxiliarySplitInfo<ElemType>();
+ // We won't be needing these members, so reset them.
+ NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
// Calculate class probabilities because we are a leaf.
- CalculateClassProbabilities(labels);
+ CalculateClassProbabilities(labels, numClasses);
}
}
@@ -420,7 +445,7 @@ size_t DecisionTree<FitnessFunction,
//! Return class probabilities for a given point.
template<typename FitnessFunction,
template<typename> class NumericSplitType,
- template<typename> class CategoricalSplittype,
+ template<typename> class CategoricalSplitType,
typename ElemType,
bool NoRecursion>
template<typename VecType>
@@ -439,7 +464,7 @@ void DecisionTree<FitnessFunction,
return;
}
- children[CalculateDirection(point)]->Classify(point, probabilities);
+ children[CalculateDirection(point)]->Classify(point, prediction, probabilities);
}
//! Return the class for a set of points.
@@ -495,7 +520,7 @@ void DecisionTree<FitnessFunction,
// be.
DecisionTree* node = children[0];
while (node->NumChildren() != 0)
- node = &node->Child(i);
+ node = &node->Child(0);
probabilities.set_size(node.classProbabilities.n_elem, data.n_cols);
for (size_t i = 0; i < data.n_cols; ++i)
@@ -528,7 +553,7 @@ void DecisionTree<FitnessFunction,
// Serialize the children first.
size_t numChildren = children.size();
- ar & CreateNVP(numChildren);
+ ar & CreateNVP(numChildren, "numChildren");
if (Archive::is_loading::value)
{
children.resize(numChildren, NULL);
@@ -547,8 +572,6 @@ void DecisionTree<FitnessFunction,
ar & CreateNVP(splitDimension, "splitDimension");
ar & CreateNVP(dimensionTypeOrMajorityClass, "dimensionTypeOrMajorityClass");
ar & CreateNVP(classProbabilities, "classProbabilities");
- ar & CreateNVP(numericAux, "numericAux");
- ar & CreateNVP(categoricalAux, "categoricalAux");
}
template<typename FitnessFunction,
@@ -563,12 +586,38 @@ size_t DecisionTree<FitnessFunction,
ElemType,
NoRecursion>::CalculateDirection(const VecType& point) const
{
- if ((Datatype) dimensionTypeOrMajorityClass == Datatype::categorical)
- return CategoricalSplit::CalculateDirection(point, classProbabilities,
- categoricalAux);
+ if ((data::Datatype) dimensionTypeOrMajorityClass ==
+ data::Datatype::categorical)
+ return CategoricalSplit::CalculateDirection(point[splitDimension],
+ classProbabilities, *this);
else
- return NumericSplit::CalculateDirection(point, classProbabilities,
- numericAux);
+ return NumericSplit::CalculateDirection(point[splitDimension],
+ classProbabilities, *this);
+}
+
+template<typename FitnessFunction,
+ template<typename> class NumericSplitType,
+ template<typename> class CategoricalSplitType,
+ typename ElemType,
+ bool NoRecursion>
+template<typename RowType>
+void DecisionTree<FitnessFunction,
+ NumericSplitType,
+ CategoricalSplitType,
+ ElemType,
+ NoRecursion>::CalculateClassProbabilities(
+ const RowType& labels,
+ const size_t numClasses)
+{
+ classProbabilities.zeros(numClasses);
+ for (size_t i = 0; i < labels.n_elem; ++i)
+ classProbabilities[labels[i]]++;
+
+ // Now normalize into probabilities.
+ classProbabilities /= labels.n_elem;
+ arma::uword maxIndex;
+ classProbabilities.max(maxIndex);
+ dimensionTypeOrMajorityClass = (size_t) maxIndex;
}
} // namespace tree
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list