[python-arrayfire] 130/250: FEAT: adding select and replace

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:40 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.

commit faaf22386f9685acb0eedabfffd892fbb7d09f85
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Tue Nov 10 09:47:41 2015 -0500

    FEAT: adding select and replace
---
 arrayfire/data.py    | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/simple/data.py |   5 +++
 2 files changed, 121 insertions(+)

diff --git a/arrayfire/data.py b/arrayfire/data.py
index c202aac..2e6c9fc 100644
--- a/arrayfire/data.py
+++ b/arrayfire/data.py
@@ -15,6 +15,7 @@ from sys import version_info
 from .library import *
 from .array import *
 from .util import *
+from .util import _is_number
 
 def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
     """
@@ -781,3 +782,118 @@ def upper(a, is_unit_diag=False):
     out = Array()
     safe_call(backend.get().af_upper(ct.pointer(out.arr), a.arr, is_unit_diag))
     return out
+
+def select(cond, lhs, rhs):
+    """
+    Select elements from one of two arrays based on condition.
+
+    Parameters
+    ----------
+
+    cond : af.Array
+           Conditional array
+
+    lhs  : af.Array or scalar
+           numerical array (or scalar) whose elements are picked when conditional element is True
+
+    rhs  : af.Array or scalar
+           numerical array (or scalar) whose elements are picked when conditional element is False
+           At least one of `lhs` and `rhs` must be an af.Array, otherwise TypeError is raised.
+           A scalar operand is passed to ArrayFire as a C double.
+
+    Returns
+    -------
+
+    out: af.Array
+         An array containing elements from `lhs` when `cond` is True and `rhs` when False.
+
+    Examples
+    --------
+
+    >>> import arrayfire as af
+    >>> a = af.randu(3,3)
+    >>> b = af.randu(3,3)
+    >>> cond = a > b
+    >>> res = af.select(cond, a, b)
+
+    >>> af.display(a)
+    [3 3 1 1]
+        0.4107     0.1794     0.3775
+        0.8224     0.4198     0.3027
+        0.9518     0.0081     0.6456
+
+    >>> af.display(b)
+    [3 3 1 1]
+        0.7269     0.3569     0.3341
+        0.7104     0.1437     0.0899
+        0.5201     0.4563     0.5363
+
+    >>> af.display(res)
+    [3 3 1 1]
+        0.7269     0.3569     0.3775
+        0.8224     0.4198     0.3027
+        0.9518     0.4563     0.6456
+    """
+    out = Array()
+    is_left_array = isinstance(lhs, Array)
+    is_right_array = isinstance(rhs, Array)
+
+    if not (is_left_array or is_right_array):
+        raise TypeError("Atleast one input needs to be of type arrayfire.array")
+    elif is_left_array and is_right_array:
+        safe_call(backend.get().af_select(ct.pointer(out.arr), cond.arr, lhs.arr, rhs.arr))
+    elif is_left_array:
+        # Exactly one operand is an Array here: dispatch on which one, and convert only the other.
+        safe_call(backend.get().af_select_scalar_r(ct.pointer(out.arr), cond.arr, lhs.arr, ct.c_double(rhs)))
+    else:
+        safe_call(backend.get().af_select_scalar_l(ct.pointer(out.arr), cond.arr, ct.c_double(lhs), rhs.arr))
+
+    return out
+
+def replace(lhs, cond, rhs):
+    """
+    Replace elements of `lhs` with `rhs` wherever `cond` is False.
+
+    `lhs` is modified in place and nothing is returned.
+
+    Parameters
+    ----------
+
+    lhs  : af.Array
+           numerical array whose elements are replaced with `rhs` when conditional element is False
+
+    cond : af.Array
+           Conditional array
+
+    rhs  : af.Array or scalar
+           numerical array (or scalar) whose elements are picked when conditional element is False
+
+    Examples
+    --------
+    >>> import arrayfire as af
+    >>> a = af.randu(3,3)
+    >>> af.display(a)
+    [3 3 1 1]
+        0.4107     0.1794     0.3775
+        0.8224     0.4198     0.3027
+        0.9518     0.0081     0.6456
+
+    >>> cond = (a >= 0.25) & (a <= 0.75)
+    >>> af.display(cond)
+    [3 3 1 1]
+             1          0          1
+             0          1          1
+             0          0          1
+
+    >>> af.replace(a, cond, 0.3333)
+    >>> af.display(a)
+    [3 3 1 1]
+        0.4107     0.3333     0.3775
+        0.3333     0.4198     0.3027
+        0.3333     0.3333     0.6456
+
+    """
+    if isinstance(rhs, Array):
+        safe_call(backend.get().af_replace(lhs.arr, cond.arr, rhs.arr))
+    else:
+        safe_call(backend.get().af_replace_scalar(lhs.arr, cond.arr, ct.c_double(rhs)))
diff --git a/tests/simple/data.py b/tests/simple/data.py
index b19b388..7c95a81 100644
--- a/tests/simple/data.py
+++ b/tests/simple/data.py
@@ -76,4 +76,9 @@ def simple_data(verbose=False):
     af.transpose_inplace(a)
     display_func(a)
 
+    display_func(af.select(a > 0.3, a, -0.3))
+
+    af.replace(a, a > 0.3, -0.3)
+    display_func(a)
+
 _util.tests['data'] = simple_data

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list