[mlpack] 71/324: Perceptron Added
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:21:57 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 689208101f3719d847b78ccc4c297d86576b468b
Author: saxena.udit <saxena.udit at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Tue Jun 24 06:58:46 2014 +0000
Perceptron Added
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16702 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/CMakeLists.txt | 1 +
src/mlpack/methods/perceptron/CMakeLists.txt | 29 ++++
.../InitializationMethods/CMakeLists.txt | 15 +++
.../InitializationMethods/random_init.hpp | 31 +++++
.../perceptron/InitializationMethods/zero_init.hpp | 34 +++++
.../methods/perceptron/LearnPolicy/CMakeLists.txt | 14 ++
.../perceptron/LearnPolicy/SimpleWeightUpdate.hpp | 53 ++++++++
src/mlpack/methods/perceptron/perceptron.hpp | 86 ++++++++++++
src/mlpack/methods/perceptron/perceptron_impl.cpp | 118 ++++++++++++++++
src/mlpack/methods/perceptron/perceptron_main.cpp | 81 +++++++++++
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/perceptron_test.cpp | 150 +++++++++++++++++++++
12 files changed, 613 insertions(+)
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index b930aaa..d9eea39 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -22,6 +22,7 @@ set(DIRS
nmf
# lmf
pca
+ perceptron
radical
range_search
rann
diff --git a/src/mlpack/methods/perceptron/CMakeLists.txt b/src/mlpack/methods/perceptron/CMakeLists.txt
new file mode 100644
index 0000000..c25c549
--- /dev/null
+++ b/src/mlpack/methods/perceptron/CMakeLists.txt
@@ -0,0 +1,35 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+  perceptron.hpp
+  perceptron_impl.cpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(src ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${src})
+endforeach()
+
+# The policy subdirectories append their headers to MLPACK_SRCS in *this*
+# directory's scope (through PARENT_SCOPE), so they must be added before
+# MLPACK_SRCS is forwarded to our own parent below; otherwise their files
+# never reach the list of all MLPACK sources.
+add_subdirectory(InitializationMethods)
+add_subdirectory(LearnPolicy)
+
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+# Command-line program: trains a perceptron and classifies a test set.
+add_executable(percep
+  perceptron_main.cpp
+)
+target_link_libraries(percep
+  mlpack
+)
+
+install(TARGETS percep RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/perceptron/InitializationMethods/CMakeLists.txt b/src/mlpack/methods/perceptron/InitializationMethods/CMakeLists.txt
new file mode 100644
index 0000000..d5d9c31
--- /dev/null
+++ b/src/mlpack/methods/perceptron/InitializationMethods/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Weight-initialization policies for the perceptron (header-only).
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+  random_init.hpp
+  zero_init.hpp
+)
+
+# Prefix every file with this directory's path, then append the result to
+# MLPACK_SRCS in the enclosing directory's scope.
+set(DIR_SRCS)
+foreach(src ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${src})
+endforeach()
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
diff --git a/src/mlpack/methods/perceptron/InitializationMethods/random_init.hpp b/src/mlpack/methods/perceptron/InitializationMethods/random_init.hpp
new file mode 100644
index 0000000..7cdeb19
--- /dev/null
+++ b/src/mlpack/methods/perceptron/InitializationMethods/random_init.hpp
@@ -0,0 +1,41 @@
+/**
+ * @file random_init.hpp
+ * @author Udit Saxena
+ *
+ * Random initialization policy for the perceptron weight matrix.
+ */
+#ifndef _MLPACK_METHODS_PERCEPTRON_RANDOM_INIT_HPP
+#define _MLPACK_METHODS_PERCEPTRON_RANDOM_INIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace perceptron {
+
+/**
+ * Initializes the weightVectors matrix with random values drawn uniformly
+ * from [0, 1] by Armadillo's randu().
+ */
+class RandomInitialization
+{
+ public:
+  RandomInitialization()
+  { }
+
+  /**
+   * Resize W to row x col and fill it with uniform random values.
+   *
+   * @param W Matrix to initialize; its previous contents are discarded.
+   * @param row Number of rows of W.
+   * @param col Number of columns of W.
+   */
+  inline static void initialize(arma::mat& W, size_t row, size_t col)
+  {
+    W.randu(row, col);
+  }
+}; // class RandomInitialization
+
+} // namespace perceptron
+} // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/InitializationMethods/zero_init.hpp b/src/mlpack/methods/perceptron/InitializationMethods/zero_init.hpp
new file mode 100644
index 0000000..7115c81
--- /dev/null
+++ b/src/mlpack/methods/perceptron/InitializationMethods/zero_init.hpp
@@ -0,0 +1,40 @@
+/**
+ * @file zero_init.hpp
+ * @author Udit Saxena
+ *
+ * Zero initialization policy for the perceptron weight matrix.
+ */
+#ifndef _MLPACK_METHODS_PERCEPTRON_ZERO_INIT_HPP
+#define _MLPACK_METHODS_PERCEPTRON_ZERO_INIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace perceptron {
+
+/**
+ * Initializes the weightVectors matrix to all zeros.
+ */
+class ZeroInitialization
+{
+ public:
+  ZeroInitialization()
+  { }
+
+  /**
+   * Resize W to row x col and set every element to zero.
+   *
+   * @param W Matrix to initialize; its previous contents are discarded.
+   * @param row Number of rows of W.
+   * @param col Number of columns of W.
+   */
+  inline static void initialize(arma::mat& W, size_t row, size_t col)
+  {
+    W.zeros(row, col);
+  }
+}; // class ZeroInitialization
+
+} // namespace perceptron
+} // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/LearnPolicy/CMakeLists.txt b/src/mlpack/methods/perceptron/LearnPolicy/CMakeLists.txt
new file mode 100644
index 0000000..a07bc01
--- /dev/null
+++ b/src/mlpack/methods/perceptron/LearnPolicy/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Weight-update (learning) policies for the perceptron (header-only).
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+  SimpleWeightUpdate.hpp
+)
+
+# Prefix every file with this directory's path, then append the result to
+# MLPACK_SRCS in the enclosing directory's scope.
+set(DIR_SRCS)
+foreach(src ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${src})
+endforeach()
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
diff --git a/src/mlpack/methods/perceptron/LearnPolicy/SimpleWeightUpdate.hpp b/src/mlpack/methods/perceptron/LearnPolicy/SimpleWeightUpdate.hpp
new file mode 100644
index 0000000..893ed0d
--- /dev/null
+++ b/src/mlpack/methods/perceptron/LearnPolicy/SimpleWeightUpdate.hpp
@@ -0,0 +1,52 @@
+/**
+ * @file SimpleWeightUpdate.hpp
+ * @author Udit Saxena
+ *
+ * Simple (Rosenblatt) weight update rule for the perceptron.
+ */
+#ifndef _MLPACK_METHOD_PERCEPTRON_LEARN_SIMPLEWEIGHTUPDATE
+#define _MLPACK_METHOD_PERCEPTRON_LEARN_SIMPLEWEIGHTUPDATE
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace perceptron {
+
+/**
+ * Updates the weightVectors matrix with the simple rule discussed by
+ * Rosenblatt: if a point x has been incorrectly classified by weight w, then
+ *
+ *   w  = w  - x
+ *   w' = w' + x
+ *
+ * where w' is the weight vector which correctly classifies x.
+ */
+class SimpleWeightUpdate
+{
+ public:
+  SimpleWeightUpdate()
+  { }
+
+  /**
+   * Decrease the weights of the incorrectly predicted class and increase the
+   * weights of the class that should have been predicted.
+   *
+   * @param trainData The training dataset, one point per column.
+   * @param weightVectors Matrix of weight vectors, one row per class.
+   * @param labelIndex Index (column) of the misclassified point in trainData.
+   * @param vectorIndex Index of the class which should have been predicted.
+   * @param rowIndex Index of the class which was incorrectly predicted.
+   */
+  void UpdateWeights(const arma::mat& trainData, arma::mat& weightVectors,
+                     size_t labelIndex, size_t vectorIndex, size_t rowIndex)
+  {
+    // Use the column view directly rather than copying the point.
+    weightVectors.row(rowIndex) -= trainData.col(labelIndex).t();
+    weightVectors.row(vectorIndex) += trainData.col(labelIndex).t();
+  }
+};
+
+} // namespace perceptron
+} // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
new file mode 100644
index 0000000..7d26875
--- /dev/null
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -0,0 +1,73 @@
+/**
+ * @file perceptron.hpp
+ * @author Udit Saxena
+ *
+ * Definition of the Perceptron class.
+ */
+#ifndef _MLPACK_METHODS_PERCEPTRON_HPP
+#define _MLPACK_METHODS_PERCEPTRON_HPP
+
+#include <mlpack/core.hpp>
+#include "InitializationMethods/zero_init.hpp"
+#include "InitializationMethods/random_init.hpp"
+#include "LearnPolicy/SimpleWeightUpdate.hpp"
+
+namespace mlpack {
+namespace perceptron {
+
+/**
+ * A simple perceptron, i.e. a single-layer neural network with one weight
+ * vector per class. Training converges if the training set is linearly
+ * separable; otherwise it stops after the given number of iterations.
+ *
+ * @tparam LearnPolicy Rule used to update the weights after a
+ *     misclassification (SimpleWeightUpdate is the one provided).
+ * @tparam WeightInitializationPolicy How the weight matrix is initialized:
+ *     ZeroInitialization or RandomInitialization.
+ * @tparam MatType Type of the data matrices.
+ */
+template <typename LearnPolicy = SimpleWeightUpdate,
+          typename WeightInitializationPolicy = ZeroInitialization,
+          typename MatType = arma::mat>
+class Perceptron
+{
+ public:
+  /**
+   * Construct and train the perceptron, i.e. build the weightVectors matrix
+   * that Classify() later uses. A bias input of 1 is added to every point so
+   * that the bias weights are learned together with the others.
+   *
+   * @param data Training data, one point per column.
+   * @param labels Class label of each training point.
+   * @param iterations Maximum number of passes the learning algorithm makes
+   *     over the training data.
+   */
+  Perceptron(const MatType& data, const arma::Row<size_t>& labels,
+             int iterations);
+
+  /**
+   * Classify test with the trained weightVectors matrix and put the predicted
+   * class of each point in predictedLabels.
+   *
+   * @param test Data to classify, one point per column.
+   * @param predictedLabels Vector in which to store the predicted classes.
+   */
+  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
+ private:
+  //! Class label of each training point.
+  arma::Row<size_t> classLabels;
+
+  //! Weight vectors, one row per class (the first column is the bias weight).
+  arma::mat weightVectors;
+
+  //! Training data with a leading row of ones for the bias; used in training.
+  arma::mat trainData;
+};
+
+} // namespace perceptron
+} // namespace mlpack
+
+#include "perceptron_impl.cpp"
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.cpp b/src/mlpack/methods/perceptron/perceptron_impl.cpp
new file mode 100644
index 0000000..b29c722
--- /dev/null
+++ b/src/mlpack/methods/perceptron/perceptron_impl.cpp
@@ -0,0 +1,118 @@
+/**
+ * @file perceptron_impl.cpp
+ * @author Udit Saxena
+ *
+ * Implementation of the Perceptron class.
+ */
+#ifndef _MLPACK_METHODS_PERCEPTRON_IMPL_CPP
+#define _MLPACK_METHODS_PERCEPTRON_IMPL_CPP
+
+#include "perceptron.hpp"
+
+namespace mlpack {
+namespace perceptron {
+
+/**
+ * Construct and train the perceptron. A row of ones is prepended to the
+ * training data so that column 0 of weightVectors acts as the bias weight.
+ * Training stops after a full pass with no misclassification, or after
+ * iterations passes.
+ *
+ * @param data Training data, one point per column.
+ * @param labels Class label of each training point; used directly as row
+ *     indices into weightVectors.
+ * @param iterations Maximum number of passes over the training data.
+ */
+template <typename LearnPolicy, typename WeightInitializationPolicy,
+          typename MatType>
+Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    int iterations)
+{
+  if (labels.n_elem != data.n_cols)
+    Log::Fatal << "Perceptron: number of labels (" << labels.n_elem
+        << ") must match the number of training points (" << data.n_cols
+        << ")!" << std::endl;
+
+  // One weight vector (row) per class. Because labels are used as row
+  // indices, size by the largest label rather than the number of distinct
+  // labels; the two agree for labels 0..k-1, but only this stays in bounds
+  // when some label value is unused.
+  const size_t numClasses = labels.is_empty() ? 0 : (arma::max(labels) + 1);
+
+  // The extra column holds the bias weight.
+  WeightInitializationPolicy WIP;
+  WIP.initialize(weightVectors, numClasses, data.n_rows + 1);
+
+  classLabels = labels;
+
+  // Prepend a row of ones to the training data for the bias term.
+  trainData = data;
+  MatType biasRow(1, data.n_cols);
+  biasRow.fill(1);
+  trainData.insert_rows(0, biasRow);
+
+  LearnPolicy LP;
+  arma::vec scores;
+  arma::uword predicted;
+
+  bool converged = false;
+  for (int i = 0; (i < iterations) && !converged; ++i)
+  {
+    // Assume convergence until some point is misclassified in this pass.
+    converged = true;
+    for (size_t j = 0; j < trainData.n_cols; ++j)
+    {
+      // The predicted class is the one whose weight vector scores highest.
+      scores = weightVectors * trainData.col(j);
+      scores.max(predicted);
+
+      if (predicted != classLabels(j))
+      {
+        converged = false;
+        // Let the learn policy fix the weights for point j, given the correct
+        // class and the (wrong) predicted class.
+        LP.UpdateWeights(trainData, weightVectors, j, classLabels(j),
+            predicted);
+      }
+    }
+  }
+}
+
+/**
+ * Classify each column of test and store the predicted classes in
+ * predictedLabels, which is resized to test.n_cols.
+ *
+ * @param test Data to classify, one point per column.
+ * @param predictedLabels Output vector of predicted classes.
+ */
+template <typename LearnPolicy, typename WeightInitializationPolicy,
+          typename MatType>
+void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
+    const MatType& test,
+    arma::Row<size_t>& predictedLabels)
+{
+  // Prepend the same row of ones used during training for the bias term.
+  MatType testData = test;
+  MatType biasRow(1, test.n_cols);
+  biasRow.fill(1);
+  testData.insert_rows(0, biasRow);
+
+  // Size the output here so callers do not have to preallocate it.
+  predictedLabels.set_size(test.n_cols);
+
+  arma::vec scores;
+  arma::uword predicted;
+  for (size_t i = 0; i < testData.n_cols; ++i)
+  {
+    scores = weightVectors * testData.col(i);
+    scores.max(predicted);
+    predictedLabels(i) = predicted;
+  }
+}
+
+} // namespace perceptron
+} // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/perceptron_main.cpp b/src/mlpack/methods/perceptron/perceptron_main.cpp
new file mode 100644
index 0000000..a6082d7
--- /dev/null
+++ b/src/mlpack/methods/perceptron/perceptron_main.cpp
@@ -0,0 +1,93 @@
+/**
+ * @file perceptron_main.cpp
+ * @author Udit Saxena
+ *
+ * Command-line program that trains a perceptron and classifies a test set.
+ */
+#include <mlpack/core.hpp>
+#include "perceptron.hpp"
+
+using namespace mlpack;
+using namespace mlpack::perceptron;
+using namespace std;
+using namespace arma;
+
+PROGRAM_INFO("Perceptron",
+    "This program trains a perceptron on the given training set and labels, "
+    "then uses it to classify the given test set. Training makes at most "
+    "--iterations passes over the training data, and the predicted labels "
+    "for the test set are written to the file given by --output.");
+
+// Necessary parameters.
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "tr");
+PARAM_STRING_REQ("labels_file", "A file containing labels for the training set.",
+    "l");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "te");
+
+// Optional parameters.
+PARAM_STRING("output", "The file in which the predicted labels for the test set"
+    " will be written.", "o", "output.csv");
+PARAM_INT("iterations","The maximum number of iterations the perceptron is "
+    "to be run", "i", 1000);
+
+int main(int argc, char *argv[])
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  // Load the training set.
+  const string trainingDataFilename = CLI::GetParam<string>("train_file");
+  mat trainingData;
+  data::Load(trainingDataFilename, trainingData, true);
+
+  // Load the labels, accepting either a single row or a single column.
+  const string labelsFilename = CLI::GetParam<string>("labels_file");
+  mat labelsIn;
+  data::Load(labelsFilename, labelsIn, true);
+  if (labelsIn.n_rows == 1)
+    labelsIn = labelsIn.t();
+
+  // Map the raw labels to 0, 1, ..., k - 1; 'mappings' is used to undo this.
+  Col<size_t> labels;
+  vec mappings;
+  data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
+
+  // Load the test set; it must have the training set's dimensionality.
+  const string testingDataFilename = CLI::GetParam<string>("test_file");
+  mat testingData;
+  data::Load(testingDataFilename, testingData, true);
+
+  if (testingData.n_rows != trainingData.n_rows)
+    Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
+        << "must be the same as training data (" << trainingData.n_rows
+        << ")!" << std::endl;
+
+  const int iterations = CLI::GetParam<int>("iterations");
+  if (iterations < 0)
+    Log::Fatal << "Number of iterations (" << iterations << ") must not be "
+        << "negative!" << std::endl;
+
+  // Perceptron takes its labels as a row vector. Transpose explicitly:
+  // Armadillo rejects assigning a multi-element column vector to a row vector.
+  const Row<size_t> trainingLabels = labels.t();
+
+  Timer::Start("Training");
+  Perceptron<> p(trainingData, trainingLabels, iterations);
+  Timer::Stop("Training");
+
+  Row<size_t> predictedLabels(testingData.n_cols);
+  Timer::Start("Testing");
+  p.Classify(testingData, predictedLabels);
+  Timer::Stop("Testing");
+
+  // Map the predictions back to the original label values. Give RevertLabels()
+  // a column vector, matching the Col<size_t> produced by NormalizeLabels().
+  const Col<size_t> predictedCol = predictedLabels.t();
+  vec results;
+  data::RevertLabels(predictedCol, mappings, results);
+
+  // Save the predicted labels (transposed, per the last argument).
+  const string outputFilename = CLI::GetParam<string>("output");
+  data::Save(outputFilename, results, true, true);
+
+  return 0;
+}
\ No newline at end of file
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 2aebb62..d35779b 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -34,6 +34,7 @@ add_executable(mlpack_test
nca_test.cpp
nmf_test.cpp
pca_test.cpp
+ perceptron_test.cpp
radical_test.cpp
range_search_test.cpp
save_restore_utility_test.cpp
diff --git a/src/mlpack/tests/perceptron_test.cpp b/src/mlpack/tests/perceptron_test.cpp
new file mode 100644
index 0000000..70b368e
--- /dev/null
+++ b/src/mlpack/tests/perceptron_test.cpp
@@ -0,0 +1,134 @@
+/**
+ * @file perceptron_test.cpp
+ * @author Udit Saxena
+ *
+ * Tests for the Perceptron class.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/perceptron/perceptron.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace arma;
+using namespace mlpack::perceptron;
+
+BOOST_AUTO_TEST_SUITE(PERCEPTRONTEST);
+
+/**
+ * Train a default Perceptron on trainData and labels (1000 iterations), then
+ * check that classifying testData yields exactly expectedLabels.
+ */
+static void CheckPredictions(const mat& trainData,
+                             const Row<size_t>& labels,
+                             const mat& testData,
+                             const Row<size_t>& expectedLabels)
+{
+  Perceptron<> p(trainData, labels, 1000);
+
+  Row<size_t> predictedLabels(testData.n_cols);
+  p.Classify(testData, predictedLabels);
+
+  BOOST_REQUIRE_EQUAL(predictedLabels.n_elem, expectedLabels.n_elem);
+  for (size_t i = 0; i < expectedLabels.n_elem; ++i)
+    BOOST_CHECK_EQUAL(predictedLabels(i), expectedLabels(i));
+}
+
+/**
+ * The perceptron should converge on the (linearly separable) AND gate.
+ */
+BOOST_AUTO_TEST_CASE(AND)
+{
+  mat trainData;
+  trainData << 0 << 1 << 1 << 0 << endr
+            << 1 << 0 << 1 << 0 << endr;
+
+  Mat<size_t> labels;
+  labels << 0 << 0 << 1 << 0;
+
+  // Classify the training points themselves.
+  CheckPredictions(trainData, labels.row(0), trainData, labels.row(0));
+}
+
+/**
+ * The perceptron should converge on the (linearly separable) OR gate.
+ */
+BOOST_AUTO_TEST_CASE(OR)
+{
+  mat trainData;
+  trainData << 0 << 1 << 1 << 0 << endr
+            << 1 << 0 << 1 << 0 << endr;
+
+  Mat<size_t> labels;
+  labels << 1 << 1 << 1 << 0;
+
+  CheckPredictions(trainData, labels.row(0), trainData, labels.row(0));
+}
+
+/**
+ * Convergence on a linearly separable dataset with three classes; the test
+ * points are the class-0 training points and must be classified as class 0.
+ */
+BOOST_AUTO_TEST_CASE(RANDOM3)
+{
+  mat trainData;
+  trainData << 0 << 1 << 1 << 4 << 5 << 4 << 1 << 2 << 1 << endr
+            << 1 << 0 << 1 << 1 << 1 << 2 << 4 << 5 << 4 << endr;
+
+  Mat<size_t> labels;
+  labels << 0 << 0 << 0 << 1 << 1 << 1 << 2 << 2 << 2;
+
+  mat testData;
+  testData << 0 << 1 << 1 << endr
+           << 1 << 0 << 1 << endr;
+
+  Mat<size_t> expectedLabels;
+  expectedLabels << 0 << 0 << 0;
+
+  CheckPredictions(trainData, labels.row(0), testData, expectedLabels.row(0));
+}
+
+/**
+ * Convergence on a dataset of only two points from different classes.
+ */
+BOOST_AUTO_TEST_CASE(TWOPOINTS)
+{
+  mat trainData;
+  trainData << 0 << 1 << endr
+            << 1 << 0 << endr;
+
+  Mat<size_t> labels;
+  labels << 0 << 1;
+
+  CheckPredictions(trainData, labels.row(0), trainData, labels.row(0));
+}
+
+/**
+ * Behavior on a dataset that is not linearly separable: training stops at the
+ * iteration limit, and the resulting classifier is checked on four test
+ * points.
+ */
+BOOST_AUTO_TEST_CASE(NONLINSEPDS)
+{
+  mat trainData;
+  trainData << 1 << 2 << 3 << 4 << 5 << 6 << 7 << 8
+            << 1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << endr
+            << 1 << 1 << 1 << 1 << 1 << 1 << 1 << 1
+            << 2 << 2 << 2 << 2 << 2 << 2 << 2 << 2 << endr;
+
+  Mat<size_t> labels;
+  labels << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1
+         << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1;
+
+  mat testData;
+  testData << 3 << 4 << 5 << 6 << endr
+           << 3 << 2.3 << 1.7 << 1.5 << endr;
+
+  Mat<size_t> expectedLabels;
+  expectedLabels << 0 << 0 << 1 << 1;
+
+  CheckPredictions(trainData, labels.row(0), testData, expectedLabels.row(0));
+}
+
+BOOST_AUTO_TEST_SUITE_END();
\ No newline at end of file
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list