[mlpack] 62/207: Add public overload for weighted learning.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:41 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit b880154df2173286e04cb898530aabfc7d216049
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Feb 22 09:42:07 2017 -0500
Add public overload for weighted learning.
---
src/mlpack/methods/decision_stump/decision_stump.hpp | 17 +++++++++++++++++
.../methods/decision_stump/decision_stump_impl.hpp | 19 +++++++++++++++++++
2 files changed, 36 insertions(+)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 5918aaa..df4ea58 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -87,6 +87,23 @@ class DecisionStump
const size_t bucketSize);
/**
+ * Train the decision stump on the given data, using the given per-point
+ * weights. This completely overwrites any previously trained model, so after
+ * training the stump may be completely different.
+ *
+ * @param data Dataset to train on.
+ * @param labels Labels for each point in the dataset.
+ * @param weights Weights for each point in the dataset.
+ * @param classes Number of classes in the dataset.
+ * @param bucketSize Minimum size of bucket when splitting.
+ */
+ void Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const arma::rowvec& weights,
+ const size_t classes,
+ const size_t bucketSize);
+
+ /**
* Classification function. After training, classify test, and put the
* predicted classes in predictedLabels.
*
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index aa7201a..4cd3a4e 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -72,6 +72,25 @@ void DecisionStump<MatType>::Train(const MatType& data,
}
/**
+ * Train the decision stump on the given data, using the given per-point
+ * weights. This completely overwrites any previously trained model, so after
+ * training the stump may be completely different.
+ */
+template<typename MatType>
+void DecisionStump<MatType>::Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const arma::rowvec& weights,
+ const size_t classes,
+ const size_t bucketSize)
+{
+ this->classes = classes; // Parameters shadow same-named members; use this->.
+ this->bucketSize = bucketSize;
+
+ // Store the hyperparameters above, then delegate to the weighted Train<true>().
+ Train<true>(data, labels, weights); // NOTE(review): weights/labels sizes are not checked here; confirm the callee validates them.
+}
+
+/**
* Train the decision stump on the given data and labels.
*
* @param data Dataset to train on.
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list