[python-arrayfire] 39/250: FEAT/TEST: Added indexing and assignment support

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:28 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.

commit f4a17385722d976c24b8c756c6da05ab25baa6b3
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Thu Jul 16 16:32:24 2015 -0400

    FEAT/TEST: Added indexing and assignment support
    
    - Added simple tests in simple_array for verification
---
 arrayfire/__init__.py |  24 ++++++++
 arrayfire/arith.py    |   2 +-
 arrayfire/array.py    | 161 ++++++++++++++++++++++++++++++++++++++++++++++++--
 arrayfire/signal.py   |   3 +-
 arrayfire/util.py     |   8 +--
 tests/simple_array.py |  30 +++++++++-
 6 files changed, 216 insertions(+), 12 deletions(-)

diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py
index 14376d3..28c2ffc 100644
--- a/arrayfire/__init__.py
+++ b/arrayfire/__init__.py
@@ -27,3 +27,27 @@ del ct
 del inspect
 del numbers
 del os
+
+#do not export internal classes
+del uidx
+del seq
+del index
+
+#do not export internal functions
+del binary_func
+del binary_funcr
+del create_array
+del constant_array
+del parallel_dim
+del reduce_all
+del arith_unary_func
+del arith_binary_func
+del brange
+del load_backend
+del dim4_tuple
+del is_number
+del to_str
+del safe_call
+del get_indices
+del get_assign_dims
+del slice_to_length
diff --git a/arrayfire/arith.py b/arrayfire/arith.py
index a57234f..24add97 100644
--- a/arrayfire/arith.py
+++ b/arrayfire/arith.py
@@ -22,7 +22,7 @@ def arith_binary_func(lhs, rhs, c_func):
     elif (is_left_array and is_right_array):
         safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, False))
 
-    elif (is_valid_scalar(rhs)):
+    elif (is_number(rhs)):
         ldims = dim4_tuple(lhs.dims())
         lty = lhs.type()
         other = array()
diff --git a/arrayfire/array.py b/arrayfire/array.py
index 91f3d7e..5ec2677 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -55,13 +55,13 @@ def binary_func(lhs, rhs, c_func):
     out = array()
     other = rhs
 
-    if (is_valid_scalar(rhs)):
+    if (is_number(rhs)):
         ldims = dim4_tuple(lhs.dims())
         lty = lhs.type()
         other = array()
         other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
     elif not isinstance(rhs, array):
-        TypeError("Invalid parameter to binary function")
+        raise TypeError("Invalid parameter to binary function")
 
     safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
 
@@ -71,18 +71,133 @@ def binary_funcr(lhs, rhs, c_func):
     out = array()
     other = lhs
 
-    if (is_valid_scalar(lhs)):
+    if (is_number(lhs)):
         rdims = dim4_tuple(rhs.dims())
         rty = rhs.type()
         other = array()
         other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
     elif not isinstance(lhs, array):
-        TypeError("Invalid parameter to binary function")
+        raise TypeError("Invalid parameter to binary function")
 
     c_func(ct.pointer(out.arr), other.arr, rhs.arr, False)
 
     return out
 
+class seq(ct.Structure):
+    _fields_ = [("begin", ct.c_double),
+                ("end"  , ct.c_double),
+                ("step" , ct.c_double)]
+
+    def __init__ (self, S):
+        num = __import__("numbers")
+
+        self.begin = ct.c_double( 0)
+        self.end   = ct.c_double(-1)
+        self.step  = ct.c_double( 1)
+
+        if is_number(S):
+            self.begin = ct.c_double(S)
+            self.end   = ct.c_double(S)
+        elif isinstance(S, slice):
+            if (S.start is not None):
+                self.begin = ct.c_double(S.start)
+            if (S.stop is not None):
+                self.end   = ct.c_double(S.stop - 1) if S.stop >= 0 else ct.c_double(S.stop)
+            if (S.step is not None):
+                self.step  = ct.c_double(S.step)
+        else:
+            raise IndexError("Invalid type while indexing arrayfire.array")
+
+class uidx(ct.Union):
+    # Payload of one index entry: either an arrayfire array handle ("arr",
+    # filled when indexing with an arrayfire.array) or a seq ("seq", filled
+    # for numbers and slices) -- see index.__init__.
+    # NOTE(review): presumably mirrors the union inside arrayfire's C index
+    # type; the layout must match the C side exactly, so keep fields as-is.
+    _fields_ = [("arr", ct.c_longlong),
+                ("seq", seq)]
+
+class index(ct.Structure):
+    _fields_ = [("idx", uidx),
+                ("isSeq", ct.c_bool),
+                ("isBatch", ct.c_bool)]
+
+    def __init__ (self, idx):
+
+        self.idx     = uidx()
+        self.isBatch = False
+        self.isSeq   = True
+
+        if isinstance(idx, array):
+            self.idx.arr = idx.arr
+            self.isSeq   = False
+        else:
+            self.idx.seq = seq(idx)
+
+def get_indices(key, n_dims):
+    index_vec = index * n_dims
+    inds = index_vec()
+
+    for n in range(n_dims):
+        inds[n] = index(slice(0, -1))
+
+    if isinstance(key, tuple):
+        num_idx = len(key)
+        for n in range(n_dims):
+            inds[n] = index(key[n]) if (n < num_idx) else index(slice(0, -1))
+    else:
+        inds[0] = index(key)
+
+    return inds
+
+def slice_to_length(key, dim):
+    tkey = [key.start, key.stop, key.step]
+
+    if tkey[0] is None:
+        tkey[0] = 0
+    elif tkey[0] < 0:
+        tkey[0] = dim - tkey[0]
+
+    if tkey[1] is None:
+        tkey[1] = dim
+    elif tkey[1] < 0:
+        tkey[1] = dim - tkey[1]
+
+    if tkey[2] is None:
+        tkey[2] = 1
+
+    return int(((tkey[1] - tkey[0] - 1) / tkey[2]) + 1)
+
+def get_assign_dims(key, idims):
+    dims = [1]*4
+
+    for n in range(len(idims)):
+        dims[n] = idims[n]
+
+    if is_number(key):
+        dims[0] = 1
+        return dims
+    elif isinstance(key, slice):
+        dims[0] = slice_to_length(key, idims[0])
+        return dims
+    elif isinstance(key, array):
+        dims[0] = key.elements()
+        return dims
+    elif isinstance(key, tuple):
+        n_inds = len(key)
+
+        if (n_inds > len(idims)):
+            raise IndexError("Number of indices greater than array dimensions")
+
+        for n in range(n_inds):
+            if (is_number(key[n])):
+                dims[n] = 1
+            elif (isinstance(key[n], array)):
+                dims[n] = key[n].elements()
+            elif (isinstance(key[n], slice)):
+                dims[n] = slice_to_length(key[n], idims[n])
+            else:
+                raise IndexError("Invalid type while assigning to arrayfire.array")
+
+        return dims
+    else:
+        raise IndexError("Invalid type while assigning to arrayfire.array")
+
 class array(object):
 
     def __init__(self, src=None, dims=(0,)):
@@ -152,7 +267,8 @@ class array(object):
         d1 = ct.c_longlong(0)
         d2 = ct.c_longlong(0)
         d3 = ct.c_longlong(0)
-        safe_call(clib.af_get_dims(ct.pointer(d0), ct.pointer(d1), ct.pointer(d2), ct.pointer(d3), self.arr))
+        safe_call(clib.af_get_dims(ct.pointer(d0), ct.pointer(d1),\
+                                   ct.pointer(d2), ct.pointer(d3), self.arr))
         dims = (d0.value,d1.value,d2.value,d3.value)
         return dims[:self.numdims()]
 
@@ -367,6 +483,41 @@ class array(object):
     # def __abs__(self):
     #     return self
 
+    def __getitem__(self, key):
+        try:
+            out = array()
+            n_dims = self.numdims()
+            inds = get_indices(key, n_dims)
+
+            safe_call(clib.af_index_gen(ct.pointer(out.arr),\
+                                        self.arr, ct.c_longlong(n_dims), ct.pointer(inds)))
+            return out
+        except RuntimeError as e:
+            raise IndexError(str(e))
+
+
+    def __setitem__(self, key, val):
+        try:
+            n_dims = self.numdims()
+
+            if (is_number(val)):
+                tdims = get_assign_dims(key, self.dims())
+                other_arr = constant_array(val, tdims[0], tdims[1], tdims[2], tdims[3])
+            else:
+                other_arr = val.arr
+
+            out_arr = ct.c_longlong(0)
+            inds  = get_indices(key, n_dims)
+
+            safe_call(clib.af_assign_gen(ct.pointer(out_arr),\
+                                         self.arr, ct.c_longlong(n_dims), ct.pointer(inds),\
+                                         other_arr))
+            safe_call(clib.af_release_array(self.arr))
+            self.arr = out_arr
+
+        except RuntimeError as e:
+            raise IndexError(str(e))
+
 def print_array(a):
     expr = inspect.stack()[1][-2]
     if (expr is not None):
diff --git a/arrayfire/signal.py b/arrayfire/signal.py
index aea2275..fd0d50e 100644
--- a/arrayfire/signal.py
+++ b/arrayfire/signal.py
@@ -12,7 +12,8 @@ from .array import *
 
 def approx1(signal, pos0, method=AF_INTERP_LINEAR, off_grid=0.0):
     output = array()
-    safe_call(clib.af_approx1(ct.pointer(output.arr), signal.arr, pos0.arr, method, ct.c_double(off_grid)))
+    safe_call(clib.af_approx1(ct.pointer(output.arr), signal.arr, pos0.arr,\
+                              method, ct.c_double(off_grid)))
     return output
 
 def approx2(signal, pos0, pos1, method=AF_INTERP_LINEAR, off_grid=0.0):
diff --git a/arrayfire/util.py b/arrayfire/util.py
index f4b3ab1..7ef5801 100644
--- a/arrayfire/util.py
+++ b/arrayfire/util.py
@@ -19,11 +19,14 @@ def dim4(d0=1, d1=1, d2=1, d3=1):
 
     return out
 
+def is_number(a):
+    return isinstance(a, numbers.Number)
+
 def dim4_tuple(dims, default=1):
     assert(isinstance(dims, tuple))
 
     if (default is not None):
-        assert(isinstance(default, numbers.Number))
+        assert(is_number(default))
 
     out = [default]*4
 
@@ -32,9 +35,6 @@ def dim4_tuple(dims, default=1):
 
     return tuple(out)
 
-def is_valid_scalar(a):
-    return isinstance(a, float) or isinstance(a, int) or isinstance(a, complex)
-
 def to_str(c_str):
     return str(c_str.value.decode('utf-8'))
 
diff --git a/tests/simple_array.py b/tests/simple_array.py
index c51998f..943ac93 100755
--- a/tests/simple_array.py
+++ b/tests/simple_array.py
@@ -19,7 +19,7 @@ print(a.is_complex(), a.is_real(), a.is_double(), a.is_single())
 print(a.is_real_floating(), a.is_floating(), a.is_integer(), a.is_bool())
 
 
-a = af.array(host.array('d', [4, 5, 6]))
+a = af.array(host.array('i', [4, 5, 6]))
 af.print_array(a)
 print(a.elements(), a.type(), a.dims(), a.numdims())
 print(a.is_empty(), a.is_scalar(), a.is_column(), a.is_row())
@@ -33,8 +33,36 @@ print(a.is_empty(), a.is_scalar(), a.is_column(), a.is_row())
 print(a.is_complex(), a.is_real(), a.is_double(), a.is_single())
 print(a.is_real_floating(), a.is_floating(), a.is_integer(), a.is_bool())
 
+a = af.randu(5, 5)
+af.print_array(a)
 b = af.array(a)
 af.print_array(b)
 
 c = a.copy()
 af.print_array(c)
+af.print_array(a[0,0])
+af.print_array(a[0])
+af.print_array(a[:])
+af.print_array(a[:,:])
+af.print_array(a[0:3,])
+af.print_array(a[-2:-1,-1])
+af.print_array(a[0:5])
+af.print_array(a[0:5:2])
+
+idx = af.array(host.array('i', [0, 3, 2]))
+af.print_array(idx)
+aa = a[idx]
+af.print_array(aa)
+
+a[0] = 1
+af.print_array(a)
+a[0] = af.randu(1, 5)
+af.print_array(a)
+a[:] = af.randu(5,5)
+af.print_array(a)
+a[:,-1] = af.randu(5,1)
+af.print_array(a)
+a[0:5:2] = af.randu(3, 5)
+af.print_array(a)
+a[idx, idx] = af.randu(3,3)
+af.print_array(a)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list