[mlpack] 46/58: Adding Softmax Regression module.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Tue Sep 9 13:19:42 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit a57bd44b59bff4d5f8436252be64927eb3cd605a
Author: siddharth.950 <siddharth.950 at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Sat Aug 23 13:22:42 2014 +0000
Adding Softmax Regression module.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17103 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/CMakeLists.txt | 1 +
.../methods/softmax_regression/CMakeLists.txt | 17 ++
.../softmax_regression/softmax_regression.hpp | 167 ++++++++++++
.../softmax_regression_function.cpp | 142 ++++++++++
.../softmax_regression_function.hpp | 127 +++++++++
.../softmax_regression/softmax_regression_impl.hpp | 115 ++++++++
.../sparse_autoencoder_function.hpp | 2 +-
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/softmax_regression_test.cpp | 301 +++++++++++++++++++++
9 files changed, 872 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index 8ea92d2..8aefd69 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -29,6 +29,7 @@ set(DIRS
range_search
rann
regularized_svd
+ softmax_regression
sparse_autoencoder
sparse_coding
nystroem_method
diff --git a/src/mlpack/methods/softmax_regression/CMakeLists.txt b/src/mlpack/methods/softmax_regression/CMakeLists.txt
new file mode 100644
index 0000000..df2a33f
--- /dev/null
+++ b/src/mlpack/methods/softmax_regression/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+ softmax_regression.hpp
+ softmax_regression_impl.hpp
+ softmax_regression_function.hpp
+ softmax_regression_function.cpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+ set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression.hpp b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
new file mode 100644
index 0000000..88db1c3
--- /dev/null
+++ b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
@@ -0,0 +1,167 @@
+/**
+ * @file softmax_regression.hpp
+ * @author Siddharth Agrawal
+ *
+ * An implementation of softmax regression.
+ */
+#ifndef __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
+#define __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+
+#include "softmax_regression_function.hpp"
+
+namespace mlpack {
+namespace regression {
+
+/**
+ * Softmax regression is a classifier for data whose labels can take two or
+ * more class values; it is a generalization of logistic regression, which
+ * handles only binary classification. The model has a separate set of
+ * parameters for each class, but the computation vectorizes cleanly, which
+ * is how it is implemented in this module. The model can be used for direct
+ * classification of feature data, or in conjunction with unsupervised
+ * learning methods. More technical details about the model can be found on
+ * the following webpage:
+ *
+ * http://ufldl.stanford.edu/wiki/index.php/Softmax_Regression
+ *
+ * An example on how to use the interface is shown below:
+ *
+ * @code
+ * arma::mat train_data; // Training data matrix.
+ * arma::vec labels; // Labels associated with the data.
+ * const size_t inputSize = 784; // Size of input feature vector.
+ * const size_t numClasses = 10; // Number of classes.
+ *
+ * // Train the model using default options.
+ * SoftmaxRegression<> regressor1(train_data, labels, inputSize, numClasses);
+ *
+ * const size_t numBasis = 5; // Parameter required for L-BFGS algorithm.
+ * const size_t numIterations = 100; // Maximum number of iterations.
+ *
+ * // Use an instantiated optimizer for the training.
+ * SoftmaxRegressionFunction srf(train_data, labels, inputSize, numClasses);
+ * L_BFGS<SoftmaxRegressionFunction> optimizer(srf, numBasis, numIterations);
+ * SoftmaxRegression<L_BFGS> regressor2(optimizer);
+ *
+ * arma::mat test_data; // Test data matrix.
+ * arma::vec predictions1, predictions2; // Vectors to store predictions in.
+ *
+ * // Obtain predictions from both the learned models.
+ * regressor1.Predict(test_data, predictions1);
+ * regressor2.Predict(test_data, predictions2);
+ * @endcode
+ */
+
+template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+>
+class SoftmaxRegression
+{
+ public:
+
+ /**
+ * Construct the SoftmaxRegression class with the provided data and labels.
+   * This will train the model. Optionally, the parameter 'lambda' can be
+   * passed, which controls the amount of L2-regularization in the objective
+   * function; by default it takes a small value (0.0001).
+ *
+ * @param data Input training features.
+ * @param labels Labels associated with the feature data.
+ * @param inputSize Size of the input feature vector.
+ * @param numClasses Number of classes for classification.
+ * @param lambda L2-regularization constant.
+ */
+ SoftmaxRegression(const arma::mat& data,
+ const arma::vec& labels,
+ const size_t inputSize,
+ const size_t numClasses,
+ const double lambda = 0.0001);
+
+ /**
+   * Construct the softmax regression model with the given training data; this
+   * will train the model. This overload takes an already instantiated
+   * optimizer, which should hold an instantiated SoftmaxRegressionFunction
+   * object for the function to operate upon. Prefer this overload when
+   * non-default optimizer options are needed.
+ *
+ * @param optimizer Instantiated optimizer with instantiated error function.
+ */
+ SoftmaxRegression(OptimizerType<SoftmaxRegressionFunction>& optimizer);
+
+ /**
+   * Predict the class labels for the provided feature points. The function
+   * calculates the probabilities of every class for each data point, then
+   * chooses the class with the highest probability.
+ *
+ * @param testData Matrix of data points for which predictions are to be made.
+ * @param predictions Vector to store the predictions in.
+ */
+ void Predict(const arma::mat& testData, arma::vec& predictions);
+
+ /**
+ * Computes accuracy of the learned model given the feature data and the
+ * labels associated with each data point. Predictions are made using the
+ * provided data and are compared with the actual labels.
+ *
+   * @param testData Matrix of data points for which predictions are made.
+ * @param labels Vector of labels associated with the data.
+ */
+ double ComputeAccuracy(const arma::mat& testData, const arma::vec& labels);
+
+ //! Sets the size of the input vector.
+ void InputSize(const size_t input)
+ {
+ this->inputSize = input;
+ }
+
+ //! Gets the size of the input vector.
+ size_t InputSize() const
+ {
+ return inputSize;
+ }
+
+ //! Sets the number of classes.
+ void NumClasses(const size_t classes)
+ {
+ this->numClasses = classes;
+ }
+
+ //! Gets the number of classes.
+ size_t NumClasses() const
+ {
+ return numClasses;
+ }
+
+ //! Sets the regularization parameter.
+ void Lambda(const double l)
+ {
+ this->lambda = l;
+ }
+
+ //! Gets the regularization parameter.
+ double Lambda() const
+ {
+ return lambda;
+ }
+
+ private:
+ //! Parameters after optimization.
+ arma::mat parameters;
+ //! Size of input feature vector.
+ size_t inputSize;
+ //! Number of classes.
+ size_t numClasses;
+ //! L2-regularization constant.
+ double lambda;
+};
+
+} // namespace regression
+} // namespace mlpack
+
+// Include implementation.
+#include "softmax_regression_impl.hpp"
+
+#endif
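
For reference, the model this class implements (following the UFLDL notes
linked in the class documentation) assigns a parameter vector \theta_j to
every class j; with K classes and m training examples, the class
probabilities and the regularized cost minimized during training are:

    P(y = j \mid x; \theta) = \frac{\exp(\theta_j^\top x)}
                                   {\sum_{k=1}^{K} \exp(\theta_k^\top x)}

    J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{K} 1\{y_i = j\}
                \log P(y_i = j \mid x_i; \theta)
                + \frac{\lambda}{2} \sum_{j,k} \theta_{jk}^2

This matches SoftmaxRegressionFunction::Evaluate() in the next file: the log
likelihood is averaged over data.n_cols examples, and the weight decay term
is 0.5 * lambda * accu(parameters % parameters).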
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp b/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
new file mode 100644
index 0000000..f0522d3
--- /dev/null
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
@@ -0,0 +1,142 @@
+/**
+ * @file softmax_regression_function.cpp
+ * @author Siddharth Agrawal
+ *
+ * Implementation of function to be optimized for softmax regression.
+ */
+#include "softmax_regression_function.hpp"
+
+using namespace mlpack;
+using namespace mlpack::regression;
+
+SoftmaxRegressionFunction::SoftmaxRegressionFunction(const arma::mat& data,
+ const arma::vec& labels,
+ const size_t inputSize,
+ const size_t numClasses,
+ const double lambda) :
+ data(data),
+ labels(labels),
+ inputSize(inputSize),
+ numClasses(numClasses),
+ lambda(lambda)
+{
+  // Initialize the parameters to suitable values.
+ initialPoint = InitializeWeights();
+
+ // Calculate the label matrix.
+ GetGroundTruthMatrix(labels, groundTruth);
+}
+
+/**
+ * Initializes parameter weights to random values taken from a scaled standard
+ * normal distribution. The weights cannot be initialized to zero, as that will
+ * lead to each class output being the same.
+ */
+const arma::mat SoftmaxRegressionFunction::InitializeWeights()
+{
+ // Initialize values to 0.005 * r. 'r' is a matrix of random values taken from
+ // a Gaussian distribution with mean zero and variance one.
+ arma::mat parameters;
+ parameters.randn(numClasses, inputSize);
+ parameters = 0.005 * parameters;
+
+ return parameters;
+}
+
+/**
+ * This is equivalent to applying the indicator function to the training
+ * labels. The output is in the form of a matrix, which leads to simpler
+ * calculations in the Evaluate() and Gradient() methods.
+ */
+void SoftmaxRegressionFunction::GetGroundTruthMatrix(const arma::vec& labels,
+ arma::sp_mat& groundTruth)
+{
+ // Calculate the ground truth matrix according to the labels passed. The
+ // ground truth matrix is a matrix of dimensions 'numClasses * numExamples',
+ // where each column contains a single entry of '1', marking the label
+ // corresponding to that example.
+
+  // Row pointers and column pointers corresponding to the entries.
+  arma::uvec rowPointers(labels.n_elem);
+  arma::uvec colPointers(labels.n_elem + 1);
+  colPointers(0) = 0; // uvec elements are not zero-initialized.
+  // Row pointers are the labels of the examples, and column pointers are the
+  // cumulative number of entries made up to that column.
+  for (size_t i = 0; i < labels.n_elem; i++)
+  {
+    rowPointers(i) = labels(i, 0);
+    colPointers(i + 1) = i + 1;
+  }
+
+ // All entries are '1'.
+ arma::vec values;
+ values.ones(labels.n_elem);
+
+ // Calculate the matrix.
+ groundTruth = arma::sp_mat(rowPointers, colPointers, values, numClasses,
+ labels.n_elem);
+}
+
+/**
+ * Evaluates the objective function given the parameters.
+ */
+double SoftmaxRegressionFunction::Evaluate(const arma::mat& parameters) const
+{
+ // The objective function is the negative log likelihood of the model
+ // calculated over all the training examples. Mathematically it is as follows:
+ // log likelihood = sum(1{y_i = j} * log(probability(j))) / m
+ // The sum is over all 'i's and 'j's, where 'i' points to a training example
+ // and 'j' points to a particular class. 1{x} is an indicator function whose
+ // value is 1 only when 'x' is satisfied, otherwise it is 0.
+ // 'm' is the number of training examples.
+ // The cost also takes into account the regularization to control the
+ // parameter weights.
+
+ // Calculate the class probabilities for each training example. The
+ // probabilities for each of the classes are given by:
+ // p_j = exp(theta_j' * x_i) / sum(exp(theta_k' * x_i))
+ // The sum is calculated over all the classes.
+ // x_i is the input vector for a particular training example.
+ // theta_j is the parameter vector associated with a particular class.
+ arma::mat hypothesis, probabilities;
+
+ hypothesis = arma::exp(parameters * data);
+ probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
+ numClasses, 1);
+
+ // Calculate the log likelihood and regularization terms.
+ double logLikelihood, weightDecay, cost;
+
+ logLikelihood = arma::accu(groundTruth % arma::log(probabilities)) /
+ data.n_cols;
+ weightDecay = 0.5 * lambda * arma::accu(parameters % parameters);
+
+ // The cost is the sum of the negative log likelihood and the regularization
+ // terms.
+ cost = -logLikelihood + weightDecay;
+
+ return cost;
+}
+
+/**
+ * Calculates and stores the gradient values given a set of parameters.
+ */
+void SoftmaxRegressionFunction::Gradient(const arma::mat& parameters,
+ arma::mat& gradient) const
+{
+ // Calculate the class probabilities for each training example. The
+ // probabilities for each of the classes are given by:
+ // p_j = exp(theta_j' * x_i) / sum(exp(theta_k' * x_i))
+ // The sum is calculated over all the classes.
+ // x_i is the input vector for a particular training example.
+ // theta_j is the parameter vector associated with a particular class.
+ arma::mat hypothesis, probabilities;
+
+ hypothesis = arma::exp(parameters * data);
+ probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
+ numClasses, 1);
+
+ // Calculate the parameter gradients.
+ gradient = (probabilities - groundTruth) * data.t() / data.n_cols +
+ lambda * parameters;
+}
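
A side note on Evaluate() and Gradient() above: computing
arma::exp(parameters * data) directly can overflow when activations are
large. A common remedy, shown here as a hypothetical sketch (not part of
this commit, and assuming an Armadillo version with .each_row() support), is
to subtract the per-column maximum before exponentiating; the shift cancels
in the normalization, so the probabilities are unchanged:

    // Numerically-stable variant of the probability computation.
    arma::mat activations = parameters * data;
    // Softmax is invariant to subtracting a per-column constant.
    activations.each_row() -= arma::max(activations, 0);
    arma::mat hypothesis = arma::exp(activations);
    arma::mat probabilities = hypothesis /
        arma::repmat(arma::sum(hypothesis, 0), numClasses, 1);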
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
new file mode 100644
index 0000000..3706436
--- /dev/null
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
@@ -0,0 +1,127 @@
+/**
+ * @file softmax_regression_function.hpp
+ * @author Siddharth Agrawal
+ *
+ * The function to be optimized for softmax regression. Any mlpack optimizer
+ * can be used.
+ */
+#ifndef __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_FUNCTION_HPP
+#define __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_FUNCTION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace regression {
+
+class SoftmaxRegressionFunction
+{
+ public:
+ /**
+ * Construct the Softmax Regression objective function with the given
+ * parameters.
+ *
+ * @param data Input training features.
+ * @param labels Labels associated with the feature data.
+ * @param inputSize Size of the input feature vector.
+ * @param numClasses Number of classes for classification.
+ * @param lambda L2-regularization constant.
+ */
+ SoftmaxRegressionFunction(const arma::mat& data,
+ const arma::vec& labels,
+ const size_t inputSize,
+ const size_t numClasses,
+ const double lambda = 0.0001);
+
+ //! Initializes the parameters of the model to suitable values.
+ const arma::mat InitializeWeights();
+
+ /**
+ * Constructs the ground truth label matrix with the passed labels.
+ *
+ * @param labels Labels associated with the training data.
+   * @param groundTruth Reference to arma::sp_mat which stores the matrix.
+ */
+ void GetGroundTruthMatrix(const arma::vec& labels, arma::sp_mat& groundTruth);
+
+ /**
+ * Evaluates the objective function of the softmax regression model using the
+   * given parameters. The cost function has terms for the negative log
+   * likelihood and the regularization cost. The objective takes a low value
+   * when the model fits the given training data well while keeping the
+   * parameter values small.
+ *
+ * @param parameters Current values of the model parameters.
+ */
+ double Evaluate(const arma::mat& parameters) const;
+
+ /**
+ * Evaluates the gradient values of the objective function given the current
+ * set of parameters. The function calculates the probabilities for each class
+ * given the parameters, and computes the gradients based on the difference
+ * from the ground truth.
+ *
+ * @param parameters Current values of the model parameters.
+ * @param gradient Matrix where gradient values will be stored.
+ */
+ void Gradient(const arma::mat& parameters, arma::mat& gradient) const;
+
+ //! Return the initial point for the optimization.
+ const arma::mat& GetInitialPoint() const { return initialPoint; }
+
+ //! Sets the size of the input vector.
+ void InputSize(const size_t input)
+ {
+ this->inputSize = input;
+ }
+
+ //! Gets the size of the input vector.
+ size_t InputSize() const
+ {
+ return inputSize;
+ }
+
+ //! Sets the number of classes.
+ void NumClasses(const size_t classes)
+ {
+ this->numClasses = classes;
+ }
+
+ //! Gets the number of classes.
+ size_t NumClasses() const
+ {
+ return numClasses;
+ }
+
+ //! Sets the regularization parameter.
+ void Lambda(const double l)
+ {
+ this->lambda = l;
+ }
+
+ //! Gets the regularization parameter.
+ double Lambda() const
+ {
+ return lambda;
+ }
+
+ private:
+ //! Training data matrix.
+ const arma::mat& data;
+ //! Labels associated with the training data.
+ const arma::vec& labels;
+ //! Label matrix for the provided data.
+ arma::sp_mat groundTruth;
+ //! Initial parameter point.
+ arma::mat initialPoint;
+ //! Size of input feature vector.
+ size_t inputSize;
+ //! Number of classes.
+ size_t numClasses;
+ //! L2-regularization constant.
+ double lambda;
+};
+
+} // namespace regression
+} // namespace mlpack
+
+#endif
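
To make the batch-insertion constructor used by GetGroundTruthMatrix()
concrete, here is a minimal standalone sketch; the labels and sizes are made
up for illustration. Each column of the result holds a single 1 in the row
given by that example's label:

    #include <armadillo>

    int main()
    {
      arma::vec labels("1 0 2");  // Three examples with labels 1, 0 and 2.
      const arma::uword numClasses = 3;

      // One nonzero per column: the row index is the label, and the column
      // pointers count cumulative entries, so they are simply 0, 1, ..., n.
      arma::uvec rowPointers(labels.n_elem);
      arma::uvec colPointers(labels.n_elem + 1);
      colPointers(0) = 0;
      for (arma::uword i = 0; i < labels.n_elem; ++i)
      {
        rowPointers(i) = (arma::uword) labels(i);
        colPointers(i + 1) = i + 1;
      }
      arma::vec values = arma::ones<arma::vec>(labels.n_elem);

      arma::sp_mat groundTruth(rowPointers, colPointers, values, numClasses,
                               labels.n_elem);

      groundTruth.print("groundTruth:");  // Ones at (1,0), (0,1) and (2,2).
      return 0;
    }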
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
new file mode 100644
index 0000000..b16ea0f
--- /dev/null
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -0,0 +1,115 @@
+/**
+ * @file softmax_regression_impl.hpp
+ * @author Siddharth Agrawal
+ *
+ * Implementation of softmax regression.
+ */
+#ifndef __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_IMPL_HPP
+#define __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "softmax_regression.hpp"
+
+namespace mlpack {
+namespace regression {
+
+template<template<typename> class OptimizerType>
+SoftmaxRegression<OptimizerType>::SoftmaxRegression(const arma::mat& data,
+ const arma::vec& labels,
+ const size_t inputSize,
+ const size_t numClasses,
+ const double lambda) :
+ inputSize(inputSize),
+ numClasses(numClasses),
+ lambda(lambda)
+{
+ SoftmaxRegressionFunction regressor(data, labels, inputSize, numClasses,
+ lambda);
+ OptimizerType<SoftmaxRegressionFunction> optimizer(regressor);
+
+ parameters = regressor.GetInitialPoint();
+
+ // Train the model.
+ Timer::Start("softmax_regression_optimization");
+ const double out = optimizer.Optimize(parameters);
+ Timer::Stop("softmax_regression_optimization");
+
+ Log::Info << "SoftmaxRegression::SoftmaxRegression(): final objective of "
+ << "trained model is " << out << "." << std::endl;
+}
+
+template<template<typename> class OptimizerType>
+SoftmaxRegression<OptimizerType>::SoftmaxRegression(
+ OptimizerType<SoftmaxRegressionFunction>& optimizer) :
+ parameters(optimizer.Function().GetInitialPoint()),
+ inputSize(optimizer.Function().InputSize()),
+ numClasses(optimizer.Function().NumClasses()),
+ lambda(optimizer.Function().Lambda())
+{
+ // Train the model.
+ Timer::Start("softmax_regression_optimization");
+ const double out = optimizer.Optimize(parameters);
+ Timer::Stop("softmax_regression_optimization");
+
+ Log::Info << "SoftmaxRegression::SoftmaxRegression(): final objective of "
+ << "trained model is " << out << "." << std::endl;
+}
+
+template<template<typename> class OptimizerType>
+void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
+ arma::vec& predictions)
+{
+ // Calculate the probabilities for each test input.
+ arma::mat hypothesis, probabilities;
+
+ hypothesis = arma::exp(parameters * testData);
+ probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
+ numClasses, 1);
+
+ // Prepare necessary data.
+ predictions.zeros(testData.n_cols);
+ double maxProbability = 0;
+
+ // For each test input.
+ for(size_t i = 0; i < testData.n_cols; i++)
+ {
+ // For each class.
+ for(size_t j = 0; j < numClasses; j++)
+ {
+ // If a higher class probability is encountered, change prediction.
+ if(probabilities(j, i) > maxProbability)
+ {
+ maxProbability = probabilities(j, i);
+ predictions(i) = j;
+ }
+ }
+
+ // Set maximum probability to zero for the next input.
+ maxProbability = 0;
+ }
+}
+
+template<template<typename> class OptimizerType>
+double SoftmaxRegression<OptimizerType>::ComputeAccuracy(
+ const arma::mat& testData,
+ const arma::vec& labels)
+{
+ arma::vec predictions;
+
+ // Get predictions for the provided data.
+ Predict(testData, predictions);
+
+ // Increment count for every correctly predicted label.
+ size_t count = 0;
+ for(size_t i = 0; i < predictions.n_elem; i++)
+ if(predictions(i) == labels(i))
+ count++;
+
+ // Return percentage accuracy.
+ return (count * 100.0) / predictions.n_elem;
+}
+
+} // namespace regression
+} // namespace mlpack
+
+#endif
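
As an aside, the per-column argmax loop in Predict() above can also be
written with Armadillo's max-with-index overload; a hypothetical, equivalent
sketch of the loop (using the same variables as Predict()):

    // For each test point, record the row index of the largest probability.
    for (size_t i = 0; i < testData.n_cols; ++i)
    {
      arma::uword maxIndex = 0;
      arma::vec probs = probabilities.col(i);
      probs.max(maxIndex);  // This overload reports the index of the max.
      predictions(i) = maxIndex;
    }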
diff --git a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp
index 331e346..e69ec1b 100644
--- a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp
+++ b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp
@@ -2,7 +2,7 @@
* @file sparse_autoencoder_function.hpp
* @author Siddharth Agrawal
*
- * The function to be optimized for sparse autoencoders. Any mlpack optimizer
+ * The function to be optimized for sparse autoencoders. Any mlpack optimizer
* can be used.
*/
#ifndef __MLPACK_METHODS_SPARSE_AUTOENCODER_SPARSE_AUTOENCODER_FUNCTION_HPP
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 5f4c0e9..b985d26 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -45,6 +45,7 @@ add_executable(mlpack_test
sa_test.cpp
save_restore_utility_test.cpp
sgd_test.cpp
+ softmax_regression_test.cpp
sort_policy_test.cpp
sparse_autoencoder_test.cpp
sparse_coding_test.cpp
diff --git a/src/mlpack/tests/softmax_regression_test.cpp b/src/mlpack/tests/softmax_regression_test.cpp
new file mode 100644
index 0000000..07ad79d
--- /dev/null
+++ b/src/mlpack/tests/softmax_regression_test.cpp
@@ -0,0 +1,301 @@
+/**
+ * @file softmax_regression_test.cpp
+ * @author Siddharth Agrawal
+ *
+ * Test the SoftmaxRegression class.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace mlpack::distribution;
+
+BOOST_AUTO_TEST_SUITE(SoftmaxRegressionTest);
+
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionFunctionEvaluate)
+{
+ const size_t points = 1000;
+ const size_t trials = 50;
+ const size_t inputSize = 10;
+ const size_t numClasses = 5;
+
+ // Initialize a random dataset.
+ arma::mat data;
+ data.randu(inputSize, points);
+
+ // Create random class labels.
+ arma::vec labels(points);
+ for(size_t i = 0; i < points; i++)
+ labels(i) = math::RandInt(0, numClasses);
+
+ // Create a SoftmaxRegressionFunction. Regularization term ignored.
+ SoftmaxRegressionFunction srf(data, labels, inputSize, numClasses, 0);
+
+ // Run a number of trials.
+ for(size_t i = 0; i < trials; i++)
+ {
+ // Create a random set of parameters.
+ arma::mat parameters;
+ parameters.randu(numClasses, inputSize);
+
+ double logLikelihood = 0;
+
+ // Compute error for each training example.
+ for(size_t j = 0; j < points; j++)
+ {
+ arma::mat hypothesis, probabilities;
+
+ hypothesis = arma::exp(parameters * data.col(j));
+ probabilities = hypothesis / arma::accu(hypothesis);
+
+ logLikelihood += log(probabilities(labels(j), 0));
+ }
+ logLikelihood /= points;
+
+ // Compare with the value returned by the function.
+ BOOST_REQUIRE_CLOSE(srf.Evaluate(parameters), -logLikelihood, 1e-5);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionFunctionRegularizationEvaluate)
+{
+ const size_t points = 1000;
+ const size_t trials = 50;
+ const size_t inputSize = 10;
+ const size_t numClasses = 5;
+
+ // Initialize a random dataset.
+ arma::mat data;
+ data.randu(inputSize, points);
+
+ // Create random class labels.
+ arma::vec labels(points);
+ for(size_t i = 0; i < points; i++)
+ labels(i) = math::RandInt(0, numClasses);
+
+  // Three objects for comparing regularization costs.
+ SoftmaxRegressionFunction srfNoReg(data, labels, inputSize, numClasses, 0);
+ SoftmaxRegressionFunction srfSmallReg(data, labels, inputSize, numClasses, 1);
+ SoftmaxRegressionFunction srfBigReg(data, labels, inputSize, numClasses, 20);
+
+ // Run a number of trials.
+ for (size_t i = 0; i < trials; i++)
+ {
+ // Create a random set of parameters.
+ arma::mat parameters;
+ parameters.randu(numClasses, inputSize);
+
+    const double wL2SquaredNorm =
+        arma::accu(parameters % parameters);
+
+ // Calculate regularization terms.
+ const double smallRegTerm = 0.5 * wL2SquaredNorm;
+ const double bigRegTerm = 10 * wL2SquaredNorm;
+
+ BOOST_REQUIRE_CLOSE(srfNoReg.Evaluate(parameters) + smallRegTerm,
+ srfSmallReg.Evaluate(parameters), 1e-5);
+ BOOST_REQUIRE_CLOSE(srfNoReg.Evaluate(parameters) + bigRegTerm,
+ srfBigReg.Evaluate(parameters), 1e-5);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionFunctionGradient)
+{
+ const size_t points = 1000;
+ const size_t inputSize = 10;
+ const size_t numClasses = 5;
+
+ // Initialize a random dataset.
+ arma::mat data;
+ data.randu(inputSize, points);
+
+ // Create random class labels.
+ arma::vec labels(points);
+ for(size_t i = 0; i < points; i++)
+ labels(i) = math::RandInt(0, numClasses);
+
+  // Two objects for the two terms in the cost function. Each term contributes
+  // to the gradient and thus needs to be checked independently.
+ SoftmaxRegressionFunction srf1(data, labels, inputSize, numClasses, 0);
+ SoftmaxRegressionFunction srf2(data, labels, inputSize, numClasses, 20);
+
+ // Create a random set of parameters.
+ arma::mat parameters;
+ parameters.randu(numClasses, inputSize);
+
+ // Get gradients for the current parameters.
+ arma::mat gradient1, gradient2;
+ srf1.Gradient(parameters, gradient1);
+ srf2.Gradient(parameters, gradient2);
+
+ // Perturbation constant.
+ const double epsilon = 0.0001;
+ double costPlus1, costMinus1, numGradient1;
+ double costPlus2, costMinus2, numGradient2;
+
+ // For each parameter.
+ for (size_t i = 0; i < numClasses; i++)
+ {
+ for (size_t j = 0; j < inputSize; j++)
+ {
+ // Perturb parameter with a positive constant and get costs.
+ parameters(i, j) += epsilon;
+ costPlus1 = srf1.Evaluate(parameters);
+ costPlus2 = srf2.Evaluate(parameters);
+
+ // Perturb parameter with a negative constant and get costs.
+ parameters(i, j) -= 2 * epsilon;
+ costMinus1 = srf1.Evaluate(parameters);
+ costMinus2 = srf2.Evaluate(parameters);
+
+ // Compute numerical gradients using the costs calculated above.
+ numGradient1 = (costPlus1 - costMinus1) / (2 * epsilon);
+ numGradient2 = (costPlus2 - costMinus2) / (2 * epsilon);
+
+ // Restore the parameter value.
+ parameters(i, j) += epsilon;
+
+      // Compare numerical and analytically computed gradient values.
+ BOOST_REQUIRE_CLOSE(numGradient1, gradient1(i, j), 1e-2);
+ BOOST_REQUIRE_CLOSE(numGradient2, gradient2(i, j), 1e-2);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionTwoClasses)
+{
+ const size_t points = 1000;
+ const size_t inputSize = 3;
+ const size_t numClasses = 2;
+ const double lambda = 0.5;
+
+ // Generate two-Gaussian dataset.
+ GaussianDistribution g1(arma::vec("1.0 9.0 1.0"), arma::eye<arma::mat>(3, 3));
+ GaussianDistribution g2(arma::vec("4.0 3.0 4.0"), arma::eye<arma::mat>(3, 3));
+
+ arma::mat data(inputSize, points);
+ arma::vec labels(points);
+
+ for (size_t i = 0; i < points/2; i++)
+ {
+ data.col(i) = g1.Random();
+ labels(i) = 0;
+ }
+ for (size_t i = points/2; i < points; i++)
+ {
+ data.col(i) = g2.Random();
+ labels(i) = 1;
+ }
+
+ // Train softmax regression object.
+ SoftmaxRegression<> sr(data, labels, inputSize, numClasses, lambda);
+
+ // Compare training accuracy to 100.
+ const double acc = sr.ComputeAccuracy(data, labels);
+ BOOST_REQUIRE_CLOSE(acc, 100.0, 0.3);
+
+ // Create test dataset.
+ for (size_t i = 0; i < points/2; i++)
+ {
+ data.col(i) = g1.Random();
+ labels(i) = 0;
+ }
+ for (size_t i = points/2; i < points; i++)
+ {
+ data.col(i) = g2.Random();
+ labels(i) = 1;
+ }
+
+ // Compare test accuracy to 100.
+ const double testAcc = sr.ComputeAccuracy(data, labels);
+ BOOST_REQUIRE_CLOSE(testAcc, 100.0, 0.6);
+}
+
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionMultipleClasses)
+{
+ const size_t points = 5000;
+ const size_t inputSize = 5;
+ const size_t numClasses = 5;
+ const double lambda = 0.5;
+
+ // Generate five-Gaussian dataset.
+ arma::mat identity = arma::eye<arma::mat>(5, 5);
+ GaussianDistribution g1(arma::vec("1.0 9.0 1.0 2.0 2.0"), identity);
+ GaussianDistribution g2(arma::vec("4.0 3.0 4.0 2.0 2.0"), identity);
+ GaussianDistribution g3(arma::vec("3.0 2.0 7.0 0.0 5.0"), identity);
+ GaussianDistribution g4(arma::vec("4.0 1.0 1.0 2.0 7.0"), identity);
+ GaussianDistribution g5(arma::vec("1.0 0.0 1.0 8.0 3.0"), identity);
+
+ arma::mat data(inputSize, points);
+ arma::vec labels(points);
+
+ for (size_t i = 0; i < points/5; i++)
+ {
+ data.col(i) = g1.Random();
+ labels(i) = 0;
+ }
+ for (size_t i = points/5; i < (2*points)/5; i++)
+ {
+ data.col(i) = g2.Random();
+ labels(i) = 1;
+ }
+ for (size_t i = (2*points)/5; i < (3*points)/5; i++)
+ {
+ data.col(i) = g3.Random();
+ labels(i) = 2;
+ }
+ for (size_t i = (3*points)/5; i < (4*points)/5; i++)
+ {
+ data.col(i) = g4.Random();
+ labels(i) = 3;
+ }
+ for (size_t i = (4*points)/5; i < points; i++)
+ {
+ data.col(i) = g5.Random();
+ labels(i) = 4;
+ }
+
+ // Train softmax regression object.
+ SoftmaxRegression<> sr(data, labels, inputSize, numClasses, lambda);
+
+ // Compare training accuracy to 100.
+ const double acc = sr.ComputeAccuracy(data, labels);
+ BOOST_REQUIRE_CLOSE(acc, 100.0, 2.0);
+
+ // Create test dataset.
+ for (size_t i = 0; i < points/5; i++)
+ {
+ data.col(i) = g1.Random();
+ labels(i) = 0;
+ }
+ for (size_t i = points/5; i < (2*points)/5; i++)
+ {
+ data.col(i) = g2.Random();
+ labels(i) = 1;
+ }
+ for (size_t i = (2*points)/5; i < (3*points)/5; i++)
+ {
+ data.col(i) = g3.Random();
+ labels(i) = 2;
+ }
+ for (size_t i = (3*points)/5; i < (4*points)/5; i++)
+ {
+ data.col(i) = g4.Random();
+ labels(i) = 3;
+ }
+ for (size_t i = (4*points)/5; i < points; i++)
+ {
+ data.col(i) = g5.Random();
+ labels(i) = 4;
+ }
+
+ // Compare test accuracy to 100.
+ const double testAcc = sr.ComputeAccuracy(data, labels);
+ BOOST_REQUIRE_CLOSE(testAcc, 100.0, 2.0);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
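
Once the mlpack_test target is built, this suite can typically be run on its
own with Boost.Test's filter option (the binary path depends on the build
tree):

    ./mlpack_test --run_test=SoftmaxRegressionTest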