[mlpack] 48/207: Add simple generalization test for numeric features.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:39 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 6f9083da7a919f71af9b96f80e48291828d51745
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Jan 23 13:23:17 2017 -0500

    Add simple generalization test for numeric features.
---
 .../methods/decision_stump/decision_stump.hpp      |  4 +-
 .../methods/decision_stump/decision_stump_impl.hpp | 43 +++++++++++++++++-----
 src/mlpack/tests/adaboost_test.cpp                 |  2 +-
 src/mlpack/tests/decision_stump_test.cpp           | 22 +++++++++--
 src/mlpack/tests/decision_tree_test.cpp            | 41 +++++++++++++++++++++
 5 files changed, 96 insertions(+), 16 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index d262143..0d36bfa 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -205,8 +205,8 @@ class DecisionStump
    * @tparam dimension dimension is the dimension decided by the constructor
    *      on which we now train the decision stump.
    */
-  template<typename VecType>
-  void TrainOnDim(const VecType& dimension,
+  void TrainOnDim(const MatType& data,
+                  const size_t dimension,
                   const arma::Row<size_t>& labels);
 
   /**
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 228b720..c1b37d9 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -233,7 +233,7 @@ void DecisionStump<MatType, NoRecursion>::Train(
   splitDimensionOrLabel = bestDim;
 
   // Once the splitting column/dimension has been decided, train on it.
-  TrainOnDim(data.row(splitDimensionOrLabel), labels);
+  TrainOnDim(data, splitDimensionOrLabel, labels);
 }
 
 /**
@@ -447,25 +447,30 @@ double DecisionStump<MatType, NoRecursion>::SetupSplitDimension(
  *      which we now train the decision stump.
  */
 template<typename MatType, bool NoRecursion>
-template<typename VecType>
 void DecisionStump<MatType, NoRecursion>::TrainOnDim(
-    const VecType& dimension,
+    const MatType& data,
+    const size_t dimension,
     const arma::Row<size_t>& labels)
 {
   size_t i, count, begin, end;
 
-  typename MatType::row_type sortedSplitDim = arma::sort(dimension);
-  arma::uvec sortedSplitIndexDim = arma::stable_sort_index(dimension.t());
-  arma::Row<size_t> sortedLabels(dimension.n_elem);
-  sortedLabels.fill(0);
+  typename MatType::row_type sortedSplitDim = arma::sort(data.row(dimension));
+  arma::uvec sortedSplitIndexDim =
+      arma::stable_sort_index(data.row(dimension).t());
+  arma::Row<size_t> sortedLabels(data.n_cols);
   arma::Col<size_t> binLabels;
 
-  for (i = 0; i < dimension.n_elem; i++)
+  for (i = 0; i < data.n_cols; i++)
     sortedLabels(i) = labels(sortedSplitIndexDim(i));
 
+  /**
+   * Loop through the points, splitting when it is advantageous.  We check to
+   * see if we can split at index i, and then if we can, the split will take the
+   * value that's the midpoint between index i and index i + 1.
+   */
   arma::rowvec subCols;
   double mostFreq;
-  i = 0;
+  i = bucketSize;
   count = 0;
   while (i < sortedLabels.n_elem)
   {
@@ -506,7 +511,7 @@ void DecisionStump<MatType, NoRecursion>::TrainOnDim(
       mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
 
       splitOrClassProbs.resize(splitOrClassProbs.n_elem + 1);
-      splitOrClassProbs(splitOrClassProbs.n_elem - 1) = sortedSplitDim(begin);
+      splitOrClassProbs(splitOrClassProbs.n_elem - 1) = sortedSplitDim(end + 1);
       binLabels.resize(binLabels.n_elem + 1);
       binLabels(binLabels.n_elem - 1) = mostFreq;
 
@@ -576,6 +581,24 @@ void DecisionStump<MatType, NoRecursion>::TrainOnDim(
   else
   {
     // Do recursion.
+    size_t begin = 0;
+    for (size_t i = 0; i < splitOrClassProbs.n_elem; ++i)
+    {
+      // Determine how many points are in this child.
+      size_t lastBegin = begin;
+      while (sortedSplitDim(++begin) < splitOrClassProbs[i]) { }
+      size_t numPoints = (lastBegin - begin);
+
+      // Allocate memory for child data and fill it.
+      MatType childData(data.n_rows, numPoints);
+      for (size_t i = lastBegin; i < begin; ++i)
+        childData.col(i - lastBegin) = data.col(sortedSplitIndexDim[i]);
+      arma::Row<size_t> childLabels = sortedLabels.subvec(lastBegin, begin - 1);
+
+      // Create the child recursively.
+      children.push_back(new DecisionStump(childData, childLabels, classes,
+          bucketSize));
+    }
   }
 }
 
diff --git a/src/mlpack/tests/adaboost_test.cpp b/src/mlpack/tests/adaboost_test.cpp
index ec4b4ca..b67f7de 100644
--- a/src/mlpack/tests/adaboost_test.cpp
+++ b/src/mlpack/tests/adaboost_test.cpp
@@ -125,7 +125,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn)
     BOOST_FAIL("Cannot load test dataset vc2.csv!");
 
   arma::Mat<size_t> labels;
-  if (!data::Load("vc2_labels.txt",labels))
+  if (!data::Load("vc2_labels.txt", labels))
     BOOST_FAIL("Cannot load labels for vc2_labels.txt");
 
   // Define your own weak learner, perceptron in this case.
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index a243076..21d1d71 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -558,11 +558,27 @@ BOOST_AUTO_TEST_CASE(DecisionStumpMoveOperatorTest)
 }
 
 /**
- * Test that the decision tree outperforms the decision stump.
+ * Test that the decision tree can be reasonably built.
  */
-BOOST_AUTO_TEST_CASE(DecisionTreeVsStumpTest)
+BOOST_AUTO_TEST_CASE(DecisionTreeBuildTest)
 {
-  
+  arma::mat inputData;
+  if (!data::Load("vc2.csv", inputData))
+    BOOST_FAIL("Cannot load test dataset vc2.csv!");
+
+  arma::Mat<size_t> labels;
+  if (!data::Load("vc2_labels.txt", labels))
+    BOOST_FAIL("Cannot load labels for vc2_labels.txt");
+
+  // Construct a full decision tree.
+  DecisionStump<arma::mat, false> tree(inputData, labels.row(0), 3);
+
+  // Ensure that it has some children.
+  BOOST_REQUIRE_GT(tree.NumChildren(), 0);
+
+  // Ensure that its children have some children.
+  for (size_t i = 0; i < tree.NumChildren(); ++i)
+    BOOST_REQUIRE_GT(tree.Child(i).NumChildren(), 0);
 }
 
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/decision_tree_test.cpp b/src/mlpack/tests/decision_tree_test.cpp
index e2b0f31..d628d34 100644
--- a/src/mlpack/tests/decision_tree_test.cpp
+++ b/src/mlpack/tests/decision_tree_test.cpp
@@ -415,6 +415,47 @@ BOOST_AUTO_TEST_CASE(ClassProbabilityTest)
 }
 
 /**
+ * Test that the decision tree generalizes reasonably.
+ */
+BOOST_AUTO_TEST_CASE(SimpleGeneralizationTest)
+{
+  arma::mat inputData;
+  if (!data::Load("vc2.csv", inputData))
+    BOOST_FAIL("Cannot load test dataset vc2.csv!");
+
+  arma::Mat<size_t> labels;
+  if (!data::Load("vc2_labels.txt", labels))
+    BOOST_FAIL("Cannot load labels for vc2_labels.txt");
+
+  // Build decision tree.
+  DecisionTree<> d(inputData, labels, 3, 10); // Leaf size of 10.
+
+  // Load testing data.
+  arma::mat testData;
+  if (!data::Load("vc2_test.csv", testData))
+    BOOST_FAIL("Cannot load test dataset vc2_test.csv!");
+
+  arma::Mat<size_t> trueTestLabels;
+  if (!data::Load("vc2_test_labels.txt", trueTestLabels))
+    BOOST_FAIL("Cannot load labels for vc2_test_labels.txt");
+
+  // Get the predicted test labels.
+  arma::Row<size_t> predictions;
+  d.Classify(testData, predictions);
+
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, testData.n_cols);
+
+  // Figure out the accuracy.
+  double correct = 0.0;
+  for (size_t i = 0; i < predictions.n_elem; ++i)
+    if (predictions[i] == trueTestLabels[i])
+      ++correct;
+  correct /= predictions.n_elem;
+
+  BOOST_REQUIRE_GT(correct, 0.75);
+}
+
+/**
 - aux split info is empty
 - basic construction test
 - build on sparse test on dense

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list