[python-arrayfire] 38/58: FEAT: Adding clamp function and relevant tests
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Wed Sep 28 13:57:07 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch master
in repository python-arrayfire.
commit 596951d9ecc16c74c446f1b21ebe3644f3c39067
Author: Pavan Yalamanchili <contact at pavanky.com>
Date: Thu Sep 22 14:28:34 2016 -0700
FEAT: Adding clamp function and relevant tests
---
arrayfire/arith.py | 38 ++++++++++++++++++++++++++++++++++++++
arrayfire/tests/simple/arith.py | 5 +++++
2 files changed, 43 insertions(+)
diff --git a/arrayfire/arith.py b/arrayfire/arith.py
index 4c73e59..3b397c4 100644
--- a/arrayfire/arith.py
+++ b/arrayfire/arith.py
@@ -126,6 +126,44 @@ def maxof(lhs, rhs):
"""
return _arith_binary_func(lhs, rhs, backend.get().af_maxof)
def clamp(val, low, high):
    """
    Clamp the input value between low and high.

    Parameters
    ----------
    val : af.Array
        Multi dimensional arrayfire array to be clamped.

    low : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number denoting the lower value(s).

    high : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number denoting the higher value(s).

    Returns
    -------
    out : af.Array
        Array holding the result of clamping `val` between `low` and `high`.

    Notes
    -----
    A scalar bound is expanded to a constant array with the dimensions and
    type of `val` before the backend call. These temporary arrays are
    released before returning, even if the backend call fails.
    """
    out = Array()

    vdims = dim4_to_tuple(val.dims())
    vty = val.type()

    # Raw af_array handles created here for scalar bounds. No Array wrapper
    # owns them, so they must be released explicitly or they leak.
    temporaries = []

    def _bound_handle(bound):
        # Return a backend handle for `bound`, materialising scalars as
        # constant arrays shaped and typed like `val`.
        if isinstance(bound, Array):
            return bound.arr
        handle = constant_array(bound, vdims[0], vdims[1], vdims[2], vdims[3], vty)
        temporaries.append(handle)
        return handle

    try:
        low_arr = _bound_handle(low)
        high_arr = _bound_handle(high)
        safe_call(backend.get().af_clamp(ct.pointer(out.arr), val.arr,
                                         low_arr, high_arr, _bcast_var.get()))
    finally:
        # Free only the handles this function created; caller-owned
        # Array bounds are left untouched.
        for handle in temporaries:
            safe_call(backend.get().af_release_array(handle))

    return out
+
def rem(lhs, rhs):
"""
Find the remainder.
diff --git a/arrayfire/tests/simple/arith.py b/arrayfire/tests/simple/arith.py
index f8407c1..84c291a 100644
--- a/arrayfire/tests/simple/arith.py
+++ b/arrayfire/tests/simple/arith.py
@@ -134,6 +134,11 @@ def simple_arith(verbose = False):
display_func(af.cast(a, af.Dtype.c32))
display_func(af.maxof(a,b))
display_func(af.minof(a,b))
+
+ display_func(af.clamp(a, 0, 1))
+ display_func(af.clamp(a, 0, b))
+ display_func(af.clamp(a, b, 1))
+
display_func(af.rem(a,b))
a = af.randu(3,3) - 0.5
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits
mailing list