[mlpack] 57/58: Backport another sparse matrix constructor for softmax regression.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Tue Sep 9 13:19:43 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 2bb218f93a706967d1ff8b74ac7a757688ab73fc
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Sat Aug 30 04:06:08 2014 +0000
Backport another sparse matrix constructor for softmax regression.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@17140 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/core/arma_extend/SpMat_extra_bones.hpp | 16 +++++-
src/mlpack/core/arma_extend/SpMat_extra_meat.hpp | 60 +++++++++++++++++++++++
2 files changed, 74 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
index 1c84221..51b724c 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
@@ -5,12 +5,14 @@
* Add a batch constructor for SpMat, if the version is older than 3.810.0.
*/
#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
const Base<uword, T1>& locations,
const Base<eT, T2>& values,
const bool sort_locations = true);
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
const Base<uword, T1>& locations,
const Base<eT, T2>& values,
const uword n_rows,
@@ -18,6 +20,16 @@ template<typename T1, typename T2> inline SpMat(
const bool sort_locations = true);
#endif
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920  // backport for Armadillo 3.x releases older than 3.920
+template<typename T1, typename T2, typename T3>
+inline SpMat(  // batch constructor from CSC arrays; size is given explicitly, array consistency is trusted
+ const Base<uword, T1>& rowind,  // row index of each non-zero value (same length as values)
+ const Base<uword, T2>& colptr,  // n_cols+1 column pointers
+ const Base<eT, T3>& values,  // the non-zero values, in the same order as rowind
+ const uword n_rows,
+ const uword n_cols);
+#endif
+
/*
* Extra functions for SpMat<eT>
* Adding definition of row_col_iterator to generalize with Mat<eT>::row_col_iterator
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
index 2cf980b..d2ad10e 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
@@ -250,6 +250,66 @@ SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_e
#endif
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920
+//! Batch-construct a sparse matrix directly from CSC (compressed sparse column) arrays.
+//! Per CSC format, rowind_expr holds the row index of each non-zero value,
+//! colptr_expr holds the n_cols+1 column pointers,
+//! and values_expr holds the corresponding non-zero values (one per row index).
+//! In this constructor the size (in_n_rows x in_n_cols) is given explicitly.
+//! Only vector shapes and lengths are checked (via arma_debug_check); the values are
+//! assumed to be sorted, and the index information is trusted, not verified.
+template<typename eT>
+template<typename T1, typename T2, typename T3>
+inline
+SpMat<eT>::SpMat
+ (
+ const Base<uword,T1>& rowind_expr,
+ const Base<uword,T2>& colptr_expr,
+ const Base<eT, T3>& values_expr,
+ const uword in_n_rows,
+ const uword in_n_cols
+ )
+ : n_rows(0)
+ , n_cols(0)
+ , n_elem(0)
+ , n_nonzero(0)
+ , vec_state(0)
+ , values(NULL)
+ , row_indices(NULL)
+ , col_ptrs(NULL)
+ {
+ arma_extra_debug_sigprint_this(this);
+
+ init(in_n_rows, in_n_cols);  // must run first: the checks and sentinel write below read the member n_cols (presumably set here)
+
+ const unwrap<T1> rowind_tmp( rowind_expr.get_ref() );
+ const unwrap<T2> colptr_tmp( colptr_expr.get_ref() );
+ const unwrap<T3> vals_tmp( values_expr.get_ref() );
+
+ const Mat<uword>& rowind = rowind_tmp.M;
+ const Mat<uword>& colptr = colptr_tmp.M;
+ const Mat<eT>& vals = vals_tmp.M;
+
+ arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object is not a vector" );
+ arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object is not a vector" );
+ arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
+
+ arma_debug_check( (rowind.n_elem != vals.n_elem), "SpMat::SpMat(): number of row indices is not equal to number of values" );
+ arma_debug_check( (colptr.n_elem != (n_cols+1) ), "SpMat::SpMat(): number of column pointers is not equal to n_cols+1" );
+
+ // Resize to correct number of elements (this also sets n_nonzero)
+ mem_resize(vals.n_elem);
+
+ // copy supplied arrays into the sparse matrix -- not checked for consistency (e.g. ordering or range of indices)
+ arrayops::copy(access::rwp(row_indices), rowind.memptr(), rowind.n_elem );
+ arrayops::copy(access::rwp(col_ptrs), colptr.memptr(), colptr.n_elem );
+ arrayops::copy(access::rwp(values), vals.memptr(), vals.n_elem );
+
+ // important: set the sentinel as well -- index n_cols + 1 lies past the n_cols+1 copied pointers (assumes init() allocated n_cols+2 entries)
+ access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max();
+ }
+#endif
+
#if ARMA_VERSION_MAJOR < 4 || \
(ARMA_VERSION_MAJOR == 4 && ARMA_VERSION_MINOR < 349)
template<typename eT>
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list