[mlpack] 165/324: * added local minima storing functionality to termination policies

Barak A. Pearlmutter barak+git at cs.nuim.ie
Sun Aug 17 08:22:06 UTC 2014


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch svn-trunk
in repository mlpack.

commit 41e0e9d1b4c47613feefa228fcbb0c31bab40f77
Author: sumedhghaisas <sumedhghaisas at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date:   Wed Jul 9 22:14:52 2014 +0000

    * added local minima storing functionality to termination policies
    
    
    git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16804 9d5b8971-822b-0410-80eb-d18c1038ef23
---
 src/mlpack/methods/amf/amf_impl.hpp                |  2 +-
 .../incomplete_incremental_termination.hpp         |  4 +--
 .../simple_residue_termination.hpp                 |  4 ++-
 .../simple_tolerance_termination.hpp               | 35 +++++++++++++++++--
 .../validation_RMSE_termination.hpp                | 39 +++++++++++++++++++---
 src/mlpack/methods/amf/update_rules/CMakeLists.txt |  2 +-
 ...vd_batchlearning.hpp => svd_batch_learning.hpp} |  0
 src/mlpack/tests/CMakeLists.txt                    |  2 +-
 .../tests/{svd_test.cpp => svd_batch_test.cpp}     |  4 +--
 9 files changed, 77 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/methods/amf/amf_impl.hpp b/src/mlpack/methods/amf/amf_impl.hpp
index ce9a2aa..a887931 100644
--- a/src/mlpack/methods/amf/amf_impl.hpp
+++ b/src/mlpack/methods/amf/amf_impl.hpp
@@ -50,7 +50,7 @@ Apply(const MatType& V,
   update.Initialize(V, r);
   terminationPolicy.Initialize(V);
 
-  while (!terminationPolicy.IsConverged())
+  while (!terminationPolicy.IsConverged(W, H))
   {
     // Update the values of W and H based on the update rules provided.
     update.WUpdate(V, W, H);
diff --git a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
index d24d571..91424db 100644
--- a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
@@ -26,9 +26,9 @@ class IncompleteIncrementalTermination
     iteration = 0;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    return t_policy.IsConverged();
+    return t_policy.IsConverged(W, H);
   }
 
   void Step(const arma::mat& W, const arma::mat& H)
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index a47ce23..cbae4ba 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -30,8 +30,10 @@ class SimpleResidueTermination
     nm = n * m;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
+    (void)W;
+    (void)H;
     if(residue < minResidue || iteration > maxIterations) return true;
     else return false;
   }
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index 777e38a..8518b54 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -30,14 +30,39 @@ class SimpleToleranceTermination
     this->V = &V;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
     if((residueOld - residue) / residueOld < tolerance && iteration > 4)
+    {
+      if(reverseStepCount == 0 && isCopy == false)
+      {
+        isCopy = true;
+        this->W = W;
+        this->H = H;
+        c_index = residue;
+        c_indexOld = residueOld;
+      }
       reverseStepCount++;
-    else reverseStepCount = 0;
+    }
+    else
+    {
+      reverseStepCount = 0;
+      if(residue <= c_indexOld && isCopy == true)
+      {
+        isCopy = false;
+      }
+    }
 
     if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+    {
+      if(isCopy)
+      {
+        W = this->W;
+        H = this->H;
+        residue = c_index;
+      }
       return true;
+    }
     else return false;
   }
 
@@ -89,6 +114,12 @@ class SimpleToleranceTermination
 
   size_t reverseStepTolerance;
   size_t reverseStepCount;
+  
+  bool isCopy;
+  arma::mat W;
+  arma::mat H;
+  double c_indexOld;
+  double c_index;
 }; // class SimpleToleranceTermination
 
 }; // namespace amf
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
index 49b0509..ceb7b0c 100644
--- a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -54,14 +54,39 @@ class ValidationRMSETermination
     reverseStepCount = 0;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4) 
+    if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+    {
+      if(reverseStepCount == 0 && isCopy == false)
+      {
+        isCopy = true;
+        this->W = W;
+        this->H = H;
+        c_indexOld = rmseOld;
+        c_index = rmse;
+      }
       reverseStepCount++;
-    else reverseStepCount = 0;
+    }
+    else
+    {
+      reverseStepCount = 0;
+      if(rmse <= c_indexOld && isCopy == true)
+      {
+        isCopy = false;
+      }
+    }
 
-    if(reverseStepCount == reverseStepTolerance || iteration > maxIterations) 
+    if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+    {
+      if(isCopy)
+      {
+        W = this->W;
+        H = this->H;
+        rmse = c_index;
+      }
       return true;
+    }
     else return false;
   }
 
@@ -115,6 +140,12 @@ class ValidationRMSETermination
 
   size_t reverseStepTolerance;
   size_t reverseStepCount;
+  
+  bool isCopy;
+  arma::mat W;
+  arma::mat H;
+  double c_indexOld;
+  double c_index;
 };
 
 } // namespace amf
diff --git a/src/mlpack/methods/amf/update_rules/CMakeLists.txt b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
index b7bde1c..baa942f 100644
--- a/src/mlpack/methods/amf/update_rules/CMakeLists.txt
+++ b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
@@ -4,7 +4,7 @@ set(SOURCES
   nmf_als.hpp
   nmf_mult_dist.hpp
   nmf_mult_div.hpp
-  svd_batchlearning.hpp
+  svd_batch_learning.hpp
   svd_incremental_learning.hpp
 )
 
diff --git a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp b/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
similarity index 100%
rename from src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
rename to src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 3be664e..5a8794c 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -50,7 +50,7 @@ add_executable(mlpack_test
   tree_test.cpp
   tree_traits_test.cpp
   union_find_test.cpp
-  svd_test.cpp
+  svd_batch_test.cpp
 )
 # Link dependencies of test executable.
 target_link_libraries(mlpack_test
diff --git a/src/mlpack/tests/svd_test.cpp b/src/mlpack/tests/svd_batch_test.cpp
similarity index 97%
rename from src/mlpack/tests/svd_test.cpp
rename to src/mlpack/tests/svd_batch_test.cpp
index 7160645..af14247 100644
--- a/src/mlpack/tests/svd_test.cpp
+++ b/src/mlpack/tests/svd_batch_test.cpp
@@ -1,6 +1,6 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/amf/amf.hpp>
-#include <mlpack/methods/amf/update_rules/svd_batchlearning.hpp>
+#include <mlpack/methods/amf/update_rules/svd_batch_learning.hpp>
 #include <mlpack/methods/amf/init_rules/random_init.hpp>
 #include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
 #include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
@@ -145,8 +145,6 @@ BOOST_AUTO_TEST_CASE(SVDNegativeElementTest)
 
   arma::mat result = m1 * m2;
 
-  std::cout << result << std::endl;
-
   for(size_t i = 0;i < 3;i++)
   {
     for(size_t j = 0;j < 3;j++)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list