[python-arrayfire] 130/250: FEAT: adding select and replace
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:40 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.
commit faaf22386f9685acb0eedabfffd892fbb7d09f85
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Tue Nov 10 09:47:41 2015 -0500
FEAT: adding select and replace
---
arrayfire/data.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++
tests/simple/data.py | 5 +++
2 files changed, 121 insertions(+)
diff --git a/arrayfire/data.py b/arrayfire/data.py
index c202aac..2e6c9fc 100644
--- a/arrayfire/data.py
+++ b/arrayfire/data.py
@@ -15,6 +15,7 @@ from sys import version_info
from .library import *
from .array import *
from .util import *
+from .util import _is_number
def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
"""
@@ -781,3 +782,118 @@ def upper(a, is_unit_diag=False):
out = Array()
safe_call(backend.get().af_upper(ct.pointer(out.arr), a.arr, is_unit_diag))
return out
+
+def select(cond, lhs, rhs):
+    """
+    Select elements from one of two arrays based on condition.
+
+    Parameters
+    ----------
+    cond : af.Array
+        Conditional array
+
+    lhs : af.Array or scalar
+        numerical array whose elements are picked when conditional element is True
+
+    rhs : af.Array or scalar
+        numerical array whose elements are picked when conditional element is False
+
+    Returns
+    -------
+    out: af.Array
+        An array containing elements from `lhs` when `cond` is True and `rhs` when False.
+
+    Raises
+    ------
+    TypeError
+        If neither `lhs` nor `rhs` is an af.Array, or the scalar cannot be converted to a C double.
+
+    Examples
+    --------
+    >>> import arrayfire as af
+    >>> a = af.randu(3,3)
+    >>> b = af.randu(3,3)
+    >>> cond = a > b
+    >>> res = af.select(cond, a, b)
+    >>> af.display(a)
+    [3 3 1 1]
+        0.4107     0.1794     0.3775
+        0.8224     0.4198     0.3027
+        0.9518     0.0081     0.6456
+
+    >>> af.display(b)
+    [3 3 1 1]
+        0.7269     0.3569     0.3341
+        0.7104     0.1437     0.0899
+        0.5201     0.4563     0.5363
+
+    >>> af.display(res)
+    [3 3 1 1]
+        0.7269     0.3569     0.3775
+        0.8224     0.4198     0.3027
+        0.9518     0.4563     0.6456
+    """
+    out = Array()
+    is_left_array = isinstance(lhs, Array)
+    is_right_array = isinstance(rhs, Array)
+
+    if not (is_left_array or is_right_array):
+        raise TypeError("Atleast one input needs to be of type arrayfire.array")
+    elif (is_left_array and is_right_array):
+        safe_call(backend.get().af_select(ct.pointer(out.arr), cond.arr, lhs.arr, rhs.arr))
+    elif is_left_array:
+        # Dispatch on which operand is the Array: a non-numeric scalar then
+        # fails inside ct.c_double instead of reaching the wrong native variant.
+        safe_call(backend.get().af_select_scalar_r(ct.pointer(out.arr), cond.arr, lhs.arr, ct.c_double(rhs)))
+    else:
+        safe_call(backend.get().af_select_scalar_l(ct.pointer(out.arr), cond.arr, ct.c_double(lhs), rhs.arr))
+
+    return out
+
+def replace(lhs, cond, rhs):
+    """
+    Replace elements of an array, in place, based on a conditional array.
+
+    Parameters
+    ----------
+    lhs : af.Array
+        numerical array modified in place: elements are kept where `cond` is True
+        and replaced with `rhs` where `cond` is False
+
+    cond : af.Array
+        Conditional array
+
+    rhs : af.Array or scalar
+        numerical array (or scalar) whose values are used where conditional element is False
+
+    Examples
+    --------
+    >>> import arrayfire as af
+    >>> a = af.randu(3,3)
+    >>> af.display(a)
+    [3 3 1 1]
+        0.4107     0.1794     0.3775
+        0.8224     0.4198     0.3027
+        0.9518     0.0081     0.6456
+
+    >>> cond = (a >= 0.25) & (a <= 0.75)
+    >>> af.display(cond)
+    [3 3 1 1]
+             1          0          1
+             0          1          1
+             0          0          1
+
+    >>> af.replace(a, cond, 0.3333)
+    >>> af.display(a)
+    [3 3 1 1]
+        0.4107     0.3333     0.3775
+        0.3333     0.4198     0.3027
+        0.3333     0.3333     0.6456
+    """
+    # Returns None: lhs.arr is handed to the native call as the destination.
+    is_right_array = isinstance(rhs, Array)
+
+    if (is_right_array):
+        safe_call(backend.get().af_replace(lhs.arr, cond.arr, rhs.arr))
+    else:
+        safe_call(backend.get().af_replace_scalar(lhs.arr, cond.arr, ct.c_double(rhs)))
diff --git a/tests/simple/data.py b/tests/simple/data.py
index b19b388..7c95a81 100644
--- a/tests/simple/data.py
+++ b/tests/simple/data.py
@@ -76,4 +76,9 @@ def simple_data(verbose=False):
af.transpose_inplace(a)
display_func(a)
+ display_func(af.select(a > 0.3, a, -0.3))
+
+ af.replace(a, a > 0.3, -0.3)
+ display_func(a)
+
_util.tests['data'] = simple_data
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits mailing list