[mlpack] 109/324: * faster implementation of SVDBatchWithMomentum * tolerance termination policy is modified according to new policy * test point selection in validation RMSE termination is shifted to constructor
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:01 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 053419e6102c1da6aa84eeba11c94bdf6473866d
Author: sumedhghaisas <sumedhghaisas at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 2 21:07:42 2014 +0000
* faster implementation of SVDBatchWithMomentum
* tolerance termination policy is modified according to new policy
* test point selection in validation RMSE termination is shifted to constructor
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16747 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/amf/amf.hpp | 6 +-
.../amf/termination_policies/CMakeLists.txt | 1 +
.../simple_residue_termination.hpp | 5 +-
.../simple_tolerance_termination.hpp | 46 +++--
.../validation_RMSE_termination.hpp | 119 ++++++++++++
.../methods/amf/update_rules/svd_batchlearning.hpp | 199 +++++++++++----------
6 files changed, 259 insertions(+), 117 deletions(-)
diff --git a/src/mlpack/methods/amf/amf.hpp b/src/mlpack/methods/amf/amf.hpp
index 55c5398..5ffad8c 100644
--- a/src/mlpack/methods/amf/amf.hpp
+++ b/src/mlpack/methods/amf/amf.hpp
@@ -82,9 +82,9 @@ class AMF
*/
template<typename MatType>
double Apply(const MatType& V,
- const size_t r,
- arma::mat& W,
- arma::mat& H);
+ const size_t r,
+ arma::mat& W,
+ arma::mat& H);
private:
//! termination policy
diff --git a/src/mlpack/methods/amf/termination_policies/CMakeLists.txt b/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
index 1a08f32..d2478ca 100644
--- a/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
+++ b/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
@@ -3,6 +3,7 @@
set(SOURCES
simple_residue_termination.hpp
simple_tolerance_termination.hpp
+ validation_RMSE_termination.hpp
)
# Add directory name to sources.
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index b41aebd..a47ce23 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -18,7 +18,7 @@ class SimpleResidueTermination
: minResidue(minResidue), maxIterations(maxIterations) { }
template<typename MatType>
- void Initialize(MatType& V)
+ void Initialize(const MatType& V)
{
residue = minResidue;
iteration = 1;
@@ -36,8 +36,7 @@ class SimpleResidueTermination
else return false;
}
- template<typename MatType>
- void Step(const MatType& W, const MatType& H)
+ void Step(const arma::mat& W, const arma::mat& H)
{
// Calculate norm of WH after each iteration.
arma::mat WH;
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index 92f9525..17bf24e 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -10,6 +10,7 @@
namespace mlpack {
namespace amf {
+template <class MatType>
class SimpleToleranceTermination
{
public:
@@ -17,44 +18,50 @@ class SimpleToleranceTermination
const size_t maxIterations = 10000)
: tolerance(tolerance), maxIterations(maxIterations) {}
- template<typename MatType>
- void Initialize(MatType& V)
+ void Initialize(const MatType& V)
{
residueOld = DBL_MAX;
iteration = 1;
- normOld = 0;
residue = DBL_MIN;
- const size_t n = V.n_rows;
- const size_t m = V.n_cols;
-
- nm = n * m;
+ this->V = &V;
}
bool IsConverged()
{
- if(((residueOld - residue) / residueOld < tolerance && iteration > 4)
+ if(((residueOld - residue) / residueOld < tolerance && iteration > 4)
|| iteration > maxIterations) return true;
else return false;
}
- template<typename MatType>
- void Step(const MatType& W, const MatType& H)
+ void Step(const arma::mat& W, const arma::mat& H)
{
// Calculate norm of WH after each iteration.
arma::mat WH;
WH = W * H;
- double norm = sqrt(accu(WH % WH) / nm);
- if (iteration != 0)
+ residueOld = residue;
+ size_t n = V->n_rows;
+ size_t m = V->n_cols;
+ double sum = 0;
+ size_t count = 0;
+ for(size_t i = 0;i < n;i++)
{
- residueOld = residue;
- residue = fabs(normOld - norm);
- residue /= normOld;
+ for(size_t j = 0;j < m;j++)
+ {
+ double temp = 0;
+ if((temp = (*V)(i,j)) != 0)
+ {
+ temp = (temp - WH(i, j));
+ temp = temp * temp;
+ sum += temp;
+ count++;
+ }
+ }
}
-
- normOld = norm;
+ residue = sum / count;
+ residue = sqrt(residue);
iteration++;
}
@@ -66,15 +73,16 @@ class SimpleToleranceTermination
double tolerance;
size_t maxIterations;
+ const MatType* V;
+
size_t iteration;
double residueOld;
double residue;
double normOld;
-
- size_t nm;
}; // class SimpleToleranceTermination
}; // namespace amf
}; // namespace mlpack
#endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
+
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
new file mode 100644
index 0000000..297a35b
--- /dev/null
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -0,0 +1,119 @@
+#ifndef MLPACK_METHODS_AMF_VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+#define MLPACK_METHODS_AMF_VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+
+#include <mlpack/core.hpp>
+
+namespace mlpack
+{
+namespace amf
+{
+/**
+ * Termination policy that holds out num_test_points nonzero entries of V
+ * (zeroing them in V) and stops once the RMSE of W * H on those entries has
+ * improved by a relative amount below tolerance in three consecutive checks
+ * after iteration 4, or once the iteration count exceeds maxIterations.
+ */
+template <class MatType>
+class ValidationRMSETermination
+{
+ public:
+  ValidationRMSETermination(MatType& V,
+                            size_t num_test_points,
+                            double tolerance = 1e-5,
+                            size_t maxIterations = 10000)
+      : tolerance(tolerance),
+        maxIterations(maxIterations),
+        num_test_points(num_test_points)
+  {
+    const size_t n = V.n_rows;
+    const size_t m = V.n_cols;
+
+    // The rejection sampling below can only finish if V has enough nonzeros.
+    size_t nonzeros = 0;
+    for (size_t j = 0; j < m; j++)
+      for (size_t i = 0; i < n; i++)
+        if (V(i, j) != 0) nonzeros++;
+    if (num_test_points == 0 || num_test_points > nonzeros)
+      Log::Fatal << "ValidationRMSETermination: num_test_points must be in "
+          << "[1, " << nonzeros << "]." << std::endl;
+
+    test_points.zeros(num_test_points, 3);
+
+    for (size_t i = 0; i < num_test_points; i++)
+    {
+      double t_val;
+      size_t t_row;
+      size_t t_col;
+      do
+      {
+        t_row = rand() % n;
+        t_col = rand() % m;
+      } while ((t_val = V(t_row, t_col)) == 0);
+
+      test_points(i, 0) = t_row;
+      test_points(i, 1) = t_col;
+      test_points(i, 2) = t_val;
+      V(t_row, t_col) = 0;  // Remove the held-out entry from V.
+    }
+  }
+
+  //! Reset per-run state; the held-out points chosen above are kept.
+  void Initialize(const MatType& /* V */)
+  {
+    iteration = 1;
+    rmse = DBL_MAX;
+    rmseOld = DBL_MAX;
+    t_count = 0;
+  }
+
+  //! True after three consecutive calls with a relative improvement below
+  //! tolerance (counted only once iteration > 4), or past maxIterations.
+  bool IsConverged()
+  {
+    if ((rmseOld - rmse) / rmseOld < tolerance && iteration > 4) t_count++;
+    else t_count = 0;
+
+    return t_count == 3 || iteration > maxIterations;
+  }
+
+  //! Recompute the held-out RMSE of W * H and advance the iteration count.
+  void Step(const arma::mat& W, const arma::mat& H)
+  {
+    rmseOld = rmse;
+    rmse = 0;
+    // Only the held-out entries are needed, so W * H is never formed.
+    for (size_t i = 0; i < num_test_points; i++)
+    {
+      const size_t t_row = test_points(i, 0);
+      const size_t t_col = test_points(i, 1);
+      const double err = test_points(i, 2) -
+          arma::dot(W.row(t_row), H.col(t_col));
+      rmse += err * err;
+    }
+    rmse = sqrt(rmse / num_test_points);
+
+    iteration++;
+  }
+
+  //! Get the held-out RMSE computed by the last Step().
+  const double& Index() { return rmse; }
+
+  //! Get the current iteration number.
+  const size_t& Iteration() { return iteration; }
+
+ private:
+  double tolerance;
+  size_t maxIterations;
+  size_t num_test_points;
+  size_t iteration;
+  //! One row per held-out point: (row, column, original value).
+  arma::Mat<double> test_points;
+  double rmseOld;
+  double rmse;
+  //! Consecutive IsConverged() calls whose improvement was below tolerance.
+  size_t t_count;
+};
+} // namespace amf
+} // namespace mlpack
+#endif // MLPACK_METHODS_AMF_VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+
diff --git a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp b/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
index c452895..9658d83 100644
--- a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
@@ -1,3 +1,7 @@
+/**
+ * @file svd_batchlearning.hpp
+ * @author Sumedh Ghaisas
+ */
#ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
#define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
@@ -9,113 +13,123 @@ namespace amf
{
class SVDBatchLearning
{
-public:
- SVDBatchLearning(double u = 0.000001,
- double kw = 0,
- double kh = 0,
- double momentum = 0.2,
- double min = -DBL_MIN,
- double max = DBL_MAX)
+ public:
+ SVDBatchLearning(double u = 0.0002,
+ double kw = 0,
+ double kh = 0,
+ double momentum = 0.5,
+ double min = -DBL_MIN,
+ double max = DBL_MAX)
: u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
{}
- template<typename MatType>
- void Initialize(const MatType& dataset, const size_t rank)
+ template<typename MatType>
+ void Initialize(const MatType& dataset, const size_t rank)
+ {
+ const size_t n = dataset.n_rows;
+ const size_t m = dataset.n_cols;
+
+ mW.zeros(n, rank);
+ mH.zeros(rank, m);
+ }
+
+ /**
+ * The update rule for the basis matrix W.
+ * The function takes in all the matrices and only changes W (and the
+ * internal momentum term mW); H is left untouched.
+ *
+ * @param V Input matrix to be factorized; MatType iterators need row()/col().
+ * @param W Basis matrix to be updated.
+ * @param H Encoding matrix.
+ */
+ template<typename MatType>
+ inline void WUpdate(const MatType& V,
+ arma::mat& W,
+ const arma::mat& H)
+ {
+ size_t n = V.n_rows;
+
+ size_t r = W.n_cols;
+
+ mW = momentum * mW;
+
+ arma::mat deltaW(n, r);
+ deltaW.zeros();
+
+ for(typename MatType::const_iterator it = V.begin();it != V.end();it++)
{
- const size_t n = dataset.n_rows;
- const size_t m = dataset.n_cols;
-
- mW.zeros(n, rank);
- mH.zeros(rank, m);
+ size_t row = it.row();
+ size_t col = it.col();
+ deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(H.col(col));
}
- /**
- * The update rule for the basis matrix W.
- * The function takes in all the matrices and only changes the
- * value of the W matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix to be updated.
- * @param H Encoding matrix.
- */
- template<typename MatType>
- inline void WUpdate(const MatType& V,
- arma::mat& W,
- const arma::mat& H)
+ if(kw != 0) for(size_t i = 0; i < n; i++)
{
- size_t n = V.n_rows;
- size_t m = V.n_cols;
-
- size_t r = W.n_cols;
-
- mW = momentum * mW;
-
- arma::mat deltaW(n, r);
- deltaW.zeros();
-
- for(size_t i = 0; i < n; i++)
- {
- for(size_t j = 0; j < m; j++)
- if(V(i,j) != 0) deltaW.row(i) += (V(i,j) - Predict(W.row(i), H.col(j))) * arma::trans(H.col(j));
- deltaW.row(i) -= kw * W.row(i);
- }
-
- mW += u * deltaW;
- W += mW;
+ deltaW.row(i) -= kw * W.row(i);
}
- /**
- * The update rule for the encoding matrix H.
- * The function takes in all the matrices and only changes the
- * value of the H matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix.
- * @param H Encoding matrix to be updated.
- */
- template<typename MatType>
- inline void HUpdate(const MatType& V,
- const arma::mat& W,
- arma::mat& H)
+ mW += u * deltaW;
+ W += mW;
+ }
+
+ /**
+ * The update rule for the encoding matrix H.
+ * The function takes in all the matrices and only changes the
+ * value of the H matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix.
+ * @param H Encoding matrix to be updated.
+ */
+ template<typename MatType>
+ inline void HUpdate(const MatType& V,
+ const arma::mat& W,
+ arma::mat& H)
+ {
+ size_t m = V.n_cols;
+
+ size_t r = W.n_cols;
+
+ mH = momentum * mH;
+
+ arma::mat deltaH(r, m);
+ deltaH.zeros();
+
+ for(typename MatType::const_iterator it = V.begin();it != V.end();it++)
{
- size_t n = V.n_rows;
- size_t m = V.n_cols;
-
- size_t r = W.n_cols;
-
- mH = momentum * mH;
-
- arma::mat deltaH(r, m);
- deltaH.zeros();
-
- for(size_t j = 0; j < m; j++)
- {
- for(size_t i = 0; i < n; i++)
- if(V(i,j) != 0) deltaH.col(j) += (V(i,j) - Predict(W.row(i), H.col(j))) * arma::trans(W.row(i));
- deltaH.col(j) -= kh * H.col(j);
- }
-
- mH += u*deltaH;
- H += mH;
+ size_t row = it.row();
+ size_t col = it.col();
+ deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(W.row(row));
}
-private:
- double Predict(const arma::mat& wi, const arma::mat& hj) const
+ if(kh != 0) for(size_t j = 0; j < m; j++)
{
- arma::mat temp = (wi * hj);
- double out = temp(0,0);
- return out;
+ deltaH.col(j) -= kh * H.col(j);
}
- double u;
- double kw;
- double kh;
- double min;
- double max;
- double momentum;
-
- arma::mat mW;
- arma::mat mH;
+ mH += u*deltaH;
+ H += mH;
+ }
+
+ private:
+ double Predict(const arma::mat& wi, const arma::mat& hj) const
+ {
+ arma::mat temp = (wi * hj);
+ double out = temp(0,0);
+ return out;
+ }
+
+ double u;
+ double kw;
+ double kh;
+ double min;
+ double max;
+ double momentum;
+
+ arma::mat mW;
+ arma::mat mH;
};
} // namespace amf
} // namespace mlpack
@@ -123,3 +137,4 @@ private:
#endif
+
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list