[python-dtcwt] 199/497: do not copy results back from OpenCL by default

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:05 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit 28f5b1d1bf5a0fd36a7d819eeafce34e0cdb7409
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Wed Nov 13 11:31:29 2013 +0000

    do not copy results back from OpenCL by default
    
    The new TransformDomainSignal representation for the output class has
    allowed one to abstract the copying from device to host until the last
    possible moment by introducing a compatible TransformDomainSignalOpenCl
    class. The copy is performed at attribute access time. It may be
    desirable to memoize the output which will be implemented in a separate
    commit.
---
 dtcwt/backend/backend_opencl/transform2d.py | 78 ++++++++++++++++++++++++-----
 1 file changed, 66 insertions(+), 12 deletions(-)

diff --git a/dtcwt/backend/backend_opencl/transform2d.py b/dtcwt/backend/backend_opencl/transform2d.py
index 50becdb..e19744f 100644
--- a/dtcwt/backend/backend_opencl/transform2d.py
+++ b/dtcwt/backend/backend_opencl/transform2d.py
@@ -15,7 +15,7 @@ from dtcwt.backend import TransformDomainSignal, ReconstructedSignal
 from dtcwt.backend.backend_numpy.transform2d import Transform2dNumPy
 
 try:
-    from pyopencl.array import concatenate
+    from pyopencl.array import concatenate, Array as CLArray
 except ImportError:
     # The lack of OpenCL will be caught by the low-level routines.
     pass
@@ -28,6 +28,47 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
     else:
         return r.lowpass, r.subbands
 
+class TransformDomainSignalOpenCL(object):
+    """
+    An interface-compatible version of
+    :py:class:`dtcwt.backend.TransformDomainSignal` where the initialiser
+    arguments are assumed to be :py:class:`pyopencl.array.Array` instances.
+
+    The attributes defined in :py:class:`dtcwt.backend.TransformDomainSignal`
+    are implemented via properties. The original OpenCL arrays may be accessed
+    via the ``cl_...`` attributes.
+
+    .. py:attribute:: cl_lowpass
+
+        The CL array containing the lowpass image.
+
+    .. py:attribute:: cl_subbands
+
+        A tuple of CL arrays containing the subband images.
+
+    .. py:attribute:: cl_scales
+
+        *(optional)* Either ``None`` or a tuple of lowpass images for each
+        scale.
+
+    """
+    def __init__(self, lowpass, subbands, scales=None):
+        self.cl_lowpass = lowpass
+        self.cl_subbands = subbands
+        self.cl_scales = scales
+
+    @property
+    def lowpass(self):
+        return to_array(self.cl_lowpass) if self.cl_lowpass is not None else None
+
+    @property
+    def subbands(self):
+        return tuple(to_array(x) for x in self.cl_subbands) if self.cl_subbands is not None else None
+
+    @property
+    def scales(self):
+        return tuple(to_array(x) for x in self.cl_scales) if self.cl_scales is not None else None
+
 class Transform2dOpenCL(Transform2dNumPy):
     """
     An implementation of the 2D DT-CWT via OpenCL. *biort* and *qshift* are the
@@ -57,13 +98,25 @@ class Transform2dOpenCL(Transform2dNumPy):
 
         :returns: A :py:class:`dtcwt.backend.TransformDomainSignal` compatible object representing the transform-domain signal
 
+        .. note::
+
+            *X* may be a :py:class:`pyopencl.array.Array` instance which has
+            already been copied to the device, in which case it must be 2D
+            (i.e. a vector will not be auto-promoted).
+
         .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
         .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
         .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
 
         """
         queue = self.queue
-        X = np.atleast_2d(asfarray(X))
+
+        if isinstance(X, CLArray):
+            if len(X.shape) != 2:
+                raise ValueError('Input array must be two-dimensional')
+        else:
+            # Not a CL array: promote to (at least) 2D here; the copy to the device happens later
+            X = np.atleast_2d(asfarray(X))
 
         # If biort has 6 elements instead of 4, then it's a modified
         # rotationally symmetric wavelet
@@ -97,21 +150,26 @@ class Transform2dOpenCL(Transform2dNumPy):
         initial_col_extend = 0
         if original_size[0] % 2 != 0:
             # if X.shape[0] is not divisible by 2 then we need to extend X by adding a row at the bottom
+            X = to_array(X)
             X = np.vstack((X, X[[-1],:]))  # Any further extension will be done in due course.
             initial_row_extend = 1
 
         if original_size[1] % 2 != 0:
             # if X.shape[1] is not divisible by 2 then we need to extend X by adding a col to the left
+            X = to_array(X)
             X = np.hstack((X, X[:,[-1]]))
             initial_col_extend = 1
 
         extended_size = X.shape
 
+        # Copy X to the device if necessary
+        X = to_device(X, queue=queue)
+
         if nlevels == 0:
             if include_scale:
-                return TransformDomainSignal(X, (), ())
+                return TransformDomainSignalOpenCL(X, (), ())
             else:
-                return TransformDomainSignal(X, ())
+                return TransformDomainSignalOpenCL(X, ())
 
         # initialise
         Yh = [None,] * nlevels
@@ -119,7 +177,7 @@ class Transform2dOpenCL(Transform2dNumPy):
             # this is only required if the user specifies a third output component.
             Yscale = [None,] * nlevels
 
-        complex_dtype = appropriate_complex_type_for(X)
+        complex_dtype = np.complex64
 
         if nlevels >= 1:
             # Do odd top-level filters on cols.
@@ -181,10 +239,7 @@ class Transform2dOpenCL(Transform2dNumPy):
             if include_scale:
                 Yscale[level] = LoLo
 
-        Yl = to_array(LoLo,queue=queue)
-        Yh = list(to_array(x) for x in Yh)
-        if include_scale:
-            Yscale = list(to_array(x) for x in Yscale)
+        Yl = LoLo
 
         if initial_row_extend == 1 and initial_col_extend == 1:
             logging.warn('The image entered is now a {0} NOT a {1}.'.format(
@@ -207,8 +262,7 @@ class Transform2dOpenCL(Transform2dNumPy):
             logging.warn(
                 'The rightmost column has been duplicated, prior to decomposition.')
 
-
         if include_scale:
-            return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
+            return TransformDomainSignalOpenCL(Yl, tuple(Yh), tuple(Yscale))
         else:
-            return TransformDomainSignal(Yl, tuple(Yh))
+            return TransformDomainSignalOpenCL(Yl, tuple(Yh))

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list