[mlpack] 61/207: Backport CLI interface to 2.2.x.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:41 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 682e938cc6c2f856234f93d9a4b573f38b7edf81
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Mar 21 10:06:36 2017 -0400

    Backport CLI interface to 2.2.x.
---
 .../methods/decision_tree/decision_tree_main.cpp   | 84 ++++++++++++----------
 1 file changed, 45 insertions(+), 39 deletions(-)

diff --git a/src/mlpack/methods/decision_tree/decision_tree_main.cpp b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
index 956ff1f..19c34c7 100644
--- a/src/mlpack/methods/decision_tree/decision_tree_main.cpp
+++ b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
@@ -38,11 +38,11 @@ PROGRAM_INFO("Decision tree",
     "option.");
 
 // Datasets.
-PARAM_MATRIX_IN("training", "Matrix of training points.", "t");
-PARAM_UMATRIX_IN("labels", "Training labels.", "l");
-PARAM_MATRIX_IN("test", "Matrix of test points.", "T");
-PARAM_UMATRIX_IN("test_labels", "Test point labels, if accuracy calculation "
-    "is desired.", "L");
+PARAM_STRING_IN("training_file", "File containing training points.", "t", "");
+PARAM_STRING_IN("labels_file", "File containing training labels.", "l", "");
+PARAM_STRING_IN("test_file", "File containing test points.", "T", "");
+PARAM_STRING_IN("test_labels_file", "File containing test labels, if accuracy "
+    "calculation is desired.", "L", "");
 
 // Training parameters.
 PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
@@ -50,9 +50,10 @@ PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
 PARAM_FLAG("print_training_error", "Print the training error.", "e");
 
 // Output parameters.
-PARAM_MATRIX_OUT("probabilities", "Class probabilities for each test point.",
-    "P");
-PARAM_UMATRIX_OUT("predictions", "Class predictions for each test point.", "p");
+PARAM_STRING_IN("probabilities_file", "File to save class probabilities to for"
+    " each test point.", "P", "");
+PARAM_STRING_IN("predictions_file", "File to save class predictions to for "
+    "each test point.", "p", "");
 
 /**
  * This is the class that we will serialize.  It is a pretty simple wrapper
@@ -77,44 +78,46 @@ class DecisionTreeModel
 };
 
 // Models.
-PARAM_MODEL_IN(DecisionTreeModel, "input_model", "Pre-trained decision tree, "
-    "to be used with test points.", "m");
-PARAM_MODEL_OUT(DecisionTreeModel, "output_model", "Output for trained decision"
-    " tree.", "M");
+PARAM_STRING_IN("input_model_file", "File to load pre-trained decision tree "
+    "from, to be used with test points.", "m", "");
+PARAM_STRING_IN("output_model_file", "File to save trained decision tree to.",
+    "M", "");
 
 int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
 
   // Check parameters.
-  if (CLI::HasParam("training") && CLI::HasParam("input_model"))
+  if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
     Log::Fatal << "Cannot specify both --training_file and --input_model_file!"
         << endl;
 
-  if (CLI::HasParam("training") && !CLI::HasParam("labels"))
+  if (CLI::HasParam("training_file") && !CLI::HasParam("labels_file"))
     Log::Fatal << "Must specify --labels_file when --training_file is "
         << "specified!" << endl;
 
-  if (CLI::HasParam("test_labels") && !CLI::HasParam("test"))
+  if (CLI::HasParam("test_labels_file") && !CLI::HasParam("test_file"))
     Log::Warn << "--test_labels_file ignored because --test_file is not passed."
         << endl;
 
-  if (!CLI::HasParam("output_model") && !CLI::HasParam("probabilities") &&
-      !CLI::HasParam("predictions") && !CLI::HasParam("test_labels"))
+  if (!CLI::HasParam("output_model_file") &&
+      !CLI::HasParam("probabilities_file") &&
+      !CLI::HasParam("predictions_file") &&
+      !CLI::HasParam("test_labels_file"))
     Log::Warn << "None of --output_model_file, --probabilities_file, or "
         << "--predictions_file are given, and accuracy is not being calculated;"
         << " no output will be saved!" << endl;
 
-  if (CLI::HasParam("print_training_error") && !CLI::HasParam("training"))
+  if (CLI::HasParam("print_training_error") && !CLI::HasParam("training_file"))
     Log::Warn << "--print_training_error ignored because --training_file is not"
         << " specified." << endl;
 
-  if (!CLI::HasParam("test"))
+  if (!CLI::HasParam("test_file"))
   {
-    if (CLI::HasParam("probabilities"))
+    if (CLI::HasParam("probabilities_file"))
       Log::Warn << "--probabilities_file ignored because --test_file is not "
           << "specified." << endl;
-    if (CLI::HasParam("predictions"))
+    if (CLI::HasParam("predictions_file"))
       Log::Warn << "--predictions_file ignored because --test_file is not "
           << "specified." << endl;
   }
@@ -122,11 +125,12 @@ int main(int argc, char** argv)
   // Load the model or build the tree.
   DecisionTreeModel model;
 
-  if (CLI::HasParam("training"))
+  if (CLI::HasParam("training_file"))
   {
-    arma::mat dataset = std::move(CLI::GetParam<arma::mat>("training"));
-    arma::Mat<size_t> labels =
-        std::move(CLI::GetParam<arma::Mat<size_t>>("labels"));
+    arma::mat dataset;
+    data::Load(CLI::GetParam<std::string>("training_file"), dataset, true);
+    arma::Mat<size_t> labels;
+    data::Load(CLI::GetParam<std::string>("labels_file"), labels, true);
 
     // Calculate number of classes.
     const size_t numClasses = arma::max(arma::max(labels)) + 1;
@@ -158,13 +162,15 @@ int main(int argc, char** argv)
   }
   else
   {
-    model = std::move(CLI::GetParam<DecisionTreeModel>("input_model"));
+    data::Load(CLI::GetParam<std::string>("input_model_file"), "model", model,
+        true);
   }
 
   // Do we need to get predictions?
-  if (CLI::HasParam("test"))
+  if (CLI::HasParam("test_file"))
   {
-    arma::mat testPoints = std::move(CLI::GetParam<arma::mat>("test"));
+    arma::mat testPoints;
+    data::Load(CLI::GetParam<std::string>("test_file"), testPoints, true);
 
     arma::Row<size_t> predictions;
     arma::mat probabilities;
@@ -172,10 +178,11 @@ int main(int argc, char** argv)
     model.tree.Classify(testPoints, predictions, probabilities);
 
     // Do we need to calculate accuracy?
-    if (CLI::HasParam("test_labels"))
+    if (CLI::HasParam("test_labels_file"))
     {
-      arma::Mat<size_t> testLabels =
-          std::move(CLI::GetParam<arma::Mat<size_t>>("test_labels"));
+      arma::Mat<size_t> testLabels;
+      data::Load(CLI::GetParam<std::string>("test_labels_file"), testLabels,
+          true);
 
       size_t correct = 0;
       for (size_t i = 0; i < testPoints.n_cols; ++i)
@@ -189,15 +196,14 @@ int main(int argc, char** argv)
     }
 
     // Do we need to save outputs?
-    if (CLI::HasParam("predictions"))
-      CLI::GetParam<arma::Mat<size_t>>("predictions") = std::move(predictions);
-    if (CLI::HasParam("probabilities"))
-      CLI::GetParam<arma::mat>("probabilities") = std::move(probabilities);
+    if (CLI::HasParam("predictions_file"))
+      data::Save(CLI::GetParam<std::string>("predictions_file"), predictions);
+    if (CLI::HasParam("probabilities_file"))
+      data::Save(CLI::GetParam<std::string>("probabilities_file"),
+          probabilities);
   }
 
   // Do we need to save the model?
-  if (CLI::HasParam("output_model"))
-    CLI::GetParam<DecisionTreeModel>("output_model") = std::move(model);
-
-  CLI::Destroy();
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<std::string>("output_model_file"), "model", model);
 }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list