[mlpack] 137/324: * added SVD Incomplete incremental learning * added Termination Policy wrapper for SVD Incomplete Learning
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:03 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit dfc837295fe03382de1a395c24c3f1177b2173d1
Author: sumedhghaisas <sumedhghaisas at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Mon Jul 7 21:31:54 2014 +0000
* added SVD Incomplete incremental learning
* added Termination Policy wrapper for SVD Incomplete Learning
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16776 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../amf/termination_policies/CMakeLists.txt | 1 +
.../incomplete_incremental_termination.hpp | 60 +++++++++
src/mlpack/methods/amf/update_rules/CMakeLists.txt | 2 +
.../methods/amf/update_rules/svd_batchlearning.hpp | 6 +-
.../amf/update_rules/svd_incremental_learning.hpp | 147 +++++++++++++++++++++
5 files changed, 214 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/methods/amf/termination_policies/CMakeLists.txt b/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
index d2478ca..2517b14 100644
--- a/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
+++ b/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
@@ -4,6 +4,7 @@ set(SOURCES
simple_residue_termination.hpp
simple_tolerance_termination.hpp
validation_rmse_termination.hpp
+ incomplete_incremental_termination.hpp
)
# Add directory name to sources.
diff --git a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
new file mode 100644
index 0000000..d24d571
--- /dev/null
+++ b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
@@ -0,0 +1,60 @@
+/**
+ * @file incomplete_incremental_termination.hpp
+ * @author Sumedh Ghaisas
+ */
+#ifndef _INCOMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+#define _INCOMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace amf {
+
+/**
+ * Wraps a termination policy for incremental learning: the wrapped policy is
+ * stepped on the first Step() call and then once every V.n_rows calls.
+ */
+template <class TerminationPolicy>
+class IncompleteIncrementalTermination
+{
+ public:
+  //! Store a copy of the policy to wrap; the counters start at zero.
+  IncompleteIncrementalTermination(TerminationPolicy t_policy = TerminationPolicy())
+      : t_policy(t_policy), incrementalIndex(0), iteration(0) {}
+
+  //! Initialize the wrapped policy with V and restart the step counter.
+  template <class MatType>
+  void Initialize(const MatType& V)
+  {
+    t_policy.Initialize(V);
+    incrementalIndex = V.n_rows;
+    iteration = 0;
+  }
+
+  //! Convergence is decided entirely by the wrapped policy.
+  bool IsConverged() { return t_policy.IsConverged(); }
+
+  //! Count one update and forward it to the wrapped policy once per period.
+  void Step(const arma::mat& W, const arma::mat& H)
+  {
+    // Zero period (empty V / no Initialize()): forward every call, never % 0.
+    if (incrementalIndex == 0 || iteration % incrementalIndex == 0)
+      t_policy.Step(W, H);
+    iteration++;
+  }
+
+  //! Index of the wrapped policy, and Step() calls since Initialize().
+  const double& Index() { return t_policy.Index(); }
+  const size_t& Iteration() { return iteration; }
+
+ private:
+  TerminationPolicy t_policy;
+  size_t incrementalIndex; // Step period; set to V.n_rows by Initialize().
+  size_t iteration;        // Step() calls since the last Initialize().
+};
+
+}; // namespace amf
+}; // namespace mlpack
+
+#endif
+
diff --git a/src/mlpack/methods/amf/update_rules/CMakeLists.txt b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
index 011ec27..b7bde1c 100644
--- a/src/mlpack/methods/amf/update_rules/CMakeLists.txt
+++ b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
@@ -4,6 +4,8 @@ set(SOURCES
nmf_als.hpp
nmf_mult_dist.hpp
nmf_mult_div.hpp
+ svd_batchlearning.hpp
+ svd_incremental_learning.hpp
)
# Add directory name to sources.
diff --git a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp b/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
index 267d651..bb31e9e 100644
--- a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
@@ -143,7 +143,8 @@ inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
{
size_t row = it.row();
size_t col = it.col();
- deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(H.col(col));
+ deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(H.col(col));
}
if(kw != 0) for(size_t i = 0; i < n; i++)
@@ -173,7 +174,8 @@ inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
{
size_t row = it.row();
size_t col = it.col();
- deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(W.row(row));
+ deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(W.row(row));
}
if(kh != 0) for(size_t j = 0; j < m; j++)
diff --git a/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
new file mode 100644
index 0000000..d1ec14d
--- /dev/null
+++ b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
@@ -0,0 +1,147 @@
+#ifndef SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+#define SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+
+namespace mlpack
+{
+namespace amf
+{
+/**
+ * SVD update rule that works on one column of V (currentUserIndex) per
+ * WUpdate()/HUpdate() pair; HUpdate() then advances to the next column.
+ */
+class SVDIncrementalLearning
+{
+ public:
+  /**
+   * @param u Step size applied to the accumulated deltas.
+   * @param kw Regularization weight for W (0 disables it).
+   * @param kh Regularization weight for H (0 disables it).
+   * @param min,max Stored, but never read by the update rules in this class.
+   * NOTE(review): -DBL_MIN is the negated smallest positive normalized double,
+   * not the most negative double; -DBL_MAX was presumably intended.
+   */
+  SVDIncrementalLearning(double u = 0.001, double kw = 0, double kh = 0,
+                         double min = -DBL_MIN, double max = DBL_MAX)
+      : u(u), kw(kw), kh(kh), min(min), max(max) {}
+
+  //! Record the size of dataset and restart at column 0; rank is not used.
+  template<typename MatType>
+  void Initialize(const MatType& dataset, const size_t rank)
+  {
+    n = dataset.n_rows;
+    m = dataset.n_cols;
+    currentUserIndex = 0;
+  }
+
+  /**
+   * The update rule for the basis matrix W; only W is modified.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix to be updated.
+   * @param H Encoding matrix.
+   */
+  template<typename MatType>
+  inline void WUpdate(const MatType& V, arma::mat& W, const arma::mat& H)
+  {
+    arma::mat deltaW(n, W.n_cols);
+    deltaW.zeros();
+    for (size_t i = 0; i < n; i++)
+    {
+      const double val = V(i, currentUserIndex);
+      // Zero entries of V contribute no error term.
+      if (val != 0)
+        deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+                         arma::trans(H.col(currentUserIndex));
+      // Regularize row i only: kw * W.row(i) is one row, not all of deltaW.
+      if (kw != 0) deltaW.row(i) -= kw * W.row(i);
+    }
+
+    W += u * deltaW;
+  }
+
+  /**
+   * The update rule for the encoding matrix H; only column currentUserIndex
+   * of H is modified, after which currentUserIndex advances (modulo m).
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix.
+   * @param H Encoding matrix to be updated.
+   */
+  template<typename MatType>
+  inline void HUpdate(const MatType& V, const arma::mat& W, arma::mat& H)
+  {
+    arma::mat deltaH(H.n_rows, 1);
+    deltaH.zeros();
+    for (size_t i = 0; i < n; i++)
+    {
+      const double val = V(i, currentUserIndex);
+      if (val != 0)
+        deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+                  arma::trans(W.row(i));
+    }
+    if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
+
+    H.col(currentUserIndex++) += u * deltaH;
+    currentUserIndex = currentUserIndex % m;
+  }
+
+ private:
+  double u;   // Step size.
+  double kw;  // Regularization weight for W.
+  double kh;  // Regularization weight for H.
+  double min; // Not read by this class.
+  double max; // Not read by this class.
+  size_t n; // Rows of the dataset passed to Initialize().
+  size_t m; // Columns of the dataset passed to Initialize().
+  size_t currentUserIndex; // Column of V used by the next update.
+};
+
+//! WUpdate() for sparse V: only stored entries of the current column are used.
+template<>
+inline void SVDIncrementalLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
+                                                          arma::mat& W,
+                                                          const arma::mat& H)
+{
+  arma::mat deltaW(n, W.n_cols);
+  deltaW.zeros();
+  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
+       it != V.end_col(currentUserIndex); ++it)
+  {
+    const size_t i = it.row();
+    deltaW.row(i) += (*it - arma::dot(W.row(i), H.col(currentUserIndex))) *
+                     arma::trans(H.col(currentUserIndex));
+    // Regularize row i only: kw * W.row(i) is one row, not all of deltaW.
+    if (kw != 0) deltaW.row(i) -= kw * W.row(i);
+  }
+  W += u * deltaW;
+}
+
+//! HUpdate() for sparse V: only stored entries of the current column are used.
+template<>
+inline void SVDIncrementalLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
+                                                          const arma::mat& W,
+                                                          arma::mat& H)
+{
+  arma::mat deltaH(H.n_rows, 1);
+  deltaH.zeros();
+  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
+       it != V.end_col(currentUserIndex); ++it)
+  {
+    // *it already is V(i, currentUserIndex); no second sparse lookup needed.
+    const double val = *it;
+    const size_t i = it.row();
+    if (val != 0)
+      deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+                arma::trans(W.row(i));
+  }
+  if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
+  H.col(currentUserIndex++) += u * deltaH;
+  currentUserIndex = currentUserIndex % m;
+}
+
+}; // namespace amf
+}; // namespace mlpack
+
+
+#endif // SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list