[mlpack] 42/207: Fix minor bugs in splitting strategies.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit f08b674048e2c5eeb87ac11d995643129dbbe9f2
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 19 10:56:11 2017 -0500
Fix minor bugs in splitting strategies.
---
.../decision_tree/all_categorical_split_impl.hpp | 11 ++++++-----
.../best_binary_numeric_split_impl.hpp | 21 +++++++++++++--------
2 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
index 76aea39..166f2de 100644
--- a/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
+++ b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
@@ -20,7 +20,8 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
const size_t numClasses,
const size_t minimumLeafSize,
arma::Col<typename VecType::elem_type>& classProbabilities,
- AllCategoricalSplit::AuxiliarySplitInfo<typename VecType::elem_type>& aux)
+ AllCategoricalSplit::AuxiliarySplitInfo<typename VecType::elem_type>&
+ /* aux */)
{
// Count the number of elements in each potential child.
arma::Col<size_t> counts(numCategories);
@@ -35,8 +36,8 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
// Calculate the gain of the split. First we have to calculate the labels
// that would be assigned to each child.
- arma::uvec childPositions(numCategories);
- std::vector<arma::Row<size_t>> childLabels;
+ arma::uvec childPositions(numCategories, arma::fill::zeros);
+ std::vector<arma::Row<size_t>> childLabels(numCategories);
for (size_t i = 0; i < numCategories; ++i)
childLabels[i].zeros(counts[i]);
@@ -83,10 +84,10 @@ template<typename FitnessFunction>
template<typename ElemType>
size_t AllCategoricalSplit<FitnessFunction>::CalculateDirection(
const ElemType& point,
- const arma::Col<ElemType>& classProbabilities,
+ const arma::Col<ElemType>& /* classProbabilities */,
const AllCategoricalSplit::AuxiliarySplitInfo<ElemType>& /* aux */)
{
- return point;
+ return (size_t) point;
}
} // namespace tree
diff --git a/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
index 0a2981a..7ec453f 100644
--- a/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
@@ -20,7 +20,7 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
const size_t minimumLeafSize,
arma::Col<typename VecType::elem_type>& classProbabilities,
BestBinaryNumericSplit::AuxiliarySplitInfo<typename VecType::elem_type>&
- aux)
+ /* aux */)
{
// First sanity check: if we don't have enough points, we can't split.
if (data.n_elem < (minimumLeafSize * 2))
@@ -30,13 +30,18 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
arma::uvec sortedIndices = arma::sort_index(data);
arma::Row<size_t> sortedLabels(labels.n_elem);
for (size_t i = 0; i < sortedLabels.n_elem; ++i)
- sortedLabels[i] = labels[sortedIndices[i]];
+ sortedLabels[sortedIndices[i]] = labels[i];
- // Loop through all possible split points, choosing the best one.
+ // Loop through all possible split points, choosing the best one. Also, force
+ // a minimum leaf size of 1 (empty children don't make sense).
double bestFoundGain = bestGain;
- for (size_t index = minimumLeafSize; index < data.n_elem - minimumLeafSize;
- ++index)
+ const size_t minimum = std::max(minimumLeafSize, (size_t) 1);
+ for (size_t index = minimum; index < data.n_elem - minimum; ++index)
{
+ // Make sure that the value has changed.
+ if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
+ continue;
+
// Calculate the gain for the left and right child.
const double leftGain = FitnessFunction::Evaluate(sortedLabels.subvec(0,
index - 1), numClasses);
@@ -44,14 +49,14 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
index, sortedLabels.n_elem - 1), numClasses);
// Calculate the fraction of points in the left and right children.
- const double leftRatio = double(index) / double(sortedLabels.n_elem);
- const double rightRatio = 1.0 - leftRatio;
+ const double rightRatio = double(index) / double(sortedLabels.n_elem);
+ const double leftRatio = 1.0 - rightRatio;
// Calculate the gain at this split point.
const double gain = leftRatio * leftGain + rightRatio * rightGain;
// Corner case: is this the best possible split?
- if (gain == FitnessFunction::BestGain(numClasses))
+ if (gain == 0.0)
{
// We can take a shortcut: no split will be better than this, so just take
// this one.
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list