[mlpack] 128/207: Improve Load() so it can load Armadillo column and row vectors

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:47 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 5bd09899d82715909c85138a0da32f049b1270f9
Author: Lakshya Agrawal <zeeshan.lakshya at gmail.com>
Date:   Wed Jan 18 14:43:57 2017 +0530

    Improve Load() so it can load Armadillo column and row vectors
---
 src/mlpack/core/data/load.hpp       |  5 +++
 src/mlpack/core/data/load_impl.hpp  | 31 +++++++--------
 src/mlpack/tests/load_save_test.cpp | 78 +++++++++++++++++++++++++++++++++++++
 3 files changed, 97 insertions(+), 17 deletions(-)

diff --git a/src/mlpack/core/data/load.hpp b/src/mlpack/core/data/load.hpp
index 193d43c..7b4bdf7 100644
--- a/src/mlpack/core/data/load.hpp
+++ b/src/mlpack/core/data/load.hpp
@@ -102,6 +102,20 @@ bool Load(const std::string& filename,
           arma::Col<eT>& vec,
           const bool fatal = false);
 
+/**
+ * Load a row vector from a file.  The data is read without transposition, so
+ * the file is expected to hold a single line of values.
+ *
+ * @param filename Name of the file to load.
+ * @param rowvec Row vector to load the contents of the file into.
+ * @param fatal If an error should be reported as fatal (default false).
+ * @return Boolean value indicating success or failure of the load.
+ */
+template<typename eT>
+bool Load(const std::string& filename,
+          arma::Row<eT>& rowvec,
+          const bool fatal = false);
+
 template<typename eT, typename PolicyType>
 bool Load(const std::string& filename,
           arma::Mat<eT>& matrix,
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index ce12008..874f32f 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -16,6 +16,7 @@
 #include "load.hpp"
 #include "extension.hpp"
 
+#include <exception>
 #include <algorithm>
 #include <mlpack/core/util/timers.hpp>
 
@@ -94,22 +95,22 @@ bool Load(const std::string& filename,
           arma::Col<eT>& vec,
           const bool fatal)
 {
-  arma::mat matrix(vec);
-  Load(filename, matrix, fatal, false);
-  //Check if the returned matrix is vector if yes then 
-  //copy this to our vec else return error
-  if( matrix.is_vec() )
-    vec = matrix;
-  else
-    if( fatal)
-      Log::Fatal << "Loading '" << filename << "'failed "<<
-          ", the data you are tying to load is not a column vector "<< 
-          std::endl; 
-    else
-      Log::Warn << "Loading '" << filename << "'failed " <<
-          ", the data you are trying to load is not a column vector " <<
-          std::endl;
-    return false;
-  return true;
+  // Return the status of the underlying matrix Load() so that callers can
+  // detect a failed load.  The final `false` disables transposition, so a
+  // file with one value per line fills the column vector.
+  // NOTE(review): the shape of the data is not validated here; a file that
+  // holds a full matrix presumably makes Armadillo throw -- confirm.
+  return Load(filename, vec, fatal, false);
+}
+
+template<typename eT>
+bool Load(const std::string& filename,
+          arma::Row<eT>& rowvec,
+          const bool fatal)
+{
+  // Return the status of the underlying matrix Load() so that callers can
+  // detect a failed load.  The final `false` disables transposition, so a
+  // file holding a single line of values fills the row vector.
+  return Load(filename, rowvec, fatal, false);
 }
 
@@ -325,7 +326,11 @@ bool Load(const std::string& filename,
   // We can't use the stream if the type is HDF5.
   bool success;
   if (loadType != arma::hdf5_binary)
+  {
     success = matrix.load(stream, loadType);
+  }
   else
+  {
     success = matrix.load(filename, loadType);
+  }
 
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index eef18e6..7ff0ce2 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -178,6 +178,90 @@ BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
   // Remove the file.
   remove("test_file.csv");
 }
+
+/**
+ * Make sure a column vector can be loaded from a CSV file with one value per
+ * line.
+ */
+BOOST_AUTO_TEST_CASE(LoadColVecCSVTest)
+{
+  fstream f;
+  f.open("test_file.csv", fstream::out);
+
+  for (int i = 0; i < 8; ++i)
+    f << i << endl;
+
+  f.close();
+
+  arma::vec test;
+  BOOST_REQUIRE(data::Load("test_file.csv", test, false) == true);
+
+  BOOST_REQUIRE_EQUAL(test.n_cols, 1);
+  BOOST_REQUIRE_EQUAL(test.n_rows, 8);
+
+  for (size_t i = 0; i < 8; ++i)
+    BOOST_REQUIRE_CLOSE(test[i], (double) i, 1e-5);
+
+  // Remove the file.
+  remove("test_file.csv");
+}
+
+/**
+ * Make sure Load() throws an exception when trying to load a matrix into a
+ * column vector or a row vector.
+ */
+BOOST_AUTO_TEST_CASE(LoadMatinColVec)
+{
+  fstream f;
+  f.open("test_file.csv", fstream::out);
+
+  f << "1,2" << endl;
+  f << "3,4" << endl;
+
+  f.close();
+
+  // BOOST_WARN_THROW only reports a warning (not a test failure) when no
+  // exception is thrown.
+  arma::colvec coltest;
+  BOOST_WARN_THROW(data::Load("test_file.csv", coltest, false),
+      std::exception);
+
+  arma::rowvec rowtest;
+  BOOST_WARN_THROW(data::Load("test_file.csv", rowtest, false),
+      std::exception);
+
+  // Remove the file.
+  remove("test_file.csv");
+}
+
+/**
+ * Make sure a row vector can be loaded from a CSV file holding a single line
+ * of values.
+ */
+BOOST_AUTO_TEST_CASE(LoadRowVecCSVTest)
+{
+  fstream f;
+  f.open("test_file.csv", fstream::out);
+
+  for (int i = 0; i < 7; ++i)
+    f << i << ",";
+  f << "7";
+  f << endl;
+
+  f.close();
+
+  arma::rowvec test;
+  BOOST_REQUIRE(data::Load("test_file.csv", test, false) == true);
+
+  BOOST_REQUIRE_EQUAL(test.n_cols, 8);
+  BOOST_REQUIRE_EQUAL(test.n_rows, 1);
+
+  for (size_t i = 0; i < 8; ++i)
+    BOOST_REQUIRE_CLOSE(test[i], (double) i, 1e-5);
+
+  // Remove the file.
+  remove("test_file.csv");
+}
 
 /**
  * Make sure TSVs can be loaded in transposed form.

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list