[mlpack] 57/324: Decision Stump added

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:21:55 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 88de233ff6343bbfdc8c9ff452339535b6922357
Author: saxena.udit <saxena.udit at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Thu Jun 12 20:46:35 2014 +0000

    Decision Stump added
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16686 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/methods/CMakeLists.txt                  |   1 +
 .../methods/{CMakeLists.txt => CMakeLists.txt~}    |   0
 src/mlpack/methods/decision_stump/CMakeLists.txt   |  26 ++
 src/mlpack/methods/decision_stump/CMakeLists.txt~  |  31 ++
 .../methods/decision_stump/decision_stump.hpp      | 127 ++++++
 .../methods/decision_stump/decision_stump_impl.cpp | 441 +++++++++++++++++++++
 .../methods/decision_stump/decision_stump_main.cpp |  90 +++++
 src/mlpack/tests/CMakeLists.txt                    |   1 +
 .../tests/{CMakeLists.txt => CMakeLists.txt~}      |   0
 src/mlpack/tests/decision_stump_test.cpp           | 174 ++++++++
 10 files changed, 891 insertions(+)

diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index fb6a0d4..b930aaa 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -2,6 +2,7 @@
 set(DIRS
   amf
   cf
+  decision_stump
   det
   emst
   fastmks
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt~
similarity index 100%
copy from src/mlpack/methods/CMakeLists.txt
copy to src/mlpack/methods/CMakeLists.txt~
diff --git a/src/mlpack/methods/decision_stump/CMakeLists.txt b/src/mlpack/methods/decision_stump/CMakeLists.txt
new file mode 100644
index 0000000..2e6dfe6
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/CMakeLists.txt
@@ -0,0 +1,26 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Sources of the decision stump method.  Only the files listed here are
+# compiled into MLPACK.
+set(SOURCES
+  decision_stump.hpp
+  decision_stump_impl.cpp
+)
+
+# Prefix every source with this directory, then hand the list up to the
+# parent scope, where all MLPACK sources are collected.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+# Command-line program: built from decision_stump_main.cpp, linked against
+# the mlpack library and installed into bin/.
+add_executable(dec_stu decision_stump_main.cpp)
+target_link_libraries(dec_stu mlpack)
+install(TARGETS dec_stu RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/decision_stump/CMakeLists.txt~ b/src/mlpack/methods/decision_stump/CMakeLists.txt~
new file mode 100644
index 0000000..402bf47
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/CMakeLists.txt~
@@ -0,0 +1,31 @@
+# NOTE(review): this is an editor backup ("~") copy of CMakeLists.txt that
+# appears to have been committed by accident.  CMake only reads
+# CMakeLists.txt, so this file is not part of the build and can be deleted.
+cmake_minimum_required(VERSION 2.8)
+
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+  decision_stump.hpp
+  decision_stump_impl.cpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+add_executable(dec_stu
+  decision_stump_main.cpp
+)
+target_link_libraries(dec_stu
+  mlpack
+)
+
+# No separate test executable is created here (decision_stump_test.cpp is
+# compiled into mlpack_test), so there is no test target to link.
+install(TARGETS dec_stu RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
new file mode 100644
index 0000000..8faa977
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -0,0 +1,127 @@
+/**
+ * @file decision_stump.hpp
+ * @author Udit Saxena
+ * 
+ * Definition of the DecisionStump class, a single-level decision tree.
+ */
+
+#ifndef _MLPACK_METHODS_DECISION_STUMP_HPP
+#define _MLPACK_METHODS_DECISION_STUMP_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace decision_stump {
+/**
+ * This class implements a decision stump, i.e. a single-level decision tree.
+ * During training it chooses one attribute (row of the data) to split on,
+ * using an entropy-based score, and divides the sorted values of that
+ * attribute into ranges, each of which predicts one class.
+ *
+ * @tparam MatType Type of the data matrix (arma::mat by default).
+ */
+template <typename MatType = arma::mat>
+class DecisionStump
+{
+ public:
+  /**
+   * Constructor. Train on the provided data and generate a decision stump
+   * from it.
+   *
+   * @param data Input training data; each row is an attribute and each column
+   *     a point.
+   * @param labels Labels of the training points.
+   * @param classes Number of distinct classes in labels.
+   * @param inpBucketSize Minimum size of a bucket when splitting.
+   */
+  DecisionStump(const MatType& data,
+                const arma::Row<size_t>& labels,
+                const size_t classes, 
+                size_t inpBucketSize);
+
+  /**
+   * Classification function. After training, classify the points (columns)
+   * of test and store the predicted classes in predictedLabels.
+   *
+   * @param test Testing data, or data to classify.
+   * @param predictedLabels Vector that receives the predicted class of each
+   *     point of test.
+   */
+  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
+ private:
+  //! Number of classes.
+  size_t numClass; 
+
+  //! Default class: the only class when all training labels are identical,
+  //! otherwise the most frequent training label.
+  size_t defaultClass;
+
+  //! Index of the attribute (row of the data) on which the stump splits.
+  int splitCol;
+
+  //! 1 if all training labels are identical (so defaultClass is always
+  //! predicted), 0 otherwise.
+  int oneClass; 
+
+  //! Minimum bucket size used when determining the splitting ranges.
+  size_t bucketSize;
+
+  //! Labels of the training data.
+  arma::Row<size_t> classLabels;
+
+  //! Splitting table built by training: row j holds the lower bound of range
+  //! j of the splitting attribute (column 0) and the class predicted for
+  //! that range (column 1).
+  arma::mat split;
+
+  /**
+   * Bucket the sorted attribute as if splitting on it, and return an
+   * entropy-based score of that split.  The constructor keeps the attribute
+   * with the lowest score.
+   *
+   * @param attribute A row from the training data, which might be a candidate
+   *     for the splitting attribute.
+   */
+  double SetupSplitAttribute(const arma::rowvec& attribute);
+
+  /**
+   * After the splitting attribute has been decided, build the splitting
+   * table (split) from it.
+   *
+   * @param attribute The attribute chosen by the constructor, on which the
+   *     decision stump is now trained.
+   */
+  template <typename rType> void TrainOnAtt(const arma::rowvec& attribute);
+
+  /**
+   * After the split matrix has been set up, merge adjacent ranges that
+   * predict the same class.
+   */
+  void MergeRanges();
+
+  /**
+   * Return the most frequently occurring element of subCols.
+   *
+   * @param subCols The vector in which to find the most frequently occurring
+   *     element.
+   */
+  template <typename rType> rType CountMostFreq(const arma::Row<rType>& subCols);
+
+  /**
+   * Return 1 if featureRow contains at least two different values, 0 if all
+   * of its values are the same.
+   *
+   * @param featureRow The attribute to check for non-identical values.
+   */
+  template <typename rType> int isDistinct(const arma::Row<rType>& featureRow);
+
+  /**
+   * Compute an entropy-based score of labels, conditioned on the values of
+   * attribute.
+   *
+   * @param attribute The attribute values of the points.
+   * @param labels Corresponding labels of the points.
+   */
+  double CalculateEntropy(const arma::rowvec& attribute, 
+                          const arma::rowvec& labels);
+
+};
+
+}; //namespace decision_stump
+}; //namespace mlpack
+
+#include "decision_stump_impl.cpp"
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.cpp b/src/mlpack/methods/decision_stump/decision_stump_impl.cpp
new file mode 100644
index 0000000..007a255
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.cpp
@@ -0,0 +1,441 @@
+/**
+ * @file decision_stump_impl.cpp
+ * @author Udit Saxena
+ *
+ * Implementation of the DecisionStump class.
+ */
+
+#ifndef _MLPACK_METHODS_DECISION_STUMP_IMPL_HPP
+#define _MLPACK_METHODS_DECISION_STUMP_IMPL_HPP
+
+#include "decision_stump.hpp"
+
+#include <set>
+#include <algorithm>
+
+namespace mlpack {
+namespace decision_stump {
+/**
+ * Constructor.  Train a decision stump on the provided data.
+ *
+ * Each attribute (row of data) whose values are not all identical is scored
+ * with SetupSplitAttribute(); the attribute with the lowest score becomes the
+ * splitting attribute and TrainOnAtt() builds the split table from it.  If
+ * the labels are all identical, or no attribute can be split on, the stump
+ * always predicts a single default class.
+ *
+ * @param data Input training data; each row is an attribute and each column
+ *     a point.
+ * @param labels Labels of the training points, one per column of data.
+ * @param classes Number of distinct classes in labels.
+ * @param inpBucketSize Minimum size of a bucket when splitting.
+ */
+template<typename MatType>
+DecisionStump<MatType>::DecisionStump(const MatType& data,
+                                      const arma::Row<size_t>& labels,
+                                      const size_t classes,
+                                      size_t inpBucketSize)
+{
+  classLabels = labels;
+
+  numClass = classes;
+  bucketSize = inpBucketSize;
+  splitCol = -1;
+
+  // Check whether the input labels are all identical.
+  if (!isDistinct<size_t>(classLabels))
+  {
+    // Only one class is present, so it is the only class ever predicted.
+    oneClass = 1;
+    defaultClass = classLabels(0);
+    return;
+  }
+
+  // The labels are not all identical, so proceed with training.
+  oneClass = 0;
+
+  // The most frequent training label is the fallback prediction.
+  defaultClass = CountMostFreq<size_t>(classLabels);
+
+  int bestAtt = -1;
+  double bestEntropy = DBL_MAX;
+  for (size_t i = 0; i < data.n_rows; ++i)
+  {
+    // An attribute whose values are all identical cannot separate anything.
+    if (!isDistinct<double>(data.row(i)))
+      continue;
+
+    // Score a split on this attribute and keep the lowest-scoring one.
+    const double entropy = SetupSplitAttribute(data.row(i));
+    if (entropy < bestEntropy)
+    {
+      bestAtt = static_cast<int>(i);
+      bestEntropy = entropy;
+    }
+  }
+
+  if (bestAtt == -1)
+  {
+    // Every attribute is constant, so there is nothing to split on (and
+    // data.row(bestAtt) would be out of range).  Always predict the most
+    // frequent class instead.
+    Log::Warn << "DecisionStump: all attributes are constant; every point "
+        << "will be assigned the most frequent class." << std::endl;
+    oneClass = 1;
+    return;
+  }
+
+  splitCol = bestAtt;
+
+  // Once the splitting attribute has been decided, train on it.
+  TrainOnAtt<double>(data.row(splitCol));
+}
+
+/**
+ * Classification function.  After training, each point (column) of test is
+ * assigned the class of the split range that its value of the splitting
+ * attribute falls into.
+ *
+ * @param test Testing data or data to classify; it is expected to have the
+ *     same attributes (rows) as the training data.
+ * @param predictedLabels Row vector receiving one predicted class per column
+ *     of test; it is resized to test.n_cols.
+ */
+template<typename MatType>
+void DecisionStump<MatType>::Classify(const MatType& test,
+                                      arma::Row<size_t>& predictedLabels)
+{
+  predictedLabels.set_size(test.n_cols);
+
+  // A stump trained on a single class (or without any split ranges) always
+  // predicts the default class.
+  if (oneClass || split.n_rows == 0)
+  {
+    predictedLabels.fill(defaultClass);
+    return;
+  }
+
+  for (size_t i = 0; i < test.n_cols; ++i)
+  {
+    // Value of the splitting attribute for this point.
+    const double val = test(splitCol, i);
+
+    // split(j, 0) holds the (ascending) lower bound of range j and split(j, 1)
+    // its class.  Choose the last range whose lower bound is <= val; values
+    // below the first bound fall into the first range.
+    size_t bucket = 0;
+    while ((bucket + 1 < split.n_rows) && (val >= split(bucket + 1, 0)))
+      ++bucket;
+
+    predictedLabels(i) = static_cast<size_t>(split(bucket, 1));
+  }
+}
+
+/**
+ * Score a candidate splitting attribute.  The attribute is sorted, the labels
+ * are put in the same order, and the sorted data is cut into buckets: a
+ * bucket is closed where the label changes, and a bucket shorter than
+ * bucketSize is extended to bucketSize points (never past the end of the
+ * data), so every bucket except possibly the last has at least bucketSize
+ * points.  The CalculateEntropy() value of each bucket is weighted by the
+ * fraction of points in that bucket and summed.
+ *
+ * @param attribute A row from the training data, which might be a candidate
+ *     for the splitting attribute.
+ * @return The weighted sum of per-bucket entropy scores; the constructor keeps
+ *     the attribute with the lowest value.
+ */
+template <typename MatType>
+double DecisionStump<MatType>::SetupSplitAttribute(const arma::rowvec& attribute)
+{
+  const size_t n = attribute.n_elem;
+  double entropy = 0.0;
+
+  // Sort the attribute to find the splitting ranges, and use the (stable)
+  // sorting permutation to put the labels in the same order.
+  const arma::rowvec sortedAtt = arma::sort(attribute);
+  const arma::uvec sortedIndexAtt = arma::stable_sort_index(attribute.t());
+
+  arma::Row<size_t> sortedLabels(n);
+  for (size_t k = 0; k < n; ++k)
+    sortedLabels(k) = classLabels(sortedIndexAtt(k));
+
+  size_t i = 0;
+  size_t count = 0;
+  while (i < n)
+  {
+    ++count;
+
+    size_t begin;
+    size_t end;
+    if (i == n - 1)
+    {
+      // The last point closes the final bucket as it is.
+      begin = i + 1 - count;
+      end = i;
+    }
+    else if (sortedLabels(i) != sortedLabels(i + 1))
+    {
+      // The label changes after i: close the bucket, widening it to
+      // bucketSize points if it is too small.  The comparison against
+      // n - begin avoids overflowing begin + bucketSize.
+      begin = i + 1 - count;
+      if (count < bucketSize)
+        end = (bucketSize > n - begin) ? (n - 1) : (begin + bucketSize - 1);
+      else
+        end = i;
+    }
+    else
+    {
+      ++i;
+      continue;
+    }
+
+    const arma::rowvec subColLabels =
+        arma::conv_to<arma::rowvec>::from(sortedLabels.cols(begin, end));
+    const arma::rowvec subColAtts(sortedAtt.cols(begin, end));
+
+    // Weight each bucket by the fraction of points it holds, so the result is
+    // a weighted conditional entropy instead of a sum in which a bucket of 2
+    // points counts as much as a bucket of 100.
+    const double weight = static_cast<double>(end - begin + 1) / n;
+    entropy += weight * CalculateEntropy(subColAtts, subColLabels);
+
+    i = end + 1;
+    count = 0;
+  }
+
+  return entropy;
+}
+
+/**
+ * Build the splitting table for the chosen attribute.  The attribute is
+ * sorted, the labels are reordered to match, and the sorted labels are cut
+ * into buckets: a bucket closes where the label changes, and a bucket that is
+ * shorter than bucketSize is widened to bucketSize points (clamped to the end
+ * of the data).  For every bucket a row (smallest attribute value in the
+ * bucket, most frequent label in the bucket) is appended to split; adjacent
+ * rows with the same label are then merged by MergeRanges().
+ *
+ * @param attribute The attribute chosen by the constructor, on which the
+ *     decision stump is now trained.
+ */
+template <typename MatType>
+template <typename rType>
+void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
+{
+  const arma::rowvec sortedSplitAtt = arma::sort(attribute);
+  const arma::uvec order = arma::stable_sort_index(attribute.t());
+
+  // Labels rearranged into the sorted order of the attribute.
+  arma::Row<size_t> sortedLabels(attribute.n_elem, arma::fill::zeros);
+  for (size_t k = 0; k < attribute.n_elem; ++k)
+    sortedLabels(k) = classLabels(order(k));
+
+  const int last = (int) sortedLabels.n_elem - 1;
+  int pos = 0;
+  int runLength = 0;
+  while (pos <= last)
+  {
+    ++runLength;
+
+    const bool atEnd = (pos == last);
+    if (!atEnd && sortedLabels(pos) == sortedLabels(pos + 1))
+    {
+      // Still inside a run of identical labels.
+      ++pos;
+      continue;
+    }
+
+    // Close the current bucket [first, stop].
+    const int first = pos - runLength + 1;
+    int stop = pos;
+    if (!atEnd && runLength < bucketSize)
+    {
+      // Too small: widen the bucket to bucketSize points, clamped to the data.
+      stop = first + bucketSize - 1;
+      if (stop > sortedLabels.n_elem - 1)
+        stop = sortedLabels.n_elem - 1;
+    }
+
+    // The bucket predicts its most frequent label.
+    const arma::rowvec bucketLabels =
+        arma::conv_to<arma::rowvec>::from(sortedLabels.cols(first, stop));
+    const rType label = CountMostFreq<double>(bucketLabels);
+
+    arma::mat row;
+    row << sortedSplitAtt(first) << label << arma::endr;
+    split = arma::join_cols(split, row);
+
+    if (atEnd)
+    {
+      ++pos;
+    }
+    else
+    {
+      pos = stop + 1;
+      runLength = 0;
+    }
+  }
+
+  // Merge neighbouring buckets that predict the same label into one range.
+  MergeRanges();
+}
+
+/**
+ * Collapse the splitting table: whenever a row carries the same label as the
+ * row directly above it, the lower row is removed, so consecutive ranges that
+ * predict the same class become one larger range.
+ */
+template <typename MatType>
+void DecisionStump<MatType>::MergeRanges()
+{
+  size_t row = 1;
+  while (row < split.n_rows)
+  {
+    if (split(row, 1) == split(row - 1, 1))
+      split.shed_row(row);  // Same label as the range above: absorb it.
+    else
+      ++row;
+  }
+}
+
+/**
+ * Return the most frequently occurring element of subCols.  Ties are broken
+ * in favour of the smallest such element.
+ *
+ * @param subCols The vector in which to find the most frequently occurring
+ *     element.
+ * @return The most frequent element, or a value-initialized rType if subCols
+ *     is empty.
+ */
+template <typename MatType>
+template <typename rType>
+rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
+{
+  if (subCols.n_elem == 0)
+    return rType();
+
+  // Sorting places equal elements in contiguous runs; the longest run wins.
+  const arma::Row<rType> sorted = arma::sort(subCols);
+
+  rType element = sorted(0);
+  size_t bestCount = 0;
+  size_t runStart = 0;
+  for (size_t i = 1; i <= sorted.n_elem; ++i)
+  {
+    // A run ends at the end of the vector or where the value changes.
+    if ((i == sorted.n_elem) || (sorted(i) != sorted(runStart)))
+    {
+      const size_t runLength = i - runStart;
+      if (runLength > bestCount)
+      {
+        bestCount = runLength;
+        element = sorted(runStart);
+      }
+      runStart = i;
+    }
+  }
+
+  return element;
+}
+
+/**
+ * Check whether featureRow holds at least two different values.
+ *
+ * @param featureRow The attribute to check for non-identical values.
+ * @return 1 if its largest value exceeds its smallest value, 0 if all of its
+ *     values are the same.
+ */
+template <typename MatType>
+template <typename rType>
+int DecisionStump<MatType>::isDistinct(const arma::Row<rType>& featureRow)
+{
+  const rType largest = featureRow.max();
+  const rType smallest = featureRow.min();
+  return (largest > smallest) ? 1 : 0;
+}
+
+/**
+ * Calculate the entropy of labels conditioned on the exact values of
+ * attribute: for every distinct attribute value v with relative frequency
+ * p(v), the Shannon entropy (in bits) of the labels of the points having
+ * value v is computed, and these entropies are summed with weights p(v).
+ *
+ * NOTE(review): when every attribute value is unique, each per-value label
+ * distribution is pure and the result is 0 regardless of the labels.
+ *
+ * @param attribute The attribute values of the points.
+ * @param labels Corresponding labels of the points; each is used as a column
+ *     index and must be < numClass.
+ * @return A non-negative entropy; lower means purer.
+ */
+template<typename MatType>
+double DecisionStump<MatType>::CalculateEntropy(const arma::rowvec& attribute,
+                                                const arma::rowvec& labels)
+{
+  double entropy = 0.0;
+
+  const arma::rowvec uniqueAtt = arma::unique(attribute);
+  arma::Row<size_t> numElem(uniqueAtt.n_elem, arma::fill::zeros);
+  arma::Mat<size_t> entropyArray(uniqueAtt.n_elem, numClass, arma::fill::zeros);
+
+  // For every distinct attribute value, count how many points have it
+  // (numElem) and how many of those belong to each class (entropyArray).
+  for (size_t j = 0; j < uniqueAtt.n_elem; ++j)
+  {
+    for (size_t i = 0; i < attribute.n_elem; ++i)
+    {
+      if (uniqueAtt(j) == attribute(i))
+      {
+        entropyArray(j, static_cast<size_t>(labels(i)))++;
+        numElem(j)++;
+      }
+    }
+  }
+
+  for (size_t j = 0; j < uniqueAtt.n_elem; ++j)
+  {
+    // Relative frequency of this attribute value.
+    const double p1 = static_cast<double>(numElem(j)) / attribute.n_elem;
+
+    for (size_t c = 0; c < numClass; ++c)
+    {
+      // Relative frequency of class c among the points with this value.
+      const double p2 = static_cast<double>(entropyArray(j, c)) / numElem(j);
+
+      // Shannon entropy is -sum(p * log2(p)); terms with p == 0 contribute 0.
+      // The minus sign matters: the constructor keeps the attribute with the
+      // *lowest* score, which must therefore be the purest split.
+      if (p2 > 0)
+        entropy -= p1 * p2 * log2(p2);
+    }
+  }
+
+  return entropy;
+}
+
+
+}; // namespace decision_stump
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
new file mode 100644
index 0000000..186a7cb
--- /dev/null
+++ b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
@@ -0,0 +1,90 @@
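+/**
+ * @file decision_stump_main.cpp
+ * @author Udit Saxena
+ *
+ * Command-line program that trains a decision stump on a labeled training set
+ * and uses it to classify a test set.
+ */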
+
+#include <mlpack/core.hpp>
+#include "decision_stump.hpp"
+
+using namespace mlpack;
+using namespace mlpack::decision_stump;
+using namespace std;
+using namespace arma;
+
+PROGRAM_INFO("Decision Stump","This program implements a decision stump, "
+    "a single level decision tree, on the given training data set. "
+    "Default size of buckets is 6");
+
+// Necessary parameters.
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "tr");
+PARAM_STRING_REQ("labels_file", "A file containing labels for the training set.",
+  "l");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "te");
+// Declared as an int so that its type matches CLI::GetParam<int> below.
+PARAM_INT_REQ("num_classes","The number of classes","c");
+
+// Output parameters (optional).
+PARAM_STRING("output", "The file in which the predicted labels for the test set"
+    " will be written.", "o", "output.csv");
+
+PARAM_INT("bucket_size","The size of ranges/buckets to be used while splitting the decision stump.","b", 6);
+
+/**
+ * Load the training data, labels and test data named on the command line,
+ * train a decision stump, classify the test data, and save the predictions
+ * (mapped back to the original label values) to the output file.
+ */
+int main(int argc, char *argv[])
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  // Load the training data.
+  const string trainingDataFilename = CLI::GetParam<string>("train_file");
+  mat trainingData;
+  data::Load(trainingDataFilename, trainingData, true);
+
+  // Load the labels; they may be stored either as a row or as a column.
+  const string labelsFilename = CLI::GetParam<string>("labels_file");
+  mat labelsIn;
+  data::Load(labelsFilename, labelsIn, true);
+  if (labelsIn.n_rows == 1)
+    labelsIn = labelsIn.t();
+
+  // Map the labels onto 0, 1, ...; mappings remembers the original values.
+  Col<size_t> labels;
+  vec mappings;
+  data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
+
+  if (labels.n_elem != trainingData.n_cols)
+    Log::Fatal << "Number of labels (" << labels.n_elem << ") must match the "
+        << "number of training points (" << trainingData.n_cols << ")!"
+        << std::endl;
+
+  // Every normalized label is used as a class index, so the class count must
+  // cover all distinct labels.
+  const int numClassesIn = CLI::GetParam<int>("num_classes");
+  if (numClassesIn <= 0 || (size_t) numClassesIn < mappings.n_elem)
+    Log::Fatal << "Number of classes (" << numClassesIn << ") must be positive "
+        << "and at least the number of distinct labels (" << mappings.n_elem
+        << ")!" << std::endl;
+  const size_t num_classes = (size_t) numClassesIn;
+
+  // A negative bucket size would wrap around when converted to size_t.
+  const int bucketSizeIn = CLI::GetParam<int>("bucket_size");
+  if (bucketSizeIn < 0)
+    Log::Fatal << "Bucket size (" << bucketSizeIn << ") must be non-negative!"
+        << std::endl;
+  const size_t inpBucketSize = (size_t) bucketSizeIn;
+
+  // Load the test data.
+  const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+  mat testingData;
+  data::Load(testingDataFilename, testingData, true);
+
+  if (testingData.n_rows != trainingData.n_rows)
+    Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
+        << "must be the same as training data (" << trainingData.n_rows
+        << ")!" << std::endl;
+
+  // DecisionStump takes the labels as a row vector.
+  const Row<size_t> labelsRow(labels.t());
+
+  Timer::Start("training");
+  DecisionStump<> ds(trainingData, labelsRow, num_classes, inpBucketSize);
+  Timer::Stop("training");
+
+  Row<size_t> predictedLabels(testingData.n_cols);
+  Timer::Start("testing");
+  ds.Classify(testingData, predictedLabels);
+  Timer::Stop("testing");
+
+  // Map the predictions back onto the original label values; RevertLabels
+  // takes a column vector.
+  const Col<size_t> predictedCol(predictedLabels.t());
+  vec results;
+  data::RevertLabels(predictedCol, mappings, results);
+
+  // Save the predictions (transposed on output).
+  const string outputFilename = CLI::GetParam<string>("output");
+  data::Save(outputFilename, results, true, true);
+
+  return 0;
+}
\ No newline at end of file
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 2e7f3ce..2aebb62 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -8,6 +8,7 @@ add_executable(mlpack_test
   aug_lagrangian_test.cpp
   cf_test.cpp
   cli_test.cpp
+  decision_stump_test.cpp
   det_test.cpp
   distribution_test.cpp
   emst_test.cpp
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt~
similarity index 100%
copy from src/mlpack/tests/CMakeLists.txt
copy to src/mlpack/tests/CMakeLists.txt~
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
new file mode 100644
index 0000000..310c93a
--- /dev/null
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -0,0 +1,174 @@
+/**
+ * @file decision_stump_test.cpp
+ * @author Udit Saxena
+ *
+ * Tests for the DecisionStump class.
+ */
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/decision_stump/decision_stump.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::decision_stump;
+using namespace arma;
+
+BOOST_AUTO_TEST_SUITE(DSTEST);
+
+/**
+ * This test handles the case where only one class exists in the input labels.
+ * It checks that the only class supplied is the only class predicted.
+ */
+BOOST_AUTO_TEST_CASE(OneClass)
+{
+  size_t numClasses = 2;
+  size_t inpBucketSize = 6;
+
+  mat trainingData;
+  trainingData << 2.4 << 3.8 << 3.8 << endr
+               << 1 << 1 << 2 << endr
+               << 1.3 << 1.9 << 1.3 << endr;
+
+  Mat<size_t> labelsIn;
+  labelsIn << 1 << 1 << 1;
+
+  // No need to normalize labels here.
+
+  // The test points have the same dimensionality (3) as the training points.
+  mat testingData;
+  testingData << 2.4 << 2.5 << 2.6 << endr
+              << 1 << 1 << 2 << endr
+              << 1.3 << 1.9 << 1.3 << endr;
+
+  DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+  Row<size_t> predictedLabels(testingData.n_cols);
+  ds.Classify(testingData, predictedLabels);
+
+  BOOST_REQUIRE_EQUAL(predictedLabels.n_elem, testingData.n_cols);
+  for (size_t i = 0; i < predictedLabels.n_elem; ++i)
+    BOOST_CHECK_EQUAL(predictedLabels(i), (size_t) 1);
+}
+
+/**
+ * This tests the classification:
+ *   if testinput < 0 - class 0
+ *   if testinput > 0 - class 1
+ * The training data is perfectly separable at zero.
+ */
+BOOST_AUTO_TEST_CASE(PerfectSplitOnZero)
+{
+  size_t numClasses = 2;
+  size_t inpBucketSize = 2;
+
+  mat trainingData;
+  trainingData << -1 << 1 << -2 << 2 << -3 << 3;
+
+  Mat<size_t> labelsIn;
+  labelsIn << 0 << 1 << 0 << 1 << 0 << 1;
+  // No need to normalize labels here.
+
+  mat testingData;
+  testingData << -4 << 7 << -7 << -5 << 6;
+
+  // Negative test points belong to class 0, positive ones to class 1.
+  Row<size_t> expected;
+  expected << 0 << 1 << 0 << 0 << 1;
+
+  DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+  Row<size_t> predictedLabels(testingData.n_cols);
+  ds.Classify(testingData, predictedLabels);
+
+  BOOST_REQUIRE_EQUAL(predictedLabels.n_elem, expected.n_elem);
+  for (size_t i = 0; i < expected.n_elem; ++i)
+    BOOST_CHECK_EQUAL(predictedLabels(i), expected(i));
+}
+
+/**
+ * This tests the binning function for the case when a dataset with fewer
+ * points than inpBucketSize is provided.  The first bucket is widened to
+ * cover all 7 points, so the stump predicts the majority class (0: four
+ * points against three) everywhere.
+ */
+BOOST_AUTO_TEST_CASE(BinningTesting)
+{
+  size_t numClasses = 2;
+  size_t inpBucketSize = 10;
+
+  mat trainingData;
+  trainingData << -1 << 1 << -2 << 2 << -3 << 3 << -4;
+
+  Mat<size_t> labelsIn;
+  labelsIn << 0 << 1 << 0 << 1 << 0 << 1 << 0;
+
+  // No need to normalize labels here.
+
+  mat testingData;
+  testingData << 5;
+
+  DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+  Row<size_t> predictedLabels(testingData.n_cols);
+  ds.Classify(testingData, predictedLabels);
+
+  BOOST_REQUIRE_EQUAL(predictedLabels.n_elem, (size_t) 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0), (size_t) 0);
+}
+
+/**
+ * This is a test for the case when non-overlapping, multiple classes are
+ * provided.  Because the classes occupy disjoint ranges of the attribute, the
+ * split is expected to be perfect.
+ */
+BOOST_AUTO_TEST_CASE(PerfectMultiClassSplit)
+{
+  size_t numClasses = 4;
+  size_t inpBucketSize = 3;
+
+  mat trainingData;
+  trainingData << -8 << -7 << -6 << -5 << -4 << -3 << -2 << -1
+               << 0 << 1 << 2 << 3 << 4 << 5 << 6 << 7;
+
+  Mat<size_t> labelsIn;
+  labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 1 << 1
+           << 2 << 2 << 2 << 2 << 3 << 3 << 3 << 3;
+  // No need to normalize labels here.
+
+  mat testingData;
+  testingData << -6.1 << -2.1 << 1.1 << 5.1;
+
+  // Each test point lies inside the range of exactly one class.
+  Row<size_t> expected;
+  expected << 0 << 1 << 2 << 3;
+
+  DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+  Row<size_t> predictedLabels(testingData.n_cols);
+  ds.Classify(testingData, predictedLabels);
+
+  BOOST_REQUIRE_EQUAL(predictedLabels.n_elem, expected.n_elem);
+  for (size_t i = 0; i < expected.n_elem; ++i)
+    BOOST_CHECK_EQUAL(predictedLabels(i), expected(i));
+}
+
+/**
+ * This test is for the case when reasonably overlapping, multiple classes are
+ * provided in the input label set.  It checks that classification takes place
+ * with a small amount of error despite the overlap between the classes.
+ */
+BOOST_AUTO_TEST_CASE(MultiClassSplit)
+{
+  size_t numClasses = 3;
+  size_t inpBucketSize = 3;
+
+  mat trainingData;
+  trainingData << -7 << -6 << -5 << -4 << -3 << -2 << -1 << 0 << 1
+               << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9 << 10;
+
+  Mat<size_t> labelsIn;
+  labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 0 << 0
+           << 1 << 1 << 1 << 2 << 1 << 2 << 2 << 2 << 2 << 2;
+  // No need to normalize labels here.
+
+  mat testingData;
+  testingData << -6.1 << -5.9 << -2.1 << -0.7 << 2.5 << 4.7 << 7.2 << 9.1;
+
+  // Reference labels: the label of the nearest training point.  Because the
+  // classes overlap, at most one of the eight points may be misclassified.
+  Row<size_t> reference;
+  reference << 0 << 0 << 1 << 0 << 1 << 1 << 2 << 2;
+
+  DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+  Row<size_t> predictedLabels(testingData.n_cols);
+  ds.Classify(testingData, predictedLabels);
+
+  BOOST_REQUIRE_EQUAL(predictedLabels.n_elem, reference.n_elem);
+  size_t errors = 0;
+  for (size_t i = 0; i < reference.n_elem; ++i)
+    if (predictedLabels(i) != reference(i))
+      ++errors;
+
+  BOOST_REQUIRE_LE(errors, (size_t) 1);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
\ No newline at end of file

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list