[mlpack] 77/324: For clarity, use separate split and binLabels objects instead of storing the label in the split matrix. No casting is necessary anymore.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:21:58 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 5a40b6a4413fd60023d5737bfcc3e4737cf49ce3
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jun 25 00:22:09 2014 +0000
For clarity, use separate split and binLabels objects instead of storing the label in the split matrix. No casting is necessary anymore.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16708 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../methods/decision_stump/decision_stump.hpp | 7 ++-
.../methods/decision_stump/decision_stump_impl.hpp | 52 ++++++++++------------
2 files changed, 29 insertions(+), 30 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 5db64be..e1bec19 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -64,8 +64,11 @@ class DecisionStump
//! Stores the class labels for the input data.
arma::Row<size_t> classLabels;
- //! Stores the splitting criterion after training.
- arma::mat split;
+ //! Stores the splitting values after training.
+ arma::vec split;
+
+ //! Stores the labels for each splitting bin.
+ arma::Col<size_t> binLabels;
/**
* Sets up attribute as if it were splitting on it and finds entropy when
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 2c757ca..b3d4075 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -39,7 +39,7 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
bucketSize = inpBucketSize;
// Check whether the input labels are not all identical.
- if (!isDistinct<size_t>(classLabels))
+ if (!isDistinct<size_t>(labels))
{
// If the classLabels are all identical, the default class is the only
// class.
@@ -99,31 +99,21 @@ void DecisionStump<MatType>::Classify(const MatType& test,
{
for (int i = 0; i < test.n_cols; i++)
{
- int j = 0;
-
+ // Determine which bin the test point falls into.
+ // Assume first that it falls into the first bin, then proceed through the
+ // bins until it is known which bin it falls into.
+ int bin = 0;
const double val = test(splitCol, i);
- while (j < split.n_rows)
+
+ while (bin < split.n_elem - 1)
{
- if (val < split(j, 0) && (!j))
- {
- predictedLabels(i) = split(0, 1);
+ if (val < split(bin + 1))
break;
- }
- else if (val >= split(j, 0))
- {
- if (j == split.n_rows - 1)
- {
- predictedLabels(i) = split(split.n_rows - 1, 1);
- break;
- }
- else if (val < split(j + 1, 0))
- {
- predictedLabels(i) = split(j, 1);
- break;
- }
- }
- j++;
+
+ ++bin;
}
+
+ predictedLabels(i) = binLabels(bin);
}
}
else
@@ -243,7 +233,8 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
arma::Row<size_t> sortedLabels(attribute.n_elem);
sortedLabels.fill(0);
- arma::mat tempSplit;
+ arma::vec tempSplit;
+ arma::Row<size_t> tempLabel;
for (i = 0; i < attribute.n_elem; i++)
sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
@@ -267,8 +258,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
mostFreq = CountMostFreq<double>(subCols);
- tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
- split = arma::join_cols(split, tempSplit);
+ split.resize(split.n_elem + 1);
+ split(split.n_elem - 1) = sortedSplitAtt(begin);
+ binLabels.resize(binLabels.n_elem + 1);
+ binLabels(binLabels.n_elem - 1) = mostFreq;
i++;
}
@@ -297,8 +290,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
// the bucket of subCols.
mostFreq = CountMostFreq<double>(subCols);
- tempSplit << sortedSplitAtt(begin) << mostFreq << arma::endr;
- split = arma::join_cols(split, tempSplit);
+ split.resize(split.n_elem + 1);
+ split(split.n_elem - 1) = sortedSplitAtt(begin);
+ binLabels.resize(binLabels.n_elem + 1);
+ binLabels(binLabels.n_elem - 1) = mostFreq;
i = end + 1;
count = 0;
@@ -321,9 +316,10 @@ void DecisionStump<MatType>::MergeRanges()
{
for (int i = 1; i < split.n_rows; i++)
{
- if (split(i, 1) == split(i - 1, 1))
+ if (binLabels(i) == binLabels(i - 1))
{
// Remove this row, as it has the same label as the previous bucket.
+ binLabels.shed_row(i);
split.shed_row(i);
// Go back to previous row.
i--;
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list