[python-arrayfire] 41/58: Adding functions from array.h and device.h
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Wed Sep 28 13:57:07 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch master
in repository python-arrayfire.
commit 7594769fd905e1aff0a7f1518cf2283bc1c27710
Author: Pavan Yalamanchili <contact at pavanky.com>
Date: Thu Sep 22 15:26:09 2016 -0700
Adding functions from array.h and device.h
- is_sparse
- is_locked_array
- modified eval to use af_eval_multiple
- set_manual_eval_flag
- get_manual_eval_flag
- Added necessary tests
---
arrayfire/array.py | 8 +++
arrayfire/device.py | 100 +++++++++++++++++++++++++++++++----
arrayfire/tests/simple/array_test.py | 2 +
arrayfire/tests/simple/device.py | 19 +++++++
4 files changed, 119 insertions(+), 10 deletions(-)
diff --git a/arrayfire/array.py b/arrayfire/array.py
index 2c2eda4..4e27af6 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -667,6 +667,14 @@ class Array(BaseArray):
safe_call(backend.get().af_is_vector(ct.pointer(res), self.arr))
return res.value
+    def is_sparse(self):
+        """
+        Check if the array is a sparse matrix.
+
+        Returns
+        -------
+        bool
+            True if ArrayFire (af_is_sparse) reports this array as sparse,
+            False otherwise.
+        """
+        sparse = ct.c_bool(False)
+        safe_call(backend.get().af_is_sparse(ct.pointer(sparse), self.arr))
+        return sparse.value
+
def is_complex(self):
"""
Check if the array is of complex type.
diff --git a/arrayfire/device.py b/arrayfire/device.py
index 0a30385..b13989b 100644
--- a/arrayfire/device.py
+++ b/arrayfire/device.py
@@ -163,24 +163,87 @@ def sync(device=None):
safe_call(backend.get().af_sync(dev))
def __eval(*args):
- for A in args:
- if isinstance(A, tuple):
- __eval(*A)
- if isinstance(A, list):
- __eval(*A)
- if isinstance(A, Array):
- safe_call(backend.get().af_eval(A.arr))
+ nargs = len(args)
+ if (nargs == 1):
+ safe_call(backend.get().af_eval(args[0].arr))
+ else:
+ c_void_p_n = ct.c_void_p * nargs
+ arrs = c_void_p_n()
+ for n in range(nargs):
+ arrs[n] = args[n].arr
+ safe_call(backend.get().af_eval_multiple(ct.c_int(nargs), ct.pointer(arrs)))
+ return
def eval(*args):
"""
- Evaluate the input
+ Evaluate one or more inputs together
Parameters
-----------
args : arguments to be evaluated
+
+ Note
+ -----
+
+ All the input arrays to this function should be of the same size.
+
+ Examples
+ --------
+
+ >>> a = af.constant(1, 3, 3)
+ >>> b = af.constant(2, 3, 3)
+ >>> c = a + b
+ >>> d = a - b
+ >>> af.eval(c, d) # A single kernel is launched here
+ >>> c
+ arrayfire.Array()
+ Type: float
+ [3 3 1 1]
+ 3.0000 3.0000 3.0000
+ 3.0000 3.0000 3.0000
+ 3.0000 3.0000 3.0000
+
+ >>> d
+ arrayfire.Array()
+ Type: float
+ [3 3 1 1]
+ -1.0000 -1.0000 -1.0000
+ -1.0000 -1.0000 -1.0000
+ -1.0000 -1.0000 -1.0000
+ """
+ for arg in args:
+ if not isinstance(arg, Array):
+ raise RuntimeError("All inputs to eval must be of type arrayfire.Array")
+
+ __eval(*args)
+
+def set_manual_eval_flag(flag):
+    """
+    Tells the backend JIT engine to disable heuristics for determining when to evaluate a JIT tree.
+
+    Parameters
+    ----------
+
+    flag : bool
+        - Specifies if the heuristic evaluation of the JIT tree needs to be disabled.
+          The value is converted with ctypes.c_bool, so any truthy object
+          disables the heuristics and any falsy object re-enables them.
+
+    Note
+    ----
+    This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
     """
+    # Convert explicitly so the C bool parameter always receives exactly 0 or 1,
+    # rather than whatever raw object ctypes would otherwise marshal.
+    safe_call(backend.get().af_set_manual_eval_flag(ct.c_bool(flag)))
-    __eval(args)
+def get_manual_eval_flag():
+    """
+    Query the backend JIT engine to see if the user disabled heuristic evaluation of the JIT tree.
+
+    Returns
+    -------
+    bool
+        True if heuristic evaluation of the JIT tree has been disabled,
+        False otherwise.
+
+    Note
+    ----
+    This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
+    """
+    manual = ct.c_bool(False)
+    safe_call(backend.get().af_get_manual_eval_flag(ct.pointer(manual)))
+    return manual.value
def device_mem_info():
"""
@@ -258,10 +321,27 @@ def lock_array(a):
Note
-----
- - The device pointer of `a` is not freed by memory manager until `unlock_device_ptr()` is called.
+ - The device pointer of `a` is not freed by memory manager until `unlock_array()` is called.
"""
safe_call(backend.get().af_lock_array(a.arr))
+def is_locked_array(a):
+    """
+    Check if the input array is locked by the user.
+
+    Parameters
+    ----------
+    a: af.Array
+        - A multi dimensional arrayfire array.
+
+    Returns
+    -------
+    bool
+        True if ArrayFire (af_is_locked_array) reports `a` as locked,
+        False otherwise.
+    """
+    locked = ct.c_bool(False)
+    safe_call(backend.get().af_is_locked_array(ct.pointer(locked), a.arr))
+    return locked.value
+
def unlock_device_ptr(a):
"""
This functions is deprecated. Please use unlock_array instead.
diff --git a/arrayfire/tests/simple/array_test.py b/arrayfire/tests/simple/array_test.py
index 1aec934..0c6ab52 100644
--- a/arrayfire/tests/simple/array_test.py
+++ b/arrayfire/tests/simple/array_test.py
@@ -60,4 +60,6 @@ def simple_array(verbose=False):
print_func(arr)
print_func(lst)
+ print_func(a.is_sparse())
+
_util.tests['array'] = simple_array
diff --git a/arrayfire/tests/simple/device.py b/arrayfire/tests/simple/device.py
index 925add2..279fa31 100644
--- a/arrayfire/tests/simple/device.py
+++ b/arrayfire/tests/simple/device.py
@@ -51,4 +51,23 @@ def simple_device(verbose=False):
af.lock_array(c)
af.unlock_array(c)
+ a = af.constant(1, 3, 3)
+ b = af.constant(2, 3, 3)
+ af.eval(a)
+ af.eval(b)
+ print_func(a)
+ print_func(b)
+ c = a + b
+ d = a - b
+ af.eval(c, d)
+ print_func(c)
+ print_func(d)
+
+ print_func(af.set_manual_eval_flag(True))
+ assert(af.get_manual_eval_flag() == True)
+ print_func(af.set_manual_eval_flag(False))
+ assert(af.get_manual_eval_flag() == False)
+
+ display_func(af.is_locked_array(a))
+
_util.tests['device'] = simple_device
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits mailing list