[mlpack] 186/324: * added svd incomplete incremental learning tests * combined functions IsConverged and Step of termination policies into IsConverged
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:09 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 4c049ec1bf5d152f51ea22d99f67df1c22b03bef
Author: sumedhghaisas <sumedhghaisas at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 16 11:14:27 2014 +0000
* Added SVD incomplete incremental learning tests.
* Merged each termination policy's Step function into its IsConverged function, so a single IsConverged call now both updates the policy state and checks for convergence.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16825 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/amf/amf_impl.hpp | 4 +-
.../incomplete_incremental_termination.hpp | 9 ++--
.../simple_residue_termination.hpp | 11 ++--
.../simple_tolerance_termination.hpp | 61 ++++++++++------------
.../validation_RMSE_termination.hpp | 53 +++++++++----------
.../amf/update_rules/svd_incremental_learning.hpp | 4 +-
src/mlpack/tests/svd_incremental_test.cpp | 11 ++--
7 files changed, 68 insertions(+), 85 deletions(-)
diff --git a/src/mlpack/methods/amf/amf_impl.hpp b/src/mlpack/methods/amf/amf_impl.hpp
index a887931..d99cf57 100644
--- a/src/mlpack/methods/amf/amf_impl.hpp
+++ b/src/mlpack/methods/amf/amf_impl.hpp
@@ -55,11 +55,9 @@ Apply(const MatType& V,
// Update the values of W and H based on the update rules provided.
update.WUpdate(V, W, H);
update.HUpdate(V, W, H);
-
- terminationPolicy.Step(W, H);
}
- const double residue = sqrt(terminationPolicy.Index());
+ const double residue = terminationPolicy.Index();
const size_t iteration = terminationPolicy.Iteration();
Log::Info << "AMF converged to residue of " << residue << " in "
diff --git a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
index cfa499e..d53b8b7 100644
--- a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
@@ -28,13 +28,10 @@ class IncompleteIncrementalTermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
- return t_policy.IsConverged(W, H);
- }
-
- void Step(const arma::mat& W, const arma::mat& H)
- {
- if(iteration % incrementalIndex == 0) t_policy.Step(W, H);
iteration++;
+ if(iteration % incrementalIndex == 0)
+ return t_policy.IsConverged(W, H);
+ else return false;
}
const double& Index()
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index b5c4fb5..3e5f7b8 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -32,14 +32,6 @@ class SimpleResidueTermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
- (void)W;
- (void)H;
- if(residue < minResidue || iteration > maxIterations) return true;
- else return false;
- }
-
- void Step(const arma::mat& W, const arma::mat& H)
- {
// Calculate norm of WH after each iteration.
arma::mat WH;
@@ -55,6 +47,9 @@ class SimpleResidueTermination
normOld = norm;
iteration++;
+
+ if(residue < minResidue || iteration > maxIterations) return true;
+ else return false;
}
const double& Index() { return residue; }
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index dfc78bd..8976c14 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -32,6 +32,35 @@ class SimpleToleranceTermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
+ // Calculate norm of WH after each iteration.
+ arma::mat WH;
+
+ WH = W * H;
+
+ residueOld = residue;
+ size_t n = V->n_rows;
+ size_t m = V->n_cols;
+ double sum = 0;
+ size_t count = 0;
+ for(size_t i = 0;i < n;i++)
+ {
+ for(size_t j = 0;j < m;j++)
+ {
+ double temp = 0;
+ if((temp = (*V)(i,j)) != 0)
+ {
+ temp = (temp - WH(i, j));
+ temp = temp * temp;
+ sum += temp;
+ count++;
+ }
+ }
+ }
+ residue = sum / count;
+ residue = sqrt(residue);
+
+ iteration++;
+
if((residueOld - residue) / residueOld < tolerance && iteration > 4)
{
if(reverseStepCount == 0 && isCopy == false)
@@ -66,38 +95,6 @@ class SimpleToleranceTermination
else return false;
}
- void Step(const arma::mat& W, const arma::mat& H)
- {
- // Calculate norm of WH after each iteration.
- arma::mat WH;
-
- WH = W * H;
-
- residueOld = residue;
- size_t n = V->n_rows;
- size_t m = V->n_cols;
- double sum = 0;
- size_t count = 0;
- for(size_t i = 0;i < n;i++)
- {
- for(size_t j = 0;j < m;j++)
- {
- double temp = 0;
- if((temp = (*V)(i,j)) != 0)
- {
- temp = (temp - WH(i, j));
- temp = temp * temp;
- sum += temp;
- count++;
- }
- }
- }
- residue = sum / count;
- residue = sqrt(residue);
-
- iteration++;
- }
-
const double& Index() { return residue; }
const size_t& Iteration() { return iteration; }
const size_t& MaxIterations() { return maxIterations; }
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
index 7f03954..a437ce5 100644
--- a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -56,6 +56,30 @@ class ValidationRMSETermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
+ // Calculate norm of WH after each iteration.
+ arma::mat WH;
+
+ WH = W * H;
+
+ if (iteration != 0)
+ {
+ rmseOld = rmse;
+ rmse = 0;
+ for(size_t i = 0; i < num_test_points; i++)
+ {
+ size_t t_row = test_points(i, 0);
+ size_t t_col = test_points(i, 1);
+ double t_val = test_points(i, 2);
+ double temp = (t_val - WH(t_row, t_col));
+ temp *= temp;
+ rmse += temp;
+ }
+ rmse /= num_test_points;
+ rmse = sqrt(rmse);
+ }
+
+ iteration++;
+
if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
{
if(reverseStepCount == 0 && isCopy == false)
@@ -89,34 +113,7 @@ class ValidationRMSETermination
}
else return false;
}
-
- void Step(const arma::mat& W, const arma::mat& H)
- {
- // Calculate norm of WH after each iteration.
- arma::mat WH;
-
- WH = W * H;
-
- if (iteration != 0)
- {
- rmseOld = rmse;
- rmse = 0;
- for(size_t i = 0; i < num_test_points; i++)
- {
- size_t t_row = test_points(i, 0);
- size_t t_col = test_points(i, 1);
- double t_val = test_points(i, 2);
- double temp = (t_val - WH(t_row, t_col));
- temp *= temp;
- rmse += temp;
- }
- rmse /= num_test_points;
- rmse = sqrt(rmse);
- }
-
- iteration++;
- }
-
+
const double& Index() { return rmse; }
const size_t& Iteration() { return iteration; }
diff --git a/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
index a0388c1..4cc7053 100644
--- a/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
@@ -49,7 +49,7 @@ class SVDIncrementalLearning
if((val = V(i, currentUserIndex)) != 0)
deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
arma::trans(H.col(currentUserIndex));
- if(kw != 0) deltaW -= kw * W.row(i);
+ if(kw != 0) deltaW.row(i) -= kw * W.row(i);
}
W += u*deltaW;
@@ -112,7 +112,7 @@ inline void SVDIncrementalLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
size_t i = it.row();
deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
arma::trans(H.col(currentUserIndex));
- if(kw != 0) deltaW -= kw * W.row(i);
+ if(kw != 0) deltaW.row(i) -= kw * W.row(i);
}
W += u*deltaW;
diff --git a/src/mlpack/tests/svd_incremental_test.cpp b/src/mlpack/tests/svd_incremental_test.cpp
index 00b55d3..682039c 100644
--- a/src/mlpack/tests/svd_incremental_test.cpp
+++ b/src/mlpack/tests/svd_incremental_test.cpp
@@ -4,6 +4,7 @@
#include <mlpack/methods/amf/init_rules/random_init.hpp>
#include <mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp>
#include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
+#include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -35,7 +36,7 @@ BOOST_AUTO_TEST_CASE(SVDIncrementalConvergenceTest)
amf.TerminationPolicy().MaxIterations());
}
-/*
+
BOOST_AUTO_TEST_CASE(SVDIncrementalRegularizationTest)
{
mat dataset;
@@ -78,14 +79,12 @@ BOOST_AUTO_TEST_CASE(SVDIncrementalRegularizationTest)
RandomInitialization,
SVDIncrementalLearning> amf_2(vrt2,
RandomInitialization(),
- SVDIncrementalLearning(0.001, 1e-5, 2e-5));
+ SVDIncrementalLearning(0.001, 0.01, 0.01));
mat m3, m4;
double RMSE_2 = amf_2.Apply(cleanedData2, 2, m3, m4);
-
- // RMSE_2 should be less than RMSE_1
- std::cout << RMSE_1 << " " << RMSE_2 << std::endl;
+
+ BOOST_REQUIRE_LT(RMSE_2, RMSE_1);
}
-*/
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list