[mlpack] 180/324: Refactor main executable. Some further refactorization will need to be done when everything transitions to arma::Row<size_t> for labels.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:22:08 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 92f718178b96dcdb6844566d4b8ffd83c365d441
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Mon Jul 14 13:17:01 2014 +0000

    Refactor main executable.  Some further refactorization will need to be done
    when everything transitions to arma::Row<size_t> for labels.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16819 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 .../methods/decision_stump/decision_stump_main.cpp | 112 ++++++++++++---------
 1 file changed, 67 insertions(+), 45 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
index 4c998d5..fe92e33 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_main.cpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
@@ -1,10 +1,9 @@
-/*
- * @author: Udit Saxena
- * @file: decision_stump_main.cpp
+/**
+ * @file decision_stump_main.cpp
+ * @author Udit Saxena
  *
  * Main executable for the decision stump.
  */
-
 #include <mlpack/core.hpp>
 #include "decision_stump.hpp"
 
@@ -13,22 +12,40 @@ using namespace mlpack::decision_stump;
 using namespace std;
 using namespace arma;
 
-PROGRAM_INFO("Decision Stump","This program implements a decision stump, "
-    "a single level decision tree, on the given training data set. "
-    "Default size of buckets is 6");
-
-// necessary parameters
-PARAM_STRING_REQ("train_file", "A file containing the training set.", "tr");
-PARAM_STRING_REQ("labels_file", "A file containing labels for the training set.",
-  "l");
-PARAM_STRING_REQ("test_file", "A file containing the test set.", "te");
-PARAM_STRING_REQ("num_classes","The number of classes","c");
-
-// output parameters (optional)
-PARAM_STRING("output", "The file in which the predicted labels for the test set"
-    " will be written.", "o", "output.csv");
-
-PARAM_INT("bucket_size","The size of ranges/buckets to be used while splitting the decision stump.","b", 6);
+PROGRAM_INFO("Decision Stump",
+    "This program implements a decision stump, which is a single-level decision"
+    " tree.  The decision stump will split on one dimension of the input data, "
+    "and will split into multiple buckets.  The dimension and bins are selected"
+    " by minimizing the entropy of the split.  Optionally, the minimum number "
+    "of training points in each bin can be specified with the --bin_size (-b) "
+    "parameter.\n"
+    "\n"
+    "The decision stump is parameterized by a splitting dimension and a vector "
+    "of values that denote the splitting values of each bin.\n"
+    "\n"
+    "This program allows training of a decision stump, and then application of "
+    "the learned decision stump to a test dataset.  To train a decision stump, "
+    "a training dataset must be passed to --train_file (-t).  Labels can either"
+    " be present as the last dimension of the training dataset, or given "
+    "explicitly with the --labels_file (-l) parameter.\n"
+    "\n"
+    "A test file is given through the --test_file (-T) parameter.  The "
+    "predicted labels for the test set will be stored in the file specified by "
+    "the --output_file (-o) parameter.");
+
+// Necessary parameters.
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
+
+// Optional parameters.
+PARAM_STRING("labels_file", "A file containing labels for the training set. If "
+    "not specified, the labels are assumed to be the last row of the training "
+    "data.", "l", "");
+PARAM_STRING("output_file", "The file in which the predicted labels for the "
+    "test set will be written.", "o", "output.csv");
+
+PARAM_INT("bin_size", "The minimum number of training points in each "
+    "decision stump bin.", "b", 6);
 
 int main(int argc, char *argv[])
 {
@@ -38,29 +55,36 @@ int main(int argc, char *argv[])
   mat trainingData;
   data::Load(trainingDataFilename, trainingData, true);
 
-  const string labelsFilename = CLI::GetParam<string>("labels_file");
-  // Load labels.
+  // Load labels, if necessary.
   mat labelsIn;
-  data::Load(labelsFilename, labelsIn, true);
-
-  // helpers for normalizing the labels
+  if (CLI::HasParam("labels_file"))
+  {
+    const string labelsFilename = CLI::GetParam<string>("labels_file");
+    // Load labels.
+    data::Load(labelsFilename, labelsIn, true);
+
+    // Do the labels need to be transposed?
+    if (labelsIn.n_rows == 1)
+      labelsIn = labelsIn.t();
+  }
+  else
+  {
+    // No labels file: use the last row of the training data as the labels, then remove that row.
+    Log::Info << "Using the last dimension of training set as labels." << endl;
+
+    labelsIn = trainingData.row(trainingData.n_rows - 1).t();
+    trainingData.shed_row(trainingData.n_rows - 1);
+  }
+
+  // Normalize the labels.
   Col<size_t> labels;
   vec mappings;
-
-  // Do the labels need to be transposed?
-  if (labelsIn.n_rows == 1)
-    labelsIn = labelsIn.t();
-
-  size_t inpBucketSize = CLI::GetParam<int>("bucket_size");
-
-  // normalize the labels
   data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
 
-  const size_t num_classes = CLI::GetParam<size_t>("num_classes");
-  /*
-  Should number of classes be input or should it be
-  derived from the labels row ?
-  */
+  const size_t inpBucketSize = CLI::GetParam<int>("bin_size");
+  const size_t numClasses = labels.max() + 1;
+
+  // Load the test file.
   const string testingDataFilename = CLI::GetParam<std::string>("test_file");
   mat testingData;
   data::Load(testingDataFilename, testingData, true);
@@ -71,7 +95,7 @@ int main(int argc, char *argv[])
         << ")!" << std::endl;
 
   Timer::Start("training");
-  DecisionStump<> ds(trainingData, labels, num_classes, inpBucketSize);
+  DecisionStump<> ds(trainingData, labels.t(), numClasses, inpBucketSize);
   Timer::Stop("training");
 
   Row<size_t> predictedLabels(testingData.n_cols);
@@ -80,11 +104,9 @@ int main(int argc, char *argv[])
   Timer::Stop("testing");
 
   vec results;
-  data::RevertLabels(predictedLabels, mappings, results);
-
-  const string outputFilename = CLI::GetParam<string>("output");
-  data::Save(outputFilename, results, true, true);
-  // saving the predictedLabels in the transposed manner in output
+  data::RevertLabels(predictedLabels.t(), mappings, results);
 
-  return 0;
+  // Save the predicted labels without transposition (results is a column vector).
+  const string outputFilename = CLI::GetParam<string>("output_file");
+  data::Save(outputFilename, results, true, false);
 }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list