[mlpack] 138/324: Minor refactoring of AMF class; mostly renaming for consistency and clarification of comments.

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:22:04 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 4a3613c0a69aabf9812de1fbf1e01b768043a7ae
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Tue Jul 8 11:27:24 2014 +0000

    Minor refactoring of AMF class; mostly renaming for consistency and
    clarification of comments.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16777 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/methods/amf/amf.hpp      | 100 ++++++++++++++++++++----------------
 src/mlpack/methods/amf/amf_impl.hpp |  61 +++++++++++-----------
 2 files changed, 87 insertions(+), 74 deletions(-)

diff --git a/src/mlpack/methods/amf/amf.hpp b/src/mlpack/methods/amf/amf.hpp
index 5ffad8c..ef20b0b 100644
--- a/src/mlpack/methods/amf/amf.hpp
+++ b/src/mlpack/methods/amf/amf.hpp
@@ -1,6 +1,12 @@
 /**
  * @file amf.hpp
  * @author Sumedh Ghaisas
+ * @author Mohan Rajendran
+ * @author Ryan Curtin
+ *
+ * The AMF (alternating matrix factorization) class, from which more commonly
+ * known techniques such as incremental SVD, NMF, and batch-learning SVD can be
+ * derived.
  */
 #ifndef __MLPACK_METHODS_AMF_AMF_HPP
 #define __MLPACK_METHODS_AMF_AMF_HPP
@@ -14,19 +20,22 @@ namespace mlpack {
 namespace amf {
 
 /**
- * This class implements the AMF on the given matrix V. Alternaing Matrix
- * Factorization decomposes V in the form \f$ V \approx WH \f$ where W is
- * called the basis matrix and H is called the encoding matrix. V is taken
- * to be of size n x m and the obtained W is n x r and H is r x m. The size r is
- * called the rank of the factorization.
+ * This class implements AMF (alternating matrix factorization) on the given
+ * matrix V. Alternating matrix factorization decomposes V in the form
+ * \f$ V \approx WH \f$ where W is called the basis matrix and H is called the
+ * encoding matrix. V is taken to be of size n x m and the obtained W is n x r
+ * and H is r x m. The size r is called the rank of the factorization.
  *
- * The implementation requires two template types; the first contains the
- * initialization rule for the W and H matrix and the other contains the update
- * rule to be used during each iteration.  This templatization allows the
- * user to try various update rules (including ones not supplied with MLPACK)
- * for factorization.
+ * The implementation requires three template types; the first contains the
+ * policy used to determine when the algorithm has converged; the second
+ * contains the initialization rule for the W and H matrices; the last contains
+ * the update rule to be used during each iteration. This templatization allows
+ * the user to try various update rules, initialization rules, and termination
+ * policies (including ones not supplied with MLPACK) for factorization.  By
+ * default, the template parameters to AMF implement non-negative matrix
+ * factorization with the multiplicative distance update rule.
  *
- * A simple example of how to run AMF is shown below.
+ * A simple example of how to run AMF (or NMF) is shown below.
  *
  * @code
  * extern arma::mat V; // Matrix that we want to perform LMF on.
@@ -34,10 +43,12 @@ namespace amf {
  * arma::mat W; // Basis matrix
  * arma::mat H; // Encoding matrix
  *
- * AMF<> amf; // Default options
+ * AMF<> amf; // Default options: NMF with multiplicative distance update rules.
  * amf.Apply(V, W, H, r);
  * @endcode
  *
+ * @tparam TerminationPolicyType The policy to use for determining when the
+ *     factorization has converged.
  * @tparam InitializationRule The initialization rule for initializing W and H
  *     matrix.
  * @tparam UpdateRule The update rule for calculating W and H matrix at each
@@ -45,35 +56,34 @@ namespace amf {
  *
  * @see NMF_MultiplicativeDistanceUpdate
  */
-template<typename TerminationPolicy = SimpleResidueTermination,
-         typename InitializationRule = RandomInitialization,
-         typename UpdateRule = NMFMultiplicativeDistanceUpdate>
+template<typename TerminationPolicyType = SimpleResidueTermination,
+         typename InitializationRuleType = RandomInitialization,
+         typename UpdateRuleType = NMFMultiplicativeDistanceUpdate>
 class AMF
 {
  public:
   /**
-   * Create the LMF object and (optionally) set the parameters which LMF will
+   * Create the AMF object and (optionally) set the parameters which AMF will
    * run with.  The minimum residue refers to the root mean square of the
    * difference between two subsequent iterations of the product W * H.  A low
    * residue indicates that subsequent iterations are not producing much change
    * in W and H.  Once the residue goes below the specified minimum residue, the
    * algorithm terminates.
    *
-   * @param maxIterations Maximum number of iterations allowed before giving up.
-   *     A value of 0 indicates no limit.
-   * @param minResidue The minimum allowed residue before the algorithm
-   *     terminates.
-   * @param Initialize Optional Initialization object for initializing the
-   *     W and H matrices
-   * @param Update Optional UpdateRule object; for when the update rule for
-   *     the W and H vector has states that it needs to store
+   * @param terminationPolicy Optional instantiated TerminationPolicy object
+   *      used to decide when the factorization has converged.
+   * @param initializeRule Optional instantiated InitializationRule object for
+   *      initializing the W and H matrices.
+   * @param update Optional instantiated UpdateRule object; useful when the
+   *      update rule has state that it needs to store (i.e. WUpdate() and
+   *      HUpdate() are not static functions).
    */
-  AMF(const TerminationPolicy& t_policy = TerminationPolicy(),
-      const InitializationRule& initializeRule = InitializationRule(),
-      const UpdateRule& update = UpdateRule());
+  AMF(const TerminationPolicyType& terminationPolicy = TerminationPolicyType(),
+      const InitializationRuleType& initializeRule = InitializationRuleType(),
+      const UpdateRuleType& update = UpdateRuleType());
 
   /**
-   * Apply Latent Matrix Factorization to the provided matrix.
+   * Apply Alternating Matrix Factorization to the provided matrix.
    *
    * @param V Input matrix to be factorized.
    * @param W Basis matrix to be output.
@@ -86,28 +96,30 @@ class AMF
                arma::mat& W,
                arma::mat& H);
 
- private:
-  //! termination policy
-  TerminationPolicy t_policy;
-  //! Instantiated initialization Rule.
-  InitializationRule initializeRule;
-  //! Instantiated update rule.
-  UpdateRule update;
+  //! Access the termination policy.
+  const TerminationPolicyType& TerminationPolicy() const
+  { return terminationPolicy; }
+  //! Modify the termination policy.
+  TerminationPolicyType& TerminationPolicy() { return terminationPolicy; }
 
- public:
   //! Access the initialization rule.
-  const InitializationRule& InitializeRule() const { return initializeRule; }
+  const InitializationRuleType& InitializeRule() const
+  { return initializationRule; }
   //! Modify the initialization rule.
-  InitializationRule& InitializeRule() { return initializeRule; }
+  InitializationRuleType& InitializeRule() { return initializationRule; }
+
   //! Access the update rule.
-  const UpdateRule& Update() const { return update; }
+  const UpdateRuleType& Update() const { return update; }
   //! Modify the update rule.
-  UpdateRule& Update() { return update; }
-  //! Access the termination policy
-  const TerminationPolicy& TPolicy() const { return t_policy; }
-  //! Modify the termination policy
-  TerminationPolicy& TPolicy() { return t_policy; }
+  UpdateRuleType& Update() { return update; }
 
+ private:
+  //! Instantiated termination policy.
+  TerminationPolicyType terminationPolicy;
+  //! Instantiated initialization rule.
+  InitializationRuleType initializationRule;
+  //! Instantiated update rule.
+  UpdateRuleType update;
 }; // class AMF
 
 }; // namespace amf
diff --git a/src/mlpack/methods/amf/amf_impl.hpp b/src/mlpack/methods/amf/amf_impl.hpp
index 4621422..ce9a2aa 100644
--- a/src/mlpack/methods/amf/amf_impl.hpp
+++ b/src/mlpack/methods/amf/amf_impl.hpp
@@ -1,65 +1,66 @@
 /**
  * @file amf_impl.hpp
  * @author Sumedh Ghaisas
+ * @author Mohan Rajendran
+ * @author Ryan Curtin
+ *
+ * Implementation of the AMF class.
  */
 namespace mlpack {
 namespace amf {
 
 /**
- * Construct the LMF object.
+ * Construct the AMF object by copying the given policy and rules.
  */
-template<typename TerminationPolicy,
-         typename InitializationRule,
-         typename UpdateRule>
-AMF<TerminationPolicy, InitializationRule, UpdateRule>::AMF(
-    const TerminationPolicy& t_policy,
-    const InitializationRule& initializeRule,
-    const UpdateRule& update) :
-    t_policy(t_policy),
-    initializeRule(initializeRule),
+template<typename TerminationPolicyType,
+         typename InitializationRuleType,
+         typename UpdateRuleType>
+AMF<TerminationPolicyType, InitializationRuleType, UpdateRuleType>::AMF(
+    const TerminationPolicyType& terminationPolicy,
+    const InitializationRuleType& initializationRule,
+    const UpdateRuleType& update) :
+    terminationPolicy(terminationPolicy),
+    initializationRule(initializationRule),
     update(update)
 { }
 
 /**
- * Apply Latent Matrix Factorization to the provided matrix.
+ * Apply Alternating Matrix Factorization to the provided matrix.
  *
  * @param V Input matrix to be factorized
  * @param W Basis matrix to be output
  * @param H Encoding matrix to output
  * @param r Rank r of the factorization
  */
-template<typename TerminationPolicy,
-         typename InitializationRule,
-         typename UpdateRule>
+template<typename TerminationPolicyType,
+         typename InitializationRuleType,
+         typename UpdateRuleType>
 template<typename MatType>
-double AMF<TerminationPolicy, InitializationRule, UpdateRule>::Apply(
-    const MatType& V,
-    const size_t r,
-    arma::mat& W,
-    arma::mat& H)
+double AMF<TerminationPolicyType, InitializationRuleType, UpdateRuleType>::
+Apply(const MatType& V,
+      const size_t r,
+      arma::mat& W,
+      arma::mat& H)
 {
   // Initialize W and H.
-  initializeRule.Initialize(V, r, W, H);
+  initializationRule.Initialize(V, r, W, H);
 
   Log::Info << "Initialized W and H." << std::endl;
 
-  arma::mat WH;
-
   update.Initialize(V, r);
-  t_policy.Initialize(V);
+  terminationPolicy.Initialize(V);
 
-  while (!t_policy.IsConverged())
+  while (!terminationPolicy.IsConverged())
   {
-    // Update step.
-    // Update the value of W and H based on the Update Rules provided
+    // Update the values of W and H based on the update rules provided.
     update.WUpdate(V, W, H);
     update.HUpdate(V, W, H);
 
-    t_policy.Step(W, H);
+    terminationPolicy.Step(W, H);
   }
 
-  double residue = sqrt(t_policy.Index());
-  size_t iteration = t_policy.Iteration();
+  const double residue = sqrt(terminationPolicy.Index());
+  const size_t iteration = terminationPolicy.Iteration();
 
   Log::Info << "AMF converged to residue of " << residue << " in "
       << iteration << " iterations." << std::endl;
@@ -67,5 +68,5 @@ double AMF<TerminationPolicy, InitializationRule, UpdateRule>::Apply(
   return residue;
 }
 
-}; // namespace nmf
+}; // namespace amf
 }; // namespace mlpack

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list