[mlpack] 166/324: Changes are part of perceptron code review, as discussed with Ryan
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:07 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 1d0265caf31227a4ab6176eefba98e8dd0d12fd1
Author: saxena.udit <saxena.udit at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Thu Jul 10 11:21:42 2014 +0000
Changes are part of perceptron code review, as discussed with Ryan
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16805 9d5b8971-822b-0410-80eb-d18c1038ef23
---
.../perceptron/learning_policies/simple_weight_update.hpp | 7 +++----
src/mlpack/methods/perceptron/perceptron_impl.hpp | 9 +++------
src/mlpack/tests/perceptron_test.cpp | 2 ++
3 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
index 0ab6086..0f6a3c1 100644
--- a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
+++ b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
@@ -42,12 +42,11 @@ class SimpleWeightUpdate
const size_t vectorIndex,
const size_t rowIndex)
{
- arma::mat instance = trainData.col(labelIndex);
-
- weightVectors.row(rowIndex) = weightVectors.row(rowIndex) - instance.t();
+ weightVectors.row(rowIndex) = weightVectors.row(rowIndex) -
+ trainData.col(labelIndex).t();
weightVectors.row(vectorIndex) = weightVectors.row(vectorIndex) +
- instance.t();
+ trainData.col(labelIndex).t();
}
};
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index 6284cb9..7244b43 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -101,15 +101,12 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
arma::mat tempLabelMat;
arma::uword maxIndexRow, maxIndexCol;
double maxVal;
- MatType testData = test;
-
- MatType zOnes(1, test.n_cols);
- zOnes.fill(1);
- testData.insert_rows(0, zOnes);
for (int i = 0; i < test.n_cols; i++)
{
- tempLabelMat = weightVectors * testData.col(i);
+ tempLabelMat = weightVectors.submat(0,1,weightVectors.n_rows-1,
+ weightVectors.n_cols-1) *
+ test.col(i) + weightVectors.col(0);
maxVal = tempLabelMat.max(maxIndexRow, maxIndexCol);
maxVal *= 2;
predictedLabels(0, i) = maxIndexRow;
diff --git a/src/mlpack/tests/perceptron_test.cpp b/src/mlpack/tests/perceptron_test.cpp
index d9265c7..91f826c 100644
--- a/src/mlpack/tests/perceptron_test.cpp
+++ b/src/mlpack/tests/perceptron_test.cpp
@@ -13,6 +13,7 @@
using namespace mlpack;
using namespace arma;
using namespace mlpack::perceptron;
+using namespace mlpack::distribution;
BOOST_AUTO_TEST_SUITE(PerceptronTest);
@@ -133,6 +134,7 @@ BOOST_AUTO_TEST_CASE(NonLinearlySeparableDataset)
Mat<size_t> labels;
labels << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1
<< 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1;
+ // labels.print("Here too.");
Perceptron<> p(trainData, labels.row(0), 1000);
mat testData;
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list