[python-arrayfire] 152/250: FEAT: Changes to use the unified backend from arrayfire when available

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:43 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.

commit 9f9b2e05d5b5d0e8c2186a20968e7d28e455cc38
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Tue Dec 8 19:35:27 2015 -0500

    FEAT: Changes to use the unified backend from arrayfire when available
---
 arrayfire/array.py   |   2 -
 arrayfire/library.py | 224 +++++++++++++++++++++++++++++++++++++--------------
 2 files changed, 165 insertions(+), 61 deletions(-)

diff --git a/arrayfire/array.py b/arrayfire/array.py
index 82266ae..f1e7dcb 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -369,8 +369,6 @@ class Array(BaseArray):
 
         _type_char='f'
 
-        backend.lock()
-
         if src is not None:
 
             if (isinstance(src, Array)):
diff --git a/arrayfire/library.py b/arrayfire/library.py
index 5701eac..0b49930 100644
--- a/arrayfire/library.py
+++ b/arrayfire/library.py
@@ -26,65 +26,6 @@ except:
         def __init__(self, v):
             self.value = v
 
-class _clibrary(object):
-
-    def __libname(self, name):
-        platform_name = platform.system()
-        assert(len(platform_name) >= 3)
-
-        libname = 'libaf' + name
-        if platform_name == 'Linux':
-            libname += '.so'
-        elif platform_name == 'Darwin':
-            libname += '.dylib'
-        elif platform_name == "Windows" or platform_name[:3] == "CYG":
-            libname += '.dll'
-            libname = libname[3:] # remove 'lib'
-            if platform_name == "Windows":
-                '''
-                Supressing crashes caused by missing dlls
-                http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
-                https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
-                '''
-                ct.windll.kernel32.SetErrorMode(0x0001 | 0x0002);
-        else:
-            raise OSError(platform_name + ' not supported')
-
-        return libname
-
-    def set(self, name, unsafe=False):
-        if (not unsafe and self.__lock):
-            raise RuntimeError("Can not change backend after creating an Array")
-        if (self.clibs[name] is None):
-            raise RuntimeError("Could not load any ArrayFire %s backend" % name)
-        self.name = name
-        return
-
-    def __init__(self):
-        self.clibs = {}
-        self.name = None
-        self.__lock = False
-        # Iterate in reverse order of preference
-        for name in ('cpu', 'opencl', 'cuda'):
-            try:
-                libname = self.__libname(name)
-                ct.cdll.LoadLibrary(libname)
-                self.clibs[name] = ct.CDLL(libname)
-                self.name = name
-            except:
-                self.clibs[name] = None
-
-        if (self.name is None):
-            raise RuntimeError("Could not load any ArrayFire libraries")
-
-    def get(self):
-        return self.clibs[self.name]
-
-    def lock(self):
-        self.__lock = True
-
-backend = _clibrary()
-
 class ERR(_Enum):
     """
     Error values. For internal use only.
@@ -373,3 +314,168 @@ class BACKEND(_Enum):
     CPU     = _Enum_Type(1)
     CUDA    = _Enum_Type(2)
     OPENCL  = _Enum_Type(4)
+
+class _clibrary(object):
+    """Loads the ArrayFire shared libraries and tracks which one is active."""
+
+    def __libname(self, name):
+        # Map a backend suffix ('cpu', 'cuda', 'opencl' or '' for unified) to a platform file name.
+        platform_name = platform.system()
+        assert(len(platform_name) >= 3)
+
+        libname = 'libaf' + name
+        if platform_name == 'Linux':
+            libname += '.so'
+        elif platform_name == 'Darwin':
+            libname += '.dylib'
+        elif platform_name == "Windows" or platform_name[:3] == "CYG":
+            libname += '.dll'
+            libname = libname[3:] # remove 'lib'
+            if platform_name == "Windows":
+                '''
+                Suppressing crashes caused by missing dlls
+                http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
+                https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
+                '''
+                ct.windll.kernel32.SetErrorMode(0x0001 | 0x0002)
+        else:
+            raise OSError(platform_name + ' not supported')
+
+        return libname
+
+    def set_unsafe(self, name):
+        lib = self.__clibs.get(name) # .get(): unknown names raise RuntimeError below, not KeyError
+        if (lib is None):
+            raise RuntimeError("Backend not found")
+        self.__name = name
+
+    def __init__(self):
+        self.__name = None
+
+        self.__clibs = {'cuda'   : None,
+                        'opencl' : None,
+                        'cpu'    : None,
+                        ''       : None}
+
+        self.__backend_map = {0 : 'default',
+                              1 : 'cpu'    ,
+                              2 : 'cuda'   ,
+                              4 : 'opencl' }
+
+        self.__backend_name_map = {'default' : 0,
+                                   'cpu'     : 1,
+                                   'cuda'    : 2,
+                                   'opencl'  : 4}
+
+        # Iterate in reverse order of preference: the last library that loads
+        # becomes active, so the unified library ('') wins when it is present.
+        for name in ('cpu', 'opencl', 'cuda', ''):
+            try:
+                libname = self.__libname(name)
+                self.__clibs[name] = ct.cdll.LoadLibrary(libname)
+                self.__name = name
+            except Exception:
+                pass # library missing or platform unsupported; try the next one
+
+        if (self.__name is None):
+            raise RuntimeError("Could not load any ArrayFire libraries")
+
+    def get_id(self, name):
+        return self.__backend_name_map[name]
+
+    def get_name(self, bk_id):
+        return self.__backend_map[bk_id]
+
+    def get(self):
+        return self.__clibs[self.__name]
+
+    def name(self):
+        return self.__name
+
+    def is_unified(self):
+        return self.__name == ''
+
+    def parse(self, res):
+        # res is a bitmask of backend ids; return the matching names ordered by id.
+        ids = sorted(self.__backend_name_map.items(), key=lambda kv: kv[1])
+        return tuple(key for key, value in ids if value & res)
+
+backend = _clibrary() # module-wide loader built at import; raises RuntimeError if no ArrayFire library loads
+
+def set_backend(name, unsafe=False):
+    """
+    Set a specific backend by name
+
+    Parameters
+    ----------
+
+    name : str. One of 'default', 'cpu', 'cuda' or 'opencl'.
+
+    unsafe : optional: bool. Default: False.
+           If False, raises RuntimeError unless the unified backend is loaded.
+    """
+    if (not backend.is_unified() and not unsafe):
+        raise RuntimeError("Can not change backend after loading %s" % backend.name())
+
+    if (backend.is_unified()):
+        safe_call(backend.get().af_set_backend(backend.get_id(name)))
+    else:
+        backend.set_unsafe(name)
+    return
+
+def get_backend_id(A):
+    """
+    Get the name of the backend that holds an array.
+
+    Parameters
+    ----------
+    A    : af.Array
+
+    Returns
+    -------
+
+    name : str.
+         Backend name, e.g. 'cpu', 'cuda' or 'opencl'.
+    """
+    if (not backend.is_unified()):
+        return backend.name()
+    # Unified library: ask ArrayFire which backend owns this array.
+    bk_id = ct.c_int(BACKEND.DEFAULT.value)
+    safe_call(backend.get().af_get_backend_id(ct.pointer(bk_id), A.arr))
+    return backend.get_name(bk_id.value)
+
+def get_backend_count():
+    """
+    Get the number of available backends
+
+    Returns
+    -------
+
+    count : int
+          Number of available backends; 1 unless the unified library is loaded.
+    """
+    # Only the unified library can report a backend count; otherwise report 1.
+    if (not backend.is_unified()):
+        return 1
+    n_backends = ct.c_int(0)
+    safe_call(backend.get().af_get_backend_count(ct.pointer(n_backends)))
+    return n_backends.value
+
+def get_available_backends():
+    """
+    Get the names of the backends available at runtime.
+
+    Returns
+    -------
+
+    names : tuple of strings
+          Names of available backends
+    """
+    if (not backend.is_unified()):
+        return (backend.name(),)
+    # The library reports availability as a bitmask of BACKEND ids (1, 2, 4).
+    mask = ct.c_int(0)
+    safe_call(backend.get().af_get_available_backends(ct.pointer(mask)))
+    return backend.parse(int(mask.value))
+
+from .util import safe_call

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list