[mlpack] 12/30: Fix too-specific DecisionStump implementation.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Dec 26 10:15:26 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit d26bfcb6714e3aaabbf71098e995b833559fc5b9
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Dec 1 10:53:59 2016 -0500
Fix too-specific DecisionStump implementation.
Now it works (and is tested) with imat.
---
.../methods/decision_stump/decision_stump.hpp | 4 +--
.../methods/decision_stump/decision_stump_impl.hpp | 9 ++----
src/mlpack/tests/decision_stump_test.cpp | 34 ++++++++++++++++++++++
3 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index f3b9c6b..b58ff4c 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -136,8 +136,8 @@ class DecisionStump
* candidate for the splitting dimension.
* @tparam UseWeights Whether we need to run a weighted Decision Stump.
*/
- template<bool UseWeights>
- double SetupSplitDimension(const arma::rowvec& dimension,
+ template<bool UseWeights, typename VecType>
+ double SetupSplitDimension(const VecType& dimension,
const arma::Row<size_t>& labels,
const arma::rowvec& weightD);
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 9143148..aa7201a 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -202,18 +202,15 @@ void DecisionStump<MatType>::Serialize(Archive& ar,
* @param UseWeights Whether we need to run a weighted Decision Stump.
*/
template<typename MatType>
-template<bool UseWeights>
+template<bool UseWeights, typename VecType>
double DecisionStump<MatType>::SetupSplitDimension(
- const arma::rowvec& dimension,
+ const VecType& dimension,
const arma::Row<size_t>& labels,
const arma::rowvec& weights)
{
size_t i, count, begin, end;
double entropy = 0.0;
- // Sort the dimension in order to calculate splitting ranges.
- arma::rowvec sortedDim = arma::sort(dimension);
-
// Store the indices of the sorted dimension to build a vector of sorted
// labels. This sort is stable.
arma::uvec sortedIndexDim = arma::stable_sort_index(dimension.t());
@@ -301,7 +298,7 @@ void DecisionStump<MatType>::TrainOnDim(const VecType& dimension,
{
size_t i, count, begin, end;
- arma::rowvec sortedSplitDim = arma::sort(dimension);
+ typename MatType::row_type sortedSplitDim = arma::sort(dimension);
arma::uvec sortedSplitIndexDim = arma::stable_sort_index(dimension.t());
arma::Row<size_t> sortedLabels(dimension.n_elem);
sortedLabels.fill(0);
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index 3a1e7dc..af1e5f7 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -358,4 +358,38 @@ BOOST_AUTO_TEST_CASE(EmptyConstructorTest)
BOOST_CHECK_EQUAL(predictedLabels(0, 7), 2);
}
+/**
+ * Ensure that DecisionStump can be trained on a matrix holding ints (imat).
+ * The main point of this test is that DecisionStump<arma::imat> compiles.
+ */
+BOOST_AUTO_TEST_CASE(IntTest)
+{
+ // Train on a small integer dataset and check that predictions are sensible.
+ imat trainingData;
+ trainingData << -7 << -6 << -5 << -4 << -3 << -2 << -1 << 0 << 1
+ << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9 << 10;
+
+ // Labels are already contiguous from 0 (0..2), so no normalization needed.
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 0 << 0
+ << 1 << 1 << 1 << 2 << 1 << 2 << 2 << 2 << 2 << 2;
+
+ DecisionStump<arma::imat> ds(trainingData, labelsIn.row(0), 4, 3);
+
+ imat testingData;
+ testingData << -6 << -6 << -2 << -1 << 3 << 5 << 7 << 9;
+
+ arma::Row<size_t> predictedLabels;
+ ds.Classify(testingData, predictedLabels);
+
+ BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 1), 0);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 2), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 3), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 4), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 5), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 6), 2);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 7), 2);
+}
+
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list