[mlpack] 56/207: Fix bugs in main program.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:40 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit c3b82d4d3881789a1a2e246b3f4f0f710660229e
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Jan 24 11:14:34 2017 -0500
Fix bugs in main program.
---
.../methods/decision_tree/decision_tree_main.cpp | 84 +++++++++++++++++-----
1 file changed, 68 insertions(+), 16 deletions(-)
diff --git a/src/mlpack/methods/decision_tree/decision_tree_main.cpp b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
index c37ffdc..956ff1f 100644
--- a/src/mlpack/methods/decision_tree/decision_tree_main.cpp
+++ b/src/mlpack/methods/decision_tree/decision_tree_main.cpp
@@ -24,7 +24,10 @@ PROGRAM_INFO("Decision tree",
"--output_model_file (-M) option. A model may be loaded from file for "
"predictions with the --input_model_file (-m) option. The "
"--input_model_file option may not be specified when the --training_file "
- "option is specified."
+ "option is specified. The --minimum_leaf_size (-n) parameter specifies "
+ "the minimum number of training points that must fall into each leaf for "
+ "it to be split. If --print_training_error (-e) is specified, the training"
+ " error will be printed."
"\n\n"
"A file containing test data may be specified with the --test_file (-T) "
"option, and if performance numbers are desired for that test set, labels "
@@ -41,21 +44,44 @@ PARAM_MATRIX_IN("test", "Matrix of test points.", "T");
PARAM_UMATRIX_IN("test_labels", "Test point labels, if accuracy calculation "
"is desired.", "L");
-// Models.
-PARAM_MODEL_IN(DecisionTree<>, "input_model", "Pre-trained decision tree, to be"
- " used with test points.", "m");
-PARAM_MODEL_OUT(DecisionTree<>, "output_model", "Output for trained decision "
- "tree.", "M");
-
// Training parameters.
-PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "l",
+PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
20);
+PARAM_FLAG("print_training_error", "Print the training error.", "e");
// Output parameters.
PARAM_MATRIX_OUT("probabilities", "Class probabilities for each test point.",
"P");
PARAM_UMATRIX_OUT("predictions", "Class predictions for each test point.", "p");
+/**
+ * This is the class that we will serialize. It is a pretty simple wrapper
+ * around DecisionTree<>. In order to support categoricals, it will need to
+ * also hold and serialize a DatasetInfo.
+ */
+class DecisionTreeModel
+{
+ public:
+ // The tree itself, left public for direct access by this program.
+ DecisionTree<> tree;
+
+ // Create the model.
+ DecisionTreeModel() { /* Nothing to do. */ }
+
+ // Serialize the model.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */)
+ {
+ ar & data::CreateNVP(tree, "tree");
+ }
+};
+
+// Models.
+PARAM_MODEL_IN(DecisionTreeModel, "input_model", "Pre-trained decision tree, "
+ "to be used with test points.", "m");
+PARAM_MODEL_OUT(DecisionTreeModel, "output_model", "Output for trained decision"
+ " tree.", "M");
+
int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
@@ -79,6 +105,10 @@ int main(int argc, char** argv)
<< "--predictions_file are given, and accuracy is not being calculated;"
<< " no output will be saved!" << endl;
+ if (CLI::HasParam("print_training_error") && !CLI::HasParam("training"))
+ Log::Warn << "--print_training_error ignored because --training_file is not"
+ << " specified." << endl;
+
if (!CLI::HasParam("test"))
{
if (CLI::HasParam("probabilities"))
@@ -90,7 +120,7 @@ int main(int argc, char** argv)
}
// Load the model or build the tree.
- DecisionTree tree;
+ DecisionTreeModel model;
if (CLI::HasParam("training"))
{
@@ -99,16 +129,36 @@ int main(int argc, char** argv)
std::move(CLI::GetParam<arma::Mat<size_t>>("labels"));
// Calculate number of classes.
- const size_t numClasses = arma::max(labels) + 1;
+ const size_t numClasses = arma::max(arma::max(labels)) + 1;
// Now build the tree.
const size_t minLeafSize = (size_t) CLI::GetParam<int>("minimum_leaf_size");
- tree = DecisionTree(dataset, labels.row(0), numClasses, minLeafSize);
+ model.tree = DecisionTree<>(dataset, labels.row(0), numClasses,
+ minLeafSize);
+
+ // Do we need to print training error?
+ if (CLI::HasParam("print_training_error"))
+ {
+ arma::Row<size_t> predictions;
+ arma::mat probabilities;
+
+ model.tree.Classify(dataset, predictions, probabilities);
+
+ size_t correct = 0;
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ if (predictions[i] == labels[i])
+ ++correct;
+
+ // Print number of correct points.
+ Log::Info << double(correct) / double(dataset.n_cols) * 100 << "\% "
+ << "correct on training set (" << correct << " / " << dataset.n_cols
+ << ")." << endl;
+ }
}
else
{
- tree = std::move(CLI::GetParam<DecisionTree<>>("input_model"));
+ model = std::move(CLI::GetParam<DecisionTreeModel>("input_model"));
}
// Do we need to get predictions?
@@ -119,7 +169,7 @@ int main(int argc, char** argv)
arma::Row<size_t> predictions;
arma::mat probabilities;
- tree.Classify(testPoints, predictions, probabilities);
+ model.tree.Classify(testPoints, predictions, probabilities);
// Do we need to calculate accuracy?
if (CLI::HasParam("test_labels"))
@@ -134,8 +184,8 @@ int main(int argc, char** argv)
// Print number of correct points.
Log::Info << double(correct) / double(testPoints.n_cols) * 100 << "\% "
- << "correct (" << correct << " / " << testPoints.n_cols << ")."
- << endl;
+ << "correct on test set (" << correct << " / " << testPoints.n_cols
+ << ")." << endl;
}
// Do we need to save outputs?
@@ -147,5 +197,7 @@ int main(int argc, char** argv)
// Do we need to save the model?
if (CLI::HasParam("output_model"))
- CLI::GetParam<DecisionTree<>>("output_model") = std::move(tree);
+ CLI::GetParam<DecisionTreeModel>("output_model") = std::move(model);
+
+ CLI::Destroy();
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list