[mlpack] 197/324: Embedding nystroem method into kernel pca method.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:10 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 9f9eb690e25b75395b13674453e1e23fb698c67b
Author: marcus <marcus at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Sat Jul 19 19:29:11 2014 +0000
Embedding nystroem method into kernel pca method.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16839 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/methods/CMakeLists.txt | 1 +
src/mlpack/methods/kernel_pca/CMakeLists.txt | 2 +
src/mlpack/methods/kernel_pca/kernel_pca.hpp | 33 ++++---
src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp | 100 ++++++---------------
.../kernel_pca/{ => kernel_rules}/CMakeLists.txt | 12 +--
.../kernel_pca/kernel_rules/naive_method.hpp | 89 ++++++++++++++++++
.../kernel_pca/kernel_rules/nystroem_method.hpp | 71 +++++++++++++++
.../{kernel_pca => nystroem_method}/CMakeLists.txt | 15 ++--
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/kernel_pca_test.cpp | 70 +++++++++++++--
10 files changed, 286 insertions(+), 108 deletions(-)
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index 1c0e61d..041e475 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -31,6 +31,7 @@ set(DIRS
regularized_svd
sparse_autoencoder
sparse_coding
+ nystroem_method
)
foreach(dir ${DIRS})
diff --git a/src/mlpack/methods/kernel_pca/CMakeLists.txt b/src/mlpack/methods/kernel_pca/CMakeLists.txt
index c575af9..4b5fcb5 100644
--- a/src/mlpack/methods/kernel_pca/CMakeLists.txt
+++ b/src/mlpack/methods/kernel_pca/CMakeLists.txt
@@ -14,6 +14,8 @@ endforeach()
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+add_subdirectory(kernel_rules)
+
add_executable(kernel_pca
kernel_pca_main.cpp
)
diff --git a/src/mlpack/methods/kernel_pca/kernel_pca.hpp b/src/mlpack/methods/kernel_pca/kernel_pca.hpp
index 98c23ca..fdf2ef4 100644
--- a/src/mlpack/methods/kernel_pca/kernel_pca.hpp
+++ b/src/mlpack/methods/kernel_pca/kernel_pca.hpp
@@ -1,6 +1,7 @@
/**
* @file kernel_pca.hpp
* @author Ajinkya Kale
+ * @author Marcus Edel
*
* Defines the KernelPCA class to perform Kernel Principal Components Analysis
* on the specified data set.
@@ -9,7 +10,7 @@
#define __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
#include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp>
namespace mlpack {
namespace kpca {
@@ -27,7 +28,10 @@ namespace kpca {
* files in mlpack/core/kernels/) and it is easy to write your own; see other
* implementations for examples.
*/
-template <typename KernelType>
+template <
+ typename KernelType,
+ typename KernelRule = NaiveKernelRule<KernelType>
+>
class KernelPCA
{
public:
@@ -38,6 +42,7 @@ class KernelPCA
* much).
*
* @param kernel Kernel to be used for computation.
+ * @param centerTransformedData Center transformed data.
*/
KernelPCA(const KernelType kernel = KernelType(),
const bool centerTransformedData = false);
@@ -49,6 +54,21 @@ class KernelPCA
* @param transformedData Matrix to output results into.
* @param eigval KPCA eigenvalues will be written to this vector.
* @param eigvec KPCA eigenvectors will be written to this matrix.
+ * @param newDimension New dimension for the dataset.
+ */
+ void Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigval,
+ arma::mat& eigvec,
+ const size_t newDimension);
+
+ /**
+ * Apply Kernel Principal Components Analysis to the provided data set.
+ *
+ * @param data Data matrix.
+ * @param transformedData Matrix to output results into.
+ * @param eigval KPCA eigenvalues will be written to this vector.
+ * @param eigvec KPCA eigenvectors will be written to this matrix.
*/
void Apply(const arma::mat& data,
arma::mat& transformedData,
@@ -90,7 +110,6 @@ class KernelPCA
bool CenterTransformedData() const { return centerTransformedData; }
//! Return whether or not the transformed data is centered.
bool& CenterTransformedData() { return centerTransformedData; }
-
// Returns a string representation of this object.
std::string ToString() const;
@@ -102,14 +121,6 @@ class KernelPCA
//! run.
bool centerTransformedData;
- /**
- * Construct the kernel matrix.
- *
- * @param data Input data points.
- * @param kernelMatrix Matrix to store the constructed kernel matrix in.
- */
- void GetKernelMatrix(const arma::mat& data, arma::mat& kernelMatrix);
-
}; // class KernelPCA
}; // namespace kpca
diff --git a/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp b/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
index 9e65845..e209b1a 100644
--- a/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
+++ b/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
@@ -3,7 +3,7 @@
* @author Ajinkya Kale
* @author Marcus Edel
*
- * Implementation of KernelPCA class to perform Kernel Principal Components
+ * Implementation of Kernel PCA class to perform Kernel Principal Components
* Analysis on the specified data set.
*/
#ifndef __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_IMPL_HPP
@@ -12,54 +12,26 @@
// In case it hasn't already been included.
#include "kernel_pca.hpp"
-#include <iostream>
-
namespace mlpack {
namespace kpca {
-template <typename KernelType>
-arma::mat GetKernelMatrix(KernelType kernel, arma::mat transData);
-
-template <typename KernelType>
-KernelPCA<KernelType>::KernelPCA(const KernelType kernel,
+template <typename KernelType, typename KernelRule>
+KernelPCA<KernelType, KernelRule>::KernelPCA(const KernelType kernel,
const bool centerTransformedData) :
kernel(kernel),
centerTransformedData(centerTransformedData)
{ }
//! Apply Kernel Principal Component Analysis to the provided data set.
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(const arma::mat& data,
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(const arma::mat& data,
arma::mat& transformedData,
arma::vec& eigval,
- arma::mat& eigvec)
+ arma::mat& eigvec,
+ const size_t newDimension)
{
- // Construct the kernel matrix.
- arma::mat kernelMatrix;
- GetKernelMatrix(data, kernelMatrix);
-
- // For PCA the data has to be centered, even if the data is centered. But it
- // is not guaranteed that the data, when mapped to the kernel space, is also
- // centered. Since we actually never work in the feature space we cannot
- // center the data. So, we perform a "psuedo-centering" using the kernel
- // matrix.
- arma::rowvec rowMean = arma::sum(kernelMatrix, 0) / kernelMatrix.n_cols;
- kernelMatrix.each_col() -= arma::sum(kernelMatrix, 1) / kernelMatrix.n_cols;
- kernelMatrix.each_row() -= rowMean;
- kernelMatrix += arma::sum(rowMean) / kernelMatrix.n_cols;
-
- // Eigendecompose the centered kernel matrix.
- arma::eig_sym(eigval, eigvec, kernelMatrix);
-
- // Swap the eigenvalues since they are ordered backwards (we need largest to
- // smallest).
- for (size_t i = 0; i < floor(eigval.n_elem / 2.0); ++i)
- eigval.swap_rows(i, (eigval.n_elem - 1) - i);
-
- // Flip the coefficients to produce the same effect.
- eigvec = arma::fliplr(eigvec);
-
- transformedData = eigvec.t() * kernelMatrix;
+ KernelRule::ApplyKernelMatrix(data, transformedData, eigval,
+ eigvec, newDimension, kernel);
// Center the transformed data, if the user asked for it.
if (centerTransformedData)
@@ -71,58 +43,42 @@ void KernelPCA<KernelType>::Apply(const arma::mat& data,
}
//! Apply Kernel Principal Component Analysis to the provided data set.
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(const arma::mat& data,
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigval,
+ arma::mat& eigvec)
+{
+ Apply(data, transformedData, eigval, eigvec, data.n_cols);
+}
+
+//! Apply Kernel Principal Component Analysis to the provided data set.
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(const arma::mat& data,
arma::mat& transformedData,
arma::vec& eigVal)
{
arma::mat coeffs;
- Apply(data, transformedData, eigVal, coeffs);
+ Apply(data, transformedData, eigVal, coeffs, data.n_cols);
}
//! Use KPCA for dimensionality reduction.
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(arma::mat& data, const size_t newDimension)
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(arma::mat& data,
+ const size_t newDimension)
{
arma::mat coeffs;
arma::vec eigVal;
- Apply(data, data, eigVal, coeffs);
+ Apply(data, data, eigVal, coeffs, newDimension);
if (newDimension < coeffs.n_rows && newDimension > 0)
data.shed_rows(newDimension, data.n_rows - 1);
}
-//! Construct the kernel matrix.
-template <typename KernelType>
-void KernelPCA<KernelType>::GetKernelMatrix(const arma::mat& data,
- arma::mat& kernelMatrix)
-{
- // Resize the kernel matrix to the right size.
- kernelMatrix.set_size(data.n_cols, data.n_cols);
-
- // Note that we only need to calculate the upper triangular part of the kernel
- // matrix, since it is symmetric. This helps minimize the number of kernel
- // evaluations.
- for (size_t i = 0; i < data.n_cols; ++i)
- {
- for (size_t j = i; j < data.n_cols; ++j)
- {
- // Evaluate the kernel on these two points.
- kernelMatrix(i, j) = kernel.Evaluate(data.unsafe_col(i),
- data.unsafe_col(j));
- }
- }
-
- // Copy to the lower triangular part of the matrix.
- for (size_t i = 1; i < data.n_cols; ++i)
- for (size_t j = 0; j < i; ++j)
- kernelMatrix(i, j) = kernelMatrix(j, i);
-}
-
// Returns a String of the Object
-template <typename KernelType>
-std::string KernelPCA<KernelType>::ToString() const
+template <typename KernelType, typename KernelRule>
+std::string KernelPCA<KernelType, KernelRule>::ToString() const
{
std::ostringstream convert;
convert << "KernelPCA [" << this << "]" << std::endl;
diff --git a/src/mlpack/methods/kernel_pca/CMakeLists.txt b/src/mlpack/methods/kernel_pca/kernel_rules/CMakeLists.txt
similarity index 69%
copy from src/mlpack/methods/kernel_pca/CMakeLists.txt
copy to src/mlpack/methods/kernel_pca/kernel_rules/CMakeLists.txt
index c575af9..7093fbf 100644
--- a/src/mlpack/methods/kernel_pca/CMakeLists.txt
+++ b/src/mlpack/methods/kernel_pca/kernel_rules/CMakeLists.txt
@@ -1,8 +1,8 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
- kernel_pca.hpp
- kernel_pca_impl.hpp
+ nystroem_method.hpp
+ naive_method.hpp
)
# Add directory name to sources.
@@ -13,11 +13,3 @@ endforeach()
# Append sources (with directory name) to list of all MLPACK sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
-
-add_executable(kernel_pca
- kernel_pca_main.cpp
-)
-target_link_libraries(kernel_pca
- mlpack
-)
-install(TARGETS kernel_pca RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp b/src/mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp
new file mode 100644
index 0000000..7a97f34
--- /dev/null
+++ b/src/mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp
@@ -0,0 +1,89 @@
+/**
+ * @file naive_method.hpp
+ * @author Ajinkya Kale
+ *
+ * Use the naive method to construct the kernel matrix.
+ */
+
+#ifndef __MLPACK_METHODS_KERNEL_PCA_NAIVE_METHOD_HPP
+#define __MLPACK_METHODS_KERNEL_PCA_NAIVE_METHOD_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kpca {
+
+template<typename KernelType>
+class NaiveKernelRule
+{
+ public:
+ public:
+ /**
+ * Construct the full kernel matrix with the naive method and apply KPCA.
+ *
+ * @param data Input data points.
+ * @param transformedData Matrix to output results into.
+ * @param eigval KPCA eigenvalues will be written to this vector.
+ * @param eigvec KPCA eigenvectors will be written to this matrix.
+ * @param rank Unused; the naive method always builds the full matrix.
+ * @param kernel Kernel to be used for computation.
+ */
+ static void ApplyKernelMatrix(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigval,
+ arma::mat& eigvec,
+ const size_t /* unused */,
+ KernelType kernel = KernelType())
+ {
+ // Construct the kernel matrix.
+ arma::mat kernelMatrix;
+ // Resize the kernel matrix to the right size.
+ kernelMatrix.set_size(data.n_cols, data.n_cols);
+
+ // Note that we only need to calculate the upper triangular part of the
+ // kernel matrix, since it is symmetric. This helps minimize the number of
+ // kernel evaluations.
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ for (size_t j = i; j < data.n_cols; ++j)
+ {
+ // Evaluate the kernel on these two points.
+ kernelMatrix(i, j) = kernel.Evaluate(data.unsafe_col(i),
+ data.unsafe_col(j));
+ }
+ }
+
+ // Copy to the lower triangular part of the matrix.
+ for (size_t i = 1; i < data.n_cols; ++i)
+ for (size_t j = 0; j < i; ++j)
+ kernelMatrix(i, j) = kernelMatrix(j, i);
+
+ // PCA requires centered data. Even if the input data is centered, it is
+ // not guaranteed that the data, when mapped to the kernel space, is also
+ // centered. Since we never actually work in the feature space we cannot
+ // center the data there, so we perform a "pseudo-centering" using the
+ // kernel matrix.
+ arma::rowvec rowMean = arma::sum(kernelMatrix, 0) / kernelMatrix.n_cols;
+ kernelMatrix.each_col() -= arma::sum(kernelMatrix, 1) / kernelMatrix.n_cols;
+ kernelMatrix.each_row() -= rowMean;
+ kernelMatrix += arma::sum(rowMean) / kernelMatrix.n_cols;
+
+ // Eigendecompose the centered kernel matrix.
+ arma::eig_sym(eigval, eigvec, kernelMatrix);
+
+ // Swap the eigenvalues since they are ordered backwards (we need largest to
+ // smallest).
+ for (size_t i = 0; i < floor(eigval.n_elem / 2.0); ++i)
+ eigval.swap_rows(i, (eigval.n_elem - 1) - i);
+
+ // Flip the coefficients to produce the same effect.
+ eigvec = arma::fliplr(eigvec);
+
+ transformedData = eigvec.t() * kernelMatrix;
+ }
+};
+
+}; // namespace kpca
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp b/src/mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp
new file mode 100644
index 0000000..34bcf62
--- /dev/null
+++ b/src/mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp
@@ -0,0 +1,71 @@
+/**
+ * @file nystroem_method.hpp
+ * @author Marcus Edel
+ *
+ * Use the Nystroem method for approximating a kernel matrix.
+ */
+
+#ifndef __MLPACK_METHODS_KERNEL_PCA_NYSTROEM_METHOD_HPP
+#define __MLPACK_METHODS_KERNEL_PCA_NYSTROEM_METHOD_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/nystroem_method/kmeans_selection.hpp>
+#include <mlpack/methods/nystroem_method/nystroem_method.hpp>
+
+namespace mlpack {
+namespace kpca {
+
+template<
+ typename KernelType,
+ typename PointSelectionPolicy = kernel::KMeansSelection<>
+>
+class NystroemKernelRule
+{
+ public:
+ /**
+ * Construct the kernel matrix approximation using the nystroem method.
+ *
+ * @param data Input data points.
+ * @param transformedData Matrix to output results into.
+ * @param eigval KPCA eigenvalues will be written to this vector.
+ * @param eigvec KPCA eigenvectors will be written to this matrix.
+ * @param rank Rank to be used for matrix approximation.
+ * @param kernel Kernel to be used for computation.
+ */
+ static void ApplyKernelMatrix(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigval,
+ arma::mat& eigvec,
+ const size_t rank,
+ KernelType kernel = KernelType())
+ {
+ arma::mat G, v;
+ kernel::NystroemMethod<KernelType, PointSelectionPolicy> nm(data, kernel,
+ rank);
+ nm.Apply(G);
+ transformedData = G.t() * G;
+
+ // PCA requires centered data. Even if the input data is centered, it is
+ // not guaranteed that the data, when mapped to the kernel space, is also
+ // centered. Since we never actually work in the feature space we cannot
+ // center the data there, so we perform a "pseudo-centering" using the
+ // kernel matrix.
+ arma::rowvec rowMean = arma::sum(transformedData, 0) /
+ transformedData.n_cols;
+ transformedData.each_col() -= arma::sum(transformedData, 1) /
+ transformedData.n_cols;
+ transformedData.each_row() -= rowMean;
+ transformedData += arma::sum(rowMean) / transformedData.n_cols;
+
+ // Eigendecompose the centered kernel matrix.
+ arma::svd(eigvec, eigval, v, transformedData);
+ eigval %= eigval / (data.n_cols - 1);
+
+ transformedData = eigvec.t() * G.t();
+ }
+};
+
+}; // namespace kpca
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/kernel_pca/CMakeLists.txt b/src/mlpack/methods/nystroem_method/CMakeLists.txt
similarity index 69%
copy from src/mlpack/methods/kernel_pca/CMakeLists.txt
copy to src/mlpack/methods/nystroem_method/CMakeLists.txt
index c575af9..5b3f5d7 100644
--- a/src/mlpack/methods/kernel_pca/CMakeLists.txt
+++ b/src/mlpack/methods/nystroem_method/CMakeLists.txt
@@ -1,8 +1,11 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
- kernel_pca.hpp
- kernel_pca_impl.hpp
+ nystroem_method.hpp
+ nystroem_method_impl.hpp
+ ordered_selection.hpp
+ random_selection.hpp
+ kmeans_selection.hpp
)
# Add directory name to sources.
@@ -13,11 +16,3 @@ endforeach()
# Append sources (with directory name) to list of all MLPACK sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
-
-add_executable(kernel_pca
- kernel_pca_main.cpp
-)
-target_link_libraries(kernel_pca
- mlpack
-)
-install(TARGETS kernel_pca RUNTIME DESTINATION bin)
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 897520b..5200932 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -52,6 +52,7 @@ add_executable(mlpack_test
union_find_test.cpp
svd_batch_test.cpp
svd_incremental_test.cpp
+ nystroem_method_test.cpp
)
# Link dependencies of test executable.
target_link_libraries(mlpack_test
diff --git a/src/mlpack/tests/kernel_pca_test.cpp b/src/mlpack/tests/kernel_pca_test.cpp
index 5c8f874..e1459a0 100644
--- a/src/mlpack/tests/kernel_pca_test.cpp
+++ b/src/mlpack/tests/kernel_pca_test.cpp
@@ -5,9 +5,9 @@
* Test file for Kernel PCA.
*/
#include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
#include <mlpack/core/kernels/gaussian_kernel.hpp>
#include <mlpack/methods/kernel_pca/kernel_pca.hpp>
+#include <mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -25,7 +25,7 @@ using namespace arma;
* If KernelPCA is working right, then it should turn a circle dataset into a
* linearly separable dataset in one dimension (which is easy to check).
*/
-BOOST_AUTO_TEST_CASE(CircleTransformationTest)
+BOOST_AUTO_TEST_CASE(CircleTransformationTestNaive)
{
// The dataset, which will have three concentric rings in three dimensions.
arma::mat dataset;
@@ -56,10 +56,8 @@ BOOST_AUTO_TEST_CASE(CircleTransformationTest)
dataset(2, i) += 5.0 * (dataset(2, i) / pointNorm);
}
- data::Save("circle.csv", dataset);
-
// Now we have a dataset; we will use the GaussianKernel to perform KernelPCA
- // to take it down to one dimension.
+ // using the naive method to take it down to one dimension.
KernelPCA<GaussianKernel> p;
p.Apply(dataset, 1);
@@ -85,4 +83,66 @@ BOOST_AUTO_TEST_CASE(CircleTransformationTest)
BOOST_REQUIRE_EQUAL(ranges[1].Contains(ranges[2]), false);
}
+/**
+ * If KernelPCA is working right, then it should turn a circle dataset into a
+ * linearly separable dataset in one dimension (which is easy to check).
+ */
+BOOST_AUTO_TEST_CASE(CircleTransformationTestNystroem)
+{
+ // The dataset, which will have three concentric rings in three dimensions.
+ arma::mat dataset;
+
+ // Now, there are 750 points centered at the origin with unit variance.
+ dataset.randn(3, 750);
+ dataset *= 0.05;
+
+ // Take the second 250 points and spread them away from the origin.
+ for (size_t i = 250; i < 500; ++i)
+ {
+ // Push the point away from the origin by 2.
+ const double pointNorm = norm(dataset.col(i), 2);
+
+ dataset(0, i) += 2.0 * (dataset(0, i) / pointNorm);
+ dataset(1, i) += 2.0 * (dataset(1, i) / pointNorm);
+ dataset(2, i) += 2.0 * (dataset(2, i) / pointNorm);
+ }
+
+ // Take the third 250 points and spread them away from the origin.
+ for (size_t i = 500; i < 750; ++i)
+ {
+ // Push the point away from the origin by 5.
+ const double pointNorm = norm(dataset.col(i), 2);
+
+ dataset(0, i) += 5.0 * (dataset(0, i) / pointNorm);
+ dataset(1, i) += 5.0 * (dataset(1, i) / pointNorm);
+ dataset(2, i) += 5.0 * (dataset(2, i) / pointNorm);
+ }
+
+ // Now we have a dataset; we will use the GaussianKernel to perform KernelPCA
+ // using the Nystroem method to take it down to one dimension.
+ KernelPCA<GaussianKernel, NystroemKernelRule<GaussianKernel> > p;
+ p.Apply(dataset, 1);
+
+ // Get the ranges of each "class". These are all initialized as empty ranges
+ // containing no points.
+ Range ranges[3];
+ ranges[0] = Range();
+ ranges[1] = Range();
+ ranges[2] = Range();
+
+ // Expand the ranges to hold all of the points in the class.
+ for (size_t i = 0; i < 250; ++i)
+ ranges[0] |= dataset(0, i);
+ for (size_t i = 250; i < 500; ++i)
+ ranges[1] |= dataset(0, i);
+ for (size_t i = 500; i < 750; ++i)
+ ranges[2] |= dataset(0, i);
+
+ // None of these ranges should overlap -- the classes should be linearly
+ // separable.
+ BOOST_REQUIRE_EQUAL(ranges[0].Contains(ranges[1]), false);
+ BOOST_REQUIRE_EQUAL(ranges[0].Contains(ranges[2]), false);
+ BOOST_REQUIRE_EQUAL(ranges[1].Contains(ranges[2]), false);
+}
+
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list