[mlpack] 153/324: Fix some formatting, fix backwards entropy splitting, add getters/setters, and comment a little bit about the internal structure of the class.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:05 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 5ec50c73e042600b64684ecfd0760ded1da3d64d
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 9 16:49:43 2014 +0000
Fix some formatting, fix backwards entropy splitting, add getters/setters, and
comment a little bit about the internal structure of the class.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16792 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../methods/decision_stump/decision_stump.hpp | 26 +++++++-
.../methods/decision_stump/decision_stump_impl.hpp | 73 +++++++++-------------
2 files changed, 53 insertions(+), 46 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 3f90729..2e57d05 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -16,6 +16,14 @@ namespace decision_stump {
* This class implements a decision stump. It constructs a single level
* decision tree, i.e., a decision stump. It uses entropy to decide splitting
* ranges.
+ *
+ * The stump is parameterized by a splitting attribute (the dimension on which
+ * points are split), a vector of bin split values, and a vector of labels for
+ * each bin. Bin i is specified by the range [split[i], split[i + 1]). The
+ * last bin has range up to \infty (split[i + 1] does not exist in that case).
+ * Points that are below the first bin will take the label of the first bin.
+ *
+ * @tparam MatType Type of matrix that is being used (sparse or dense).
*/
template <typename MatType = arma::mat>
class DecisionStump
@@ -45,13 +53,27 @@ class DecisionStump
*/
void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
- int splitCol;
+ //! Access the splitting attribute.
+ int SplitAttribute() const { return splitAttribute; }
+ //! Modify the splitting attribute (be careful!).
+ int& SplitAttribute() { return splitAttribute; }
+
+ //! Access the splitting values.
+ const arma::vec& Split() const { return split; }
+ //! Modify the splitting values (be careful!).
+ arma::vec& Split() { return split; }
+
+ //! Access the labels for each split bin.
+ const arma::Col<size_t> BinLabels() const { return binLabels; }
+ //! Modify the labels for each split bin (be careful!).
+ arma::Col<size_t>& BinLabels() { return binLabels; }
+
private:
//! Stores the number of classes.
size_t numClass;
//! Stores the value of the attribute on which to split.
- // int splitCol;
+ int splitAttribute;
//! Size of bucket while determining splitting criterion.
size_t bucketSize;
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index bdf531c..6e02538 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -37,11 +37,7 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
// If classLabels are not all identical, proceed with training.
int bestAtt = -1;
double entropy;
- double bestEntropy = DBL_MAX;
-
- // Set the default class to handle attribute values which are not present in
- // the training data.
- //defaultClass = CountMostFreq<size_t>(classLabels);
+ double bestEntropy = -DBL_MAX;
for (int i = 0; i < data.n_rows; i++)
{
@@ -52,37 +48,21 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
// splitting attribute and calculate entropy if split on it.
entropy = SetupSplitAttribute(data.row(i), labels);
- // Find the attribute with the bestEntropy so that the gain is
+ Log::Debug << "Entropy for attribute " << i << " is " << entropy << ".\n";
+
+ // Find the attribute with the best entropy so that the gain is
// maximized.
- if (entropy < bestEntropy)
+ if (entropy > bestEntropy)
{
bestAtt = i;
bestEntropy = entropy;
}
-
- /* This section is commented out because I believe entropy calculation is
- * wrong. Entropy should only be 0 if there is only one class, in which
- * case classification is perfect and we can take the shortcut below.
-
- // If the entropy is 0, then all the labels are the same and we are done.
- Log::Debug << "Entropy is " << entropy << "\n";
- if (entropy == 0)
- {
- // Only one split element... there is no split at all, just one bin.
- split.set_size(1);
- binLabels.set_size(1);
- split[0] = -DBL_MAX;
- binLabels[0] = labels[0];
- splitCol = 0; // It doesn't matter.
- return;
- }
- */
}
}
- splitCol = bestAtt;
+ splitAttribute = bestAtt;
// Once the splitting column/attribute has been decided, train on it.
- TrainOnAtt<double>(data.row(splitCol), labels);
+ TrainOnAtt<double>(data.row(splitAttribute), labels);
}
/**
@@ -103,7 +83,7 @@ void DecisionStump<MatType>::Classify(const MatType& test,
// Assume first that it falls into the first bin, then proceed through the
// bins until it is known which bin it falls into.
int bin = 0;
- const double val = test(splitCol, i);
+ const double val = test(splitAttribute, i);
while (bin < split.n_elem - 1)
{
@@ -147,33 +127,34 @@ double DecisionStump<MatType>::SetupSplitAttribute(
i = 0;
count = 0;
- double ratioEl;
+
// This splits the sorted labels into buckets of size greater than or equal
// to inpBucketSize.
while (i < sortedLabels.n_elem)
{
count++;
- if (i == sortedLabels.n_elem - 1)
+ if (i == sortedLabels.n_elem - 1)
{
- // if we're at the end, then don't worry about the bucket size
- // just take this as the last bin.
+ // If we're at the end, then don't worry about the bucket size; just take
+ // this as the last bin.
begin = i - count + 1;
end = i;
-
- // using ratioEl to calculate the ratio of elements in this split.
- ratioEl = ((double)(end - begin + 1)/sortedLabels.n_elem);
-
- entropy += ratioEl * CalculateEntropy<size_t>(sortedLabels.subvec(begin,end));
+
+ // Use ratioEl to calculate the ratio of elements in this split.
+ const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
+
+ entropy += ratioEl * CalculateEntropy<size_t>(
+ sortedLabels.subvec(begin, end));
i++;
}
else if (sortedLabels(i) != sortedLabels(i + 1))
{
- // if we're not at the last element of sortedLabels, then check whether
+ // If we're not at the last element of sortedLabels, then check whether
// count is less than the current bucket size.
if (count < bucketSize)
{
- // if it is, then take the minimum bucket size anyways
- // this is where the inpBucketSize comes into use
+ // If it is, then take the minimum bucket size anyways.
+ // This is where the inpBucketSize comes into use.
// This makes sure there isn't a bucket for every change in labels.
begin = i - count + 1;
end = begin + bucketSize - 1;
@@ -183,13 +164,14 @@ double DecisionStump<MatType>::SetupSplitAttribute(
}
else
{
- // if it is not, then take the bucket size as the value of count.
+ // If it is not, then take the bucket size as the value of count.
begin = i - count + 1;
end = i;
}
- ratioEl = ((double)(end - begin + 1)/sortedLabels.n_elem);
-
- entropy +=ratioEl * CalculateEntropy<size_t>(sortedLabels.subvec(begin,end));
+ const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
+
+ entropy += ratioEl * CalculateEntropy<size_t>(
+ sortedLabels.subvec(begin, end));
i = end + 1;
count = 0;
@@ -321,6 +303,9 @@ rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
rType element;
int count = 0, localCount = 0;
+ if (sortCounts.n_elem == 1)
+ return sortCounts[0];
+
// An O(n) loop which counts the most frequent element in sortCounts
for (int i = 0; i < sortCounts.n_elem; ++i)
{
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list