[mlpack] 61/207: Backport CLI interface to 2.2.x.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:41 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 682e938cc6c2f856234f93d9a4b573f38b7edf81
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Mar 21 10:06:36 2017 -0400
Backport CLI interface to 2.2.x.
---
.../methods/decision_tree/decision_tree_main.cpp | 84 ++++++++++++----------
1 file changed, 45 insertions(+), 39 deletions(-)
diff --git a/src/mlpack/methods/decision_tree/decision_tree_main.cpp b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
index 956ff1f..19c34c7 100644
--- a/src/mlpack/methods/decision_tree/decision_tree_main.cpp
+++ b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
@@ -38,11 +38,11 @@ PROGRAM_INFO("Decision tree",
"option.");
// Datasets.
-PARAM_MATRIX_IN("training", "Matrix of training points.", "t");
-PARAM_UMATRIX_IN("labels", "Training labels.", "l");
-PARAM_MATRIX_IN("test", "Matrix of test points.", "T");
-PARAM_UMATRIX_IN("test_labels", "Test point labels, if accuracy calculation "
- "is desired.", "L");
+PARAM_STRING_IN("training_file", "File containing training points.", "t", "");
+PARAM_STRING_IN("labels_file", "File containing training labels.", "l", "");
+PARAM_STRING_IN("test_file", "File containing test points.", "T", "");
+PARAM_STRING_IN("test_labels_file", "File containing test labels, if accuracy "
+ "calculation is desired.", "L", "");
// Training parameters.
PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
@@ -50,9 +50,10 @@ PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
PARAM_FLAG("print_training_error", "Print the training error.", "e");
// Output parameters.
-PARAM_MATRIX_OUT("probabilities", "Class probabilities for each test point.",
- "P");
-PARAM_UMATRIX_OUT("predictions", "Class predictions for each test point.", "p");
+PARAM_STRING_IN("probabilities_file", "File to save class probabilities to for"
+ " each test point.", "P", "");
+PARAM_STRING_IN("predictions_file", "File to save class predictions to for "
+ "each test point.", "p", "");
/**
* This is the class that we will serialize. It is a pretty simple wrapper
@@ -77,44 +78,46 @@ class DecisionTreeModel
};
// Models.
-PARAM_MODEL_IN(DecisionTreeModel, "input_model", "Pre-trained decision tree, "
- "to be used with test points.", "m");
-PARAM_MODEL_OUT(DecisionTreeModel, "output_model", "Output for trained decision"
- " tree.", "M");
+PARAM_STRING_IN("input_model_file", "File to load pre-trained decision tree "
+ "from, to be used with test points.", "m", "");
+PARAM_STRING_IN("output_model_file", "File to save trained decision tree to.",
+ "M", "");
int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
// Check parameters.
- if (CLI::HasParam("training") && CLI::HasParam("input_model"))
+ if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
Log::Fatal << "Cannot specify both --training_file and --input_model_file!"
<< endl;
- if (CLI::HasParam("training") && !CLI::HasParam("labels"))
+ if (CLI::HasParam("training_file") && !CLI::HasParam("labels_file"))
Log::Fatal << "Must specify --labels_file when --training_file is "
<< "specified!" << endl;
- if (CLI::HasParam("test_labels") && !CLI::HasParam("test"))
+ if (CLI::HasParam("test_labels_file") && !CLI::HasParam("test_file"))
Log::Warn << "--test_labels_file ignored because --test_file is not passed."
<< endl;
- if (!CLI::HasParam("output_model") && !CLI::HasParam("probabilities") &&
- !CLI::HasParam("predictions") && !CLI::HasParam("test_labels"))
+ if (!CLI::HasParam("output_model_file") &&
+ !CLI::HasParam("probabilities_file") &&
+ !CLI::HasParam("predictions_file") &&
+ !CLI::HasParam("test_labels_file"))
Log::Warn << "None of --output_model_file, --probabilities_file, or "
<< "--predictions_file are given, and accuracy is not being calculated;"
<< " no output will be saved!" << endl;
- if (CLI::HasParam("print_training_error") && !CLI::HasParam("training"))
+ if (CLI::HasParam("print_training_error") && !CLI::HasParam("training_file"))
Log::Warn << "--print_training_error ignored because --training_file is not"
<< " specified." << endl;
- if (!CLI::HasParam("test"))
+ if (!CLI::HasParam("test_file"))
{
- if (CLI::HasParam("probabilities"))
+ if (CLI::HasParam("probabilities_file"))
Log::Warn << "--probabilities_file ignored because --test_file is not "
<< "specified." << endl;
- if (CLI::HasParam("predictions"))
+ if (CLI::HasParam("predictions_file"))
Log::Warn << "--predictions_file ignored because --test_file is not "
<< "specified." << endl;
}
@@ -122,11 +125,12 @@ int main(int argc, char** argv)
// Load the model or build the tree.
DecisionTreeModel model;
- if (CLI::HasParam("training"))
+ if (CLI::HasParam("training_file"))
{
- arma::mat dataset = std::move(CLI::GetParam<arma::mat>("training"));
- arma::Mat<size_t> labels =
- std::move(CLI::GetParam<arma::Mat<size_t>>("labels"));
+ arma::mat dataset;
+ data::Load(CLI::GetParam<std::string>("training_file"), dataset, true);
+ arma::Mat<size_t> labels;
+ data::Load(CLI::GetParam<std::string>("labels_file"), labels, true);
// Calculate number of classes.
const size_t numClasses = arma::max(arma::max(labels)) + 1;
@@ -158,13 +162,15 @@ int main(int argc, char** argv)
}
else
{
- model = std::move(CLI::GetParam<DecisionTreeModel>("input_model"));
+ data::Load(CLI::GetParam<std::string>("input_model_file"), "model", model,
+ true);
}
// Do we need to get predictions?
- if (CLI::HasParam("test"))
+ if (CLI::HasParam("test_file"))
{
- arma::mat testPoints = std::move(CLI::GetParam<arma::mat>("test"));
+ arma::mat testPoints;
+ data::Load(CLI::GetParam<std::string>("test_file"), testPoints, true);
arma::Row<size_t> predictions;
arma::mat probabilities;
@@ -172,10 +178,11 @@ int main(int argc, char** argv)
model.tree.Classify(testPoints, predictions, probabilities);
// Do we need to calculate accuracy?
- if (CLI::HasParam("test_labels"))
+ if (CLI::HasParam("test_labels_file"))
{
- arma::Mat<size_t> testLabels =
- std::move(CLI::GetParam<arma::Mat<size_t>>("test_labels"));
+ arma::Mat<size_t> testLabels;
+ data::Load(CLI::GetParam<std::string>("test_labels_file"), testLabels,
+ true);
size_t correct = 0;
for (size_t i = 0; i < testPoints.n_cols; ++i)
@@ -189,15 +196,14 @@ int main(int argc, char** argv)
}
// Do we need to save outputs?
- if (CLI::HasParam("predictions"))
- CLI::GetParam<arma::Mat<size_t>>("predictions") = std::move(predictions);
- if (CLI::HasParam("probabilities"))
- CLI::GetParam<arma::mat>("probabilities") = std::move(probabilities);
+ if (CLI::HasParam("predictions_file"))
+ data::Save(CLI::GetParam<std::string>("predictions_file"), predictions);
+ if (CLI::HasParam("probabilities_file"))
+ data::Save(CLI::GetParam<std::string>("probabilities_file"),
+ probabilities);
}
// Do we need to save the model?
- if (CLI::HasParam("output_model"))
- CLI::GetParam<DecisionTreeModel>("output_model") = std::move(model);
-
- CLI::Destroy();
+ if (CLI::HasParam("output_model_file"))
+ data::Save(CLI::GetParam<std::string>("output_model_file"), "model", model);
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list