[mlpack] 186/324: * added svd incomplete incremental learning tests * combined functions IsConverged and Step of termination policies into IsConverged

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:22:09 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 4c049ec1bf5d152f51ea22d99f67df1c22b03bef
Author: sumedhghaisas <sumedhghaisas at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Wed Jul 16 11:14:27 2014 +0000

    * Added SVD incomplete incremental learning tests.
    * Merged the Step function of each termination policy into its IsConverged function.
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16825 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/methods/amf/amf_impl.hpp                |  4 +-
 .../incomplete_incremental_termination.hpp         |  9 ++--
 .../simple_residue_termination.hpp                 | 11 ++--
 .../simple_tolerance_termination.hpp               | 61 ++++++++++------------
 .../validation_RMSE_termination.hpp                | 53 +++++++++----------
 .../amf/update_rules/svd_incremental_learning.hpp  |  4 +-
 src/mlpack/tests/svd_incremental_test.cpp          | 11 ++--
 7 files changed, 68 insertions(+), 85 deletions(-)

diff --git a/src/mlpack/methods/amf/amf_impl.hpp b/src/mlpack/methods/amf/amf_impl.hpp
index a887931..d99cf57 100644
--- a/src/mlpack/methods/amf/amf_impl.hpp
+++ b/src/mlpack/methods/amf/amf_impl.hpp
@@ -55,11 +55,9 @@ Apply(const MatType& V,
     // Update the values of W and H based on the update rules provided.
     update.WUpdate(V, W, H);
     update.HUpdate(V, W, H);
-
-    terminationPolicy.Step(W, H);
   }
 
-  const double residue = sqrt(terminationPolicy.Index());
+  const double residue = terminationPolicy.Index();
   const size_t iteration = terminationPolicy.Iteration();
 
   Log::Info << "AMF converged to residue of " << residue << " in "
diff --git a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
index cfa499e..d53b8b7 100644
--- a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
@@ -28,13 +28,10 @@ class IncompleteIncrementalTermination
 
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    return t_policy.IsConverged(W, H);
-  }
-
-  void Step(const arma::mat& W, const arma::mat& H)
-  {
-    if(iteration % incrementalIndex == 0) t_policy.Step(W, H);
     iteration++;
+    if(iteration % incrementalIndex == 0)  
+      return t_policy.IsConverged(W, H);
+    else return false;
   }
 
   const double& Index()
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index b5c4fb5..3e5f7b8 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -32,14 +32,6 @@ class SimpleResidueTermination
 
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    (void)W;
-    (void)H;
-    if(residue < minResidue || iteration > maxIterations) return true;
-    else return false;
-  }
-
-  void Step(const arma::mat& W, const arma::mat& H)
-  {
     // Calculate norm of WH after each iteration.
     arma::mat WH;
 
@@ -55,6 +47,9 @@ class SimpleResidueTermination
     normOld = norm;
 
     iteration++;
+    
+    if(residue < minResidue || iteration > maxIterations) return true;
+    else return false;
   }
 
   const double& Index() { return residue; }
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index dfc78bd..8976c14 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -32,6 +32,35 @@ class SimpleToleranceTermination
 
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
+    // Calculate norm of WH after each iteration.
+    arma::mat WH;
+
+    WH = W * H;
+
+    residueOld = residue;
+    size_t n = V->n_rows;
+    size_t m = V->n_cols;
+    double sum = 0;
+    size_t count = 0;
+    for(size_t i = 0;i < n;i++)
+    {
+        for(size_t j = 0;j < m;j++)
+        {
+            double temp = 0;
+            if((temp = (*V)(i,j)) != 0)
+            {
+                temp = (temp - WH(i, j));
+                temp = temp * temp;
+                sum += temp;
+                count++;
+            }
+        }
+    }
+    residue = sum / count;
+    residue = sqrt(residue);
+
+    iteration++;  
+  
     if((residueOld - residue) / residueOld < tolerance && iteration > 4)
     {
       if(reverseStepCount == 0 && isCopy == false)
@@ -66,38 +95,6 @@ class SimpleToleranceTermination
     else return false;
   }
 
-  void Step(const arma::mat& W, const arma::mat& H)
-  {
-    // Calculate norm of WH after each iteration.
-    arma::mat WH;
-
-    WH = W * H;
-
-    residueOld = residue;
-    size_t n = V->n_rows;
-    size_t m = V->n_cols;
-    double sum = 0;
-    size_t count = 0;
-    for(size_t i = 0;i < n;i++)
-    {
-        for(size_t j = 0;j < m;j++)
-        {
-            double temp = 0;
-            if((temp = (*V)(i,j)) != 0)
-            {
-                temp = (temp - WH(i, j));
-                temp = temp * temp;
-                sum += temp;
-                count++;
-            }
-        }
-    }
-    residue = sum / count;
-    residue = sqrt(residue);
-
-    iteration++;
-  }
-
   const double& Index() { return residue; }
   const size_t& Iteration() { return iteration; }
   const size_t& MaxIterations() { return maxIterations; }
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
index 7f03954..a437ce5 100644
--- a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -56,6 +56,30 @@ class ValidationRMSETermination
 
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
+    // Calculate norm of WH after each iteration.
+    arma::mat WH;
+
+    WH = W * H;
+
+    if (iteration != 0)
+    {
+      rmseOld = rmse;
+      rmse = 0;
+      for(size_t i = 0; i < num_test_points; i++)
+      {
+        size_t t_row = test_points(i, 0);
+        size_t t_col = test_points(i, 1);
+        double t_val = test_points(i, 2);
+        double temp = (t_val - WH(t_row, t_col));
+        temp *= temp;
+        rmse += temp;
+      }
+      rmse /= num_test_points;
+      rmse = sqrt(rmse);
+    }
+
+    iteration++;
+  
     if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
     {
       if(reverseStepCount == 0 && isCopy == false)
@@ -89,34 +113,7 @@ class ValidationRMSETermination
     }
     else return false;
   }
-
-  void Step(const arma::mat& W, const arma::mat& H)
-  {
-    // Calculate norm of WH after each iteration.
-    arma::mat WH;
-
-    WH = W * H;
-
-    if (iteration != 0)
-    {
-      rmseOld = rmse;
-      rmse = 0;
-      for(size_t i = 0; i < num_test_points; i++)
-      {
-        size_t t_row = test_points(i, 0);
-        size_t t_col = test_points(i, 1);
-        double t_val = test_points(i, 2);
-        double temp = (t_val - WH(t_row, t_col));
-        temp *= temp;
-        rmse += temp;
-      }
-      rmse /= num_test_points;
-      rmse = sqrt(rmse);
-    }
-
-    iteration++;
-  }
-
+  
   const double& Index() { return rmse; }
 
   const size_t& Iteration() { return iteration; }
diff --git a/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
index a0388c1..4cc7053 100644
--- a/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
@@ -49,7 +49,7 @@ class SVDIncrementalLearning
       if((val = V(i, currentUserIndex)) != 0)
         deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
                                          arma::trans(H.col(currentUserIndex));
-      if(kw != 0) deltaW -= kw * W.row(i);
+      if(kw != 0) deltaW.row(i) -= kw * W.row(i);
     }
 
     W += u*deltaW;
@@ -112,7 +112,7 @@ inline void SVDIncrementalLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
     size_t i = it.row();
     deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
                                          arma::trans(H.col(currentUserIndex));
-    if(kw != 0) deltaW -= kw * W.row(i);
+    if(kw != 0) deltaW.row(i) -= kw * W.row(i);
   }
 
   W += u*deltaW;
diff --git a/src/mlpack/tests/svd_incremental_test.cpp b/src/mlpack/tests/svd_incremental_test.cpp
index 00b55d3..682039c 100644
--- a/src/mlpack/tests/svd_incremental_test.cpp
+++ b/src/mlpack/tests/svd_incremental_test.cpp
@@ -4,6 +4,7 @@
 #include <mlpack/methods/amf/init_rules/random_init.hpp>
 #include <mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp>
 #include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
+#include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -35,7 +36,7 @@ BOOST_AUTO_TEST_CASE(SVDIncrementalConvergenceTest)
                     amf.TerminationPolicy().MaxIterations());
 }
 
-/*
+
 BOOST_AUTO_TEST_CASE(SVDIncrementalRegularizationTest)
 {
   mat dataset;
@@ -78,14 +79,12 @@ BOOST_AUTO_TEST_CASE(SVDIncrementalRegularizationTest)
       RandomInitialization,
       SVDIncrementalLearning> amf_2(vrt2,
                               RandomInitialization(),
-                              SVDIncrementalLearning(0.001, 1e-5, 2e-5));
+                              SVDIncrementalLearning(0.001, 0.01, 0.01));
 
   mat m3, m4;
   double RMSE_2 = amf_2.Apply(cleanedData2, 2, m3, m4);
-
-  // RMSE_2 should be less than RMSE_1
-  std::cout << RMSE_1 << " " << RMSE_2 << std::endl;
+  
+  BOOST_REQUIRE_LT(RMSE_2, RMSE_1);
 }
-*/
 
 BOOST_AUTO_TEST_SUITE_END();

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list