[mlpack] 57/324: Decision Stump added
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:21:55 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 88de233ff6343bbfdc8c9ff452339535b6922357
Author: saxena.udit <saxena.udit at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Thu Jun 12 20:46:35 2014 +0000
Decision Stump added
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16686 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/CMakeLists.txt | 1 +
.../methods/{CMakeLists.txt => CMakeLists.txt~} | 0
src/mlpack/methods/decision_stump/CMakeLists.txt | 26 ++
src/mlpack/methods/decision_stump/CMakeLists.txt~ | 31 ++
.../methods/decision_stump/decision_stump.hpp | 127 ++++++
.../methods/decision_stump/decision_stump_impl.cpp | 441 +++++++++++++++++++++
.../methods/decision_stump/decision_stump_main.cpp | 90 +++++
src/mlpack/tests/CMakeLists.txt | 1 +
.../tests/{CMakeLists.txt => CMakeLists.txt~} | 0
src/mlpack/tests/decision_stump_test.cpp | 174 ++++++++
10 files changed, 891 insertions(+)
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index fb6a0d4..b930aaa 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -2,6 +2,7 @@
set(DIRS
amf
cf
+ decision_stump
det
emst
fastmks
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt~
similarity index 100%
copy from src/mlpack/methods/CMakeLists.txt
copy to src/mlpack/methods/CMakeLists.txt~
diff --git a/src/mlpack/methods/decision_stump/CMakeLists.txt b/src/mlpack/methods/decision_stump/CMakeLists.txt
new file mode 100644
index 0000000..2e6dfe6
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/CMakeLists.txt
@@ -0,0 +1,26 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+ decision_stump.hpp
+ decision_stump_impl.cpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+ set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+add_executable(dec_stu
+ decision_stump_main.cpp
+)
+target_link_libraries(dec_stu
+ mlpack
+)
+
+install(TARGETS dec_stu RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/decision_stump/CMakeLists.txt~ b/src/mlpack/methods/decision_stump/CMakeLists.txt~
new file mode 100644
index 0000000..402bf47
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/CMakeLists.txt~
@@ -0,0 +1,31 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+ decision_stump.hpp
+ decision_stump_impl.cpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+ set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+add_executable(dec_stu
+ decision_stump_main.cpp
+)
+target_link_libraries(dec_stu
+ mlpack
+)
+
+target_link_libraries(dec_stu_test
+ mlpack
+ boost_unit_test_framework
+)
+
+install(TARGETS dec_stu RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
new file mode 100644
index 0000000..8faa977
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -0,0 +1,127 @@
+/**
+ * @file decision_stump.hpp
+ * @author Udit Saxena
+ *
+ * Definition of decision stumps.
+ */
+
+#ifndef _MLPACK_METHODS_DECISION_STUMP_HPP
+#define _MLPACK_METHODS_DECISION_STUMP_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace decision_stump {
+/**
+ * This class implements a decision stump. It constructs a single-level
+ * decision tree, i.e. a decision stump. It uses entropy to decide the
+ * splitting ranges.
+ */
+template <typename MatType = arma::mat>
+class DecisionStump
+{
+ public:
+ /*
+ Constructor. Train on the provided data. Generate a decision stump
+ from data.
+
+ @param: data - Input, training data.
+ @param: labels - Labels of data.
+ @param: classes - number of distinct classes in labels.
+ @param: inpBucketSize - minimum size of bucket when splitting.
+ */
+ DecisionStump(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const size_t classes,
+ size_t inpBucketSize);
+
+ /*
+ Classification function. After training, classify test, and put the
+ predicted classes in predictedLabels.
+
+ @param: test - testing data or data to classify.
+ @param: predictedLabels - vector to store the predicted classes after
+ classifying test
+ */
+ void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
+ private:
+ /* Stores the number of classes.*/
+ size_t numClass;
+
+ /* Stores the default class. Provided for handling missing attribute values.*/
+ size_t defaultClass;
+
+ /* Stores the value of the attribute on which to split.*/
+ int splitCol;
+
+ /* Flag value for distinct input class labels.*/
+ int oneClass;
+
+ /* Size of bucket while determining splitting criterion.*/
+ size_t bucketSize;
+
+ /* Stores the class labels for the input data*/
+ arma::Row<size_t> classLabels;
+
+ /* Stores the splitting criterion after training.*/
+ arma::mat split;
+
+ /*
+ Sets up attribute as if it were splitting on it and
+ finds entropy when splitting on attribute.
+
+ @param: attribute - a row from the training data, which might be a
+ candidate for the splitting attribute.
+ */
+ double SetupSplitAttribute(const arma::rowvec& attribute);
+
+ /*
+ After having decided the attribute on which to split,
+ train on that attribute.
+
+ @param: attribute - attribute is the attribute decided by the constructor
+ on which we now train the decision stump.
+ */
+ template <typename rType> void TrainOnAtt(const arma::rowvec& attribute);
+
+ /* After the "split" matrix has been set up,
+ merging ranges with identical class labels.
+ */
+ void MergeRanges();
+
+ /*
+ Used to count the most frequently occurring element in subCols.
+
+ @param: subCols - the vector in which to find the most frequently
+ occurring element.
+ */
+ template <typename rType> rType CountMostFreq(const arma::Row<rType>& subCols);
+
+ /*
+  Returns 1 if the values of featureRow are not all the same.
+
+ @param: featureRow - the attribute which is checked so that it
+ does not have identical values.
+ */
+ template <typename rType> int isDistinct(const arma::Row<rType>& featureRow);
+
+ /*
+ Calculating Entropy of attribute.
+
+ @param: attribute - the attribute of which we calculate the entropy.
+ @param: labels - corresponding labels of the attribute.
+ */
+ double CalculateEntropy(const arma::rowvec& attribute,
+ const arma::rowvec& labels);
+
+
+};
+
+}; //namespace decision_stump
+}; //namespace mlpack
+
+#include "decision_stump_impl.cpp"
+
+#endif
\ No newline at end of file
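For reference, a minimal sketch of how the new DecisionStump class is meant to be used (the data values mirror the PerfectSplitOnZero unit test added further below; the snippet assumes it is compiled and linked against mlpack, just like decision_stump_main.cpp):

    #include <mlpack/core.hpp>
    #include <mlpack/methods/decision_stump/decision_stump.hpp>

    using namespace mlpack;
    using namespace mlpack::decision_stump;
    using namespace arma;

    int main()
    {
      // One attribute (row), six points (columns), two classes.
      mat trainingData;
      trainingData << -1 << 1 << -2 << 2 << -3 << 3;

      Mat<size_t> labelsIn;
      labelsIn << 0 << 1 << 0 << 1 << 0 << 1;

      // Train the stump with 2 classes and a bucket size of 2.
      DecisionStump<> ds(trainingData, labelsIn.row(0), 2, 2);

      // Classify five new points on the same single attribute.
      mat testingData;
      testingData << -4 << 7 << -7 << -5 << 6;
      Row<size_t> predictedLabels(testingData.n_cols);
      ds.Classify(testingData, predictedLabels);

      predictedLabels.print("predicted labels");
      return 0;
    }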
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.cpp b/src/mlpack/methods/decision_stump/decision_stump_impl.cpp
new file mode 100644
index 0000000..007a255
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.cpp
@@ -0,0 +1,441 @@
+/**
+ * @file decision_stump_impl.cpp
+ * @author Udit Saxena
+ *
+ * Implementation of the DecisionStump class.
+ */
+
+#ifndef _MLPACK_METHODS_DECISION_STUMP_IMPL_HPP
+#define _MLPACK_METHODS_DECISION_STUMP_IMPL_HPP
+
+#include "decision_stump.hpp"
+
+#include <set>
+#include <algorithm>
+#include <cfloat>    // For DBL_MAX.
+#include <cmath>     // For log2().
+
+namespace mlpack {
+namespace decision_stump {
+/*
+ Constructor. Train on the provided data. Generate a decision stump
+ from data.
+
+ @param: data - Input, training data.
+ @param: labels - Labels of data.
+ @param: classes - number of distinct classes in labels.
+ @param: inpBucketSize - minimum size of bucket when splitting.
+ */
+template<typename MatType>
+DecisionStump<MatType>::DecisionStump(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const size_t classes,
+ size_t inpBucketSize)
+{
+ classLabels = labels + arma::zeros<arma::Row<size_t> >(labels.n_elem);
+
+ numClass = classes;
+ bucketSize = inpBucketSize;
+
+ /* Check whether the input labels are not all identical. */
+ if ( !isDistinct<size_t>(classLabels) )
+ {
+ // If the classLabels are all identical,
+ // the default class is the only class set.
+ oneClass = 1;
+ defaultClass = classLabels(0);
+ }
+
+ else
+ {
+ // If classLabels are not all identical
+ // proceed for training
+
+ oneClass = 0;
+    int bestAtt = -1, i;
+    double entropy, bestEntropy = DBL_MAX;
+
+ // Set the default class to handle attribute values which are
+ // not present in the training data.
+ defaultClass = CountMostFreq<size_t>(classLabels);
+
+ for (i = 0;i < data.n_rows; i++)
+ {
+ // going through each attribute of data.
+ if (isDistinct<double>(data.row(i)))
+ {
+ // for each attribute with non-identical values,
+ // treat it as a potential splitting attribute
+ // and calculate entropy if split on it.
+ entropy=SetupSplitAttribute(data.row(i));
+
+ // finding the attribute with the bestEntropy
+ // so that the gain is max.
+ if (entropy < bestEntropy)
+ {
+ bestAtt = i;
+ bestEntropy = entropy;
+ }
+
+ }
+ }
+ splitCol = bestAtt;
+
+ // once the splitting column/attribute has been decided,
+ // train on it.
+ TrainOnAtt<double>(data.row(splitCol));
+ }
+}
+
+/*
+ Classification function. After training, classify test, and put the
+ predicted classes in predictedLabels.
+
+ @param: test - testing data or data to classify.
+ @param: predictedLabels - vector to store the predicted classes after
+ classifying test
+ */
+template<typename MatType>
+void DecisionStump<MatType>::Classify(const MatType& test,
+ arma::Row<size_t>& predictedLabels)
+{
+  int i, j, flag;
+  double val;
+ if ( !oneClass )
+ {
+ for (i = 0; i < test.n_cols; i++)
+ {
+      // The value of the split attribute for this test point.
+      val = test(splitCol, i);
+      j = 0;
+      flag = 0;
+
+ while ((j < split.n_rows) && (!flag))
+ {
+ if(val < split(j,0) && (!j))
+ {
+ predictedLabels(i) = split(0,1);
+ flag = 1;
+ }
+ else if (val >= split(j,0))
+ {
+ if(j == split.n_rows - 1)
+ {
+ predictedLabels(i) = split(split.n_rows - 1, 1);
+ flag = 1;
+ }
+ else if (val < split(j+1,0))
+ {
+ predictedLabels(i) = split(j,1);
+ flag = 1;
+ }
+ }
+ j++;
+ }
+ }
+ }
+ else
+ {
+ for (i = 0;i < test.n_cols;i++)
+ predictedLabels(i)=defaultClass;
+ }
+
+}
+
+/*
+ Sets up attribute as if it were splitting on it and
+ finds entropy when splitting on attribute.
+
+ @param: attribute - a row from the training data, which might be a
+ candidate for the splitting attribute.
+ */
+template <typename MatType>
+double DecisionStump<MatType>::SetupSplitAttribute(const arma::rowvec& attribute)
+{
+ int i, count, begin, end;
+ double entropy = 0.0;
+
+ // sorting the attribute, for calculating splitting ranges
+ arma::rowvec sortedAtt = arma::sort(attribute);
+
+ // storing the indexes of the sorted attribute to build
+ // a vector of sorted labels.
+ // this sort is stable.
+ arma::uvec sortedIndexAtt = arma::stable_sort_index(attribute.t());
+
+ // vector of sorted labels
+ arma::Row<size_t> sortedLabels(attribute.n_elem,arma::fill::zeros);
+
+ for (i = 0; i < attribute.n_elem; i++)
+ sortedLabels(i) = classLabels(sortedIndexAtt(i));
+
+ arma::rowvec subColLabels;
+ arma::rowvec subColAtts;
+
+ i = 0;
+ count = 0;
+
+  // This splits the sorted attribute into buckets of size >= bucketSize.
+ while (i < sortedLabels.n_elem)
+ {
+ count++;
+ if (i == sortedLabels.n_elem - 1)
+ {
+ begin = i - count + 1;
+ end = i;
+
+ subColLabels = sortedLabels.cols(begin, end) +
+ arma::zeros<arma::rowvec>((sortedLabels.cols(begin, end)).n_elem);
+
+ subColAtts = sortedAtt.cols(begin, end) +
+ arma::zeros<arma::rowvec>((sortedAtt.cols(begin, end)).n_elem);
+
+ entropy += CalculateEntropy(subColAtts, subColLabels);
+ i++;
+ }
+ else if( sortedLabels(i) != sortedLabels(i + 1) )
+ {
+ if (count < bucketSize)
+ {
+ begin = i - count + 1;
+ end = begin + bucketSize - 1;
+
+ if ( end > sortedLabels.n_elem - 1)
+ end = sortedLabels.n_elem - 1;
+ }
+ else
+ {
+ begin = i - count + 1;
+ end = i;
+ }
+
+ subColLabels = sortedLabels.cols(begin, end) +
+ arma::zeros<arma::rowvec>((sortedLabels.cols(begin, end)).n_elem);
+
+ subColAtts = sortedAtt.cols(begin, end) +
+ arma::zeros<arma::rowvec>((sortedAtt.cols(begin, end)).n_elem);
+
+      // Now use subColLabels and subColAtts to calculate entropy.
+ entropy += CalculateEntropy(subColAtts, subColLabels);
+
+ i = end + 1;
+ count = 0;
+
+ }
+ else
+ i++;
+ }
+ return entropy;
+}
+
+/*
+ After having decided the attribute on which to split,
+ train on that attribute.
+
+ @param: attribute - attribute is the attribute decided by the constructor
+ on which we now train the decision stump.
+ */
+template <typename MatType>
+template <typename rType>
+void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
+{
+ int i, count, begin, end;
+
+ arma::rowvec sortedSplitAtt = arma::sort(attribute);
+ arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
+ arma::Row<size_t> sortedLabels(attribute.n_elem,arma::fill::zeros);
+ arma::mat tempSplit;
+
+ for (i = 0; i < attribute.n_elem; i++)
+ sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
+
+ arma::rowvec subCols;
+ rType mostFreq;
+ i = 0;
+ count = 0;
+ while (i < sortedLabels.n_elem)
+ {
+ count++;
+ if (i == sortedLabels.n_elem - 1)
+ {
+ begin = i - count + 1;
+ end = i;
+
+ subCols = sortedLabels.cols(begin, end) +
+ arma::zeros<arma::rowvec>((sortedLabels.cols(begin, end)).n_elem);
+
+ mostFreq = CountMostFreq<double>(subCols);
+
+ tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
+ split = arma::join_cols(split, tempSplit);
+
+ i++;
+ }
+ else if( sortedLabels(i) != sortedLabels(i + 1) )
+ {
+      if (count < bucketSize) // Test for different values of bucketSize, especially extreme cases.
+ {
+ begin = i - count + 1;
+ end = begin + bucketSize - 1;
+
+ if ( end > sortedLabels.n_elem - 1)
+ end = sortedLabels.n_elem - 1;
+ }
+ else
+ {
+ begin = i - count + 1;
+ end = i;
+ }
+ subCols = sortedLabels.cols(begin, end) +
+ arma::zeros<arma::rowvec>((sortedLabels.cols(begin, end)).n_elem);
+
+ // finding the most freq element in subCols so as to assign a label to the
+ // bucket of subCols
+
+ mostFreq = CountMostFreq<double>(subCols);
+
+ tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
+ split = arma::join_cols(split, tempSplit);
+
+ i = end + 1;
+ count = 0;
+ }
+ else
+ i++;
+ }
+
+  // Now trim the split matrix so that consecutive buckets which point to the
+  // same class label are merged into one big bucket.
+ MergeRanges();
+}
+
+/* After the "split" matrix has been set up,
+ merging ranges with identical class labels.
+ */
+template <typename MatType>
+void DecisionStump<MatType>::MergeRanges()
+{
+ int i;
+ for (i = 1;i < split.n_rows; i++)
+ {
+ if (split(i,1) == split(i-1,1))
+ {
+ // remove this row, as it has the same label as
+ // the previous bucket.
+ split.shed_row(i);
+ // go back to previous row.
+ i--;
+ }
+ }
+}
+
+template <typename MatType>
+template <typename rType>
+rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
+{
+  // Sort subCols for easier processing.
+  arma::Row<rType> sortCounts = arma::sort(subCols);
+
+  // A single pass over the sorted row, tracking the longest run of equal
+  // values; the value of the longest run is the most frequent element.
+  rType element = sortCounts(0);
+  int count = 1, localCount = 1, i;
+
+  for (i = 1; i < sortCounts.n_elem; ++i)
+  {
+    if (sortCounts(i) == sortCounts(i - 1))
+      localCount++;
+    else
+      localCount = 1;
+
+    if (localCount > count)
+    {
+      count = localCount;
+      element = sortCounts(i);
+    }
+  }
+
+  return element;
+}
+
+/*
+ Returns 1 if all the values of featureRow are not same.
+
+ @param: featureRow - the attribute which is checked so that it
+ does not have identical values.
+ */
+template <typename MatType>
+template <typename rType>
+int DecisionStump<MatType>::isDistinct(const arma::Row<rType>& featureRow)
+{
+ if (featureRow.max()-featureRow.min() > 0)
+ return 1;
+ else
+ return 0;
+}
+
+/*
+ Calculating Entropy of attribute.
+
+ @param: attribute - the attribute of which we calculate the entropy.
+ @param: labels - corresponding labels of the attribute.
+ */
+template<typename MatType>
+double DecisionStump<MatType>::CalculateEntropy(const arma::rowvec& attribute,
+ const arma::rowvec& labels)
+{
+  int i, j;
+  double entropy = 0.0;
+
+ arma::rowvec uniqueAtt = arma::unique(attribute);
+ arma::rowvec uniqueLabel = arma::unique(labels);
+ arma::Row<size_t> numElem(uniqueAtt.n_elem,arma::fill::zeros);
+ arma::Mat<size_t> entropyArray(uniqueAtt.n_elem,numClass,arma::fill::zeros);
+
+ // populating entropyArray and numElem, they are to be used as
+ // helpers to calculate entropy
+ for (j = 0;j < uniqueAtt.n_elem; j++)
+ {
+ for (i = 0; i < attribute.n_elem; i++)
+ {
+ if (uniqueAtt[j] == attribute[i])
+ {
+ entropyArray(j,labels(i))++;
+ numElem(j)++;
+ }
+ }
+ }
+
+ double p1, p2, p3;
+ for ( j = 0; j < uniqueAtt.size(); j++ )
+ {
+ p1 = ((double)numElem(j) / attribute.n_elem);
+
+ for ( i = 0; i < numClass; i++)
+ {
+ p2 = ((double)entropyArray(j,i) / numElem(j));
+
+ if(p2 == 0)
+ p3 = 0;
+ else
+ p3 = ( p2 * log2(p2) );
+
+ entropy+=( p1 * p3 );
+ }
+ }
+
+ return entropy;
+}
+
+
+}; // namespace decision_stump
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
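For reference, the quantity returned by CalculateEntropy() for a bucket, and summed over buckets by SetupSplitAttribute() (the attribute with the smallest total is kept as splitCol), can be written in LaTeX notation as

    E_{bucket} = \sum_{v} \frac{n_v}{n} \sum_{c} p_{v,c} \log_2 p_{v,c}

where v runs over the distinct values of the attribute inside the bucket, n is the number of points in the bucket, n_v is the number of those points taking value v, and p_{v,c} is the fraction of the n_v points whose label is c.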
diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
new file mode 100644
index 0000000..186a7cb
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
@@ -0,0 +1,90 @@
+/**
+ * @file decision_stump_main.cpp
+ * @author Udit Saxena
+ *
+ * Main executable for the decision stump.
+ */
+
+#include <mlpack/core.hpp>
+#include "decision_stump.hpp"
+
+using namespace mlpack;
+using namespace mlpack::decision_stump;
+using namespace std;
+using namespace arma;
+
+PROGRAM_INFO("Decision Stump", "This program implements a decision stump, "
+  "a single-level decision tree, on the given training data set. "
+  "The default bucket size is 6.");
+
+// necessary parameters
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "tr");
+PARAM_STRING_REQ("labels_file", "A file containing labels for the training set.",
+ "l");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "te");
+PARAM_STRING_REQ("num_classes","The number of classes","c");
+
+// output parameters (optional)
+PARAM_STRING("output", "The file in which the predicted labels for the test set"
+ " will be written.", "o", "output.csv");
+
+PARAM_INT("bucket_size","The size of ranges/buckets to be used while splitting the decision stump.","b", 6);
+
+int main(int argc, char *argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ const string trainingDataFilename = CLI::GetParam<string>("train_file");
+ mat trainingData;
+ data::Load(trainingDataFilename, trainingData, true);
+
+ const string labelsFilename = CLI::GetParam<string>("labels_file");
+ // Load labels.
+ mat labelsIn;
+ data::Load(labelsFilename, labelsIn, true);
+
+  // Helpers for normalizing the labels.
+  Row<size_t> labels;
+ vec mappings;
+
+ // Do the labels need to be transposed?
+ if (labelsIn.n_rows == 1)
+ labelsIn = labelsIn.t();
+
+ size_t inpBucketSize = CLI::GetParam<int>("bucket_size");
+
+ // normalize the labels
+ data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
+
+  const size_t num_classes = (size_t) CLI::GetParam<int>("num_classes");
+  /*
+  Should the number of classes be an input, or should it be
+  derived from the labels row?
+  */
+ const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+ mat testingData;
+ data::Load(testingDataFilename, testingData, true);
+
+ if (testingData.n_rows != trainingData.n_rows)
+ Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
+ << "must be the same as training data (" << trainingData.n_rows - 1
+ << ")!" << std::endl;
+
+ Timer::Start("training");
+ DecisionStump<> ds(trainingData, labels, num_classes, inpBucketSize);
+ Timer::Stop("training");
+
+ Row<size_t> predictedLabels(testingData.n_cols);
+ Timer::Start("testing");
+ ds.Classify(testingData, predictedLabels);
+ Timer::Stop("testing");
+
+  rowvec results;
+ data::RevertLabels(predictedLabels, mappings, results);
+
+ const string outputFilename = CLI::GetParam<string>("output");
+ data::Save(outputFilename, results, true, true);
+  // Save the predicted labels (transposed) to the output file.
+
+ return 0;
+}
\ No newline at end of file
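For reference, a sample invocation of the new dec_stu executable; the file names are placeholders, and the long option names are assumed to follow mlpack's usual mapping of the parameter names declared above:

    $ dec_stu --train_file train.csv --labels_file labels.csv \
              --test_file test.csv --num_classes 3 \
              --bucket_size 6 --output predictions.csv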
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 2e7f3ce..2aebb62 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -8,6 +8,7 @@ add_executable(mlpack_test
aug_lagrangian_test.cpp
cf_test.cpp
cli_test.cpp
+ decision_stump_test.cpp
det_test.cpp
distribution_test.cpp
emst_test.cpp
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt~
similarity index 100%
copy from src/mlpack/tests/CMakeLists.txt
copy to src/mlpack/tests/CMakeLists.txt~
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
new file mode 100644
index 0000000..310c93a
--- /dev/null
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -0,0 +1,174 @@
+/*
+ * @file decision_stump_test.cpp
+ * @author Udit Saxena
+ *
+ * Tests for the DecisionStump class.
+ */
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/decision_stump/decision_stump.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::decision_stump;
+using namespace arma;
+
+BOOST_AUTO_TEST_SUITE(DSTEST);
+
+/*
+ * This test handles the case where only one class exists in the input labels.
+ * It checks whether the only class supplied is the only class predicted.
+ */
+BOOST_AUTO_TEST_CASE(OneClass)
+{
+ size_t numClasses = 2;
+ size_t inpBucketSize = 6;
+
+ mat trainingData;
+ trainingData << 2.4 << 3.8 << 3.8 << endr
+ << 1 << 1 << 2 << endr
+ << 1.3 << 1.9 << 1.3 << endr;
+
+ Mat<size_t> labelsIn;
+ labelsIn << 1 << 1 << 1;
+
+ // no need to normalize labels here.
+
+ mat testingData;
+ testingData << 2.4 << 2.5 << 2.6;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+ Row<size_t> predictedLabels(testingData.n_cols);
+ ds.Classify(testingData, predictedLabels);
+
+  for (size_t i = 0; i < predictedLabels.n_elem; i++)
+    BOOST_CHECK_EQUAL(predictedLabels(i), 1);
+
+}
+
+/*
+ * This tests the classification rule:
+ *   if the test input < 0, predict class 0;
+ *   if the test input > 0, predict class 1.
+ * This is an almost perfect split on zero.
+ */
+BOOST_AUTO_TEST_CASE(PerfectSplitOnZero)
+{
+ size_t numClasses = 2;
+ const char* output = "outputPerfectSplitOnZero.csv";
+ size_t inpBucketSize = 2;
+
+ mat trainingData;
+ trainingData << -1 << 1 << -2 << 2 << -3 << 3;
+
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 1 << 0 << 1 << 0 << 1;
+ // no need to normalize labels here.
+
+ mat testingData;
+ testingData << -4 << 7 << -7 << -5 << 6;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+ Row<size_t> predictedLabels(testingData.n_cols);
+ ds.Classify(testingData, predictedLabels);
+
+ data::Save(output, predictedLabels, true, true);
+}
+
+/*
+ * This tests the binning function for the case when a dataset with
+ * input cardinality less than inpBucketSize is provided.
+ */
+BOOST_AUTO_TEST_CASE(BinningTesting)
+{
+ size_t numClasses = 2;
+ const char* output = "outputBinningTesting.csv";
+ size_t inpBucketSize = 10;
+
+ mat trainingData;
+ trainingData << -1 << 1 << -2 << 2 << -3 << 3 << -4;
+
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 1 << 0 << 1 << 0 << 1 << 0;
+
+ // no need to normalize labels here.
+
+ mat testingData;
+ testingData << 5;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+ Row<size_t> predictedLabels(testingData.n_cols);
+ ds.Classify(testingData, predictedLabels);
+
+ data::Save(output, predictedLabels, true, true);
+}
+
+/*
+ * This is a test for the case when multiple, non-overlapping classes are
+ * provided. It tests for a perfect split due to the non-overlapping nature
+ * of the input classes.
+ */
+BOOST_AUTO_TEST_CASE(PerfectMultiClassSplit)
+{
+ size_t numClasses = 4;
+ const char* output = "outputPerfectMultiClassSplit.csv";
+ size_t inpBucketSize = 3;
+
+ mat trainingData;
+ trainingData << -8 << -7 << -6 << -5 << -4 << -3 << -2 << -1
+ << 0 << 1 << 2 << 3 << 4 << 5 << 6 << 7;
+
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 1 << 1
+ << 2 << 2 << 2 << 2 << 3 << 3 << 3 << 3;
+ // no need to normalize labels here.
+
+ mat testingData;
+ testingData << -6.1 << -2.1 << 1.1 << 5.1;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+ Row<size_t> predictedLabels(testingData.n_cols);
+ ds.Classify(testingData, predictedLabels);
+
+ data::Save(output, predictedLabels, true, true);
+}
+
+/*
+ * This test is for the case when multiple, reasonably overlapping classes
+ * are provided in the input label set. It tests whether classification
+ * happens with a reasonable amount of error due to the overlapping nature
+ * of the input classes.
+ */
+BOOST_AUTO_TEST_CASE(MultiClassSplit)
+{
+ size_t numClasses = 3;
+ const char* output = "outputMultiClassSplit.csv";
+ size_t inpBucketSize = 3;
+
+ mat trainingData;
+ trainingData << -7 << -6 << -5 << -4 << -3 << -2 << -1 << 0 << 1
+ << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9 << 10;
+
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 0 << 0
+ << 1 << 1 << 1 << 2 << 1 << 2 << 2 << 2 << 2 << 2;
+ // no need to normalize labels here.
+
+ mat testingData;
+ testingData << -6.1 << -5.9 << -2.1 << -0.7 << 2.5 << 4.7 << 7.2 << 9.1;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+ Row<size_t> predictedLabels(testingData.n_cols);
+ ds.Classify(testingData, predictedLabels);
+
+ data::Save(output, predictedLabels, true, true);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
\ No newline at end of file
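For reference, once mlpack_test has been rebuilt with the new source file, the suite added here can be run on its own using the Boost unit test framework's standard filter option:

    $ ./mlpack_test --run_test=DSTEST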
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git