[mlpack] 106/149: Refactor for cleaner code and avoid storing WH explicitly if possible.

Barak A. Pearlmutter barak+git at pearlmutter.net
Sat May 2 09:11:14 UTC 2015


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit efd784a565c0a69fd7be77c214be33cd7bbb7295
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Wed Nov 19 17:08:48 2014 +0000

    Refactor for cleaner code and avoid storing WH explicitly if possible.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17387 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 .../simple_residue_termination.hpp                 | 68 +++++++++++-----------
 1 file changed, 33 insertions(+), 35 deletions(-)

diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index 66d7930..3b7f18e 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -13,21 +13,30 @@ namespace mlpack {
 namespace amf {
 
 /**
- * This class implements simple residue based termination policy. Termination 
- * decision depends on two factors, value of residue and number of iteration. 
- * If the current value of residue drops below the threshold or the number of 
- * iterations goes above the threshold, positive termination signal is passed 
- * to AMF.
+ * This class implements a simple residue-based termination policy. The
+ * termination decision depends on two factors: the value of the residue (the
+ * difference between the norm of WH this iteration and the previous iteration),
+ * and the number of iterations.  If the current value of residue drops below
+ * the threshold or the number of iterations goes above the iteration limit,
+ * IsConverged() will return true.  This class is meant for use with the AMF
+ * (alternating matrix factorization) class.
  *
  * @see AMF
  */
 class SimpleResidueTermination
 {
  public:
-  //! empty constructor
+  /**
+   * Construct the SimpleResidueTermination object with the given minimum
+   * residue (or the default) and the given maximum number of iterations (or the
+   * default).  0 indicates no iteration limit.
+   *
+   * @param minResidue Minimum residue for termination.
+   * @param maxIterations Maximum number of iterations.
+   */
   SimpleResidueTermination(const double minResidue = 1e-10,
                            const size_t maxIterations = 10000)
-        : minResidue(minResidue), maxIterations(maxIterations) { }
+      : minResidue(minResidue), maxIterations(maxIterations) { }
 
   /**
    * Initializes the termination policy before stating the factorization.
@@ -37,57 +46,46 @@ class SimpleResidueTermination
   template<typename MatType>
   void Initialize(const MatType& V)
   {
-    // set resisue to minimum value
-    residue = minResidue;
-    // set iteration to minimum value
+    // Initialize the things we keep track of.
+    residue = DBL_MAX;
     iteration = 1;
-    // remove history
+    nm = V.n_rows * V.n_cols;
+    // Remove history.
     normOld = 0;
-
-    // initialize required variables
-    const size_t n = V.n_rows;
-    const size_t m = V.n_cols;
-    nm = n * m;
   }
 
   /**
-   * Check if termination criterio is met.
+   * Check if termination criterion is met.
    *
    * @param W Basis matrix of output.
    * @param H Encoding matrix of output.
    */
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    // Calculate norm of WH after each iteration.
-    arma::mat WH;
-
-    // calculate the norm and compute the residue 
-    WH = W * H;
-    double norm = sqrt(accu(WH % WH) / nm);
-    residue = fabs(normOld - norm);
-    residue /= normOld;
+    // Calculate the norm and compute the residue
+    const double norm = arma::norm(W * H, "fro");
+    residue = fabs(normOld - norm) / normOld;
 
-    // store the residue into history
+    // Store the norm.
     normOld = norm;
-    
-    // increment iteration count
+
+    // Increment iteration count
     iteration++;
-    
-    // check if termination criterion is met
-    if(residue < minResidue || iteration > maxIterations) return true;
-    else return false;
+
+    // Check if termination criterion is met.
+    return (residue < minResidue || iteration > maxIterations);
   }
 
   //! Get current value of residue
   const double& Index() const { return residue; }
 
-  //! Get current iteration count  
+  //! Get current iteration count
   const size_t& Iteration() const { return iteration; }
-  
+
   //! Access max iteration count
   const size_t& MaxIterations() const { return maxIterations; }
   size_t& MaxIterations() { return maxIterations; }
-  
+
   //! Access minimum residue value
   const double& MinResidue() const { return minResidue; }
   double& MinResidue() { return minResidue; }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list