[mlpack] 106/149: Refactor for cleaner code and avoid storing WH explicitly if possible.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Sat May 2 09:11:14 UTC 2015
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit efd784a565c0a69fd7be77c214be33cd7bbb7295
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Nov 19 17:08:48 2014 +0000
Refactor for cleaner code and avoid storing WH explicitly if possible.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17387 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../simple_residue_termination.hpp | 68 +++++++++++-----------
1 file changed, 33 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index 66d7930..3b7f18e 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -13,21 +13,30 @@ namespace mlpack {
namespace amf {
/**
- * This class implements simple residue based termination policy. Termination
- * decision depends on two factors, value of residue and number of iteration.
- * If the current value of residue drops below the threshold or the number of
- * iterations goes above the threshold, positive termination signal is passed
- * to AMF.
+ * This class implements a simple residue-based termination policy. The
+ * termination decision depends on two factors: the value of the residue (the
+ * relative change in the Frobenius norm of WH from the previous iteration to
+ * this one), and the number of iterations. If the current residue drops below
+ * the minimum residue or the number of iterations goes above the iteration
+ * limit, IsConverged() will return true. This class is meant for use with the
+ * AMF (alternating matrix factorization) class.
*
* @see AMF
*/
class SimpleResidueTermination
{
public:
- //! empty constructor
+ /**
+ * Construct the SimpleResidueTermination object with the given minimum residue
+ * and maximum number of iterations (or the defaults). A maxIterations of 0 is
+ * not treated as "no limit": IsConverged() returns true on its first call.
+ *
+ * @param minResidue Minimum residue for termination.
+ * @param maxIterations Maximum number of iterations.
+ */
SimpleResidueTermination(const double minResidue = 1e-10,
const size_t maxIterations = 10000)
- : minResidue(minResidue), maxIterations(maxIterations) { }
+ : minResidue(minResidue), maxIterations(maxIterations) { }
/**
* Initializes the termination policy before stating the factorization.
@@ -37,57 +46,46 @@ class SimpleResidueTermination
template<typename MatType>
void Initialize(const MatType& V)
{
- // set resisue to minimum value
- residue = minResidue;
- // set iteration to minimum value
+ // Initialize the things we keep track of.
+ residue = DBL_MAX;
iteration = 1;
- // remove history
+ nm = V.n_rows * V.n_cols;
+ // Remove history.
normOld = 0;
-
- // initialize required variables
- const size_t n = V.n_rows;
- const size_t m = V.n_cols;
- nm = n * m;
}
/**
- * Check if termination criterio is met.
+ * Check if termination criterion is met.
*
* @param W Basis matrix of output.
* @param H Encoding matrix of output.
*/
bool IsConverged(arma::mat& W, arma::mat& H)
{
- // Calculate norm of WH after each iteration.
- arma::mat WH;
-
- // calculate the norm and compute the residue
- WH = W * H;
- double norm = sqrt(accu(WH % WH) / nm);
- residue = fabs(normOld - norm);
- residue /= normOld;
+ // Calculate the Frobenius norm of WH and the relative residue.
+ const double norm = arma::norm(W * H, "fro");
+ residue = fabs(normOld - norm) / normOld;
- // store the residue into history
+ // Store the norm.
normOld = norm;
-
- // increment iteration count
+
+ // Increment iteration count.
iteration++;
-
- // check if termination criterion is met
- if(residue < minResidue || iteration > maxIterations) return true;
- else return false;
+
+ // Check if termination criterion is met.
+ return (residue < minResidue || iteration > maxIterations);
}
//! Get current value of residue
const double& Index() const { return residue; }
- //! Get current iteration count
+ //! Get current iteration count
const size_t& Iteration() const { return iteration; }
-
+
//! Access max iteration count
const size_t& MaxIterations() const { return maxIterations; }
size_t& MaxIterations() { return maxIterations; }
-
+
//! Access minimum residue value
const double& MinResidue() const { return minResidue; }
double& MinResidue() { return minResidue; }
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list