[python-arrayfire] 41/58: Adding functions from array.h and device.h

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Wed Sep 28 13:57:07 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository python-arrayfire.

commit 7594769fd905e1aff0a7f1518cf2283bc1c27710
Author: Pavan Yalamanchili <contact at pavanky.com>
Date:   Thu Sep 22 15:26:09 2016 -0700

    Adding functions from array.h and device.h
    
    - is_sparse
    - is_locked_array
    - modified eval to use af_eval_multiple
    - set_manual_eval_flag
    - get_manual_eval_flag
    - Added necessary tests
---
 arrayfire/array.py                   |   8 +++
 arrayfire/device.py                  | 100 +++++++++++++++++++++++++++++++----
 arrayfire/tests/simple/array_test.py |   2 +
 arrayfire/tests/simple/device.py     |  19 +++++++
 4 files changed, 119 insertions(+), 10 deletions(-)

diff --git a/arrayfire/array.py b/arrayfire/array.py
index 2c2eda4..4e27af6 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -667,6 +667,14 @@ class Array(BaseArray):
         safe_call(backend.get().af_is_vector(ct.pointer(res), self.arr))
         return res.value
 
+    def is_sparse(self):
+        """
+        Check if the array is a sparse matrix.
+        """
+        res = ct.c_bool(False)
+        safe_call(backend.get().af_is_sparse(ct.pointer(res), self.arr))
+        return res.value
+
     def is_complex(self):
         """
         Check if the array is of complex type.
diff --git a/arrayfire/device.py b/arrayfire/device.py
index 0a30385..b13989b 100644
--- a/arrayfire/device.py
+++ b/arrayfire/device.py
@@ -163,24 +163,87 @@ def sync(device=None):
     safe_call(backend.get().af_sync(dev))
 
 def __eval(*args):
-    for A in args:
-        if isinstance(A, tuple):
-            __eval(*A)
-        if isinstance(A, list):
-            __eval(*A)
-        if isinstance(A, Array):
-            safe_call(backend.get().af_eval(A.arr))
+    nargs = len(args)
+    if (nargs == 1):
+        safe_call(backend.get().af_eval(args[0].arr))
+    else:
+        c_void_p_n = ct.c_void_p * nargs
+        arrs = c_void_p_n()
+        for n in range(nargs):
+            arrs[n] = args[n].arr
+        safe_call(backend.get().af_eval_multiple(ct.c_int(nargs), ct.pointer(arrs)))
+    return
 
 def eval(*args):
     """
-    Evaluate the input
+    Evaluate one or more inputs together
 
     Parameters
     -----------
     args : arguments to be evaluated
+
+    Note
+    -----
+
+    All the input arrays to this function should be of the same size.
+
+    Examples
+    --------
+
+    >>> a = af.constant(1, 3, 3)
+    >>> b = af.constant(2, 3, 3)
+    >>> c = a + b
+    >>> d = a - b
+    >>> af.eval(c, d) # A single kernel is launched here
+    >>> c
+    arrayfire.Array()
+    Type: float
+    [3 3 1 1]
+    3.0000     3.0000     3.0000
+    3.0000     3.0000     3.0000
+    3.0000     3.0000     3.0000
+
+    >>> d
+    arrayfire.Array()
+    Type: float
+    [3 3 1 1]
+    -1.0000    -1.0000    -1.0000
+    -1.0000    -1.0000    -1.0000
+    -1.0000    -1.0000    -1.0000
+    """
+    for arg in args:
+        if not isinstance(arg, Array):
+            raise RuntimeError("All inputs to eval must be of type arrayfire.Array")
+
+    __eval(*args)
+
+def set_manual_eval_flag(flag):
+    """
+    Tells the backend JIT engine to disable heuristics for determining when to evaluate a JIT tree.
+
+    Parameters
+    ----------
+
+    flag : optional: bool.
+         - Specifies if the heuristic evaluation of the JIT tree needs to be disabled.
+
+    Note
+    ----
+    This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
     """
+    safe_call(backend.get().af_set_manual_eval_flag(flag))
 
-    __eval(args)
+def get_manual_eval_flag():
+    """
+    Query the backend JIT engine to see if the user disabled heuristic evaluation of the JIT tree.
+
+    Note
+    ----
+    This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
+    """
+    res = ct.c_bool(False)
+    safe_call(backend.get().af_get_manual_eval_flag(ct.pointer(res)))
+    return res.value
 
 def device_mem_info():
     """
@@ -258,10 +321,27 @@ def lock_array(a):
 
     Note
     -----
-        - The device pointer of `a` is not freed by memory manager until `unlock_device_ptr()` is called.
+        - The device pointer of `a` is not freed by memory manager until `unlock_array()` is called.
     """
     safe_call(backend.get().af_lock_array(a.arr))
 
+def is_locked_array(a):
+    """
+    Check if the input array is locked by the user.
+
+    Parameters
+    ----------
+    a: af.Array
+       - A multi dimensional arrayfire array.
+
+    Returns
+    -----------
+    A bool specifying if the input array is locked.
+    """
+    res = ct.c_bool(False)
+    safe_call(backend.get().af_is_locked_array(ct.pointer(res), a.arr))
+    return res.value
+
 def unlock_device_ptr(a):
     """
     This functions is deprecated. Please use unlock_array instead.
diff --git a/arrayfire/tests/simple/array_test.py b/arrayfire/tests/simple/array_test.py
index 1aec934..0c6ab52 100644
--- a/arrayfire/tests/simple/array_test.py
+++ b/arrayfire/tests/simple/array_test.py
@@ -60,4 +60,6 @@ def simple_array(verbose=False):
     print_func(arr)
     print_func(lst)
 
+    print_func(a.is_sparse())
+
 _util.tests['array'] = simple_array
diff --git a/arrayfire/tests/simple/device.py b/arrayfire/tests/simple/device.py
index 925add2..279fa31 100644
--- a/arrayfire/tests/simple/device.py
+++ b/arrayfire/tests/simple/device.py
@@ -51,4 +51,23 @@ def simple_device(verbose=False):
     af.lock_array(c)
     af.unlock_array(c)
 
+    a = af.constant(1, 3, 3)
+    b = af.constant(2, 3, 3)
+    af.eval(a)
+    af.eval(b)
+    print_func(a)
+    print_func(b)
+    c = a + b
+    d = a - b
+    af.eval(c, d)
+    print_func(c)
+    print_func(d)
+
+    print_func(af.set_manual_eval_flag(True))
+    assert(af.get_manual_eval_flag() == True)
+    print_func(af.set_manual_eval_flag(False))
+    assert(af.get_manual_eval_flag() == False)
+
+    display_func(af.is_locked_array(a))
+
 _util.tests['device'] = simple_device

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list