[mlpack] 128/207: Improved Load to Load col and row arma vectors
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:47 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 5bd09899d82715909c85138a0da32f049b1270f9
Author: Lakshya Agrawal <zeeshan.lakshya at gmail.com>
Date: Wed Jan 18 14:43:57 2017 +0530
Improved data::Load() so it can load Armadillo column and row vectors (arma::Col and arma::Row)
---
src/mlpack/core/data/load.hpp | 5 +++
src/mlpack/core/data/load_impl.hpp | 31 +++++++--------
src/mlpack/tests/load_save_test.cpp | 78 +++++++++++++++++++++++++++++++++++++
3 files changed, 97 insertions(+), 17 deletions(-)
diff --git a/src/mlpack/core/data/load.hpp b/src/mlpack/core/data/load.hpp
index 193d43c..7b4bdf7 100644
--- a/src/mlpack/core/data/load.hpp
+++ b/src/mlpack/core/data/load.hpp
@@ -102,6 +102,20 @@ bool Load(const std::string& filename,
arma::Col<eT>& vec,
const bool fatal = false);
+/**
+ * Load a row vector from a file. The file should hold a single row of data;
+ * for example, a one-line CSV file is loaded as a 1 x N row vector.
+ *
+ * @param filename Name of the file to load.
+ * @param rowvec Row vector to load the data into.
+ * @param fatal If true, a failure to load is reported as a fatal error.
+ * @return Boolean value indicating success or failure of load.
+ */
+template<typename eT>
+bool Load(const std::string& filename,
+ arma::Row<eT>& rowvec,
+ const bool fatal = false);
+
template<typename eT, typename PolicyType>
bool Load(const std::string& filename,
arma::Mat<eT>& matrix,
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index ce12008..874f32f 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -16,6 +16,7 @@
#include "load.hpp"
#include "extension.hpp"
#include <algorithm>
+#include <exception>
#include <mlpack/core/util/timers.hpp>
@@ -94,22 +95,17 @@ bool Load(const std::string& filename,
arma::Col<eT>& vec,
const bool fatal)
{
- arma::mat matrix(vec);
- Load(filename, matrix, fatal, false);
- //Check if the returned matrix is vector if yes then
- //copy this to our vec else return error
- if( matrix.is_vec() )
- vec = matrix;
- else
- if( fatal)
- Log::Fatal << "Loading '" << filename << "'failed "<<
- ", the data you are tying to load is not a column vector "<<
- std::endl;
- else
- Log::Warn << "Loading '" << filename << "'failed " <<
- ", the data you are trying to load is not a column vector " <<
- std::endl;
- return false;
- return true;
+ // Forward straight into the vector (trailing 'false' as before) and return
+ // the result, so that a failed load is not reported to callers as success.
+ return Load(filename, vec, fatal, false);
+}
+
+template<typename eT>
+bool Load(const std::string& filename,
+ arma::Row<eT>& rowvec,
+ const bool fatal)
+{
+ // As for column vectors: load directly into the row vector and propagate
+ return Load(filename, rowvec, fatal, false);
}
@@ -325,7 +320,11 @@ bool Load(const std::string& filename,
// We can't use the stream if the type is HDF5.
bool success;
if (loadType != arma::hdf5_binary)
- success = matrix.load(stream, loadType);
+ {
+ success = matrix.load(stream, loadType);
+ }
else
+ {
success = matrix.load(filename, loadType);
+ }
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index eef18e6..7ff0ce2 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -178,6 +178,84 @@ BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
// Remove the file.
remove("test_file.csv");
}
+/**
+ * Make sure a column vector can be loaded from a CSV file.
+ */
+BOOST_AUTO_TEST_CASE(LoadColVecCSVTest)
+{
+ fstream f;
+ f.open("test_file.csv", fstream::out);
+
+ for (int i = 0; i < 8; ++i)
+ f << i << endl;
+
+ f.close();
+
+ arma::vec test;
+ BOOST_REQUIRE(data::Load("test_file.csv", test, false) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_cols, 1);
+ BOOST_REQUIRE_EQUAL(test.n_rows, 8);
+
+ for (size_t i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], static_cast<double>(i), 1e-5);
+
+ // Remove the file.
+ remove("test_file.csv");
+}
+
+/**
+ * Check that loading a 2x2 matrix into a column vector or a row vector throws
+ * a std::exception. BOOST_WARN_THROW is used, so a missing exception is only
+ * reported as a warning and does not fail the test.
+ * NOTE(review): presumably because Armadillo's size checks can be compiled
+ * out; confirm before tightening this to BOOST_REQUIRE_THROW.
+ */
+BOOST_AUTO_TEST_CASE(LoadMatinColVec)
+{
+ fstream f;
+ f.open("test_file.csv", fstream::out);
+
+ f << "1,2" << endl;
+ f << "3,4" << endl;
+
+ f.close();
+
+ arma::colvec coltest;
+ BOOST_WARN_THROW(data::Load("test_file.csv", coltest, false),
+ std::exception);
+
+ arma::rowvec rowtest;
+ BOOST_WARN_THROW(data::Load("test_file.csv", rowtest, false),
+ std::exception);
+
+ // Remove the file.
+ remove("test_file.csv");
+}
+
+/**
+ * Make sure a row vector can be loaded from a one-line CSV file.
+ */
+BOOST_AUTO_TEST_CASE(LoadRowVecCSVTest)
+{
+ fstream f;
+ f.open("test_file.csv", fstream::out);
+
+ for (int i = 0; i < 7; ++i)
+ f << i << ",";
+ f << "7" << endl;
+
+ f.close();
+
+ arma::rowvec test;
+ BOOST_REQUIRE(data::Load("test_file.csv", test, false) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_cols, 8);
+ BOOST_REQUIRE_EQUAL(test.n_rows, 1);
+
+ for (size_t i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], static_cast<double>(i), 1e-5);
+
+ // Remove the file.
+ remove("test_file.csv");
+}
/**
* Make sure TSVs can be loaded in transposed form.
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list