[mlpack] 62/207: Add public overload for weighted learning.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:41 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit b880154df2173286e04cb898530aabfc7d216049
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Feb 22 09:42:07 2017 -0500

    Add public overload for weighted learning.
---
 src/mlpack/methods/decision_stump/decision_stump.hpp  | 17 +++++++++++++++++
 .../methods/decision_stump/decision_stump_impl.hpp    | 19 +++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 5918aaa..df4ea58 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -87,6 +87,23 @@ class DecisionStump
              const size_t bucketSize);
 
   /**
+   * Train the decision stump on the given data, using the given per-point
+   * weights.  This completely overwrites any previous training, so after
+   * training the stump may be completely different.
+   *
+   * @param data Dataset to train on.
+   * @param labels Labels for each point in the dataset (one per point).
+   * @param weights Weight of each point in the dataset (one per point).
+   * @param classes Number of classes in the dataset.
+   * @param bucketSize Minimum size of bucket when splitting.
+   */
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const arma::rowvec& weights,
+             const size_t classes,
+             const size_t bucketSize);
+
+  /**
    * Classification function. After training, classify test, and put the
    * predicted classes in predictedLabels.
    *
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index aa7201a..4cd3a4e 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -72,6 +72,25 @@ void DecisionStump<MatType>::Train(const MatType& data,
 }
 
 /**
+ * Train the decision stump on the given data, with the given weights.  This
+ * records classes and bucketSize, then defers to the weighted Train<true>();
+ * any previous training is overwritten, so the stump may change completely.
+ */
+template<typename MatType>
+void DecisionStump<MatType>::Train(const MatType& data,
+                                   const arma::Row<size_t>& labels,
+                                   const arma::rowvec& weights,
+                                   const size_t classes,
+                                   const size_t bucketSize)
+{
+  this->classes = classes;        // this-> is needed: the parameters shadow
+  this->bucketSize = bucketSize;  // the same-named members.
+
+  // Delegate training to the weighted template overload, Train<true>().
+  Train<true>(data, labels, weights);  // weights size unchecked here.
+}
+
+/**
  * Train the decision stump on the given data and labels.
  *
  * @param data Dataset to train on.

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list