[mlpack] 53/207: Add first pass at CLI program.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:40 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit c37dcad982e12055f4931823b333cd1900d16d60
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Jan 23 17:28:30 2017 -0500

    Add first pass at CLI program.
---
 .../methods/decision_tree/decision_tree_main.cpp   | 151 +++++++++++++++++++++
 1 file changed, 151 insertions(+)

diff --git a/src/mlpack/methods/decision_tree/decision_tree_main.cpp b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
new file mode 100644
index 0000000..c37ffdc
--- /dev/null
+++ b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
@@ -0,0 +1,151 @@
+/**
+ * @file decision_tree_main.cpp
+ * @author Ryan Curtin
+ *
+ * A command-line program to build a decision tree.
+ */
+#include <mlpack/core.hpp>
+#include "decision_tree.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::tree;
+
+PROGRAM_INFO("Decision tree",
+    "Train and evaluate using a decision tree.  Given a dataset containing "
+    "numeric features and associated labels for each point in the dataset, this"
+    " program can train a decision tree on that data."
+    "\n\n"
+    "The training file and associated labels are specified with the "
+    "--training_file (-t) and --labels_file (-l) options, respectively.  The "
+    "labels should be in the range [0, num_classes - 1]."
+    "\n\n"
+    "When a model is trained, it may be saved to file with the "
+    "--output_model_file (-M) option.  A model may be loaded from file for "
+    "predictions with the --input_model_file (-m) option.  The "
+    "--input_model_file option may not be specified when the --training_file "
+    "option is specified."
+    "\n\n"
+    "A file containing test data may be specified with the --test_file (-T) "
+    "option, and if performance numbers are desired for that test set, labels "
+    "may be specified with the --test_labels_file (-L) option.  Predictions "
+    "for each test point may be stored into the file specified by the "
+    "--predictions_file (-p) option.  Class probabilities for each prediction "
+    "will be stored in the file specified by the --probabilities_file (-P) "
+    "option.");
+
+// Datasets.
+PARAM_MATRIX_IN("training", "Matrix of training points.", "t");
+PARAM_UMATRIX_IN("labels", "Training labels.", "l");
+PARAM_MATRIX_IN("test", "Matrix of test points.", "T");
+PARAM_UMATRIX_IN("test_labels", "Test point labels, if accuracy calculation "
+    "is desired.", "L");
+
+// Models.
+PARAM_MODEL_IN(DecisionTree<>, "input_model", "Pre-trained decision tree, to be"
+    " used with test points.", "m");
+PARAM_MODEL_OUT(DecisionTree<>, "output_model", "Output for trained decision "
+    "tree.", "M");
+
+// Training parameters ("-l" is already taken by --labels_file).
+PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
+    20);
+
+// Output parameters.
+PARAM_MATRIX_OUT("probabilities", "Class probabilities for each test point.",
+    "P");
+PARAM_UMATRIX_OUT("predictions", "Class predictions for each test point.", "p");
+
+/** Validate options, train or load a tree, then classify, evaluate, save. */
+int main(int argc, char** argv)
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  // Check parameters.  Exactly one model source (train or load) is required.
+  if (CLI::HasParam("training") && CLI::HasParam("input_model"))
+    Log::Fatal << "Cannot specify both --training_file and --input_model_file!"
+        << endl;
+  if (!CLI::HasParam("training") && !CLI::HasParam("input_model"))
+    Log::Fatal << "Must specify either --training_file or --input_model_file!"
+        << endl;
+  if (CLI::HasParam("training") && !CLI::HasParam("labels"))
+    Log::Fatal << "Must specify --labels_file when --training_file is "
+        << "specified!" << endl;
+  if (CLI::HasParam("training") && CLI::GetParam<int>("minimum_leaf_size") < 1)
+    Log::Fatal << "--minimum_leaf_size must be positive!" << endl;
+
+  if (CLI::HasParam("test_labels") && !CLI::HasParam("test"))
+    Log::Warn << "--test_labels_file ignored because --test_file is not passed."
+        << endl;
+  if (!CLI::HasParam("test") && CLI::HasParam("probabilities"))
+    Log::Warn << "--probabilities_file ignored because --test_file is not "
+        << "specified." << endl;
+  if (!CLI::HasParam("test") && CLI::HasParam("predictions"))
+    Log::Warn << "--predictions_file ignored because --test_file is not "
+        << "specified." << endl;
+
+  // Accuracy is only calculated when test points and test labels are given.
+  if (!CLI::HasParam("output_model") && !CLI::HasParam("probabilities") &&
+      !CLI::HasParam("predictions") &&
+      !(CLI::HasParam("test") && CLI::HasParam("test_labels")))
+    Log::Warn << "None of --output_model_file, --probabilities_file, or "
+        << "--predictions_file are given, and accuracy is not being calculated;"
+        << " no output will be saved!" << endl;
+
+  // Load the model or build the tree.
+  DecisionTree<> tree;
+  if (CLI::HasParam("training"))
+  {
+    arma::mat dataset = std::move(CLI::GetParam<arma::mat>("training"));
+    arma::Mat<size_t> labels =
+        std::move(CLI::GetParam<arma::Mat<size_t>>("labels"));
+    if (labels.n_cols != dataset.n_cols)
+      Log::Fatal << "Number of labels does not match number of training "
+          << "points!" << endl;
+
+    // Mat::max() is a scalar; arma::max(labels) would be per-column maxima.
+    const size_t numClasses = labels.max() + 1;
+    const size_t minLeafSize = (size_t) CLI::GetParam<int>("minimum_leaf_size");
+    tree = DecisionTree<>(dataset, labels.row(0), numClasses, minLeafSize);
+  }
+  else
+    tree = std::move(CLI::GetParam<DecisionTree<>>("input_model"));
+
+  // Do we need to get predictions?
+  if (CLI::HasParam("test"))
+  {
+    arma::mat testPoints = std::move(CLI::GetParam<arma::mat>("test"));
+    arma::Row<size_t> predictions;
+    arma::mat probabilities;
+    tree.Classify(testPoints, predictions, probabilities);
+
+    // Do we need to calculate accuracy?  Sizes must agree to index safely.
+    if (CLI::HasParam("test_labels"))
+    {
+      arma::Mat<size_t> testLabels =
+          std::move(CLI::GetParam<arma::Mat<size_t>>("test_labels"));
+      if (testLabels.n_elem != testPoints.n_cols)
+        Log::Fatal << "Number of test labels does not match number of test "
+            << "points!" << endl;
+      size_t correct = 0;
+      for (size_t i = 0; i < testPoints.n_cols; ++i)
+        if (predictions[i] == testLabels[i])
+          ++correct;
+
+      // Print number of correct points.
+      Log::Info << double(correct) / double(testPoints.n_cols) * 100 << "% "
+          << "correct (" << correct << " / " << testPoints.n_cols << ")."
+          << endl;
+    }
+
+    // Do we need to save outputs?
+    if (CLI::HasParam("predictions"))
+      CLI::GetParam<arma::Mat<size_t>>("predictions") = std::move(predictions);
+    if (CLI::HasParam("probabilities"))
+      CLI::GetParam<arma::mat>("probabilities") = std::move(probabilities);
+  }
+
+  // Do we need to save the model?
+  if (CLI::HasParam("output_model"))
+    CLI::GetParam<DecisionTree<>>("output_model") = std::move(tree);
+}

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list