[mlpack] 183/207: Change Predict() to wrap Classify()
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:52 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 0b49d05e9175d05b545bb759ce0d138db5ab624f
Author: Kirill Mishchenko <ki.mishchenko at gmail.com>
Date: Thu Mar 16 07:22:23 2017 +0500
Change Predict() to wrap Classify()
---
.../softmax_regression/softmax_regression_impl.hpp | 32 +++++++++++-----------
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
index 0a915f2..5538383 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -66,10 +66,18 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
arma::Row<size_t>& predictions)
const
{
- if (testData.n_rows != FeatureSize())
+ Classify(testData, predictions);
+}
+
+template<template<typename> class OptimizerType>
+void SoftmaxRegression<OptimizerType>::Classify(const arma::mat& dataset,
+ arma::Row<size_t>& labels)
+ const
+{
+ if (dataset.n_rows != FeatureSize())
{
std::ostringstream oss;
- oss << "SoftmaxRegression::Predict(): test data has " << testData.n_rows
+ oss << "SoftmaxRegression::Classify(): dataset has " << dataset.n_rows
<< " dimensions, but model has " << FeatureSize() << "dimensions";
throw std::invalid_argument(oss.str());
}
@@ -85,23 +93,23 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
// Since the cost of join maybe high due to the copy of original data,
// split the hypothesis computation to two components.
hypothesis = arma::exp(
- arma::repmat(parameters.col(0), 1, testData.n_cols) +
- parameters.cols(1, parameters.n_cols - 1) * testData);
+ arma::repmat(parameters.col(0), 1, dataset.n_cols) +
+ parameters.cols(1, parameters.n_cols - 1) * dataset);
}
else
{
- hypothesis = arma::exp(parameters * testData);
+ hypothesis = arma::exp(parameters * dataset);
}
probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
numClasses, 1);
// Prepare necessary data.
- predictions.zeros(testData.n_cols);
+ labels.zeros(dataset.n_cols);
double maxProbability = 0;
// For each test input.
- for (size_t i = 0; i < testData.n_cols; i++)
+ for (size_t i = 0; i < dataset.n_cols; i++)
{
// For each class.
for (size_t j = 0; j < numClasses; j++)
@@ -110,7 +118,7 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
if (probabilities(j, i) > maxProbability)
{
maxProbability = probabilities(j, i);
- predictions(i) = j;
+ labels(i) = j;
}
}
@@ -120,14 +128,6 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
}
template<template<typename> class OptimizerType>
-void SoftmaxRegression<OptimizerType>::Classify(const arma::mat& dataset,
- arma::Row<size_t>& labels)
- const
-{
- Predict(dataset, labels);
-}
-
-template<template<typename> class OptimizerType>
double SoftmaxRegression<OptimizerType>::ComputeAccuracy(
const arma::mat& testData,
const arma::Row<size_t>& labels) const
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list