[python-arrayfire] 45/250: FEAT/TEST: Adding support for getting data back to the host
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:29 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.
commit 71c38cae3db2c0baba51c41487281cfcacd67d06
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Mon Jul 20 07:00:00 2015 -0400
FEAT/TEST: Adding support for getting data back to the host
---
arrayfire/__init__.py | 3 +++
arrayfire/array.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
arrayfire/blas.py | 8 --------
arrayfire/util.py | 18 ++++++++++++++++++
tests/simple_array.py | 20 ++++++++++++++++++++
tests/simple_blas.py | 5 -----
6 files changed, 85 insertions(+), 13 deletions(-)
diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py
index 326234a..7062420 100644
--- a/arrayfire/__init__.py
+++ b/arrayfire/__init__.py
@@ -53,3 +53,6 @@ del safe_call
del get_indices
del get_assign_dims
del slice_to_length
# Hide internal conversion helpers from the public `arrayfire` namespace
# (presumably they arrive here through the star imports above -- confirm).
# to_typecode is defined in util right next to to_c_type, so it is removed
# the same way instead of leaking out as a public name.
del ctype_to_lists
del to_dtype
del to_c_type
del to_typecode
diff --git a/arrayfire/array.py b/arrayfire/array.py
index c2450bb..98dfff8 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -83,6 +83,14 @@ def binary_funcr(lhs, rhs, c_func):
return out
def transpose(a, conj=False):
    """Return a new array holding the transpose of `a`.

    `conj` is forwarded to af_transpose as its conjugate flag.
    Errors reported by the C call are handled by safe_call.
    """
    result = array()
    result_handle = ct.pointer(result.arr)
    safe_call(clib.af_transpose(result_handle, a.arr, conj))
    return result
+
def transpose_inplace(a, conj=False):
    """Transpose `a` in place via af_transpose_inplace; returns None.

    `conj` is forwarded to the C call as its conjugate flag.
    """
    status = clib.af_transpose_inplace(a.arr, conj)
    safe_call(status)
+
class seq(ct.Structure):
_fields_ = [("begin", ct.c_double),
("end" , ct.c_double),
@@ -163,6 +171,17 @@ def slice_to_length(key, dim):
return int(((tkey[1] - tkey[0] - 1) / tkey[2]) + 1)
def ctype_to_lists(ctype_arr, dim, shape, offset=0):
    """Convert a flat, column-major buffer into nested Python lists.

    Parameters
    ----------
    ctype_arr : sliceable sequence (e.g. a ctypes array or list)
        Flat element buffer; the first dimension is contiguous.
    dim : int
        Highest dimension index to nest over (usually ``len(shape) - 1``).
    shape : sequence of int
        Extent of every dimension.
    offset : int
        Flat index at which this (sub-)block starts.

    Returns
    -------
    list
        The outermost list runs over ``shape[dim]``; the innermost lists
        are contiguous runs of ``shape[0]`` elements.
    """
    if dim == 0:
        return list(ctype_arr[offset : offset + shape[0]])
    # One sub-list at level dim-1 covers shape[0] * ... * shape[dim-1]
    # flat elements. Stepping by shape[0] alone is only right for dim == 1.
    stride = 1
    for extent in shape[:dim]:
        stride *= extent
    return [ctype_to_lists(ctype_arr, dim - 1, shape, offset + n * stride)
            for n in range(shape[dim])]
+
def get_assign_dims(key, idims):
dims = [1]*4
@@ -518,6 +537,31 @@ class array(object):
except RuntimeError as e:
raise IndexError(str(e))
def to_ctype(self, row_major=False, return_shape=False):
    """Copy the array's data to the host into a newly allocated ctypes array.

    Parameters
    ----------
    row_major : bool
        If True, copy from ``transpose(self)`` so the buffer holds the data
        in row-major order. Otherwise copy ``self`` as stored
        (column-major, first dimension contiguous).
    return_shape : bool
        If True, return ``(data, self.dims())`` instead of just ``data``.

    Returns
    -------
    A ctypes array of ``self.elements()`` items of type
    ``to_c_type[self.type()]``, optionally paired with ``self.dims()``.
    """
    # Copy from `src`, not `self`, so that row_major actually takes effect.
    src = transpose(self) if row_major else self
    buffer_type = to_c_type[self.type()] * self.elements()
    data = buffer_type()
    safe_call(clib.af_get_data_ptr(ct.pointer(data), src.arr))
    # The reported shape is always that of the original (untransposed) array.
    if return_shape:
        return data, self.dims()
    return data
+
def to_array(self, row_major=False, return_shape=False):
    """Copy the array's data to the host as a standard-library array.array.

    The data is fetched with ``self.to_ctype(row_major, return_shape)``, and
    the typecode comes from ``to_typecode[self.type()]``. When
    `return_shape` is True, returns ``(host_array, shape)``.
    """
    fetched = self.to_ctype(row_major, return_shape)

    # Inside this module the name `array` is the device-array class, so the
    # stdlib module is fetched with the builtin __import__ instead.
    host_array = __import__("array").array
    typecode = to_typecode[self.type()]

    if not return_shape:
        return host_array(typecode, fetched)
    data, shape = fetched
    return host_array(typecode, data), shape
+
def to_list(self, row_major=False):
    """Copy the array's data to the host as nested Python lists.

    Parameters
    ----------
    row_major : bool
        If False, the nesting follows the stored column-major layout: the
        innermost lists are columns, so for 2-D ``lst[j][i] == a[i, j]``.
        If True, the data is taken from ``transpose(self)``: the innermost
        lists are rows, so for 2-D ``lst[i][j] == a[i, j]``.
    """
    # Transpose here instead of passing row_major to to_ctype. The nesting
    # must follow the dims of the buffer actually copied, but to_ctype
    # always reports the untransposed shape.
    src = transpose(self) if row_major else self
    ct_array, shape = src.to_ctype(False, True)
    return ctype_to_lists(ct_array, len(shape) - 1, shape)
+
def display(a):
expr = inspect.stack()[1][-2]
if (expr is not None):
diff --git a/arrayfire/blas.py b/arrayfire/blas.py
index 312a6d0..da78919 100644
--- a/arrayfire/blas.py
+++ b/arrayfire/blas.py
@@ -39,11 +39,3 @@ def dot(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
safe_call(clib.af_dot(ct.pointer(out.arr), lhs.arr, rhs.arr,\
lhs_opts, rhs_opts))
return out
-
-def transpose(a, conj=False):
- out = array()
- safe_call(clib.af_transpose(ct.pointer(out.arr), a.arr, conj))
- return out
-
-def transpose_inplace(a, conj=False):
- safe_call(clib.af_transpose_inplace(a.arr, conj))
diff --git a/arrayfire/util.py b/arrayfire/util.py
index 7ef5801..93dac2e 100644
--- a/arrayfire/util.py
+++ b/arrayfire/util.py
@@ -60,3 +60,21 @@ to_dtype = {'f' : f32,
'I' : u32,
'l' : s64,
'L' : u64}
+
# ArrayFire dtype code (the .value of f32, f64, ...) -> typecode of the
# standard-library `array` module, used when copying data to the host.
# NOTE(review): 'l'/'L' are C long, which is only 4 bytes on some platforms
# (e.g. Windows), while s64/u64 are 64-bit. 'q'/'Q' would be exact but need
# Python 3.3+ -- confirm which interpreters must be supported.
to_typecode = {f32.value : 'f',
               f64.value : 'd',
               b8.value  : 'b',
               u8.value  : 'B',
               s32.value : 'i',
               u32.value : 'I',
               s64.value : 'l',
               u64.value : 'L'}

# ArrayFire dtype code -> ctypes element type of the host buffer.
# b8 uses c_byte (1-byte signed char whose elements are Python ints) to agree
# with typecode 'b' above. c_char elements are 1-byte bytes/str objects,
# which array.array('b', ...) rejects with a TypeError.
to_c_type = {f32.value : ct.c_float,
             f64.value : ct.c_double,
             b8.value  : ct.c_byte,
             u8.value  : ct.c_ubyte,
             s32.value : ct.c_int,
             u32.value : ct.c_uint,
             s64.value : ct.c_longlong,
             u64.value : ct.c_ulonglong}
diff --git a/tests/simple_array.py b/tests/simple_array.py
index 141b741..95426a5 100755
--- a/tests/simple_array.py
+++ b/tests/simple_array.py
@@ -66,3 +66,23 @@ a[0:5:2] = af.randu(3, 5)
af.display(a)
a[idx, idx] = af.randu(3,3)
af.display(a)
+
# Out-of-place transpose, then in-place transpose of `a`.
af.display(af.transpose(a))

af.transpose_inplace(a)
af.display(a)

# Raw ctypes buffer with the default arguments.
c = a.to_ctype()
for value in c[:a.elements()]:
    print(value)

# Buffer requested in row-major order, together with the shape.
c, s = a.to_ctype(True, True)
for value in c[:a.elements()]:
    print(value)
print(s)

# Conversions to a stdlib array.array and to nested lists.
arr = a.to_array()
lst = a.to_list(True)

print(arr)
print(lst)
diff --git a/tests/simple_blas.py b/tests/simple_blas.py
index 41dbf00..1fc4afa 100755
--- a/tests/simple_blas.py
+++ b/tests/simple_blas.py
@@ -19,8 +19,3 @@ af.display(af.matmul(a,b,af.AF_MAT_NONE, af.AF_MAT_TRANS))
b = af.randu(5,1)
af.display(af.dot(b,b))
-
-af.display(af.transpose(a))
-
-af.transpose_inplace(a)
-af.display(a)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits
mailing list