[mlpack] 180/324: Refactor main executable. Some further refactorization will need to be done when everything transitions to arma::Row<size_t> for labels.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:08 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 92f718178b96dcdb6844566d4b8ffd83c365d441
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Mon Jul 14 13:17:01 2014 +0000
Refactor main executable. Some further refactorization will need to be done
when everything transitions to arma::Row<size_t> for labels.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16819 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../methods/decision_stump/decision_stump_main.cpp | 112 ++++++++++++---------
1 file changed, 67 insertions(+), 45 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
index 4c998d5..fe92e33 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_main.cpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
@@ -1,10 +1,9 @@
-/*
- * @author: Udit Saxena
- * @file: decision_stump_main.cpp
+/**
+ * @file decision_stump_main.cpp
+ * @author Udit Saxena
*
* Main executable for the decision stump.
*/
-
#include <mlpack/core.hpp>
#include "decision_stump.hpp"
@@ -13,22 +12,40 @@ using namespace mlpack::decision_stump;
using namespace std;
using namespace arma;
-PROGRAM_INFO("Decision Stump","This program implements a decision stump, "
- "a single level decision tree, on the given training data set. "
- "Default size of buckets is 6");
-
-// necessary parameters
-PARAM_STRING_REQ("train_file", "A file containing the training set.", "tr");
-PARAM_STRING_REQ("labels_file", "A file containing labels for the training set.",
- "l");
-PARAM_STRING_REQ("test_file", "A file containing the test set.", "te");
-PARAM_STRING_REQ("num_classes","The number of classes","c");
-
-// output parameters (optional)
-PARAM_STRING("output", "The file in which the predicted labels for the test set"
- " will be written.", "o", "output.csv");
-
-PARAM_INT("bucket_size","The size of ranges/buckets to be used while splitting the decision stump.","b", 6);
+PROGRAM_INFO("Decision Stump",
+ "This program implements a decision stump, which is a single-level decision"
+ " tree. The decision stump will split on one dimension of the input data, "
+ "and will split into multiple buckets. The dimension and bins are selected"
+ " by minimizing the entropy of the split. Optionally, the minimum number "
+ "of training points in each bin can be specified with the --bin_size (-b) "
+ "parameter.\n"
+ "\n"
+ "The decision stump is parameterized by a splitting dimension and a vector "
+ "of values that denote the splitting values of each bin.\n"
+ "\n"
+ "This program allows training of a decision stump, and then application of "
+ "the learned decision stump to a test dataset. To train a decision stump, "
+ "a training dataset must be passed to --train_file (-t). Labels can either"
+ " be present as the last dimension of the training dataset, or given "
+ "explicitly with the --labels_file (-l) parameter.\n"
+ "\n"
+ "A test file is given through the --test_file (-T) parameter. The "
+ "predicted labels for the test set will be stored in the file specified by "
+ "the --output_file (-o) parameter.");
+
+// Necessary parameters.
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
+
+// Optional parameters.
+PARAM_STRING("labels_file", "A file containing labels for the training set. If "
+ "not specified, the labels are assumed to be the last row of the training "
+ "data.", "l", "");
+PARAM_STRING("output_file", "The file in which the predicted labels for the "
+ "test set will be written.", "o", "output.csv");
+
+PARAM_INT("bin_size", "The minimum number of training points in each "
+ "decision stump bin.", "b", 6);
int main(int argc, char *argv[])
{
@@ -38,29 +55,36 @@ int main(int argc, char *argv[])
mat trainingData;
data::Load(trainingDataFilename, trainingData, true);
- const string labelsFilename = CLI::GetParam<string>("labels_file");
- // Load labels.
+ // Load labels, if necessary.
mat labelsIn;
- data::Load(labelsFilename, labelsIn, true);
-
- // helpers for normalizing the labels
+ if (CLI::HasParam("labels_file"))
+ {
+ const string labelsFilename = CLI::GetParam<string>("labels_file");
+ // Load labels.
+ data::Load(labelsFilename, labelsIn, true);
+
+ // Do the labels need to be transposed?
+ if (labelsIn.n_rows == 1)
+ labelsIn = labelsIn.t();
+ }
+ else
+ {
+ // Use the last row of the training data as the labels, then remove that row.
+ Log::Info << "Using the last dimension of training set as labels." << endl;
+
+ labelsIn = trainingData.row(trainingData.n_rows - 1).t();
+ trainingData.shed_row(trainingData.n_rows - 1);
+ }
+
+ // Normalize the labels.
Col<size_t> labels;
vec mappings;
-
- // Do the labels need to be transposed?
- if (labelsIn.n_rows == 1)
- labelsIn = labelsIn.t();
-
- size_t inpBucketSize = CLI::GetParam<int>("bucket_size");
-
- // normalize the labels
data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
- const size_t num_classes = CLI::GetParam<size_t>("num_classes");
- /*
- Should number of classes be input or should it be
- derived from the labels row ?
- */
+ const size_t inpBucketSize = CLI::GetParam<int>("bin_size");
+ const size_t numClasses = labels.max() + 1;
+
+ // Load the test file.
const string testingDataFilename = CLI::GetParam<std::string>("test_file");
mat testingData;
data::Load(testingDataFilename, testingData, true);
@@ -71,7 +95,7 @@ int main(int argc, char *argv[])
<< ")!" << std::endl;
Timer::Start("training");
- DecisionStump<> ds(trainingData, labels, num_classes, inpBucketSize);
+ DecisionStump<> ds(trainingData, labels.t(), numClasses, inpBucketSize);
Timer::Stop("training");
Row<size_t> predictedLabels(testingData.n_cols);
@@ -80,11 +104,9 @@ int main(int argc, char *argv[])
Timer::Stop("testing");
vec results;
- data::RevertLabels(predictedLabels, mappings, results);
-
- const string outputFilename = CLI::GetParam<string>("output");
- data::Save(outputFilename, results, true, true);
- // saving the predictedLabels in the transposed manner in output
+ data::RevertLabels(predictedLabels.t(), mappings, results);
- return 0;
+ // Save the predicted labels in a transposed form as output.
+ const string outputFilename = CLI::GetParam<string>("output_file");
+ data::Save(outputFilename, results, true, false);
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list