[mlpack] 09/22: Change to two-pass algorithm suggested by Vahab in #344.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Thu Apr 17 12:23:02 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 8fbd23738487cab4fd898a4e43c5d52ee26e73cb
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Wed Apr 16 18:18:13 2014 +0000

    Change to two-pass algorithm suggested by Vahab in #344.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16430 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 .../naive_bayes/naive_bayes_classifier_impl.hpp    | 30 +++++++++++++++-------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 2fd92f4..5cd4fb9 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -1,6 +1,7 @@
 /**
  * @file naive_bayes_classifier_impl.hpp
  * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Vahab Akbarzadeh (v.akbarzadeh at gmail.com)
  *
  * A Naive Bayes Classifier which parametrically estimates the distribution of
  * the features.  This classifier makes its predictions based on the assumption
@@ -59,25 +60,36 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
   }
   else
   {
-    // Don't use incremental algorithm.
+    // Don't use incremental algorithm.  This is a two-pass algorithm.  It is
+    // possible to calculate the means and variances using a faster one-pass
+    // algorithm, but that approach has precision and stability issues.  If
+    // this is too slow, the faster algorithm could be made the default, with
+    // this one (and the incremental algorithm) offered as other options.
+
+    // Calculate the means.
     for (size_t j = 0; j < data.n_cols; ++j)
     {
       const size_t label = labels[j];
       ++probabilities[label];
-
       means.col(label) += data.col(j);
-      variances.col(label) += square(data.col(j));
     }
 
+    // Normalize means.
     for (size_t i = 0; i < classes; ++i)
-    {
-      if (probabilities[i] != 0)
-      {
-        variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+      if (probabilities[i] != 0.0)
         means.col(i) /= probabilities[i];
-        variances.col(i) /= (probabilities[i] - 1);
-      }
+
+    // Calculate variances.
+    for (size_t j = 0; j < data.n_cols; ++j)
+    {
+      const size_t label = labels[j];
+      variances.col(label) += square(data.col(j) - means.col(label));
     }
+
+    // Normalize variances.
+    for (size_t i = 0; i < classes; ++i)
+      if (probabilities[i] > 1)
+        variances.col(i) /= (probabilities[i] - 1);
   }
 
   // Ensure that the variances are invertible.

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list