[python-arrayfire] 39/250: FEAT/TEST: Added indexing and assignment support
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:28 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.
commit f4a17385722d976c24b8c756c6da05ab25baa6b3
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Thu Jul 16 16:32:24 2015 -0400
FEAT/TEST: Added indexing and assignment support
- Added simple tests in simple_array for verification
---
arrayfire/__init__.py | 24 ++++++++
arrayfire/arith.py | 2 +-
arrayfire/array.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++--
arrayfire/signal.py | 3 +-
arrayfire/util.py | 8 +--
tests/simple_array.py | 30 +++++++++-
6 files changed, 216 insertions(+), 12 deletions(-)
diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py
index 14376d3..28c2ffc 100644
--- a/arrayfire/__init__.py
+++ b/arrayfire/__init__.py
@@ -27,3 +27,27 @@ del ct
del inspect
del numbers
del os
+
# Remove implementation details from the package namespace so they are not
# exported as part of the public arrayfire API.

# Internal ctypes structures used to build index descriptors.
del uidx, seq, index

# Internal helper functions.
del (binary_func, binary_funcr, create_array, constant_array, parallel_dim,
     reduce_all, arith_unary_func, arith_binary_func, brange, load_backend,
     dim4_tuple, is_number, to_str, safe_call, get_indices, get_assign_dims,
     slice_to_length)
diff --git a/arrayfire/arith.py b/arrayfire/arith.py
index a57234f..24add97 100644
--- a/arrayfire/arith.py
+++ b/arrayfire/arith.py
@@ -22,7 +22,7 @@ def arith_binary_func(lhs, rhs, c_func):
elif (is_left_array and is_right_array):
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, False))
- elif (is_valid_scalar(rhs)):
+ elif (is_number(rhs)):
ldims = dim4_tuple(lhs.dims())
lty = lhs.type()
other = array()
diff --git a/arrayfire/array.py b/arrayfire/array.py
index 91f3d7e..5ec2677 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -55,13 +55,13 @@ def binary_func(lhs, rhs, c_func):
out = array()
other = rhs
- if (is_valid_scalar(rhs)):
+ if (is_number(rhs)):
ldims = dim4_tuple(lhs.dims())
lty = lhs.type()
other = array()
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
elif not isinstance(rhs, array):
- TypeError("Invalid parameter to binary function")
+ raise TypeError("Invalid parameter to binary function")
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
@@ -71,18 +71,133 @@ def binary_funcr(lhs, rhs, c_func):
out = array()
other = lhs
- if (is_valid_scalar(lhs)):
+ if (is_number(lhs)):
rdims = dim4_tuple(rhs.dims())
rty = rhs.type()
other = array()
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
elif not isinstance(lhs, array):
- TypeError("Invalid parameter to binary function")
+ raise TypeError("Invalid parameter to binary function")
c_func(ct.pointer(out.arr), other.arr, rhs.arr, False)
return out
class seq(ct.Structure):
    """ctypes mirror of an ArrayFire sequence: ``begin``, ``end``, ``step``.

    Built from a Python number (selects that single position) or a
    ``slice``.  ``end`` is inclusive, so Python's exclusive ``stop`` is moved
    one position back toward ``start``.  Negative positions are passed
    through to ArrayFire, which is expected to count them from the end of
    the dimension; an ``end`` of -1 is what this module uses for "through
    the last element".

    Raises IndexError for any other key type or for a zero slice step.
    """
    _fields_ = [("begin", ct.c_double),
                ("end", ct.c_double),
                ("step", ct.c_double)]

    def __init__(self, S):
        if isinstance(S, slice):
            step = 1 if S.step is None else S.step
            if step == 0:
                raise IndexError("slice step cannot be zero")
            self.step = step
            # Open-ended slices cover the whole dimension in the direction
            # of travel: forwards 0 .. -1, backwards -1 .. 0.
            self.begin, self.end = (0, -1) if step > 0 else (-1, 0)
            if S.start is not None:
                self.begin = S.start
            if S.stop is not None:
                # Exclusive stop -> inclusive end, for negative stops too
                # (e.g. [-2:-1] must select only position -2).
                # NOTE(review): stop == 0 (step > 0) or stop == -1 (step < 0)
                # asks for an empty range but maps to end -1 / 0 here.
                self.end = S.stop - 1 if step > 0 else S.stop + 1
        elif is_number(S):
            # A single position: begin == end, unit step.
            self.begin = S
            self.end = S
            self.step = 1
        else:
            raise IndexError("Invalid type while indexing arrayfire.array")


class uidx(ct.Union):
    """Payload of an index descriptor: an array handle or a ``seq``."""
    _fields_ = [("arr", ct.c_longlong),
                ("seq", seq)]


class index(ct.Structure):
    """ctypes mirror of a per-dimension index descriptor.

    Holds either the handle of an ``arrayfire.array`` (``isSeq`` False) or a
    ``seq`` built from a number or slice (``isSeq`` True).  ``isBatch`` is
    always False.
    """
    _fields_ = [("idx", uidx),
                ("isSeq", ct.c_bool),
                ("isBatch", ct.c_bool)]

    def __init__(self, idx):
        self.idx = uidx()
        self.isBatch = False
        if isinstance(idx, array):
            # Only the handle value is copied; the caller presumably has to
            # keep ``idx`` alive while this descriptor is in use.
            self.idx.arr = idx.arr
            self.isSeq = False
        else:
            self.idx.seq = seq(idx)
            self.isSeq = True


def get_indices(key, n_dims):
    """Build a ctypes array of ``n_dims`` index descriptors for ``key``.

    ``key`` is either a single index (number, slice or arrayfire.array)
    applied to the first dimension, or a tuple with one index per leading
    dimension.  Dimensions without an index span their full extent; tuple
    entries beyond ``n_dims`` are ignored.
    """
    keys = key if isinstance(key, tuple) else (key,)
    # Truncate to n_dims, then pad the remaining dimensions with full spans.
    keys = list(keys[:n_dims]) + [slice(None)] * (n_dims - len(keys))
    return (index * n_dims)(*[index(k) for k in keys])
+
def slice_to_length(key, dim):
    """Return how many elements ``key`` selects along a dimension of size ``dim``.

    Follows Python's own slice semantics via ``slice.indices``: negative
    ``start``/``stop`` count from the end, out-of-range bounds are clamped,
    negative steps walk backwards, and an empty selection yields 0.
    A zero step raises ValueError (from ``slice.indices``).
    """
    return len(range(*key.indices(dim)))
+
def _index_length(k, dim):
    """Return how many elements the single index ``k`` selects from ``dim``.

    Raises IndexError if ``k`` is not a number, slice or arrayfire.array.
    """
    if is_number(k):
        return 1
    if isinstance(k, slice):
        return slice_to_length(k, dim)
    if isinstance(k, array):
        return k.elements()
    raise IndexError("Invalid type while assigning to arrayfire.array")


def get_assign_dims(key, idims):
    """Return the 4-element shape of the region selected by ``key``.

    ``idims`` is the shape of the array being indexed.  ``key`` is a single
    index (applied to the first dimension) or a tuple of per-dimension
    indices.  Dimensions not indexed keep their full extent from ``idims``,
    and dimensions beyond ``len(idims)`` are 1.

    Raises IndexError if ``key`` has more entries than ``idims`` or contains
    an entry that is not a number, slice or arrayfire.array.
    """
    keys = key if isinstance(key, tuple) else (key,)
    if len(keys) > len(idims):
        raise IndexError("Number of indices greater than array dimensions")

    # Start from the target shape padded to four dimensions.
    dims = list(idims) + [1] * (4 - len(idims))
    for n, k in enumerate(keys):
        dims[n] = _index_length(k, idims[n])
    return dims
+
class array(object):
def __init__(self, src=None, dims=(0,)):
@@ -152,7 +267,8 @@ class array(object):
d1 = ct.c_longlong(0)
d2 = ct.c_longlong(0)
d3 = ct.c_longlong(0)
- safe_call(clib.af_get_dims(ct.pointer(d0), ct.pointer(d1), ct.pointer(d2), ct.pointer(d3), self.arr))
+ safe_call(clib.af_get_dims(ct.pointer(d0), ct.pointer(d1),\
+ ct.pointer(d2), ct.pointer(d3), self.arr))
dims = (d0.value,d1.value,d2.value,d3.value)
return dims[:self.numdims()]
@@ -367,6 +483,41 @@ class array(object):
# def __abs__(self):
# return self
def __getitem__(self, key):
    """Return a new array with the elements of ``self`` selected by ``key``.

    ``key`` may be a number, slice, arrayfire.array, or a tuple of those
    (one per dimension).  A RuntimeError raised while indexing (presumably
    from ``safe_call`` on a library error) is re-raised as IndexError.
    """
    try:
        result = array()
        ndims = self.numdims()
        descriptors = get_indices(key, ndims)
        safe_call(clib.af_index_gen(ct.pointer(result.arr), self.arr,
                                    ct.c_longlong(ndims),
                                    ct.pointer(descriptors)))
    except RuntimeError as err:
        raise IndexError(str(err))
    return result
+
+
def __setitem__(self, key, val):
    """Assign ``val`` to the elements of ``self`` selected by ``key``.

    ``val`` is either a number, expanded to a constant array shaped like the
    selected region and created with ``self.type()`` as its dtype, or an
    arrayfire.array.  ``key`` accepts the same forms as ``__getitem__``.

    Raises TypeError if ``val`` is neither a number nor an arrayfire.array.
    A RuntimeError raised while assigning (presumably from ``safe_call``) is
    re-raised as IndexError.
    """
    try:
        n_dims = self.numdims()

        tmp_arr = None  # handle created here and released below
        if is_number(val):
            tdims = get_assign_dims(key, self.dims())
            # Pass our dtype, as binary_func does for scalar operands.
            tmp_arr = constant_array(val, tdims[0], tdims[1], tdims[2],
                                     tdims[3], self.type())
            other_arr = tmp_arr
        elif isinstance(val, array):
            other_arr = val.arr
        else:
            raise TypeError("Invalid value while assigning to arrayfire.array")

        try:
            out_arr = ct.c_longlong(0)
            inds = get_indices(key, n_dims)
            safe_call(clib.af_assign_gen(ct.pointer(out_arr), self.arr,
                                         ct.c_longlong(n_dims),
                                         ct.pointer(inds), other_arr))
        finally:
            if tmp_arr is not None:
                # Best-effort cleanup of the temporary constant. The status
                # is not checked so a release failure cannot mask an
                # assignment error.
                clib.af_release_array(tmp_arr)

        # Install the new handle before releasing the old one, so self.arr
        # never refers to an already released array.
        old_arr, self.arr = self.arr, out_arr
        safe_call(clib.af_release_array(old_arr))

    except RuntimeError as e:
        raise IndexError(str(e))
+
def print_array(a):
expr = inspect.stack()[1][-2]
if (expr is not None):
diff --git a/arrayfire/signal.py b/arrayfire/signal.py
index aea2275..fd0d50e 100644
--- a/arrayfire/signal.py
+++ b/arrayfire/signal.py
@@ -12,7 +12,8 @@ from .array import *
def approx1(signal, pos0, method=AF_INTERP_LINEAR, off_grid=0.0):
output = array()
- safe_call(clib.af_approx1(ct.pointer(output.arr), signal.arr, pos0.arr, method, ct.c_double(off_grid)))
+ safe_call(clib.af_approx1(ct.pointer(output.arr), signal.arr, pos0.arr,\
+ method, ct.c_double(off_grid)))
return output
def approx2(signal, pos0, pos1, method=AF_INTERP_LINEAR, off_grid=0.0):
diff --git a/arrayfire/util.py b/arrayfire/util.py
index f4b3ab1..7ef5801 100644
--- a/arrayfire/util.py
+++ b/arrayfire/util.py
@@ -19,11 +19,14 @@ def dim4(d0=1, d1=1, d2=1, d3=1):
return out
def is_number(a):
    """Return True when ``a`` is a numeric scalar (any ``numbers.Number``).

    This includes int, float, complex and bool, plus any type registered
    with the ``numbers`` abstract base classes.
    """
    return isinstance(a, (numbers.Number,))
+
def dim4_tuple(dims, default=1):
assert(isinstance(dims, tuple))
if (default is not None):
- assert(isinstance(default, numbers.Number))
+ assert(is_number(default))
out = [default]*4
@@ -32,9 +35,6 @@ def dim4_tuple(dims, default=1):
return tuple(out)
-def is_valid_scalar(a):
- return isinstance(a, float) or isinstance(a, int) or isinstance(a, complex)
-
def to_str(c_str):
return str(c_str.value.decode('utf-8'))
diff --git a/tests/simple_array.py b/tests/simple_array.py
index c51998f..943ac93 100755
--- a/tests/simple_array.py
+++ b/tests/simple_array.py
@@ -19,7 +19,7 @@ print(a.is_complex(), a.is_real(), a.is_double(), a.is_single())
print(a.is_real_floating(), a.is_floating(), a.is_integer(), a.is_bool())
-a = af.array(host.array('d', [4, 5, 6]))
+a = af.array(host.array('i', [4, 5, 6]))
af.print_array(a)
print(a.elements(), a.type(), a.dims(), a.numdims())
print(a.is_empty(), a.is_scalar(), a.is_column(), a.is_row())
@@ -33,8 +33,36 @@ print(a.is_empty(), a.is_scalar(), a.is_column(), a.is_row())
print(a.is_complex(), a.is_real(), a.is_double(), a.is_single())
print(a.is_real_floating(), a.is_floating(), a.is_integer(), a.is_bool())
+a = af.randu(5, 5)
+af.print_array(a)
b = af.array(a)
af.print_array(b)
c = a.copy()
af.print_array(c)
+af.print_array(a[0,0])
+af.print_array(a[0])
+af.print_array(a[:])
+af.print_array(a[:,:])
+af.print_array(a[0:3,])
+af.print_array(a[-2:-1,-1])
+af.print_array(a[0:5])
+af.print_array(a[0:5:2])
+
+idx = af.array(host.array('i', [0, 3, 2]))
+af.print_array(idx)
+aa = a[idx]
+af.print_array(aa)
+
+a[0] = 1
+af.print_array(a)
+a[0] = af.randu(1, 5)
+af.print_array(a)
+a[:] = af.randu(5,5)
+af.print_array(a)
+a[:,-1] = af.randu(5,1)
+af.print_array(a)
+a[0:5:2] = af.randu(3, 5)
+af.print_array(a)
+a[idx, idx] = af.randu(3,3)
+af.print_array(a)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits mailing list