[python-arrayfire] 81/250: Making sure indexing operation is not dropping dimensions
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:34 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.
commit 6b543d2aeb2b6ec78f40a700b3c5b92dfcaa1ab5
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Mon Aug 31 07:14:40 2015 -0400
Making sure indexing operation is not dropping dimensions
- Added relevant test
---
arrayfire/array.py | 4 ++--
arrayfire/index.py | 13 ++++---------
tests/simple_index.py | 8 ++++++++
3 files changed, 14 insertions(+), 11 deletions(-)
diff --git a/arrayfire/array.py b/arrayfire/array.py
index adcef8a..21a6eb0 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -426,7 +426,7 @@ class Array(BaseArray):
try:
out = Array()
n_dims = self.numdims()
- inds = get_indices(key, n_dims)
+ inds = get_indices(key)
safe_call(clib.af_index_gen(ct.pointer(out.arr),
self.arr, ct.c_longlong(n_dims), ct.pointer(inds)))
@@ -446,7 +446,7 @@ class Array(BaseArray):
other_arr = val.arr
out_arr = ct.c_void_p(0)
- inds = get_indices(key, n_dims)
+ inds = get_indices(key)
safe_call(clib.af_assign_gen(ct.pointer(out_arr),
self.arr, ct.c_longlong(n_dims), ct.pointer(inds),
diff --git a/arrayfire/index.py b/arrayfire/index.py
index 1f1d419..608b062 100644
--- a/arrayfire/index.py
+++ b/arrayfire/index.py
@@ -104,13 +104,11 @@ class Index(ct.Structure):
else:
self.idx.seq = Seq(idx)
-def get_indices(key, n_dims):
+def get_indices(key):
- index_vec = Index * n_dims
- inds = index_vec()
-
- for n in range(n_dims):
- inds[n] = Index(slice(None))
+ index_vec = Index * 4
+ S = Index(slice(None))
+ inds = index_vec(S, S, S, S)
if isinstance(key, tuple):
n_idx = len(key)
@@ -143,9 +141,6 @@ def get_assign_dims(key, idims):
elif isinstance(key, tuple):
n_inds = len(key)
- if (n_inds > len(idims)):
- raise IndexError("Number of indices greater than array dimensions")
-
for n in range(n_inds):
if (is_number(key[n])):
dims[n] = 1
diff --git a/tests/simple_index.py b/tests/simple_index.py
index 716d96a..a3c924c 100755
--- a/tests/simple_index.py
+++ b/tests/simple_index.py
@@ -58,3 +58,11 @@ af.display(a)
for ii in ParallelRange(2,5):
b[ii] = 2
af.display(b)
+
+a = af.randu(3,2)
+rows = af.constant(0, 1, dtype=af.s32)
+b = a[:,rows]
+af.display(b)
+for r in rows:
+ af.display(r)
+ af.display(b[:,r])
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits mailing list