[mlpack] 06/44: Merge observation weight support for LinearRegression (r17104-17106) and style fixes (r17194-17195).
Barak A. Pearlmutter
barak+git at pearlmutter.net
Mon Feb 15 19:35:52 UTC 2016
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to tag mlpack-1.0.11
in repository mlpack.
commit 42de78b4dade0a878019eff105d54fd18c8623e8
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Dec 7 19:17:59 2014 +0000
Merge observation weight support for LinearRegression (r17104-17106) and style fixes (r17194-17195).
---
HISTORY.txt | 2 +
.../linear_regression/linear_regression.cpp | 107 +++++++++++++--------
.../linear_regression/linear_regression.hpp | 11 ++-
3 files changed, 79 insertions(+), 41 deletions(-)
diff --git a/HISTORY.txt b/HISTORY.txt
index 4e58de7..81ed204 100644
--- a/HISTORY.txt
+++ b/HISTORY.txt
@@ -6,6 +6,8 @@
* Linker fixes for AugLagrangian specializations under Visual Studio.
+ * Add support for observation weights to LinearRegression.
+
2014-08-29 mlpack 1.0.10
* Bugfix for NeighborSearch regression which caused very slow allknn/allkfn.
diff --git a/src/mlpack/methods/linear_regression/linear_regression.cpp b/src/mlpack/methods/linear_regression/linear_regression.cpp
index 0bdd718..c08ac23 100644
--- a/src/mlpack/methods/linear_regression/linear_regression.cpp
+++ b/src/mlpack/methods/linear_regression/linear_regression.cpp
@@ -1,6 +1,7 @@
/**
* @file linear_regression.cpp
* @author James Cline
+ * @author Michael Fox
*
* Implementation of simple linear regression.
*
@@ -25,9 +26,13 @@ using namespace mlpack;
using namespace mlpack::regression;
LinearRegression::LinearRegression(const arma::mat& predictors,
- const arma::colvec& responses,
- const double lambda) :
- lambda(lambda)
+ const arma::vec& responses,
+ const double lambda,
+ const bool intercept,
+ const arma::vec& weights
+ ) :
+ lambda(lambda),
+ intercept(intercept)
{
/*
* We want to calculate the a_i coefficients of:
@@ -40,25 +45,31 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
// that is, columns are actually rows (see: column major order).
const size_t nCols = predictors.n_cols;
+ arma::mat p = predictors;
+ arma::vec r = responses;
// Here we add the row of ones to the predictors.
- arma::mat p;
- if (lambda == 0.0)
+ // The intercept term is not penalized. To obtain a penalized intercept
+ // instead, add an "all ones" row to the design matrix and set intercept = false.
+ if(intercept)
{
- p.set_size(predictors.n_rows + 1, nCols);
- p.submat(1, 0, p.n_rows - 1, nCols - 1) = predictors;
- p.row(0).fill(1);
+ p.insert_rows(0, arma::ones<arma::mat>(1,nCols));
}
- else
+
+ if(weights.n_elem > 0)
+ {
+ p = p * diagmat(sqrt(weights));
+ r = sqrt(weights) % responses;
+ }
+
+ if (lambda != 0.0)
{
// Add the identity matrix to the predictors (this is equivalent to ridge
// regression). See http://math.stackexchange.com/questions/299481/ for
// more information.
- p.set_size(predictors.n_rows + 1, nCols + predictors.n_rows + 1);
- p.submat(1, 0, p.n_rows - 1, nCols - 1) = predictors;
- p.row(0).subvec(0, nCols - 1).fill(1);
- p.submat(0, nCols, p.n_rows - 1, nCols + predictors.n_rows) =
- lambda * arma::eye<arma::mat>(predictors.n_rows + 1,
- predictors.n_rows + 1);
+ p.insert_cols(nCols, predictors.n_rows);
+ p.submat(p.n_rows - predictors.n_rows, nCols, p.n_rows - 1, nCols +
+ predictors.n_rows - 1) = sqrt(lambda) * arma::eye<arma::mat>(predictors.n_rows,
+ predictors.n_rows);
}
// We compute the QR decomposition of the predictors.
@@ -72,15 +83,12 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
// If lambda > 0, then we must add a bunch of empty responses.
if (lambda == 0.0)
{
- arma::solve(parameters, R, arma::trans(Q) * responses);
+ arma::solve(parameters, R, arma::trans(Q) * r);
}
else
{
// Copy responses into larger vector.
- arma::vec r(nCols + predictors.n_rows + 1);
- r.subvec(0, nCols - 1) = responses;
- r.subvec(nCols, nCols + predictors.n_rows).fill(0);
-
+ r.insert_rows(nCols,p.n_cols - nCols);
arma::solve(parameters, R, arma::trans(Q) * r);
}
}
@@ -101,15 +109,25 @@ LinearRegression::LinearRegression(const LinearRegression& linearRegression) :
void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
const
{
- // We want to be sure we have the correct number of dimensions in the dataset.
- Log::Assert(points.n_rows == parameters.n_rows - 1);
-
- // Get the predictions, but this ignores the intercept value (parameters[0]).
- predictions = arma::trans(arma::trans(
- parameters.subvec(1, parameters.n_elem - 1)) * points);
+ if (intercept)
+ {
+ // We want to be sure we have the correct number of dimensions in the
+ // dataset.
+ Log::Assert(points.n_rows == parameters.n_rows - 1);
+ // Get the predictions, but this ignores the intercept value
+ // (parameters[0]).
+ predictions = arma::trans(arma::trans(parameters.subvec(1,
+ parameters.n_elem - 1)) * points);
+ // Now add the intercept.
+ predictions += parameters(0);
+ }
+ else
+ {
+ // We want to be sure we have the correct number of dimensions in the dataset.
+ Log::Assert(points.n_rows == parameters.n_rows);
+ predictions = arma::trans(arma::trans(parameters) * points);
+ }
- // Now add the intercept.
- predictions += parameters(0);
}
//! Compute the L2 squared error on the given predictors and responses.
@@ -120,19 +138,30 @@ double LinearRegression::ComputeError(const arma::mat& predictors,
const size_t nCols = predictors.n_cols;
const size_t nRows = predictors.n_rows;
- // Ensure that we have the correct number of dimensions in the dataset.
- if (nRows != parameters.n_rows - 1)
- {
- Log::Fatal << "The test data must have the same number of columns as the "
- "training file." << std::endl;
- }
-
// Calculate the differences between actual responses and predicted responses.
// We must also add the intercept (parameters(0)) to the predictions.
- arma::vec temp = responses - arma::trans(
- (arma::trans(parameters.subvec(1, parameters.n_elem - 1)) * predictors) +
- parameters(0));
-
+ arma::vec temp;
+ if (intercept)
+ {
+ // Ensure that we have the correct number of dimensions in the dataset.
+ if (nRows != parameters.n_rows - 1)
+ {
+ Log::Fatal << "The test data must have the same number of columns as the "
+ "training file." << std::endl;
+ }
+ temp = responses - arma::trans( (arma::trans(parameters.subvec(1,
+ parameters.n_elem - 1)) * predictors) + parameters(0));
+ }
+ else
+ {
+ // Ensure that we have the correct number of dimensions in the dataset.
+ if (nRows != parameters.n_rows)
+ {
+ Log::Fatal << "The test data must have the same number of columns as the "
+ "training file." << std::endl;
+ }
+ temp = responses - arma::trans((arma::trans(parameters) * predictors));
+ }
const double cost = arma::dot(temp, temp) / nCols;
return cost;
diff --git a/src/mlpack/methods/linear_regression/linear_regression.hpp b/src/mlpack/methods/linear_regression/linear_regression.hpp
index ef98f9b..b529a40 100644
--- a/src/mlpack/methods/linear_regression/linear_regression.hpp
+++ b/src/mlpack/methods/linear_regression/linear_regression.hpp
@@ -1,6 +1,7 @@
/**
* @file linear_regression.hpp
* @author James Cline
+ * @author Michael Fox
*
* Simple least-squares linear regression.
*
@@ -40,10 +41,15 @@ class LinearRegression
*
* @param predictors X, matrix of data points to create B with.
* @param responses y, the measured data for each point in X
+ * @param intercept whether to fit an intercept term
+ * @param weights observation weights, one per point (empty vector for unweighted)
*/
LinearRegression(const arma::mat& predictors,
const arma::vec& responses,
- const double lambda = 0);
+ const double lambda = 0,
+ const bool intercept = true,
+ const arma::vec& weights = arma::vec()
+ );
/**
* Initialize the model from a file.
@@ -111,12 +117,13 @@ class LinearRegression
* Initialized and filled by constructor to hold the least squares solution.
*/
arma::vec parameters;
-
/**
* The Tikhonov regularization parameter for ridge regression (0 for linear
* regression).
*/
double lambda;
+ //! Indicates whether the first element of parameters is the intercept.
+ bool intercept;
};
}; // namespace linear_regression
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list