[mlpack] 09/22: Change to two-pass algorithm suggested by Vahab in #344.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Thu Apr 17 12:23:02 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 8fbd23738487cab4fd898a4e43c5d52ee26e73cb
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Apr 16 18:18:13 2014 +0000
Change to two-pass algorithm suggested by Vahab in #344.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16430 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../naive_bayes/naive_bayes_classifier_impl.hpp | 30 +++++++++++++++-------
1 file changed, 21 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 2fd92f4..5cd4fb9 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -1,6 +1,7 @@
/**
* @file naive_bayes_classifier_impl.hpp
* @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Vahab Akbarzadeh (v.akbarzadeh at gmail.com)
*
* A Naive Bayes Classifier which parametrically estimates the distribution of
* the features. This classifier makes its predictions based on the assumption
@@ -59,25 +60,36 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
}
else
{
- // Don't use incremental algorithm.
+ // Don't use incremental algorithm. This is a two-pass algorithm. It is
+ // possible to calculate the means and variances using a faster one-pass
+ // algorithm but there are some precision and stability issues. If this is
+ // too slow, it's an option to use the faster algorithm by default and then
+ // have this (and the incremental algorithm) be other options.
+
+ // Calculate the means.
for (size_t j = 0; j < data.n_cols; ++j)
{
const size_t label = labels[j];
++probabilities[label];
-
means.col(label) += data.col(j);
- variances.col(label) += square(data.col(j));
}
+ // Normalize means.
for (size_t i = 0; i < classes; ++i)
- {
- if (probabilities[i] != 0)
- {
- variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+ if (probabilities[i] != 0.0)
means.col(i) /= probabilities[i];
- variances.col(i) /= (probabilities[i] - 1);
- }
+
+ // Calculate variances.
+ for (size_t j = 0; j < data.n_cols; ++j)
+ {
+ const size_t label = labels[j];
+ variances.col(label) += square(data.col(j) - means.col(label));
}
+
+ // Normalize variances.
+ for (size_t i = 0; i < classes; ++i)
+ if (probabilities[i] > 1)
+ variances.col(i) /= (probabilities[i] - 1);
}
// Ensure that the variances are invertible.
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list