[mlpack] 48/207: Add simple generalization test for numeric features.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master in repository mlpack.
commit 6f9083da7a919f71af9b96f80e48291828d51745
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Jan 23 13:23:17 2017 -0500
Add simple generalization test for numeric features.
---
.../methods/decision_stump/decision_stump.hpp | 4 +-
.../methods/decision_stump/decision_stump_impl.hpp | 43 +++++++++++++++++-----
src/mlpack/tests/adaboost_test.cpp | 2 +-
src/mlpack/tests/decision_stump_test.cpp | 22 +++++++++--
src/mlpack/tests/decision_tree_test.cpp | 41 +++++++++++++++++++++
5 files changed, 96 insertions(+), 16 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index d262143..0d36bfa 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -205,8 +205,8 @@ class DecisionStump
* @tparam dimension dimension is the dimension decided by the constructor
* on which we now train the decision stump.
*/
- template<typename VecType>
- void TrainOnDim(const VecType& dimension,
+ void TrainOnDim(const MatType& data,
+ const size_t dimension,
const arma::Row<size_t>& labels);
/**
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 228b720..c1b37d9 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -233,7 +233,7 @@ void DecisionStump<MatType, NoRecursion>::Train(
splitDimensionOrLabel = bestDim;
// Once the splitting column/dimension has been decided, train on it.
- TrainOnDim(data.row(splitDimensionOrLabel), labels);
+ TrainOnDim(data, splitDimensionOrLabel, labels);
}
/**
@@ -447,25 +447,30 @@ double DecisionStump<MatType, NoRecursion>::SetupSplitDimension(
* which we now train the decision stump.
*/
template<typename MatType, bool NoRecursion>
-template<typename VecType>
void DecisionStump<MatType, NoRecursion>::TrainOnDim(
- const VecType& dimension,
+ const MatType& data,
+ const size_t dimension,
const arma::Row<size_t>& labels)
{
size_t i, count, begin, end;
- typename MatType::row_type sortedSplitDim = arma::sort(dimension);
- arma::uvec sortedSplitIndexDim = arma::stable_sort_index(dimension.t());
- arma::Row<size_t> sortedLabels(dimension.n_elem);
- sortedLabels.fill(0);
+ typename MatType::row_type sortedSplitDim = arma::sort(data.row(dimension));
+ arma::uvec sortedSplitIndexDim =
+ arma::stable_sort_index(data.row(dimension).t());
+ arma::Row<size_t> sortedLabels(data.n_cols);
arma::Col<size_t> binLabels;
- for (i = 0; i < dimension.n_elem; i++)
+ for (i = 0; i < data.n_cols; i++)
sortedLabels(i) = labels(sortedSplitIndexDim(i));
+ /**
+ * Loop through the points, splitting when it is advantageous. We check to
+ * see if we can split at index i, and then if we can, the split will take the
+ * value that's the midpoint between index i and index i + 1.
+ */
arma::rowvec subCols;
double mostFreq;
- i = 0;
+ i = bucketSize;
count = 0;
while (i < sortedLabels.n_elem)
{
@@ -506,7 +511,7 @@ void DecisionStump<MatType, NoRecursion>::TrainOnDim(
mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
splitOrClassProbs.resize(splitOrClassProbs.n_elem + 1);
- splitOrClassProbs(splitOrClassProbs.n_elem - 1) = sortedSplitDim(begin);
+ splitOrClassProbs(splitOrClassProbs.n_elem - 1) = sortedSplitDim(end + 1);
binLabels.resize(binLabels.n_elem + 1);
binLabels(binLabels.n_elem - 1) = mostFreq;
@@ -576,6 +581,24 @@ void DecisionStump<MatType, NoRecursion>::TrainOnDim(
else
{
// Do recursion.
+ size_t begin = 0;
+ for (size_t i = 0; i < splitOrClassProbs.n_elem; ++i)
+ {
+ // Determine how many points are in this child.
+ size_t lastBegin = begin;
+ while (sortedSplitDim(++begin) < splitOrClassProbs[i]) { }
+ size_t numPoints = (lastBegin - begin);
+
+ // Allocate memory for child data and fill it.
+ MatType childData(data.n_rows, numPoints);
+ for (size_t i = lastBegin; i < begin; ++i)
+ childData.col(i - lastBegin) = data.col(sortedSplitIndexDim[i]);
+ arma::Row<size_t> childLabels = sortedLabels.subvec(lastBegin, begin - 1);
+
+ // Create the child recursively.
+ children.push_back(new DecisionStump(childData, childLabels, classes,
+ bucketSize));
+ }
}
}
diff --git a/src/mlpack/tests/adaboost_test.cpp b/src/mlpack/tests/adaboost_test.cpp
index ec4b4ca..b67f7de 100644
--- a/src/mlpack/tests/adaboost_test.cpp
+++ b/src/mlpack/tests/adaboost_test.cpp
@@ -125,7 +125,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn)
BOOST_FAIL("Cannot load test dataset vc2.csv!");
arma::Mat<size_t> labels;
- if (!data::Load("vc2_labels.txt",labels))
+ if (!data::Load("vc2_labels.txt", labels))
BOOST_FAIL("Cannot load labels for vc2_labels.txt");
// Define your own weak learner, perceptron in this case.
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index a243076..21d1d71 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -558,11 +558,27 @@ BOOST_AUTO_TEST_CASE(DecisionStumpMoveOperatorTest)
}
/**
- * Test that the decision tree outperforms the decision stump.
+ * Test that the decision tree can be reasonably built.
*/
-BOOST_AUTO_TEST_CASE(DecisionTreeVsStumpTest)
+BOOST_AUTO_TEST_CASE(DecisionTreeBuildTest)
{
-
+ arma::mat inputData;
+ if (!data::Load("vc2.csv", inputData))
+ BOOST_FAIL("Cannot load test dataset vc2.csv!");
+
+ arma::Mat<size_t> labels;
+ if (!data::Load("vc2_labels.txt", labels))
+ BOOST_FAIL("Cannot load labels for vc2_labels.txt");
+
+ // Construct a full decision tree.
+ DecisionStump<arma::mat, false> tree(inputData, labels.row(0), 3);
+
+ // Ensure that it has some children.
+ BOOST_REQUIRE_GT(tree.NumChildren(), 0);
+
+ // Ensure that its children have some children.
+ for (size_t i = 0; i < tree.NumChildren(); ++i)
+ BOOST_REQUIRE_GT(tree.Child(i).NumChildren(), 0);
}
BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/decision_tree_test.cpp b/src/mlpack/tests/decision_tree_test.cpp
index e2b0f31..d628d34 100644
--- a/src/mlpack/tests/decision_tree_test.cpp
+++ b/src/mlpack/tests/decision_tree_test.cpp
@@ -415,6 +415,47 @@ BOOST_AUTO_TEST_CASE(ClassProbabilityTest)
}
/**
+ * Test that the decision tree generalizes reasonably.
+ */
+BOOST_AUTO_TEST_CASE(SimpleGeneralizationTest)
+{
+ arma::mat inputData;
+ if (!data::Load("vc2.csv", inputData))
+ BOOST_FAIL("Cannot load test dataset vc2.csv!");
+
+ arma::Mat<size_t> labels;
+ if (!data::Load("vc2_labels.txt", labels))
+ BOOST_FAIL("Cannot load labels for vc2_labels.txt");
+
+ // Build decision tree.
+ DecisionTree<> d(inputData, labels, 3, 10); // Leaf size of 10.
+
+ // Load testing data.
+ arma::mat testData;
+ if (!data::Load("vc2_test.csv", testData))
+ BOOST_FAIL("Cannot load test dataset vc2_test.csv!");
+
+ arma::Mat<size_t> trueTestLabels;
+ if (!data::Load("vc2_test_labels.txt", trueTestLabels))
+ BOOST_FAIL("Cannot load labels for vc2_test_labels.txt");
+
+ // Get the predicted test labels.
+ arma::Row<size_t> predictions;
+ d.Classify(testData, predictions);
+
+ BOOST_REQUIRE_EQUAL(predictions.n_elem, testData.n_cols);
+
+ // Figure out the accuracy.
+ double correct = 0.0;
+ for (size_t i = 0; i < predictions.n_elem; ++i)
+ if (predictions[i] == trueTestLabels[i])
+ ++correct;
+ correct /= predictions.n_elem;
+
+ BOOST_REQUIRE_GT(correct, 0.75);
+}
+
+/**
- aux split info is empty
- basic construction test
- build on sparse test on dense
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list