[mlpack] 49/207: Add test for categorical data and stumps.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit d62eaf235b5924fb27f57b9573c74e444e4eed88
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Jan 23 17:26:06 2017 -0500

    Add test for categorical data and stumps.
---
 src/mlpack/tests/decision_tree_test.cpp | 140 ++++++++++++++++++++++++++++++--
 1 file changed, 135 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/tests/decision_tree_test.cpp b/src/mlpack/tests/decision_tree_test.cpp
index d628d34..02f8733 100644
--- a/src/mlpack/tests/decision_tree_test.cpp
+++ b/src/mlpack/tests/decision_tree_test.cpp
@@ -20,6 +20,7 @@
 
 using namespace mlpack;
 using namespace mlpack::tree;
+using namespace mlpack::distribution;
 
 BOOST_AUTO_TEST_SUITE(DecisionTreeTest);
 
@@ -456,10 +457,139 @@ BOOST_AUTO_TEST_CASE(SimpleGeneralizationTest)
 }
 
 /**
-- aux split info is empty
-- basic construction test
-- build on sparse test on dense
-- efficacy test
+ * Test that we can build a decision tree on a simple categorical dataset.
+ */
+BOOST_AUTO_TEST_CASE(CategoricalBuildTest)
+{
+  math::RandomSeed(std::time(NULL));
+
+  // We'll build a spiral dataset plus two noisy categorical features.  We need
+  // to build the distributions for the categorical features (they'll be
+  // discrete distributions).
+  DiscreteDistribution c1[5];
+  // The distribution will be automatically normalized.
+  for (size_t i = 0; i < 5; ++i)
+    c1[i] = DiscreteDistribution(arma::vec(4, arma::fill::randu));
+
+  DiscreteDistribution c2[5];
+  for (size_t i = 0; i < 5; ++i)
+    c2[i] = DiscreteDistribution(arma::vec(2, arma::fill::randu));
+
+  arma::mat spiralDataset(4, 10000);
+  arma::Row<size_t> labels(10000);
+  for (size_t i = 0; i < 10000; ++i)
+  {
+    // One full turn every 20000 samples, radius +1 per 2000; plus noise.
+    const double magnitude = 2.0 + (double(i) / 2000.0) +
+        0.5 * mlpack::math::Random();
+    const double angle = 2 * M_PI * (i % 20000) / 20000.0 + math::Random();
+
+    const double x = magnitude * cos(angle);
+    const double y = magnitude * sin(angle);
+
+    spiralDataset(0, i) = x;
+    spiralDataset(1, i) = y;
+
+    // Each 2000-point segment gets one label; c1/c2 are drawn per label.
+    if (i < 2000)
+    {
+      spiralDataset(2, i) = c1[1].Random()[0];
+      spiralDataset(3, i) = c2[1].Random()[0];
+      labels[i] = 1;
+    }
+    else if (i < 4000)
+    {
+      spiralDataset(2, i) = c1[3].Random()[0];
+      spiralDataset(3, i) = c2[3].Random()[0];
+      labels[i] = 3;
+    }
+    else if (i < 6000)
+    {
+      spiralDataset(2, i) = c1[2].Random()[0];
+      spiralDataset(3, i) = c2[2].Random()[0];
+      labels[i] = 2;
+    }
+    else if (i < 8000)
+    {
+      spiralDataset(2, i) = c1[0].Random()[0];
+      spiralDataset(3, i) = c2[0].Random()[0];
+      labels[i] = 0;
+    }
+    else
+    {
+      spiralDataset(2, i) = c1[4].Random()[0];
+      spiralDataset(3, i) = c2[4].Random()[0];
+      labels[i] = 4;
+    }
+  }
+
+  // Now create the dataset info.
+  data::DatasetInfo di(4);
+  di.Type(2) = data::Datatype::categorical;
+  di.Type(3) = data::Datatype::categorical;
+  // Set mappings.
+  di.MapString("0", 2);
+  di.MapString("1", 2);
+  di.MapString("2", 2);
+  di.MapString("3", 2);
+  di.MapString("0", 3);
+  di.MapString("1", 3);
+
+  // Now shuffle the dataset.
+  arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0, 9999,
+      10000));
+  arma::mat d(4, 10000);
+  arma::Row<size_t> l(10000);
+  for (size_t i = 0; i < 10000; ++i)
+  {
+    d.col(i) = spiralDataset.col(indices[i]);
+    l[i] = labels[indices[i]];
+  }
+
+  // Split into a training set and a test set.
+  arma::mat trainingData = d.cols(0, 4999);
+  arma::mat testData = d.cols(5000, 9999);
+  arma::Row<size_t> trainingLabels = l.subvec(0, 4999);
+  arma::Row<size_t> testLabels = l.subvec(5000, 9999);
+
+  // Build the tree.
+  DecisionTree<> tree(trainingData, di, trainingLabels, 5, 10);
+
+  // Now evaluate the accuracy of the tree.
+  arma::Row<size_t> predictions;
+  tree.Classify(testData, predictions);
+
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, testData.n_cols);
+  size_t correct = 0;
+  for (size_t i = 0; i < testData.n_cols; ++i)
+    if (testLabels[i] == predictions[i])
+      ++correct;
+
+  // Require strictly more than 70% accuracy on the held-out half.
+  const double correctPct = double(correct) / double(testData.n_cols);
+  BOOST_REQUIRE_GT(correctPct, 0.70);
+}
+
+/**
+ * Make sure that when we ask for a decision stump, we get one.
+ */
+BOOST_AUTO_TEST_CASE(DecisionStumpTest)
+{
+  // Random features in 10 dimensions, round-robin labels over three classes.
+  arma::mat dataset(10, 1000, arma::fill::randu);
+  arma::Row<size_t> labels(1000);
+  for (size_t j = 0; j < labels.n_elem; ++j)
+    labels[j] = j % 3;
+
+  // Build a decision stump (presumably via the trailing `true` argument).
+  DecisionTree<GiniGain, BestBinaryNumericSplit, AllCategoricalSplit, double,
+      true> stump(dataset, labels, 3, 1);
+
+  // The root must be split exactly once, into two children...
+  BOOST_REQUIRE_EQUAL(stump.NumChildren(), 2);
+  // ...and neither of those children may be split any further.
+  for (size_t c = 0; c < stump.NumChildren(); ++c)
+    BOOST_REQUIRE_EQUAL(stump.Child(c).NumChildren(), 0);
+}
 
-*/
 BOOST_AUTO_TEST_SUITE_END();

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list