[mlpack] 12/30: Fix too-specific DecisionStump implementation.

Barak A. Pearlmutter barak+git at pearlmutter.net
Mon Dec 26 10:15:26 UTC 2016


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit d26bfcb6714e3aaabbf71098e995b833559fc5b9
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Dec 1 10:53:59 2016 -0500

    Fix too-specific DecisionStump implementation.
    
    Now it works (and is tested) with imat.
---
 .../methods/decision_stump/decision_stump.hpp      |  4 +--
 .../methods/decision_stump/decision_stump_impl.hpp |  9 ++----
 src/mlpack/tests/decision_stump_test.cpp           | 34 ++++++++++++++++++++++
 3 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index f3b9c6b..b58ff4c 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -136,8 +136,8 @@ class DecisionStump
    *     candidate for the splitting dimension.
    * @tparam UseWeights Whether we need to run a weighted Decision Stump.
    */
-  template<bool UseWeights>
-  double SetupSplitDimension(const arma::rowvec& dimension,
+  template<bool UseWeights, typename VecType>
+  double SetupSplitDimension(const VecType& dimension,
                              const arma::Row<size_t>& labels,
                              const arma::rowvec& weightD);
 
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 9143148..aa7201a 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -202,18 +202,15 @@ void DecisionStump<MatType>::Serialize(Archive& ar,
  * @param UseWeights Whether we need to run a weighted Decision Stump.
  */
 template<typename MatType>
-template<bool UseWeights>
+template<bool UseWeights, typename VecType>
 double DecisionStump<MatType>::SetupSplitDimension(
-    const arma::rowvec& dimension,
+    const VecType& dimension,
     const arma::Row<size_t>& labels,
     const arma::rowvec& weights)
 {
   size_t i, count, begin, end;
   double entropy = 0.0;
 
-  // Sort the dimension in order to calculate splitting ranges.
-  arma::rowvec sortedDim = arma::sort(dimension);
-
   // Store the indices of the sorted dimension to build a vector of sorted
   // labels.  This sort is stable.
   arma::uvec sortedIndexDim = arma::stable_sort_index(dimension.t());
@@ -301,7 +298,7 @@ void DecisionStump<MatType>::TrainOnDim(const VecType& dimension,
 {
   size_t i, count, begin, end;
 
-  arma::rowvec sortedSplitDim = arma::sort(dimension);
+  typename MatType::row_type sortedSplitDim = arma::sort(dimension);
   arma::uvec sortedSplitIndexDim = arma::stable_sort_index(dimension.t());
   arma::Row<size_t> sortedLabels(dimension.n_elem);
   sortedLabels.fill(0);
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index 3a1e7dc..af1e5f7 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -358,4 +358,38 @@ BOOST_AUTO_TEST_CASE(EmptyConstructorTest)
   BOOST_CHECK_EQUAL(predictedLabels(0, 7), 2);
 }
 
+/**
+ * Ensure that a DecisionStump can be trained on a matrix holding ints (imat)
+ * and then used to classify.  The main point is that this compiles at all.
+ */
+BOOST_AUTO_TEST_CASE(IntTest)
+{
+  // Train on a small integer dataset; the predicted labels are checked below.
+  imat trainingData;
+  trainingData << -7 << -6 << -5 << -4 << -3 << -2 << -1 << 0 << 1
+               << 2  << 3  << 4  << 5  << 6  << 7  << 8  << 9 << 10;
+
+  // Labels are already size_t class indices, so no normalization is needed.
+  Mat<size_t> labelsIn;
+  labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 0 << 0
+           << 1 << 1 << 1 << 2 << 1 << 2 << 2 << 2 << 2 << 2;
+
+  DecisionStump<arma::imat> ds(trainingData, labelsIn.row(0), 4, 3);
+
+  imat testingData;
+  testingData << -6 << -6 << -2 << -1 << 3 << 5 << 7 << 9;
+
+  arma::Row<size_t> predictedLabels;
+  ds.Classify(testingData, predictedLabels);
+
+  BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 1), 0);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 2), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 3), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 4), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 5), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 6), 2);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 7), 2);
+}
+
 BOOST_AUTO_TEST_SUITE_END();

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list