[arrayfire] 341/408: Fixing various typos and bug fixes for SVD in CUDA and OpenCL
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:26 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit 4118733e94e72315ba8ed2b2050bc158ae2b83d8
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Tue Aug 25 06:02:55 2015 -0400
Fixing various typos and bug fixes for SVD in CUDA and OpenCL
---
src/backend/cuda/svd.cu | 8 +++++---
src/backend/opencl/magma/gebrd.cpp | 3 ++-
src/backend/opencl/magma/labrd.cpp | 2 +-
src/backend/opencl/magma/magma_cpu_lapack.h | 7 +++++++
src/backend/opencl/svd.cpp | 10 ++++++----
5 files changed, 21 insertions(+), 9 deletions(-)
diff --git a/src/backend/cuda/svd.cu b/src/backend/cuda/svd.cu
index fed35cf..37ffa78 100644
--- a/src/backend/cuda/svd.cu
+++ b/src/backend/cuda/svd.cu
@@ -114,12 +114,14 @@ SVD_SPECIALIZE(cdouble, double, Z);
int M = iDims[0];
int N = iDims[1];
- if (M <= N) {
+ if (M >= N) {
Array<T> in_copy = copyArray(in);
- return svdInPlace(s, u, vt, in_copy);
+ svdInPlace(s, u, vt, in_copy);
} else {
Array<T> in_trans = transpose(in, true);
- return svdInPlace(s, vt, u, in_trans);
+ svdInPlace(s, vt, u, in_trans);
+ transpose_inplace(vt, true);
+ transpose_inplace(u, true);
}
}
diff --git a/src/backend/opencl/magma/gebrd.cpp b/src/backend/opencl/magma/gebrd.cpp
index f1d6817..c83efb4 100644
--- a/src/backend/opencl/magma/gebrd.cpp
+++ b/src/backend/opencl/magma/gebrd.cpp
@@ -284,7 +284,8 @@ magma_gebrd_hybrid(
}
magma_labrd_gpu<Ty>(nrow, ncol, nb,
- A(i, i), lda, dA(i, i), ldda,
+ A(i, i), lda,
+ dA(i, i), ldda,
d+i, e+i, tauq+i, taup+i,
work, ldwrkx, dwork, dwork_offset, ldwrkx, // x, dx
work+(ldwrkx*nb), ldwrky, dwork, dwork_offset+(ldwrkx*nb), ldwrky, // y, dy
diff --git a/src/backend/opencl/magma/labrd.cpp b/src/backend/opencl/magma/labrd.cpp
index 3b0e186..0284ee8 100644
--- a/src/backend/opencl/magma/labrd.cpp
+++ b/src/backend/opencl/magma/labrd.cpp
@@ -249,7 +249,7 @@ magma_labrd_gpu(
cpu_blas_scal_func<Ty> cpu_blas_scal;
cpu_blas_axpy_func<Ty> cpu_blas_axpy;
cpu_lapack_larfg_work_func<Ty> cpu_lapack_larfg;
- cpu_lapack_lacgv_func<Ty> cpu_lapack_lacgv;
+ cpu_lapack_lacgv_work_func<Ty> cpu_lapack_lacgv;
CBLAS_TRANSPOSE CblasTransParam = is_cplx ? CblasConjTrans : CblasTrans;
diff --git a/src/backend/opencl/magma/magma_cpu_lapack.h b/src/backend/opencl/magma/magma_cpu_lapack.h
index 02fc644..445defd 100644
--- a/src/backend/opencl/magma/magma_cpu_lapack.h
+++ b/src/backend/opencl/magma/magma_cpu_lapack.h
@@ -26,6 +26,12 @@ int LAPACKE_slacgv(Args... args) { return 0; }
template<typename... Args>
int LAPACKE_dlacgv(Args... args) { return 0; }
+template<typename... Args>
+int LAPACKE_slacgv_work(Args... args) { return 0; }
+
+template<typename... Args>
+int LAPACKE_dlacgv_work(Args... args) { return 0; }
+
#define lapack_complex_float magmaFloatComplex
#define lapack_complex_double magmaDoubleComplex
#define LAPACK_PREFIX LAPACKE_
@@ -129,6 +135,7 @@ CPU_LAPACK_DECL1(laswp)
CPU_LAPACK_DECL1(laset)
CPU_LAPACK_DECL2(lacgv)
+CPU_LAPACK_DECL2(lacgv_work)
CPU_LAPACK_DECL2(larfg)
CPU_LAPACK_DECL2(larfg_work)
CPU_LAPACK_DECL1(lacpy)
diff --git a/src/backend/opencl/svd.cpp b/src/backend/opencl/svd.cpp
index 3d90444..1a9bb43dd 100644
--- a/src/backend/opencl/svd.cpp
+++ b/src/backend/opencl/svd.cpp
@@ -67,7 +67,7 @@ void svd(Array<T > &arrU,
{
dim4 idims = arrA.dims();
- dim4 istrides = arrA.dims();
+ dim4 istrides = arrA.strides();
const int m = (int)idims[0];
const int n = (int)idims[1];
@@ -206,12 +206,14 @@ void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
int M = iDims[0];
int N = iDims[1];
- if (M <= N) {
+ if (M >= N) {
Array<T> in_copy = copyArray(in);
- return svdInPlace(s, u, vt, in_copy);
+ svdInPlace(s, u, vt, in_copy);
} else {
Array<T> in_trans = transpose(in, true);
- return svdInPlace(s, vt, u, in_trans);
+ svdInPlace(s, vt, u, in_trans);
+ transpose_inplace(u, true);
+ transpose_inplace(vt, true);
}
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits mailing list