[mlpack] 165/324: * added local minima storing functionality to termination policies
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:06 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 41e0e9d1b4c47613feefa228fcbb0c31bab40f77
Author: sumedhghaisas <sumedhghaisas at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Wed Jul 9 22:14:52 2014 +0000
* added local minima storing functionality to termination policies
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16804 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/amf/amf_impl.hpp | 2 +-
.../incomplete_incremental_termination.hpp | 4 +--
.../simple_residue_termination.hpp | 4 ++-
.../simple_tolerance_termination.hpp | 35 +++++++++++++++++--
.../validation_RMSE_termination.hpp | 39 +++++++++++++++++++---
src/mlpack/methods/amf/update_rules/CMakeLists.txt | 2 +-
...vd_batchlearning.hpp => svd_batch_learning.hpp} | 0
src/mlpack/tests/CMakeLists.txt | 2 +-
.../tests/{svd_test.cpp => svd_batch_test.cpp} | 4 +--
9 files changed, 77 insertions(+), 15 deletions(-)
diff --git a/src/mlpack/methods/amf/amf_impl.hpp b/src/mlpack/methods/amf/amf_impl.hpp
index ce9a2aa..a887931 100644
--- a/src/mlpack/methods/amf/amf_impl.hpp
+++ b/src/mlpack/methods/amf/amf_impl.hpp
@@ -50,7 +50,7 @@ Apply(const MatType& V,
update.Initialize(V, r);
terminationPolicy.Initialize(V);
- while (!terminationPolicy.IsConverged())
+ while (!terminationPolicy.IsConverged(W, H))
{
// Update the values of W and H based on the update rules provided.
update.WUpdate(V, W, H);
diff --git a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
index d24d571..91424db 100644
--- a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
@@ -26,9 +26,9 @@ class IncompleteIncrementalTermination
iteration = 0;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
- return t_policy.IsConverged();
+ return t_policy.IsConverged(W, H);
}
void Step(const arma::mat& W, const arma::mat& H)
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index a47ce23..cbae4ba 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -30,8 +30,10 @@ class SimpleResidueTermination
nm = n * m;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
+ (void)W;
+ (void)H;
if(residue < minResidue || iteration > maxIterations) return true;
else return false;
}
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index 777e38a..8518b54 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -30,14 +30,39 @@ class SimpleToleranceTermination
this->V = &V;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
if((residueOld - residue) / residueOld < tolerance && iteration > 4)
+ {
+ if(reverseStepCount == 0 && isCopy == false)
+ {
+ isCopy = true;
+ this->W = W;
+ this->H = H;
+ c_index = residue;
+ c_indexOld = residueOld;
+ }
reverseStepCount++;
- else reverseStepCount = 0;
+ }
+ else
+ {
+ reverseStepCount = 0;
+ if(residue <= c_indexOld && isCopy == true)
+ {
+ isCopy = false;
+ }
+ }
if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ {
+ if(isCopy)
+ {
+ W = this->W;
+ H = this->H;
+ residue = c_index;
+ }
return true;
+ }
else return false;
}
@@ -89,6 +114,12 @@ class SimpleToleranceTermination
size_t reverseStepTolerance;
size_t reverseStepCount;
+
+ bool isCopy;
+ arma::mat W;
+ arma::mat H;
+ double c_indexOld;
+ double c_index;
}; // class SimpleToleranceTermination
}; // namespace amf
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
index 49b0509..ceb7b0c 100644
--- a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -54,14 +54,39 @@ class ValidationRMSETermination
reverseStepCount = 0;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
- if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+ if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+ {
+ if(reverseStepCount == 0 && isCopy == false)
+ {
+ isCopy = true;
+ this->W = W;
+ this->H = H;
+ c_indexOld = rmseOld;
+ c_index = rmse;
+ }
reverseStepCount++;
- else reverseStepCount = 0;
+ }
+ else
+ {
+ reverseStepCount = 0;
+ if(rmse <= c_indexOld && isCopy == true)
+ {
+ isCopy = false;
+ }
+ }
- if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ {
+ if(isCopy)
+ {
+ W = this->W;
+ H = this->H;
+ rmse = c_index;
+ }
return true;
+ }
else return false;
}
@@ -115,6 +140,12 @@ class ValidationRMSETermination
size_t reverseStepTolerance;
size_t reverseStepCount;
+
+ bool isCopy;
+ arma::mat W;
+ arma::mat H;
+ double c_indexOld;
+ double c_index;
};
} // namespace amf
diff --git a/src/mlpack/methods/amf/update_rules/CMakeLists.txt b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
index b7bde1c..baa942f 100644
--- a/src/mlpack/methods/amf/update_rules/CMakeLists.txt
+++ b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
@@ -4,7 +4,7 @@ set(SOURCES
nmf_als.hpp
nmf_mult_dist.hpp
nmf_mult_div.hpp
- svd_batchlearning.hpp
+ svd_batch_learning.hpp
svd_incremental_learning.hpp
)
diff --git a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp b/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
similarity index 100%
rename from src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
rename to src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 3be664e..5a8794c 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -50,7 +50,7 @@ add_executable(mlpack_test
tree_test.cpp
tree_traits_test.cpp
union_find_test.cpp
- svd_test.cpp
+ svd_batch_test.cpp
)
# Link dependencies of test executable.
target_link_libraries(mlpack_test
diff --git a/src/mlpack/tests/svd_test.cpp b/src/mlpack/tests/svd_batch_test.cpp
similarity index 97%
rename from src/mlpack/tests/svd_test.cpp
rename to src/mlpack/tests/svd_batch_test.cpp
index 7160645..af14247 100644
--- a/src/mlpack/tests/svd_test.cpp
+++ b/src/mlpack/tests/svd_batch_test.cpp
@@ -1,6 +1,6 @@
#include <mlpack/core.hpp>
#include <mlpack/methods/amf/amf.hpp>
-#include <mlpack/methods/amf/update_rules/svd_batchlearning.hpp>
+#include <mlpack/methods/amf/update_rules/svd_batch_learning.hpp>
#include <mlpack/methods/amf/init_rules/random_init.hpp>
#include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
#include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
@@ -145,8 +145,6 @@ BOOST_AUTO_TEST_CASE(SVDNegativeElementTest)
arma::mat result = m1 * m2;
- std::cout << result << std::endl;
-
for(size_t i = 0;i < 3;i++)
{
for(size_t j = 0;j < 3;j++)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list