[mlpack] 157/324: Entropy calculation improved.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:06 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit b942dc823dd1223fb084de4797419673a0935a9b
Author: saxena.udit <saxena.udit at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 9 19:19:52 2014 +0000
Entropy calculation improved.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16796 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../methods/decision_stump/decision_stump_impl.hpp | 20 ++++++++++++-----
src/mlpack/tests/decision_stump_test.cpp | 26 ++++++++++++++++++----
2 files changed, 36 insertions(+), 10 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 6e02538..e30168f 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -35,10 +35,12 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
bucketSize = inpBucketSize;
// If classLabels are not all identical, proceed with training.
- int bestAtt = -1;
+ int bestAtt = 0;
double entropy;
- double bestEntropy = -DBL_MAX;
-
+ double rootEntropy = CalculateEntropy<size_t>(labels.subvec(0,labels.n_elem-1));
+ // std::cout<<"rootEntropy is: "<<rootEntropy<<"\n";
+ // double bestEntropy = DBL_MAX;
+ double gain, bestGain = 0.0;
for (int i = 0; i < data.n_rows; i++)
{
// Go through each attribute of the data.
@@ -49,13 +51,18 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
entropy = SetupSplitAttribute(data.row(i), labels);
Log::Debug << "Entropy for attribute " << i << " is " << entropy << ".\n";
-
+ gain = rootEntropy - entropy;
// Find the attribute with the best entropy so that the gain is
// maximized.
- if (entropy > bestEntropy)
+
+ // if (entropy < bestEntropy)
+ // Instead of the above rule, we are maximizing gain, which was
+ // what is returned from SetupSplitAttribute.
+ if (gain < bestGain)
{
bestAtt = i;
- bestEntropy = entropy;
+ // bestEntropy = entropy;
+ bestGain = gain;
}
}
}
@@ -380,6 +387,7 @@ double DecisionStump<MatType>::CalculateEntropy(arma::subview_row<LabelType> lab
entropy += (p1 == 0) ? 0 : p1 * log2(p1);
}
+
return entropy;
}
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index 1fec0bc..a56dff8 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -221,7 +221,7 @@ BOOST_AUTO_TEST_CASE(MultiClassSplit)
BOOST_AUTO_TEST_CASE(DimensionSelectionTest)
{
const size_t numClasses = 2;
- const size_t inpBucketSize = 25;
+ const size_t inpBucketSize = 2500;
arma::mat dataset(4, 5000);
@@ -294,17 +294,35 @@ BOOST_AUTO_TEST_CASE(DimensionSelectionTest)
DecisionStump<> ds(dataset, labels, numClasses, inpBucketSize);
// Make sure it split on the dimension that is most separable.
- BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), 1);
+ BOOST_CHECK_EQUAL(ds.SplitAttribute(), 1);
// Make sure every bin below -1 classifies as label 0, and every bin above 1
// classifies as label 1 (What happens in [-1, 1] isn't that big a deal.).
for (size_t i = 0; i < ds.Split().n_elem; ++i)
{
if (ds.Split()[i] <= -3.0)
- BOOST_REQUIRE_EQUAL(ds.BinLabels()[i], 0);
+ BOOST_CHECK_EQUAL(ds.BinLabels()[i], 0);
else if (ds.Split()[i] >= 3.0)
- BOOST_REQUIRE_EQUAL(ds.BinLabels()[i], 1);
+ BOOST_CHECK_EQUAL(ds.BinLabels()[i], 1);
}
}
+BOOST_AUTO_TEST_CASE(TempAttributeSplit)
+{
+ const size_t numClasses = 2;
+ const size_t inpBucketSize = 3;
+
+ mat trainingData;
+ trainingData << 1 << 1 << 1 << 2 << 2 << 2 << endr
+ << 0.5 << 0.6 << 0.7 << 0.4 << 0.3 << 0.5 << endr;
+
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 1;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+ // Row<size_t> predictedLabels(testingData.n_cols);
+ // ds.Classify(testingData, predictedLabels);
+ BOOST_CHECK_EQUAL(ds.SplitAttribute(), 0);
+}
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list