[mlpack] 183/207: Change Predict() to wrap Classify()

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:52 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 0b49d05e9175d05b545bb759ce0d138db5ab624f
Author: Kirill Mishchenko <ki.mishchenko at gmail.com>
Date:   Thu Mar 16 07:22:23 2017 +0500

    Change Predict() to wrap Classify()
---
 .../softmax_regression/softmax_regression_impl.hpp | 32 +++++++++++-----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
index 0a915f2..5538383 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -66,10 +66,18 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
                                                arma::Row<size_t>& predictions)
     const
 {
-  if (testData.n_rows != FeatureSize())
+  Classify(testData, predictions);
+}
+
+template<template<typename> class OptimizerType>
+void SoftmaxRegression<OptimizerType>::Classify(const arma::mat& dataset,
+                                                arma::Row<size_t>& labels)
+    const
+{
+  if (dataset.n_rows != FeatureSize())
   {
     std::ostringstream oss;
-    oss << "SoftmaxRegression::Predict(): test data has " << testData.n_rows
+    oss << "SoftmaxRegression::Classify(): dataset has " << dataset.n_rows
         << " dimensions, but model has " << FeatureSize() << "dimensions";
     throw std::invalid_argument(oss.str());
   }
@@ -85,23 +93,23 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
     // Since the cost of join maybe high due to the copy of original data,
     // split the hypothesis computation to two components.
     hypothesis = arma::exp(
-      arma::repmat(parameters.col(0), 1, testData.n_cols) +
-      parameters.cols(1, parameters.n_cols - 1) * testData);
+      arma::repmat(parameters.col(0), 1, dataset.n_cols) +
+      parameters.cols(1, parameters.n_cols - 1) * dataset);
   }
   else
   {
-    hypothesis = arma::exp(parameters * testData);
+    hypothesis = arma::exp(parameters * dataset);
   }
 
   probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
                                             numClasses, 1);
 
   // Prepare necessary data.
-  predictions.zeros(testData.n_cols);
+  labels.zeros(dataset.n_cols);
   double maxProbability = 0;
 
   // For each test input.
-  for (size_t i = 0; i < testData.n_cols; i++)
+  for (size_t i = 0; i < dataset.n_cols; i++)
   {
     // For each class.
     for (size_t j = 0; j < numClasses; j++)
@@ -110,7 +118,7 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
       if (probabilities(j, i) > maxProbability)
       {
         maxProbability = probabilities(j, i);
-        predictions(i) = j;
+        labels(i) = j;
       }
     }
 
@@ -120,14 +128,6 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
 }
 
 template<template<typename> class OptimizerType>
-void SoftmaxRegression<OptimizerType>::Classify(const arma::mat& dataset,
-                                                arma::Row<size_t>& labels)
-    const
-{
-  Predict(dataset, labels);
-}
-
-template<template<typename> class OptimizerType>
 double SoftmaxRegression<OptimizerType>::ComputeAccuracy(
     const arma::mat& testData,
     const arma::Row<size_t>& labels) const

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list