[mlpack] 56/207: Fix bugs in main program.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:40 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit c3b82d4d3881789a1a2e246b3f4f0f710660229e
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Jan 24 11:14:34 2017 -0500

    Fix bugs in main program.
---
 .../methods/decision_tree/decision_tree_main.cpp   | 84 +++++++++++++++++-----
 1 file changed, 68 insertions(+), 16 deletions(-)

diff --git a/src/mlpack/methods/decision_tree/decision_tree_main.cpp b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
index c37ffdc..956ff1f 100644
--- a/src/mlpack/methods/decision_tree/decision_tree_main.cpp
+++ b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
@@ -24,7 +24,10 @@ PROGRAM_INFO("Decision tree",
     "--output_model_file (-M) option.  A model may be loaded from file for "
     "predictions with the --input_model_file (-m) option.  The "
     "--input_model_file option may not be specified when the --training_file "
-    "option is specified."
+    "option is specified.  The --minimum_leaf_size (-n) parameter specifies "
+    "the minimum number of training points that must fall into each leaf for "
+    "it to be split.  If --print_training_error (-e) is specified, the training"
+    " error will be printed."
     "\n\n"
     "A file containing test data may be specified with the --test_file (-T) "
     "option, and if performance numbers are desired for that test set, labels "
@@ -41,21 +44,44 @@ PARAM_MATRIX_IN("test", "Matrix of test points.", "T");
 PARAM_UMATRIX_IN("test_labels", "Test point labels, if accuracy calculation "
     "is desired.", "L");
 
-// Models.
-PARAM_MODEL_IN(DecisionTree<>, "input_model", "Pre-trained decision tree, to be"
-    " used with test points.", "m");
-PARAM_MODEL_OUT(DecisionTree<>, "output_model", "Output for trained decision "
-    "tree.", "M");
-
 // Training parameters.
-PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "l",
+PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
     20);
+PARAM_FLAG("print_training_error", "Print the training error.", "e");
 
 // Output parameters.
 PARAM_MATRIX_OUT("probabilities", "Class probabilities for each test point.",
     "P");
 PARAM_UMATRIX_OUT("predictions", "Class predictions for each test point.", "p");
 
+/**
+ * This is the class that we will serialize.  It is a pretty simple wrapper
+ * around DecisionTree<>.  In order to support categoricals, it will need to
+ * also hold and serialize a DatasetInfo.
+ */
+class DecisionTreeModel
+{
+ public:
+  // The tree itself, left public for direct access by this program.
+  DecisionTree<> tree;
+
+  // Create the model.
+  DecisionTreeModel() { /* Nothing to do. */ }
+
+  // Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(tree, "tree");
+  }
+};
+
+// Models.
+PARAM_MODEL_IN(DecisionTreeModel, "input_model", "Pre-trained decision tree, "
+    "to be used with test points.", "m");
+PARAM_MODEL_OUT(DecisionTreeModel, "output_model", "Output for trained decision"
+    " tree.", "M");
+
 int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
@@ -79,6 +105,10 @@ int main(int argc, char** argv)
         << "--predictions_file are given, and accuracy is not being calculated;"
         << " no output will be saved!" << endl;
 
+  if (CLI::HasParam("print_training_error") && !CLI::HasParam("training"))
+    Log::Warn << "--print_training_error ignored because --training_file is not"
+        << " specified." << endl;
+
   if (!CLI::HasParam("test"))
   {
     if (CLI::HasParam("probabilities"))
@@ -90,7 +120,7 @@ int main(int argc, char** argv)
   }
 
   // Load the model or build the tree.
-  DecisionTree tree;
+  DecisionTreeModel model;
 
   if (CLI::HasParam("training"))
   {
@@ -99,16 +129,36 @@ int main(int argc, char** argv)
         std::move(CLI::GetParam<arma::Mat<size_t>>("labels"));
 
     // Calculate number of classes.
-    const size_t numClasses = arma::max(labels) + 1;
+    const size_t numClasses = arma::max(arma::max(labels)) + 1;
 
     // Now build the tree.
     const size_t minLeafSize = (size_t) CLI::GetParam<int>("minimum_leaf_size");
 
-    tree = DecisionTree(dataset, labels.row(0), numClasses, minLeafSize);
+    model.tree = DecisionTree<>(dataset, labels.row(0), numClasses,
+        minLeafSize);
+
+    // Do we need to print training error?
+    if (CLI::HasParam("print_training_error"))
+    {
+      arma::Row<size_t> predictions;
+      arma::mat probabilities;
+
+      model.tree.Classify(dataset, predictions, probabilities);
+
+      size_t correct = 0;
+      for (size_t i = 0; i < dataset.n_cols; ++i)
+        if (predictions[i] == labels[i])
+          ++correct;
+
+      // Print number of correct points.
+      Log::Info << double(correct) / double(dataset.n_cols) * 100 << "\% "
+          << "correct on training set (" << correct << " / " << dataset.n_cols
+          << ")." << endl;
+    }
   }
   else
   {
-    tree = std::move(CLI::GetParam<DecisionTree<>>("input_model"));
+    model = std::move(CLI::GetParam<DecisionTreeModel>("input_model"));
   }
 
   // Do we need to get predictions?
@@ -119,7 +169,7 @@ int main(int argc, char** argv)
     arma::Row<size_t> predictions;
     arma::mat probabilities;
 
-    tree.Classify(testPoints, predictions, probabilities);
+    model.tree.Classify(testPoints, predictions, probabilities);
 
     // Do we need to calculate accuracy?
     if (CLI::HasParam("test_labels"))
@@ -134,8 +184,8 @@ int main(int argc, char** argv)
 
       // Print number of correct points.
       Log::Info << double(correct) / double(testPoints.n_cols) * 100 << "\% "
-          << "correct (" << correct << " / " << testPoints.n_cols << ")."
-          << endl;
+          << "correct on test set (" << correct << " / " << testPoints.n_cols
+          << ")." << endl;
     }
 
     // Do we need to save outputs?
@@ -147,5 +197,7 @@ int main(int argc, char** argv)
 
   // Do we need to save the model?
   if (CLI::HasParam("output_model"))
-    CLI::GetParam<DecisionTree<>>("output_model") = std::move(tree);
+    CLI::GetParam<DecisionTreeModel>("output_model") = std::move(model);
+
+  CLI::Destroy();
 }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list