[mlpack] 49/207: Add test for categorical data and stumps.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit d62eaf235b5924fb27f57b9573c74e444e4eed88
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Jan 23 17:26:06 2017 -0500
Add test for categorical data and stumps.
---
src/mlpack/tests/decision_tree_test.cpp | 140 ++++++++++++++++++++++++++++++--
1 file changed, 135 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/tests/decision_tree_test.cpp b/src/mlpack/tests/decision_tree_test.cpp
index d628d34..02f8733 100644
--- a/src/mlpack/tests/decision_tree_test.cpp
+++ b/src/mlpack/tests/decision_tree_test.cpp
@@ -20,6 +20,7 @@
using namespace mlpack;
using namespace mlpack::tree;
+using namespace mlpack::distribution;
BOOST_AUTO_TEST_SUITE(DecisionTreeTest);
@@ -456,10 +457,139 @@ BOOST_AUTO_TEST_CASE(SimpleGeneralizationTest)
}
/**
-- aux split info is empty
-- basic construction test
-- build on sparse test on dense
-- efficacy test
+ * Test that we can build a decision tree on a simple categorical dataset.
+ */
+BOOST_AUTO_TEST_CASE(CategoricalBuildTest)
+{
+ math::RandomSeed(std::time(NULL));
+
+ // We'll build a spiral dataset plus two noisy categorical features. We need
+ // to build the distributions for the categorical features (they'll be
+ // discrete distributions).
+ DiscreteDistribution c1[5];
+ // The distribution will be automatically normalized.
+ for (size_t i = 0; i < 5; ++i)
+ c1[i] = DiscreteDistribution(arma::vec(4, arma::fill::randu));
+
+ DiscreteDistribution c2[5];
+ for (size_t i = 0; i < 5; ++i)
+ c2[i] = DiscreteDistribution(arma::vec(2, arma::fill::randu));
+
+ arma::mat spiralDataset(4, 10000);
+ arma::Row<size_t> labels(10000);
+ for (size_t i = 0; i < 10000; ++i)
+ {
+ // One circle every 2000 samples. Plus some noise.
+ const double magnitude = 2.0 + (double(i) / 2000.0) +
+ 0.5 * mlpack::math::Random();
+ const double angle = (i % 2000) * (2 * M_PI) + mlpack::math::Random();
+
+ const double x = magnitude * cos(angle);
+ const double y = magnitude * sin(angle);
+
+ spiralDataset(0, i) = x;
+ spiralDataset(1, i) = y;
+
+ // Set categorical features c1 and c2.
+ if (i < 2000)
+ {
+ spiralDataset(2, i) = c1[1].Random()[0];
+ spiralDataset(3, i) = c2[1].Random()[0];
+ labels[i] = 1;
+ }
+ else if (i < 4000)
+ {
+ spiralDataset(2, i) = c1[3].Random()[0];
+ spiralDataset(3, i) = c2[3].Random()[0];
+ labels[i] = 3;
+ }
+ else if (i < 6000)
+ {
+ spiralDataset(2, i) = c1[2].Random()[0];
+ spiralDataset(3, i) = c2[2].Random()[0];
+ labels[i] = 2;
+ }
+ else if (i < 8000)
+ {
+ spiralDataset(2, i) = c1[0].Random()[0];
+ spiralDataset(3, i) = c2[0].Random()[0];
+ labels[i] = 0;
+ }
+ else
+ {
+ spiralDataset(2, i) = c1[4].Random()[0];
+ spiralDataset(3, i) = c2[4].Random()[0];
+ labels[i] = 4;
+ }
+ }
+
+ // Now create the dataset info.
+ data::DatasetInfo di(4);
+ di.Type(2) = data::Datatype::categorical;
+ di.Type(3) = data::Datatype::categorical;
+ // Set mappings.
+ di.MapString("0", 2);
+ di.MapString("1", 2);
+ di.MapString("2", 2);
+ di.MapString("3", 2);
+ di.MapString("0", 3);
+ di.MapString("1", 3);
+
+ // Now shuffle the dataset.
+ arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0, 9999,
+ 10000));
+ arma::mat d(4, 10000);
+ arma::Row<size_t> l(10000);
+ for (size_t i = 0; i < 10000; ++i)
+ {
+ d.col(i) = spiralDataset.col(indices[i]);
+ l[i] = labels[indices[i]];
+ }
+
+ // Split into a training set and a test set.
+ arma::mat trainingData = d.cols(0, 4999);
+ arma::mat testData = d.cols(5000, 9999);
+ arma::Row<size_t> trainingLabels = l.subvec(0, 4999);
+ arma::Row<size_t> testLabels = l.subvec(5000, 9999);
+
+ // Build the tree.
+ DecisionTree<> tree(trainingData, di, trainingLabels, 5, 10);
+
+ // Now evaluate the accuracy of the tree.
+ arma::Row<size_t> predictions;
+ tree.Classify(testData, predictions);
+
+ BOOST_REQUIRE_EQUAL(predictions.n_elem, testData.n_cols);
+ size_t correct = 0;
+ for (size_t i = 0; i < testData.n_cols; ++i)
+ if (testLabels[i] == predictions[i])
+ ++correct;
+
+ // Make sure we got at least 70% accuracy.
+ const double correctPct = double(correct) / double(testData.n_cols);
+ BOOST_REQUIRE_GT(correctPct, 0.70);
+}
+
+/**
+ * Make sure that when we ask for a decision stump, we get one.
+ */
+BOOST_AUTO_TEST_CASE(DecisionStumpTest)
+{
+ // Use a random dataset.
+ arma::mat dataset(10, 1000, arma::fill::randu);
+ arma::Row<size_t> labels(1000);
+ for (size_t i = 0; i < 1000; ++i)
+ labels[i] = i % 3; // 3 classes.
+
+ // Build a decision stump.
+ DecisionTree<GiniGain, BestBinaryNumericSplit, AllCategoricalSplit, double,
+ true> stump(dataset, labels, 3, 1);
+
+ // Check that it has children.
+ BOOST_REQUIRE_EQUAL(stump.NumChildren(), 2);
+ // Check that its children don't have children.
+ BOOST_REQUIRE_EQUAL(stump.Child(0).NumChildren(), 0);
+ BOOST_REQUIRE_EQUAL(stump.Child(1).NumChildren(), 0);
+}
-*/
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list