[mlpack] 322/324: Decision Stumps modified, along with adding Classify() function to AdaBoost. Other minor changes (renaming).
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:22 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit b669269cd00e80486aa496c2c5c260ce7f0cdfb7
Author: saxena.udit <saxena.udit at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Sat Aug 16 16:00:06 2014 +0000
Decision Stumps modified, along with adding Classify() function to AdaBoost. Other minor changes (renaming).
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17042 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/adaboost/adaboost.hpp | 16 +++--
src/mlpack/methods/adaboost/adaboost_impl.hpp | 73 +++++++++++++++----
src/mlpack/methods/adaboost/adaboost_main.cpp | 21 ++++--
.../methods/decision_stump/decision_stump.hpp | 19 +++--
.../methods/decision_stump/decision_stump_impl.hpp | 82 ++++++++++++++--------
.../methods/decision_stump/decision_stump_main.cpp | 3 +-
src/mlpack/tests/adaboost_test.cpp | 61 ++++++++--------
src/mlpack/tests/decision_stump_test.cpp | 2 +-
8 files changed, 190 insertions(+), 87 deletions(-)
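Taken together, the patch renames Adaboost to AdaBoost, promotes numClasses, the trained weak learners, and their vote weights to member state, and adds a Classify() member for labeling unseen data. A minimal usage sketch of the new interface follows; the file names, label-loading details, and perceptron iteration count are illustrative assumptions, and only the AdaBoost construction and Classify() call mirror the patch:

    #include <mlpack/core.hpp>
    #include <mlpack/methods/adaboost/adaboost.hpp>
    #include <mlpack/methods/perceptron/perceptron.hpp>

    using namespace mlpack;

    int main()
    {
      // Hypothetical inputs; names are illustrative only.
      arma::mat trainingData, testingData;
      data::Load("train.csv", trainingData, true);
      data::Load("test.csv", testingData, true);

      arma::Mat<size_t> rawLabels;
      data::Load("train_labels.csv", rawLabels, true);
      const arma::Row<size_t> labels = rawLabels.row(0);

      // Weak learner, initialized up front as adaboost_main.cpp does.
      perceptron::Perceptron<> p(trainingData, labels, 1000);

      // The constructor runs AdaBoost.mh for up to 100 iterations.
      adaboost::AdaBoost<> a(trainingData, labels, 100, 1e-10, p);

      // New in this commit: classify test points with the trained ensemble.
      arma::Row<size_t> predictedLabels(testingData.n_cols);
      a.Classify(testingData, predictedLabels);

      return 0;
    }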
diff --git a/src/mlpack/methods/adaboost/adaboost.hpp b/src/mlpack/methods/adaboost/adaboost.hpp
index b013355..58ee336 100644
--- a/src/mlpack/methods/adaboost/adaboost.hpp
+++ b/src/mlpack/methods/adaboost/adaboost.hpp
@@ -32,11 +32,11 @@ namespace adaboost {
template<typename MatType = arma::mat,
typename WeakLearner = mlpack::perceptron::Perceptron<> >
-class Adaboost
+class AdaBoost
{
public:
/**
- * Constructor. Currently runs the Adaboost.mh algorithm.
+ * Constructor. Currently runs the AdaBoost.mh algorithm.
*
* @param data Input data.
* @param labels Corresponding labels.
@@ -44,7 +44,7 @@ class Adaboost
* @param tol The tolerance for change in values of rt.
* @param other Weak Learner, which has been initialized already.
*/
- Adaboost(const MatType& data,
+ AdaBoost(const MatType& data,
const arma::Row<size_t>& labels,
const int iterations,
const double tol,
@@ -59,6 +59,8 @@ class Adaboost
// The tolerance for change in rt and when to stop.
double tolerance;
+ void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
private:
/**
* This function helps in building the Weight Distribution matrix
@@ -72,8 +74,14 @@ private:
*/
void BuildWeightMatrix(const arma::mat& D, arma::rowvec& weights);
+ size_t numClasses;
+
+ std::vector<WeakLearner> wl;
+ std::vector<double> alpha;
+ std::vector<double> z;
+
-}; // class Adaboost
+}; // class AdaBoost
} // namespace adaboost
} // namespace mlpack
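The new private members are what make a post-training Classify() possible: each stored weak hypothesis wl[t] votes for its predicted label with weight alpha[t], and the strong hypothesis is the familiar weighted plurality vote

    H(x) = \arg\max_{\ell} \sum_{t=1}^{T} \alpha_t \, [\, h_t(x) = \ell \,]

(the implementation below additionally scales each vote by the numeric label value), while z collects the per-round normalizers Z_t used for the Hamming-loss bound.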
diff --git a/src/mlpack/methods/adaboost/adaboost_impl.hpp b/src/mlpack/methods/adaboost/adaboost_impl.hpp
index 70cc9a7..187ce2f 100644
--- a/src/mlpack/methods/adaboost/adaboost_impl.hpp
+++ b/src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -2,7 +2,7 @@
* @file adaboost_impl.hpp
* @author Udit Saxena
*
- * Implementation of the Adaboost class.
+ * Implementation of the AdaBoost class.
*
* @code
* @article{schapire1999improved,
@@ -27,7 +27,7 @@ namespace mlpack {
namespace adaboost {
/**
- * Constructor. Currently runs the Adaboost.mh algorithm
+ * Constructor. Currently runs the AdaBoost.mh algorithm
*
* @param data Input data
* @param labels Corresponding labels
@@ -35,7 +35,7 @@ namespace adaboost {
* @param other Weak Learner, which has been initialized already
*/
template<typename MatType, typename WeakLearner>
-Adaboost<MatType, WeakLearner>::Adaboost(
+AdaBoost<MatType, WeakLearner>::AdaBoost(
const MatType& data,
const arma::Row<size_t>& labels,
const int iterations,
@@ -43,7 +43,7 @@ Adaboost<MatType, WeakLearner>::Adaboost(
const WeakLearner& other)
{
// Count the number of classes.
- const size_t numClasses = (arma::max(labels) - arma::min(labels)) + 1;
+ numClasses = (arma::max(labels) - arma::min(labels)) + 1;
tolerance = tol;
double rt, crt, alphat = 0.0, zt;
@@ -97,6 +97,7 @@ Adaboost<MatType, WeakLearner>::Adaboost(
// Build the weight vectors
BuildWeightMatrix(D, weights);
+ // std::cout<<"Just about to call the weak learner. \n";
// call the other weak learner and train the labels.
WeakLearner w(other, tempData, weights, labels);
w.Classify(tempData, predictedLabels);
@@ -110,14 +111,16 @@ Adaboost<MatType, WeakLearner>::Adaboost(
{
if (predictedLabels(j) == labels(j))
{
- for (int k = 0;k < numClasses; k++)
- rt += D(j,k);
+ // for (int k = 0;k < numClasses; k++)
+ // rt += D(j,k);
+ rt += arma::accu(D.row(j));
}
else
{
- for (int k = 0;k < numClasses; k++)
- rt -= D(j,k);
+ // for (int k = 0;k < numClasses; k++)
+ // rt -= D(j,k);
+ rt -= arma::accu(D.row(j));
}
}
// end calculation of rt
@@ -136,6 +139,9 @@ Adaboost<MatType, WeakLearner>::Adaboost(
alphat = 0.5 * log((1 + rt) / (1 - rt));
// end calculation of alphat
+ alpha.push_back(alphat);
+ wl.push_back(w);
+
// now start modifying weights
for (int j = 0;j < D.n_rows; j++)
{
@@ -178,17 +184,20 @@ Adaboost<MatType, WeakLearner>::Adaboost(
// Accumulating the value of zt for the Hamming Loss bound.
ztAccumulator *= zt;
+ z.push_back(zt);
}
// Iterations are over, now build a strong hypothesis
// from a weighted combination of these weak hypotheses.
- arma::rowvec tempSumFinalH;
+ // std::cout<<"Just about to look at the final hypo.\n";
+ arma::colvec tempSumFinalH;
arma::uword max_index;
-
- for (int i = 0;i < sumFinalH.n_rows; i++)
+ arma::mat sfh = sumFinalH.t();
+
+ for (int i = 0;i < sfh.n_cols; i++)
{
- tempSumFinalH = sumFinalH.row(i);
+ tempSumFinalH = sfh.col(i);
tempSumFinalH.max(max_index);
finalH(i) = max_index;
}
@@ -196,6 +205,44 @@ Adaboost<MatType, WeakLearner>::Adaboost(
}
/**
+ * Classify the given test data, filling predictedLabels with the
+ * weighted-vote predictions of the stored weak hypotheses.
+ *
+ * @param test Testing data to classify.
+ * @param predictedLabels Vector in which to store the predicted labels.
+ */
+ template <typename MatType, typename WeakLearner>
+ void AdaBoost<MatType, WeakLearner>::Classify(
+ const MatType& test,
+ arma::Row<size_t>& predictedLabels
+ )
+ {
+ arma::Row<size_t> tempPredictedLabels(predictedLabels.n_cols);
+ arma::mat cMatrix(test.n_cols, numClasses);
+
+ cMatrix.fill(0.0);
+ predictedLabels.fill(0);
+
+ for(int i = 0;i < wl.size();i++)
+ {
+ wl[i].Classify(test,tempPredictedLabels);
+
+ for(int j = 0;j < tempPredictedLabels.n_cols; j++)
+ {
+ cMatrix(j,tempPredictedLabels(j)) += (alpha[i] * tempPredictedLabels(j));
+ }
+
+ }
+
+ arma::rowvec cMRow;
+ arma::uword max_index;
+
+ for(int i = 0;i < predictedLabels.n_cols;i++)
+ {
+ cMRow = cMatrix.row(i);
+ cMRow.max(max_index);
+ predictedLabels(i) = max_index;
+ }
+
+ }
+
+/**
* This function helps in building the Weight Distribution matrix
* which is updated during every iteration. It calculates the
* "difficulty" in classifying a point by adding the weights for all
@@ -206,7 +253,7 @@ Adaboost<MatType, WeakLearner>::Adaboost(
* @param weights The output weight vector.
*/
template <typename MatType, typename WeakLearner>
-void Adaboost<MatType, WeakLearner>::BuildWeightMatrix(
+void AdaBoost<MatType, WeakLearner>::BuildWeightMatrix(
const arma::mat& D,
arma::rowvec& weights)
{
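For reference, the quantities accumulated above follow Schapire and Singer's confidence-rated formulation cited in the file header. With D_t the weight distribution over (point, label) pairs and h_t the weak hypothesis's ±1 agreement,

    r_t = \sum_{i,\ell} D_t(i,\ell)\, Y_i(\ell)\, h_t(x_i,\ell), \qquad
    \alpha_t = \frac{1}{2} \ln\!\left( \frac{1 + r_t}{1 - r_t} \right),

    D_{t+1}(i,\ell) = \frac{D_t(i,\ell) \exp\bigl( -\alpha_t\, Y_i(\ell)\, h_t(x_i,\ell) \bigr)}{Z_t},

so rt += accu(D.row(j)) on agreement and rt -= accu(D.row(j)) on disagreement is exactly the r_t sum, and ztAccumulator, the running product of the Z_t, upper-bounds the Hamming loss.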
diff --git a/src/mlpack/methods/adaboost/adaboost_main.cpp b/src/mlpack/methods/adaboost/adaboost_main.cpp
index d6a1c12..ecc1cbc 100644
--- a/src/mlpack/methods/adaboost/adaboost_main.cpp
+++ b/src/mlpack/methods/adaboost/adaboost_main.cpp
@@ -2,7 +2,7 @@
* @file: adaboost_main.cpp
* @author: Udit Saxena
*
- * Implementation of the Adaboost main file
+ * Implementation of the AdaBoost main file
*
* @code
* @article{Schapire:1999:IBA:337859.337870,
@@ -37,8 +37,8 @@ using namespace std;
using namespace arma;
using namespace mlpack::adaboost;
-PROGRAM_INFO("Adaboost","This program implements the Adaboost (or Adaptive Boost)"
- " algorithm. The variant of Adaboost implemented here is Adaboost.mh. It uses a"
+PROGRAM_INFO("AdaBoost","This program implements the AdaBoost (or Adaptive Boost)"
+ " algorithm. The variant of AdaBoost implemented here is AdaBoost.mh. It uses a"
" weak learner, either of Decision Stumps or a Perceptron, and over many"
" iterations, creates a strong learner. It runs these iterations till a tolerance"
" value is crossed for change in the value of rt."
@@ -64,7 +64,7 @@ PARAM_STRING("output", "The file in which the predicted labels for the test set"
" will be written.", "o", "output.csv");
PARAM_INT("iterations","The maximum number of boosting iterations "
"to be run", "i", 1000);
-PARAM_INT_REQ("classes","The number of classes in the input label set.","c");
+// PARAM_INT("classes","The number of classes in the input label set.","c");
PARAM_DOUBLE("tolerance","The tolerance for change in values of rt","e",1e-10);
int main(int argc, char *argv[])
@@ -129,8 +129,19 @@ int main(int argc, char *argv[])
perceptron::Perceptron<> p(trainingData, labels.t(), iter);
Timer::Start("Training");
- Adaboost<> a(trainingData, labels.t(), iterations, tolerance, p);
+ AdaBoost<> a(trainingData, labels.t(), iterations, tolerance, p);
Timer::Stop("Training");
+ Row<size_t> predictedLabels(testingData.n_cols);
+ Timer::Start("testing");
+ a.Classify(testingData, predictedLabels);
+ Timer::Stop("testing");
+
+ vec results;
+ data::RevertLabels(predictedLabels.t(), mappings, results);
+
+ // Save the predicted labels in a transposed form as output.
+ const string outputFilename = CLI::GetParam<string>("output_file");
+ data::Save(outputFilename, results, true, false);
return 0;
}
\ No newline at end of file
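With the required --classes option dropped (the constructor now derives numClasses from the labels), a hypothetical invocation of the rebuilt tool could look as follows; only -o, -i, and -e appear in the hunks above, so the binary name and the training/test/labels flags are assumptions based on mlpack's usual conventions:

    $ adaboost -t train_data.csv -l train_labels.csv -T test_data.csv \
        -o output.csv -i 1000 -e 1e-10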
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 895fd4a..de1418a 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -63,11 +63,13 @@ class DecisionStump
* @param data The data on which to train this object on.
* @param D Weight vector to use while training. For boosting purposes.
* @param labels The labels of data.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
*/
DecisionStump(const DecisionStump<>& other,
const MatType& data,
const arma::rowvec& weights,
- const arma::Row<size_t>& labels);
+ const arma::Row<size_t>& labels
+ );
//! Access the splitting attribute.
int SplitAttribute() const { return splitAttribute; }
@@ -106,9 +108,12 @@ class DecisionStump
*
* @param attribute A row from the training data, which might be a
* candidate for the splitting attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
*/
+ template <typename W>
double SetupSplitAttribute(const arma::rowvec& attribute,
- const arma::Row<size_t>& labels);
+ const arma::Row<size_t>& labels,
+ W isWeight);
/**
* After having decided the attribute on which to split, train on that
@@ -147,17 +152,21 @@ class DecisionStump
*
* @param attribute The attribute of which we calculate the entropy.
* @param labels Corresponding labels of the attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
*/
- template <typename LabelType>
- double CalculateEntropy(arma::subview_row<LabelType> labels, int begin);
+ template <typename LabelType, typename W>
+ double CalculateEntropy(arma::subview_row<LabelType> labels, int begin,
+ W isWeight);
/**
* Train the decision stump on the given data and labels.
*
* @param data Dataset to train on.
* @param labels Labels for dataset.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
*/
- void Train(const MatType& data, const arma::Row<size_t>& labels);
+ template <typename W>
+ void Train(const MatType& data, const arma::Row<size_t>& labels, W isWeight);
//! To store the weight vectors for boosting purposes.
arma::rowvec weightD;
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 348ab9a..e3b5824 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -30,22 +30,27 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
{
numClass = classes;
bucketSize = inpBucketSize;
+ const bool isWeight = false;
- weightD = arma::rowvec(data.n_cols);
- weightD.fill(1.0);
- tempD = weightD;
-
- Train(data, labels);
+ Train<bool>(data, labels, isWeight);
}
+/**
+ * Train the decision stump on the given data and labels.
+ *
+ * @param data Dataset to train on.
+ * @param labels Labels for dataset.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
+ */
template<typename MatType>
-void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>& labels)
+template <typename W>
+void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>& labels, W isWeight)
{
// If classLabels are not all identical, proceed with training.
int bestAtt = 0;
double entropy;
- const double rootEntropy = CalculateEntropy<size_t>(
- labels.subvec(0, labels.n_elem - 1), 0);
+ const double rootEntropy = CalculateEntropy<size_t, W>(
+ labels.subvec(0, labels.n_elem - 1), 0, isWeight);
double gain, bestGain = 0.0;
for (int i = 0; i < data.n_rows; i++)
@@ -55,7 +60,7 @@ void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>&
{
// For each attribute with non-identical values, treat it as a potential
// splitting attribute and calculate entropy if split on it.
- entropy = SetupSplitAttribute(data.row(i), labels);
+ entropy = SetupSplitAttribute<W>(data.row(i), labels, isWeight);
gain = rootEntropy - entropy;
// Find the attribute with the best entropy so that the gain is
@@ -119,6 +124,7 @@ void DecisionStump<MatType>::Classify(const MatType& test,
* @param data The data on which to train this object on.
* @param D Weight vector to use while training. For boosting purposes.
* @param labels The labels of data.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
*/
template <typename MatType>
DecisionStump<MatType>::DecisionStump(
@@ -133,8 +139,8 @@ DecisionStump<MatType>::DecisionStump(
weightD = weights;
tempD = weightD;
-
- Train(data, labels);
+ const bool isWeight = true;
+ Train<bool>(data, labels, isWeight);
}
/**
@@ -143,11 +149,14 @@ DecisionStump<MatType>::DecisionStump(
*
* @param attribute A row from the training data, which might be a candidate for
* the splitting attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
*/
template <typename MatType>
+template <typename W>
double DecisionStump<MatType>::SetupSplitAttribute(
const arma::rowvec& attribute,
- const arma::Row<size_t>& labels)
+ const arma::Row<size_t>& labels,
+ W isWeight)
{
int i, count, begin, end;
double entropy = 0.0;
@@ -167,7 +176,9 @@ double DecisionStump<MatType>::SetupSplitAttribute(
for (i = 0; i < attribute.n_elem; i++)
{
sortedLabels(i) = labels(sortedIndexAtt(i));
- tempD(i) = weightD(sortedIndexAtt(i));
+
+ if(isWeight)
+ tempD(i) = weightD(sortedIndexAtt(i));
}
i = 0;
@@ -188,8 +199,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
// Use ratioEl to calculate the ratio of elements in this split.
const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
- entropy += ratioEl * CalculateEntropy<size_t>(
- sortedLabels.subvec(begin, end), begin);
+ entropy += ratioEl * CalculateEntropy<size_t, W>(
+ sortedLabels.subvec(begin, end), begin, isWeight);
i++;
}
else if (sortedLabels(i) != sortedLabels(i + 1))
@@ -215,8 +226,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
}
const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
- entropy += ratioEl * CalculateEntropy<size_t>(
- sortedLabels.subvec(begin, end), begin);
+ entropy += ratioEl * CalculateEntropy<size_t, W>(
+ sortedLabels.subvec(begin, end), begin, isWeight);
i = end + 1;
count = 0;
@@ -404,12 +415,13 @@ int DecisionStump<MatType>::IsDistinct(const arma::Row<rType>& featureRow)
*
* @param attribute The attribute for which we calculate the entropy.
* @param labels Corresponding labels of the attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
*/
template<typename MatType>
-template<typename LabelType>
+template<typename LabelType, typename W>
double DecisionStump<MatType>::CalculateEntropy(
arma::subview_row<LabelType> labels,
- int begin)
+ int begin, W isWeight)
{
double entropy = 0.0;
size_t j;
@@ -421,20 +433,34 @@ double DecisionStump<MatType>::CalculateEntropy(
double accWeight = 0.0;
// Populate numElem; they are used as helpers to calculate entropy.
- for (j = 0; j < labels.n_elem; j++)
+ if(isWeight)
{
- numElem(labels(j)) += tempD(j + begin);
- accWeight += tempD(j + begin);
- }
- // numElem(labels(j))++;
+ for (j = 0; j < labels.n_elem; j++)
+ {
+ numElem(labels(j)) += tempD(j + begin);
+ accWeight += tempD(j + begin);
+ }
+ // numElem(labels(j))++;
- for (j = 0; j < numClass; j++)
- {
- const double p1 = ((double) numElem(j) / accWeight);
+ for (j = 0; j < numClass; j++)
+ {
+ const double p1 = ((double) numElem(j) / accWeight);
- entropy += (p1 == 0) ? 0 : p1 * log2(p1);
+ entropy += (p1 == 0) ? 0 : p1 * log2(p1);
+ }
}
+ else
+ {
+ for (j = 0; j < labels.n_elem; j++)
+ numElem(labels(j))++;
+ for (j = 0; j < numClass; j++)
+ {
+ const double p1 = ((double) numElem(j) / labels.n_elem);
+
+ entropy += (p1 == 0) ? 0 : p1 * log2(p1);
+ }
+ }
return entropy;
}
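The weighted and unweighted branches above compute the same quantity with two different normalizers: accumulated class weight over total weight, versus class count over element count. A standalone sketch of that computation, with plain std::vector stand-ins for the Armadillo types (the sign convention follows the patch, which accumulates p * log2(p) without negating):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Entropy helper mirroring DecisionStump::CalculateEntropy(): when
    // isWeight is false every point counts 1.0, so accWeight reduces to
    // labels.size() and the two branches of the patch coincide.
    double Entropy(const std::vector<size_t>& labels,
                   const std::vector<double>& weights, // used only if isWeight
                   const size_t numClass,
                   const bool isWeight)
    {
      std::vector<double> numElem(numClass, 0.0);
      double accWeight = 0.0;

      for (size_t j = 0; j < labels.size(); ++j)
      {
        const double w = isWeight ? weights[j] : 1.0;
        numElem[labels[j]] += w;
        accWeight += w;
      }

      double entropy = 0.0;
      for (size_t j = 0; j < numClass; ++j)
      {
        const double p1 = numElem[j] / accWeight;
        entropy += (p1 == 0.0) ? 0.0 : p1 * std::log2(p1);
      }

      return entropy;
    }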
diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
index f6d6053..48ad4e3 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_main.cpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
@@ -95,7 +95,8 @@ int main(int argc, char *argv[])
<< ")!" << std::endl;
Timer::Start("training");
- DecisionStump<> ds(trainingData, labels.t(), numClasses, inpBucketSize);
+ DecisionStump<> ds(trainingData, labels.t(), numClasses,
+ inpBucketSize);
Timer::Stop("training");
Row<size_t> predictedLabels(testingData.n_cols);
diff --git a/src/mlpack/tests/adaboost_test.cpp b/src/mlpack/tests/adaboost_test.cpp
index 704f3d0..3abba11 100644
--- a/src/mlpack/tests/adaboost_test.cpp
+++ b/src/mlpack/tests/adaboost_test.cpp
@@ -1,8 +1,8 @@
/**
- * @file Adaboost_test.cpp
+ * @file AdaBoost_test.cpp
* @author Udit Saxena
*
- * Tests for Adaboost class.
+ * Tests for AdaBoost class.
*/
#include <mlpack/core.hpp>
@@ -15,10 +15,10 @@ using namespace mlpack;
using namespace arma;
using namespace mlpack::adaboost;
-BOOST_AUTO_TEST_SUITE(AdaboostTest);
+BOOST_AUTO_TEST_SUITE(AdaBoostTest);
/**
- * This test case runs the Adaboost.mh algorithm on the UCI Iris dataset.
+ * This test case runs the AdaBoost.mh algorithm on the UCI Iris dataset.
* It checks whether the hamming loss breaches the upperbound, which
* is provided by ztAccumulator.
*/
@@ -45,7 +45,8 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundIris)
// Define parameters for the adaboost
int iterations = 100;
double tolerance = 1e-10;
- Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+ AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
+
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
if(labels(i) != a.finalHypothesis(i))
@@ -56,7 +57,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundIris)
}
/**
- * This test case runs the Adaboost.mh algorithm on the UCI Iris dataset.
+ * This test case runs the AdaBoost.mh algorithm on the UCI Iris dataset.
* It checks if the error returned by running a single instance of the
* weak learner is worse than running the boosted weak learner using
* adaboost.
@@ -92,7 +93,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris)
// Define parameters for the adaboost
int iterations = 100;
double tolerance = 1e-10;
- Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+ AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
if(labels(i) != a.finalHypothesis(i))
@@ -103,7 +104,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris)
}
/**
- * This test case runs the Adaboost.mh algorithm on the UCI Vertebral
+ * This test case runs the AdaBoost.mh algorithm on the UCI Vertebral
* Column dataset.
* It checks whether the hamming loss breaches the upperbound, which
* is provided by ztAccumulator.
@@ -131,7 +132,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn)
// Define parameters for the adaboost
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+ AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
if(labels(i) != a.finalHypothesis(i))
@@ -142,7 +143,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn)
}
/**
- * This test case runs the Adaboost.mh algorithm on the UCI Vertebral
+ * This test case runs the AdaBoost.mh algorithm on the UCI Vertebral
* Column dataset.
* It checks if the error returned by running a single instance of the
* weak learner is worse than running the boosted weak learner using
@@ -179,7 +180,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn)
// Define parameters for the adaboost
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+ AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
if(labels(i) != a.finalHypothesis(i))
@@ -190,7 +191,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn)
}
/**
- * This test case runs the Adaboost.mh algorithm on non-linearly
+ * This test case runs the AdaBoost.mh algorithm on non-linearly
* separable dataset.
* It checks whether the hamming loss breaches the upperbound, which
* is provided by ztAccumulator.
@@ -218,7 +219,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData)
// Define parameters for the adaboost
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+ AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
if(labels(i) != a.finalHypothesis(i))
@@ -229,7 +230,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData)
}
/**
- * This test case runs the Adaboost.mh algorithm on a non-linearly
+ * This test case runs the AdaBoost.mh algorithm on a non-linearly
* separable dataset.
* It checks if the error returned by running a single instance of the
* weak learner is worse than running the boosted weak learner using
@@ -266,7 +267,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData)
// Define parameters for the adaboost
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+ AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
if(labels(i) != a.finalHypothesis(i))
@@ -277,7 +278,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData)
}
/**
- * This test case runs the Adaboost.mh algorithm on the UCI Iris dataset.
+ * This test case runs the AdaBoost.mh algorithm on the UCI Iris dataset.
* It checks whether the hamming loss breaches the upperbound, which
* is provided by ztAccumulator.
* This is for the weak learner: Decision Stumps.
@@ -307,7 +308,7 @@ BOOST_AUTO_TEST_CASE(HammingLossIris_DS)
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
+ AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
labels.row(0), iterations, tolerance, ds);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
@@ -319,7 +320,7 @@ BOOST_AUTO_TEST_CASE(HammingLossIris_DS)
}
/**
- * This test case runs the Adaboost.mh algorithm on a non-linearly
+ * This test case runs the AdaBoost.mh algorithm on a non-linearly
* separable dataset.
* It checks if the error returned by running a single instance of the
* weak learner is worse than running the boosted weak learner using
@@ -360,7 +361,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris_DS)
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
+ AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
labels.row(0), iterations, tolerance, ds);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
@@ -371,7 +372,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris_DS)
BOOST_REQUIRE(error <= weakLearnerErrorRate);
}
/**
- * This test case runs the Adaboost.mh algorithm on the UCI Vertebral
+ * This test case runs the AdaBoost.mh algorithm on the UCI Vertebral
* Column dataset.
* It checks if the error returned by running a single instance of the
* weak learner is worse than running the boosted weak learner using
@@ -404,7 +405,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn_DS)
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
+ AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
labels.row(0), iterations, tolerance, ds);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
@@ -416,7 +417,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn_DS)
}
/**
- * This test case runs the Adaboost.mh algorithm on the UCI Vertebral
+ * This test case runs the AdaBoost.mh algorithm on the UCI Vertebral
* Column dataset.
* It checks if the error returned by running a single instance of the
* weak learner is worse than running the boosted weak learner using
@@ -456,7 +457,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn_DS)
// Define parameters for the adaboost
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
+ AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
labels.row(0), iterations, tolerance, ds);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
@@ -467,7 +468,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn_DS)
BOOST_REQUIRE(error <= weakLearnerErrorRate);
}
/**
- * This test case runs the Adaboost.mh algorithm on non-linearly
+ * This test case runs the AdaBoost.mh algorithm on non-linearly
* separable dataset.
* It checks whether the hamming loss breaches the upperbound, which
* is provided by ztAccumulator.
@@ -499,7 +500,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData_DS)
int iterations = 50;
double tolerance = 1e-10;
- Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
+ AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
labels.row(0), iterations, tolerance, ds);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
@@ -511,7 +512,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData_DS)
}
/**
- * This test case runs the Adaboost.mh algorithm on a non-linearly
+ * This test case runs the AdaBoost.mh algorithm on a non-linearly
* separable dataset.
* It checks if the error returned by running a single instance of the
* weak learner is worse than running the boosted weak learner using
@@ -535,7 +536,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData_DS)
// Define your own weak learner, Decision Stump in this case.
const size_t numClasses = 2;
- const size_t inpBucketSize = 6;
+ const size_t inpBucketSize = 3;
arma::Row<size_t> dsPrediction(labels.n_cols);
@@ -549,10 +550,10 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData_DS)
double weakLearnerErrorRate = (double) countWeakLearnerError / labels.n_cols;
// Define parameters for the adaboost
- int iterations = 50;
- double tolerance = 1e-10;
+ int iterations = 500;
+ double tolerance = 1e-23;
- Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
+ AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
labels.row(0), iterations, tolerance, ds);
int countError = 0;
for (size_t i = 0; i < labels.n_cols; i++)
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index dec4f2c..325bbc6 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -26,7 +26,7 @@ BOOST_AUTO_TEST_CASE(OneClass)
{
const size_t numClasses = 2;
const size_t inpBucketSize = 6;
-
+
mat trainingData;
trainingData << 2.4 << 3.8 << 3.8 << endr
<< 1 << 1 << 2 << endr