[python-arrayfire] 38/58: FEAT: Adding clamp function and relevant tests

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Wed Sep 28 13:57:07 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository python-arrayfire.

commit 596951d9ecc16c74c446f1b21ebe3644f3c39067
Author: Pavan Yalamanchili <contact at pavanky.com>
Date:   Thu Sep 22 14:28:34 2016 -0700

    FEAT: Adding clamp function and relevant tests
---
 arrayfire/arith.py              | 38 ++++++++++++++++++++++++++++++++++++++
 arrayfire/tests/simple/arith.py |  5 +++++
 2 files changed, 43 insertions(+)

diff --git a/arrayfire/arith.py b/arrayfire/arith.py
index 4c73e59..3b397c4 100644
--- a/arrayfire/arith.py
+++ b/arrayfire/arith.py
@@ -126,6 +126,44 @@ def maxof(lhs, rhs):
     """
     return _arith_binary_func(lhs, rhs, backend.get().af_maxof)
 
+def clamp(val, low, high):
+    """
+    Clamp the values of the input array between low and high.
+
+
+    Parameters
+    ----------
+    val  : af.Array
+          Multi-dimensional arrayfire array to be clamped.
+
+    low  : af.Array or scalar
+          Multi-dimensional arrayfire array or a scalar number denoting the lower bound(s).
+
+    high : af.Array or scalar
+          Multi-dimensional arrayfire array or a scalar number denoting the upper bound(s).
+    """
+    out = Array()
+
+    is_low_array = isinstance(low, Array)
+    is_high_array = isinstance(high, Array)
+
+    vdims = dim4_to_tuple(val.dims())
+    vty = val.type()
+
+    if not is_low_array:
+        low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
+    else:
+        low_arr = low.arr
+
+    if not is_high_array:
+        high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
+    else:
+        high_arr = high.arr
+
+    safe_call(backend.get().af_clamp(ct.pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
+
+    return out
+
 def rem(lhs, rhs):
     """
     Find the remainder.
diff --git a/arrayfire/tests/simple/arith.py b/arrayfire/tests/simple/arith.py
index f8407c1..84c291a 100644
--- a/arrayfire/tests/simple/arith.py
+++ b/arrayfire/tests/simple/arith.py
@@ -134,6 +134,11 @@ def simple_arith(verbose = False):
     display_func(af.cast(a, af.Dtype.c32))
     display_func(af.maxof(a,b))
     display_func(af.minof(a,b))
+
+    display_func(af.clamp(a, 0, 1))
+    display_func(af.clamp(a, 0, b))
+    display_func(af.clamp(a, b, 1))
+
     display_func(af.rem(a,b))
 
     a = af.randu(3,3) - 0.5

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list