[mlpack] 77/324: For clarity, use separate split and binLabels objects instead of storing the label in the split matrix. No casting is necessary anymore.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:21:58 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 5a40b6a4413fd60023d5737bfcc3e4737cf49ce3
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Wed Jun 25 00:22:09 2014 +0000

    For clarity, use separate split and binLabels objects instead of storing the label in the split matrix.  No casting is necessary anymore.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16708 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 .../methods/decision_stump/decision_stump.hpp      |  7 ++-
 .../methods/decision_stump/decision_stump_impl.hpp | 52 ++++++++++------------
 2 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 5db64be..e1bec19 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -64,8 +64,11 @@ class DecisionStump
   //! Stores the class labels for the input data.
   arma::Row<size_t> classLabels;
 
-  //! Stores the splitting criterion after training.
-  arma::mat split;
+  //! Stores the splitting values after training.
+  arma::vec split;
+
+  //! Stores the labels for each splitting bin.
+  arma::Col<size_t> binLabels;
 
   /**
    * Sets up attribute as if it were splitting on it and finds entropy when
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 2c757ca..b3d4075 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -39,7 +39,7 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
   bucketSize = inpBucketSize;
 
   // Check whether the input labels are not all identical.
-  if (!isDistinct<size_t>(classLabels))
+  if (!isDistinct<size_t>(labels))
   {
     // If the classLabels are all identical, the default class is the only
     // class.
@@ -99,31 +99,21 @@ void DecisionStump<MatType>::Classify(const MatType& test,
   {
     for (int i = 0; i < test.n_cols; i++)
     {
-      int j = 0;
-
+      // Determine which bin the test point falls into.
+      // Assume first that it falls into the first bin, then proceed through the
+      // bins until it is known which bin it falls into.
+      int bin = 0;
       const double val = test(splitCol, i);
-      while (j < split.n_rows)
+
+      while (bin < split.n_elem - 1)
       {
-        if (val < split(j, 0) && (!j))
-        {
-          predictedLabels(i) = split(0, 1);
+        if (val < split(bin + 1))
           break;
-        }
-        else if (val >= split(j, 0))
-        {
-          if (j == split.n_rows - 1)
-          {
-            predictedLabels(i) = split(split.n_rows - 1, 1);
-            break;
-          }
-          else if (val < split(j + 1, 0))
-          {
-            predictedLabels(i) = split(j, 1);
-            break;
-          }
-        }
-        j++;
+
+        ++bin;
       }
+
+      predictedLabels(i) = binLabels(bin);
     }
   }
   else
@@ -243,7 +233,8 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
   arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
   arma::Row<size_t> sortedLabels(attribute.n_elem);
   sortedLabels.fill(0);
-  arma::mat tempSplit;
+  arma::vec tempSplit;
+  arma::Row<size_t> tempLabel;
 
   for (i = 0; i < attribute.n_elem; i++)
     sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
@@ -267,8 +258,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
 
       mostFreq = CountMostFreq<double>(subCols);
 
-      tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
-      split = arma::join_cols(split, tempSplit);
+      split.resize(split.n_elem + 1);
+      split(split.n_elem - 1) = sortedSplitAtt(begin);
+      binLabels.resize(binLabels.n_elem + 1);
+      binLabels(binLabels.n_elem - 1) = mostFreq;
 
       i++;
     }
@@ -297,8 +290,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
       // the bucket of subCols.
       mostFreq = CountMostFreq<double>(subCols);
 
-      tempSplit << sortedSplitAtt(begin) << mostFreq << arma::endr;
-      split = arma::join_cols(split, tempSplit);
+      split.resize(split.n_elem + 1);
+      split(split.n_elem - 1) = sortedSplitAtt(begin);
+      binLabels.resize(binLabels.n_elem + 1);
+      binLabels(binLabels.n_elem - 1) = mostFreq;
 
       i = end + 1;
       count = 0;
@@ -321,9 +316,10 @@ void DecisionStump<MatType>::MergeRanges()
 {
   for (int i = 1; i < split.n_rows; i++)
   {
-    if (split(i, 1) == split(i - 1, 1))
+    if (binLabels(i) == binLabels(i - 1))
     {
       // Remove this row, as it has the same label as the previous bucket.
+      binLabels.shed_row(i);
       split.shed_row(i);
       // Go back to previous row.
       i--;

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list