[mlpack] 44/207: Add some tests for the decision tree.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 4483792a6d088b8b83575b5706bb733f69ca89d6
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 19 10:57:06 2017 -0500
Add some tests for the decision tree.
---
src/mlpack/tests/decision_tree_test.cpp | 229 ++++++++++++++++++++++++++++++++
1 file changed, 229 insertions(+)
diff --git a/src/mlpack/tests/decision_tree_test.cpp b/src/mlpack/tests/decision_tree_test.cpp
index df4e961..5d57340 100644
--- a/src/mlpack/tests/decision_tree_test.cpp
+++ b/src/mlpack/tests/decision_tree_test.cpp
@@ -175,6 +175,235 @@ BOOST_AUTO_TEST_CASE(InformationGainManyPoints)
}
/**
+ * Check that the BestBinaryNumericSplit will split on an obviously splittable
+ * dimension.
+ */
+BOOST_AUTO_TEST_CASE(BestBinaryNumericSplitSimpleSplitTest)
+{
+ // Values below 0.5 carry label 0 and values from 0.5 up carry label 1, so a
+ // single numeric threshold separates the two classes perfectly.
+ arma::vec values("0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0");
+ arma::Row<size_t> labels("0 0 0 0 0 1 1 1 1 1 1");
+
+ arma::vec classProbabilities;
+ BestBinaryNumericSplit<GiniGain>::template AuxiliarySplitInfo<double> aux;
+
+ // Call the method to do the splitting.
+ const double bestGain = GiniGain::Evaluate(labels, 2);
+ const double gain = BestBinaryNumericSplit<GiniGain>::SplitIfBetter(bestGain,
+ values, labels, 2, 3, classProbabilities, aux);
+
+ // Make sure that a split was made.
+ BOOST_REQUIRE_GT(gain, bestGain);
+
+ // The split is perfect, so we should be able to accomplish a gain of 0.
+ BOOST_REQUIRE_SMALL(gain, 1e-5);
+
+ // The class probabilities, for this split, hold the splitting point, which
+ // should be between 0.4 and 0.5.
+ BOOST_REQUIRE_EQUAL(classProbabilities.n_elem, 1);
+ BOOST_REQUIRE_GT(classProbabilities[0], 0.4);
+ // Fixed: only element [0] exists (n_elem == 1 asserted above); indexing [1]
+ // was out of bounds.
+ BOOST_REQUIRE_LT(classProbabilities[0], 0.5);
+}
+
+/**
+ * Check that the BestBinaryNumericSplit won't split if not enough points are
+ * given.
+ */
+BOOST_AUTO_TEST_CASE(BestBinaryNumericSplitMinSamplesTest)
+{
+ // Same perfectly-splittable data as the simple split test above.
+ arma::vec values("0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0");
+ arma::Row<size_t> labels("0 0 0 0 0 1 1 1 1 1 1");
+
+ arma::vec classProbabilities;
+ BestBinaryNumericSplit<GiniGain>::template AuxiliarySplitInfo<double> aux;
+
+ // Call the method to do the splitting.  The 8 here is presumably the minimum
+ // number of points required in each child (TODO confirm against the split's
+ // signature); no split of 11 points can leave 8 on both sides, so the split
+ // must be refused even though it would have positive gain.
+ const double bestGain = GiniGain::Evaluate(labels, 2);
+ const double gain = BestBinaryNumericSplit<GiniGain>::SplitIfBetter(bestGain,
+ values, labels, 2, 8, classProbabilities, aux);
+
+ // Make sure that no split was made.
+ BOOST_REQUIRE_EQUAL(gain, bestGain);
+ BOOST_REQUIRE_EQUAL(classProbabilities.n_elem, 0);
+}
+
+/**
+ * Check that the BestBinaryNumericSplit doesn't split a dimension that gives no
+ * gain.
+ */
+BOOST_AUTO_TEST_CASE(BestBinaryNumericSplitNoGainTest)
+{
+ // Build 50 pairs of points: the two points of each pair share the same value
+ // but carry opposite labels.  Any threshold therefore falls between whole
+ // pairs, leaving both children with the same 50/50 class balance as the
+ // parent -- so no candidate split can improve the gain.
+ arma::vec values(100);
+ arma::Row<size_t> labels(100);
+ for (size_t i = 0; i < 100; i += 2)
+ {
+ values[i] = i;
+ labels[i] = 0;
+ values[i + 1] = i;
+ labels[i + 1] = 1;
+ }
+
+ arma::vec classProbabilities;
+ BestBinaryNumericSplit<GiniGain>::template AuxiliarySplitInfo<double> aux;
+
+ // Call the method to do the splitting.
+ const double bestGain = GiniGain::Evaluate(labels, 2);
+ const double gain = BestBinaryNumericSplit<GiniGain>::SplitIfBetter(bestGain,
+ values, labels, 2, 10, classProbabilities, aux);
+
+ // Make sure there was no split.
+ BOOST_REQUIRE_EQUAL(gain, bestGain);
+ BOOST_REQUIRE_EQUAL(classProbabilities.n_elem, 0);
+}
+
+/**
+ * Check that the AllCategoricalSplit will split when the split is obviously
+ * better.
+ */
+BOOST_AUTO_TEST_CASE(AllCategoricalSplitSimpleSplitTest)
+{
+ // Four categories (0..3), each mapping to exactly one label, so splitting
+ // into one child per category separates the points perfectly.  (Note labels
+ // need not be distinct per category: categories 1 and 3 both map to class 2.)
+ arma::vec values("0 0 0 1 1 1 2 2 2 3 3 3");
+ arma::Row<size_t> labels("0 0 0 2 2 2 1 1 1 2 2 2");
+
+ arma::vec classProbabilities;
+ AllCategoricalSplit<GiniGain>::template AuxiliarySplitInfo<double> aux;
+
+ // Call the method to do the splitting.  Arguments appear to be (values,
+ // numCategories = 4, labels, numClasses = 3, minimum samples = 3, ...) --
+ // TODO confirm against AllCategoricalSplit::SplitIfBetter()'s signature.
+ const double bestGain = GiniGain::Evaluate(labels, 3);
+ const double gain = AllCategoricalSplit<GiniGain>::SplitIfBetter(bestGain,
+ values, 4, labels, 3, 3, classProbabilities, aux);
+
+ // Make sure that a split was made.
+ BOOST_REQUIRE_GT(gain, bestGain);
+
+ // Since the split is perfect, make sure the new gain is 0.
+ BOOST_REQUIRE_SMALL(gain, 1e-5);
+
+ // Make sure the class probabilities now hold the number of children.
+ BOOST_REQUIRE_EQUAL(classProbabilities.n_elem, 1);
+ BOOST_REQUIRE_EQUAL((size_t) classProbabilities[0], 4);
+}
+
+/**
+ * Make sure that AllCategoricalSplit respects the minimum number of samples
+ * required to split.
+ */
+BOOST_AUTO_TEST_CASE(AllCategoricalSplitMinSamplesTest)
+{
+ // Same perfectly-splittable data as the simple categorical split test.
+ arma::vec values("0 0 0 1 1 1 2 2 2 3 3 3");
+ arma::Row<size_t> labels("0 0 0 2 2 2 1 1 1 2 2 2");
+
+ arma::vec classProbabilities;
+ AllCategoricalSplit<GiniGain>::template AuxiliarySplitInfo<double> aux;
+
+ // Call the method to do the splitting.  Each category holds only 3 points,
+ // which is below the required minimum of 4, so the split must be refused
+ // even though it would be perfect.
+ const double bestGain = GiniGain::Evaluate(labels, 3);
+ const double gain = AllCategoricalSplit<GiniGain>::SplitIfBetter(bestGain,
+ values, 4, labels, 3, 4, classProbabilities, aux);
+
+ // Make sure it's not split.
+ BOOST_REQUIRE_EQUAL(gain, bestGain);
+ BOOST_REQUIRE_EQUAL(classProbabilities.n_elem, 0);
+}
+
+/**
+ * Check that no split is made when it doesn't get us anything.
+ */
+BOOST_AUTO_TEST_CASE(AllCategoricalSplitNoGainTest)
+{
+ // 300 points in 10 categories; every category receives exactly 10 points of
+ // each of the 3 classes, so each child would have the same class mix as the
+ // parent and the categorical split yields zero gain.
+ arma::vec values(300);
+ arma::Row<size_t> labels(300);
+ for (size_t i = 0; i < 300; i += 3)
+ {
+ values[i] = (i / 3) % 10;
+ labels[i] = 0;
+ values[i + 1] = (i / 3) % 10;
+ labels[i + 1] = 1;
+ values[i + 2] = (i / 3) % 10;
+ labels[i + 2] = 2;
+ }
+
+ arma::vec classProbabilities;
+ AllCategoricalSplit<GiniGain>::template AuxiliarySplitInfo<double> aux;
+
+ // Call the method to do the splitting.
+ const double bestGain = GiniGain::Evaluate(labels, 3);
+ const double gain = AllCategoricalSplit<GiniGain>::SplitIfBetter(bestGain,
+ values, 10, labels, 3, 10, classProbabilities, aux);
+
+ // Make sure that there was no split.
+ BOOST_REQUIRE_EQUAL(gain, bestGain);
+ BOOST_REQUIRE_EQUAL(classProbabilities.n_elem, 0);
+}
+
+/**
+ * A basic construction of the decision tree---ensure that we can create the
+ * tree and that it split at least once.
+ */
+BOOST_AUTO_TEST_CASE(BasicConstructionTest)
+{
+ // Random 10-dimensional data; the labels cycle so each class has ~333
+ // points, far more than the minimum leaf size of 50, so the root should
+ // always find some split to make.
+ arma::mat dataset(10, 1000, arma::fill::randu);
+ arma::Row<size_t> labels(1000);
+ for (size_t i = 0; i < 1000; ++i)
+ labels[i] = i % 3; // 3 classes.
+
+ // Use default parameters.
+ DecisionTree<> d(dataset, labels, 3, 50);
+
+ // Now require that we have some children.
+ BOOST_REQUIRE_GT(d.NumChildren(), 0);
+}
+
+/**
+ * Construct the decision tree on numeric data only and see that we can fit it
+ * exactly and achieve perfect performance on the training set.
+ */
+BOOST_AUTO_TEST_CASE(PerfectTrainingSet)
+{
+ // Completely random dataset with no structure.
+ arma::mat dataset(10, 1000, arma::fill::randu);
+ arma::Row<size_t> labels(1000);
+ for (size_t i = 0; i < 1000; ++i)
+ labels[i] = i % 3; // 3 classes.
+
+ // With a leaf size of 1 the tree can keep splitting until every training
+ // point is isolated, i.e. it memorizes the training set.  (Relies on the
+ // random points being pairwise distinct, which holds with probability ~1.)
+ DecisionTree<> d(dataset, labels, 3, 1); // Minimum leaf size of 1.
+
+ // Make sure that we can get perfect accuracy on the training set.
+ arma::Row<size_t> predictions;
+ d.Classify(dataset, predictions);
+
+ for (size_t i = 0; i < 1000; ++i)
+ BOOST_REQUIRE_EQUAL(predictions[i], labels[i]);
+}
+
+/**
+ * Make sure class probabilities are computed correctly in the root node.
+ */
+BOOST_AUTO_TEST_CASE(ClassProbabilityTest)
+{
+ // Exactly half the points get each label, so the root's class distribution
+ // is 50/50.
+ arma::mat dataset(5, 100, arma::fill::randu);
+ arma::Row<size_t> labels(100);
+ for (size_t i = 0; i < 100; i += 2)
+ {
+ labels[i] = 0;
+ labels[i + 1] = 1;
+ }
+
+ // Create a decision tree that can't split.  The minimum leaf size of 1000
+ // exceeds the 100 available points, forcing the tree to stay a single leaf.
+ DecisionTree<> d(dataset, labels, 2, 1000);
+
+ BOOST_REQUIRE_EQUAL(d.NumChildren(), 0);
+
+ // Estimate a point's probabilities.  Since the tree is a single leaf, every
+ // query must return the root's 50/50 distribution.
+ arma::vec probabilities;
+ size_t prediction;
+ d.Classify(dataset.col(0), prediction, probabilities);
+
+ BOOST_REQUIRE_EQUAL(probabilities.n_elem, 2);
+ BOOST_REQUIRE_CLOSE(probabilities[0], 0.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(probabilities[1], 0.5, 1e-5);
+}
+
+/**
- aux split info is empty
- basic construction test
- build on sparse test on dense
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list