[mlpack] 06/22: Overhaul implementation; do not use gmm::phi(). This gives a serious speedup, as high-dimensional matrix inverses are no longer calculated. The previous calls to gmm::phi() would invert a diagonal matrix without being able to assume that the matrix was diagonal. This explains the very poor benchmarking results for nbc in mlpack.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Thu Apr 17 12:23:02 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 727b070da193ebc808efadc80b8d1b76db5cd4be
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Mon Apr 14 20:11:16 2014 +0000
Overhaul implementation; do not use gmm::phi(). This gives a serious speedup,
as high-dimensional matrix inverses are no longer calculated. The previous
calls to gmm::phi() would invert a diagonal matrix without being able to assume
that the matrix was diagonal. This explains the very poor benchmarking results
for nbc in mlpack.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16426 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../naive_bayes/naive_bayes_classifier_impl.hpp | 57 +++++++++++++---------
1 file changed, 34 insertions(+), 23 deletions(-)
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index c1b0290..08791ee 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -1,10 +1,11 @@
/**
- * @file simple_nbc_impl.hpp
+ * @file naive_bayes_classifier_impl.hpp
* @author Parikshit Ram (pram at cc.gatech.edu)
*
* A Naive Bayes Classifier which parametrically estimates the distribution of
- * the features. It is assumed that the features have been sampled from a
- * Gaussian PDF.
+ * the features. This classifier makes its predictions based on the assumption
+ * that the features have been sampled from a set of Gaussians with diagonal
+ * covariance.
*/
#ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
#define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
@@ -53,6 +54,11 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
means.col(i) /= probabilities[i];
variances.col(i) /= (probabilities[i] - 1);
}
+
+ // Make sure variance is invertible.
+ for (size_t j = 0; j < dimensionality; ++j)
+ if (variances(j, i) == 0.0)
+ variances(j, i) = 1e-50;
}
probabilities /= data.n_cols;
@@ -66,9 +72,12 @@ void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
// training data.
Log::Assert(data.n_rows == means.n_rows);
- arma::vec probs(means.n_cols);
+ arma::vec probs = arma::log(probabilities);
+ arma::mat invVar = 1.0 / variances;
+
+ arma::mat testProbs = arma::repmat(probs.t(), data.n_cols, 1);
- results.zeros(data.n_cols);
+ results.set_size(data.n_cols); // No need to fill with anything yet.
Log::Info << "Running Naive Bayes classifier on " << data.n_cols
<< " data points with " << data.n_rows << " features each." << std::endl;
@@ -76,28 +85,30 @@ void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
// Calculate the joint probability for each of the data points for each of the
// means.n_cols.
- // Loop over every test case.
- for (size_t n = 0; n < data.n_cols; n++)
+ // Loop over every class.
+ for (size_t i = 0; i < means.n_cols; i++)
{
- // Loop over every class.
- for (size_t i = 0; i < means.n_cols; i++)
- {
- // Use the log values to prevent floating point underflow.
- probs(i) = log(probabilities(i));
-
- // Loop over every feature, but avoid inverting empty matrices.
- if (probabilities[i] != 0)
- {
- probs(i) += log(gmm::phi(data.unsafe_col(n), means.unsafe_col(i),
- diagmat(variances.unsafe_col(i))));
- }
- }
+ // This is an adaptation of gmm::phi() for the case where the covariance is
+ // a diagonal matrix.
+ arma::mat diffs = data - arma::repmat(means.col(i), 1, data.n_cols);
+ arma::mat rhs = -0.5 * arma::diagmat(invVar.col(i)) * diffs;
+ arma::vec exponents(diffs.n_cols);
+ for (size_t j = 0; j < diffs.n_cols; ++j)
+ exponents(j) = std::exp(arma::accu(diffs.col(j) % rhs.unsafe_col(j)));
+
+ testProbs.col(i) += log(pow(2 * M_PI, (double) data.n_rows / -2.0) *
+ pow(det(arma::diagmat(invVar.col(i))), -0.5) * exponents);
+ }
- // Find the index of the maximum value in tmp_vals.
+ // Now calculate the label.
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ // Find the index of the class with maximum probability for this point.
arma::uword maxIndex = 0;
- probs.max(maxIndex);
+ arma::vec pointProbs = testProbs.row(i).t();
+ pointProbs.max(maxIndex);
- results[n] = maxIndex;
+ results[i] = maxIndex;
}
return;
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list