[python-dtcwt] 189/497: refactor opencl backend to support modified rotational symmetry wavelets

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:03 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit d42a78369c6ad037d7b8b6e1641b91f847506fb7
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Mon Nov 11 17:17:30 2013 +0000

    refactor opencl backend to support modified rotational symmetry wavelets
---
 docs/reference.rst                                 |  12 +
 dtcwt/backend/__init__.py                          |  95 +++++++
 dtcwt/backend/{numpy => backend_numpy}/__init__.py |   0
 dtcwt/backend/{numpy => backend_numpy}/lowlevel.py |   0
 .../{numpy => backend_numpy}/transform2d.py        |  79 +-----
 dtcwt/lowlevel.py                                  |   2 +-
 dtcwt/opencl/transform2d.py                        | 315 +++++++++++----------
 dtcwt/transform2d.py                               |  11 +-
 tests/testopenclxfm2.py                            |   7 +
 9 files changed, 296 insertions(+), 225 deletions(-)

diff --git a/docs/reference.rst b/docs/reference.rst
index a2fc241..a962fce 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -33,3 +33,15 @@ here just in case you do.
 
 .. automodule:: dtcwt.lowlevel
     :members:
+
+Backends
+````````
+
+.. automodule:: dtcwt.backend
+    :members:
+
+NumPy
+'''''
+
+.. automodule:: dtcwt.backend.backend_numpy
+    :members:
diff --git a/dtcwt/backend/__init__.py b/dtcwt/backend/__init__.py
index e69de29..34d08bf 100644
--- a/dtcwt/backend/__init__.py
+++ b/dtcwt/backend/__init__.py
@@ -0,0 +1,95 @@
+from dtcwt.utils import asfarray
+from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
+
+class TransformDomainSignal(object):
+    """A representation of a transform domain signal.
+
+    Backends are free to implement any class which respects this interface for
+    storing transform-domain signals. The inverse transform may accept a
+    backend-specific version of this class but should always accept any object
+    matching this interface. Each stored array is converted with ``asfarray``.
+
+    .. py:attribute:: lowpass
+
+        A NumPy-compatible array containing the coarsest scale lowpass signal.
+
+    .. py:attribute:: subbands
+
+        A tuple where each element holds the complex subband coefficients for
+        the corresponding scale, ordered finest to coarsest.
+
+    .. py:attribute:: scales
+
+        *(optional)* A tuple where each element is a NumPy-compatible array
+        containing the lowpass signal for the corresponding scale, ordered
+        finest to coarsest. Not required for the inverse; may be *None*.
+
+    """
+    def __init__(self, lowpass, subbands, scales=None):
+        self.lowpass = asfarray(lowpass)
+        self.subbands = tuple(asfarray(x) for x in subbands)
+        self.scales = tuple(asfarray(x) for x in scales) if scales is not None else None
+
+class ReconstructedSignal(object):
+    """
+    A representation of the reconstructed signal from the inverse transform. A
+    backend is free to implement its own version of this class provided that it
+    corresponds to the documented interface. *value* goes through ``asfarray``.
+
+    .. py:attribute:: value
+
+        A NumPy-compatible array containing the reconstructed signal.
+
+    """
+    def __init__(self, value):
+        self.value = asfarray(value)
+
+class Transform2d(object):
+    """
+    The interface of a 2D DT-CWT transformation. Backends must provide a
+    transform class which provides an interface compatible with this base
+    class. Every method here raises :py:exc:`NotImplementedError`.
+
+    :param biort: Level 1 wavelets to use. See :py:func:`biort`.
+    :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+
+    If *biort* or *qshift* are strings, they are used as an argument to the
+    :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+    interpreted as tuples of vectors giving filter coefficients. In the *biort*
+    case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
+    be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+
+    In some cases the tuples may have more elements. This is used to represent
+    the :ref:`rot-symm-wavelets`.
+
+    """
+    def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
+        raise NotImplementedError()
+
+    def forward(self, X, nlevels=3, include_scale=False):
+        """Perform an *n*-level DTCWT-2D decomposition on a 2D matrix *X*.
+
+        :param X: 2D real array
+        :param nlevels: Number of levels of wavelet decomposition
+        :param include_scale: If true, also populate the *scales* attribute of the result
+        :returns td: A :py:class:`dtcwt.backend.TransformDomainSignal` compatible object representing the transform-domain signal
+
+        """
+        raise NotImplementedError()
+
+    def inverse(self, td_signal, gain_mask=None):
+        """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
+        reconstruction.
+
+        :param td_signal: A :py:class:`dtcwt.backend.TransformDomainSignal`-like object holding the transform domain representation to invert.
+        :param gain_mask: Gain to be applied to each subband.
+
+        :returns Z: A :py:class:`dtcwt.backend.ReconstructedSignal` compatible instance with the reconstruction.
+
+        The (*d*, *l*)-th element of *gain_mask* is the gain for the subband with
+        direction *d* at level *l*. If gain_mask[d,l] == 0, no computation is
+        performed for band (d,l). Default *gain_mask* is all ones. Note that both
+        *d* and *l* are zero-indexed.
+
+        """
+        raise NotImplementedError()
diff --git a/dtcwt/backend/numpy/__init__.py b/dtcwt/backend/backend_numpy/__init__.py
similarity index 100%
rename from dtcwt/backend/numpy/__init__.py
rename to dtcwt/backend/backend_numpy/__init__.py
diff --git a/dtcwt/backend/numpy/lowlevel.py b/dtcwt/backend/backend_numpy/lowlevel.py
similarity index 100%
rename from dtcwt/backend/numpy/lowlevel.py
rename to dtcwt/backend/backend_numpy/lowlevel.py
diff --git a/dtcwt/backend/numpy/transform2d.py b/dtcwt/backend/backend_numpy/transform2d.py
similarity index 78%
rename from dtcwt/backend/numpy/transform2d.py
rename to dtcwt/backend/backend_numpy/transform2d.py
index f18153c..ff1b8f6 100644
--- a/dtcwt/backend/numpy/transform2d.py
+++ b/dtcwt/backend/backend_numpy/transform2d.py
@@ -3,26 +3,12 @@ import logging
 
 from six.moves import xrange
 
-from dtcwt import biort as _biort, qshift as _qshift
+from dtcwt.backend import TransformDomainSignal, ReconstructedSignal
+from dtcwt.coeffs import biort as _biort, qshift as _qshift
 from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
 from dtcwt.lowlevel import colfilter, coldfilt, colifilt
 from dtcwt.utils import appropriate_complex_type_for, asfarray
 
-from dtcwt import biort as _biort, qshift as _qshift
-
-class ForwardTransformResultNumPy(object):
-    def __init__(self, Yl, Yh, Yscale=None):
-        self.lowpass = Yl
-        self.highpass_coeffs = Yh
-        self.scales = Yscale
-
-class InverseTransformResultNumPy(object):
-    def __init__(self, X):
-        self._X = X
-
-    def to_array(self):
-        return self._X
-
 class Transform2dNumPy(object):
     def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
         # Load bi-orthogonal wavelets
@@ -40,27 +26,6 @@ class Transform2dNumPy(object):
     def forward(self, X, nlevels=3, include_scale=False):
         """Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
 
-        :param X: 2D real array
-        :param nlevels: Number of levels of wavelet decomposition
-        :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-        :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-
-        :returns Yl: The real lowpass image from the final level
-        :returns Yh: A tuple containing the complex highpass subimages for each level.
-        :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
-
-        If *biort* or *qshift* are strings, they are used as an argument to the
-        :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-        interpreted as tuples of vectors giving filter coefficients. In the *biort*
-        case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-        be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
-        Example::
-
-            # Performs a 3-level transform on the real image X using the 13,19-tap
-            # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
-            Yl, Yh = dtwavexfm2(X, 3, 'near_sym_b', 'qshift_b')
-
         .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
         .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
         .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
@@ -111,9 +76,9 @@ class Transform2dNumPy(object):
 
         if nlevels == 0:
             if include_scale:
-                return ForwardTransformResultNumPy(X, (), ())
+                return TransformDomainSignal(X, (), ())
             else:
-                return ForwardTransformResultNumPy(X, ())
+                return TransformDomainSignal(X, ())
 
         # initialise
         Yh = [None,] * nlevels
@@ -197,44 +162,22 @@ class Transform2dNumPy(object):
                 'The rightmost column has been duplicated, prior to decomposition.')
 
         if include_scale:
-            return ForwardTransformResultNumPy(Yl, tuple(Yh), tuple(Yscale))
+            return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
         else:
-            return ForwardTransformResultNumPy(Yl, tuple(Yh))
+            return TransformDomainSignal(Yl, tuple(Yh))
 
-    def inverse(self, Yl, Yh, gain_mask=None):
+    def inverse(self, td_signal, gain_mask=None):
         """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
         reconstruction.
 
-        :param Yl: The real lowpass subband from the final level
-        :param Yh: A sequence containing the complex highpass subband for each level.
-        :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-        :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-        :param gain_mask: Gain to be applied to each subband.
-
-        :returns Z: Reconstructed real array
-
-        The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
-        *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
-        band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
-        zero-indexed.
-
-        If *biort* or *qshift* are strings, they are used as an argument to the
-        :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-        interpreted as tuples of vectors giving filter coefficients. In the *biort*
-        case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-        be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
-        Example::
-
-            # Performs a 3-level reconstruction from Yl,Yh using the 13,19-tap
-            # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
-            Z = dtwaveifm2(Yl, Yh, 'near_sym_b', 'qshift_b')
-
         .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
         .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
         .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
 
         """
+        Yl = td_signal.lowpass
+        Yh = td_signal.subbands
+
         a = len(Yh) # No of levels.
 
         if gain_mask is None:
@@ -310,7 +253,7 @@ class Transform2dNumPy(object):
                 # Do odd top-level filters on rows.
                 Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T
 
-        return InverseTransformResultNumPy(Z)
+        return ReconstructedSignal(Z)
 
 #==========================================================================================
 #                       **********    INTERNAL FUNCTIONS    **********
diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index 77307e5..355f1cf 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -1,4 +1,4 @@
-from dtcwt.backend.numpy.lowlevel import LowLevelBackendNumPy
+from dtcwt.backend.backend_numpy.lowlevel import LowLevelBackendNumPy
 
 _BACKEND = LowLevelBackendNumPy()
 
diff --git a/dtcwt/opencl/transform2d.py b/dtcwt/opencl/transform2d.py
index 81f665a..fcc0596 100644
--- a/dtcwt/opencl/transform2d.py
+++ b/dtcwt/opencl/transform2d.py
@@ -11,6 +11,9 @@ from dtcwt.opencl.lowlevel import colfilter, coldfilt, colifilt
 from dtcwt.opencl.lowlevel import axis_convolve, axis_convolve_dfilter, q2c
 from dtcwt.opencl.lowlevel import to_device, to_queue, to_array, empty
 
+from dtcwt.backend import TransformDomainSignal, ReconstructedSignal
+from dtcwt.backend.backend_numpy.transform2d import Transform2dNumPy
+
 try:
     from pyopencl.array import concatenate
 except ImportError:
@@ -18,159 +21,169 @@ except ImportError:
     pass
 
 def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False, queue=None):
-    """Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
-
-    :param X: 2D real array
-    :param nlevels: Number of levels of wavelet decomposition
-    :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-    :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-
-    :returns Yl: The real lowpass image from the final level
-    :returns Yh: A tuple containing the complex highpass subimages for each level.
-    :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
-
-    If *biort* or *qshift* are strings, they are used as an argument to the
-    :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-    interpreted as tuples of vectors giving filter coefficients. In the *biort*
-    case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-    be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
-    Example::
-
-        # Performs a 3-level transform on the real image X using the 13,19-tap
-        # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
-        Yl, Yh = dtwavexfm2(X, 3, 'near_sym_b', 'qshift_b')
-
-    .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
-    .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
-    .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
-
-    """
-    queue = to_queue(queue)
-    X = np.atleast_2d(asfarray(X))
-
-    # Try to load coefficients if biort is a string parameter
-    try:
-        h0o, g0o, h1o, g1o = tuple(to_device(x) for x in _biort(biort))
-    except TypeError:
-        h0o, g0o, h1o, g1o = tuple(to_device(x) for x in biort)
-
-    # Try to load coefficients if qshift is a string parameter
-    try:
-        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = tuple(to_device(x) for x in _qshift(qshift))
-    except TypeError:
-        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = tuple(to_device(x) for x in qshift)
-
-    original_size = X.shape
-
-    if len(X.shape) >= 3:
-        raise ValueError('The entered image is {0}, please enter each image slice separately.'.
-                format('x'.join(list(str(s) for s in X.shape))))
-
-    # The next few lines of code check to see if the image is odd in size, if so an extra ...
-    # row/column will be added to the bottom/right of the image
-    initial_row_extend = 0  #initialise
-    initial_col_extend = 0
-    if original_size[0] % 2 != 0:
-        # if X.shape[0] is not divisible by 2 then we need to extend X by adding a row at the bottom
-        X = np.vstack((X, X[[-1],:]))  # Any further extension will be done in due course.
-        initial_row_extend = 1
-
-    if original_size[1] % 2 != 0:
-        # if X.shape[1] is not divisible by 2 then we need to extend X by adding a col to the left
-        X = np.hstack((X, X[:,[-1]]))
-        initial_col_extend = 1
-
-    extended_size = X.shape
-
-    if nlevels == 0:
-        if include_scale:
-            return X, (), ()
-        else:
-            return X, ()
-
-    # initialise
-    Yh = [None,] * nlevels
+    t = Transform2dOpenCL(biort=biort, qshift=qshift, queue=queue)
+    r = t.forward(X, nlevels=nlevels, include_scale=include_scale)
     if include_scale:
-        # this is only required if the user specifies a third output component.
-        Yscale = [None,] * nlevels
-
-    complex_dtype = appropriate_complex_type_for(X)
-
-    if nlevels >= 1:
-        # Do odd top-level filters on cols.
-        Lo = axis_convolve(X,h0o,axis=0,queue=queue)
-        Hi = axis_convolve(X,h1o,axis=0,queue=queue)
-
-        # Do odd top-level filters on rows.
-        LoLo = axis_convolve(Lo,h0o,axis=1)
-
-        Yh[0] = q2c(
-            axis_convolve(Hi,h0o,axis=1,queue=queue),
-            axis_convolve(Lo,h1o,axis=1,queue=queue),
-            axis_convolve(Hi,h1o,axis=1,queue=queue),
-        )
-
+        return r.lowpass, r.subbands, r.scales
+    else:
+        return r.lowpass, r.subbands
+
+class Transform2dOpenCL(Transform2dNumPy):
+    def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, queue=None):
+        super(Transform2dOpenCL, self).__init__(biort=biort, qshift=qshift)
+        self.queue = to_queue(queue)
+
+    def forward(self, X, nlevels=3, include_scale=False):
+        """Perform an *n*-level DTCWT-2D decomposition on a 2D matrix *X*.
+
+        Parameters and return value are as for :py:meth:`dtcwt.backend.Transform2d.forward`."""
+        queue = self.queue
+        X = np.atleast_2d(asfarray(X))
+
+        # If biort has 6 elements instead of 4, then it's a modified
+        # rotationally symmetric wavelet
+        # FIXME: there's probably a nicer way to do this
+        if len(self.biort) == 4:
+            h0o, g0o, h1o, g1o = self.biort
+        elif len(self.biort) == 6:
+            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
+        else:
+            raise ValueError('Biort wavelet must have 6 or 4 components.')
+
+        # If qshift has 10 elements instead of 8, then it's a modified
+        # rotationally symmetric wavelet
+        # FIXME: there's probably a nicer way to do this
+        if len(self.qshift) == 8:
+            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+        elif len(self.qshift) == 10:
+            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift
+        else:
+            raise ValueError('Qshift wavelet must have 10 or 8 components.')
+
+        original_size = X.shape
+
+        if len(X.shape) >= 3:
+            raise ValueError('The entered image is {0}, please enter each image slice separately.'.
+                    format('x'.join(list(str(s) for s in X.shape))))
+
+        # The next few lines of code check to see if the image is odd in size, if so an extra ...
+        # row/column will be added to the bottom/right of the image
+        initial_row_extend = 0  #initialise
+        initial_col_extend = 0
+        if original_size[0] % 2 != 0:
+            # if X.shape[0] is not divisible by 2 then we need to extend X by adding a row at the bottom
+            X = np.vstack((X, X[[-1],:]))  # Any further extension will be done in due course.
+            initial_row_extend = 1
+
+        if original_size[1] % 2 != 0:
+            # if X.shape[1] is not divisible by 2 then we need to extend X by adding a col to the left
+            X = np.hstack((X, X[:,[-1]]))
+            initial_col_extend = 1
+
+        extended_size = X.shape
+
+        if nlevels == 0:
+            if include_scale:
+                return TransformDomainSignal(X, (), ())
+            else:
+                return TransformDomainSignal(X, ())
+
+        # initialise
+        Yh = [None,] * nlevels
         if include_scale:
-            Yscale[0] = LoLo
-
-    for level in xrange(1, nlevels):
-        row_size, col_size = LoLo.shape
-
-        if row_size % 4 != 0:
-            # Extend by 2 rows if no. of rows of LoLo are not divisible by 4
-            LoLo = to_array(LoLo)
-            LoLo = np.vstack((LoLo[:1,:], LoLo, LoLo[-1:,:]))
-
-        if col_size % 4 != 0:
-            # Extend by 2 cols if no. of cols of LoLo are not divisible by 4
-            LoLo = to_array(LoLo)
-            LoLo = np.hstack((LoLo[:,:1], LoLo, LoLo[:,-1:]))
-
-        # Do even Qshift filters on rows.
-        Lo = axis_convolve_dfilter(LoLo,h0b,axis=0,queue=queue)
-        Hi = axis_convolve_dfilter(LoLo,h1b,axis=0,queue=queue)
-
-        # Do even Qshift filters on columns.
-        LoLo = axis_convolve_dfilter(Lo,h0b,axis=1,queue=queue)
-
-        Yh[level] = q2c(
-            axis_convolve_dfilter(Hi,h0b,axis=1,queue=queue),
-            axis_convolve_dfilter(Lo,h1b,axis=1,queue=queue),
-            axis_convolve_dfilter(Hi,h1b,axis=1,queue=queue),
-        )
-
+            # this is only required if the user specifies a third output component.
+            Yscale = [None,] * nlevels
+
+        complex_dtype = appropriate_complex_type_for(X)
+
+        if nlevels >= 1:
+            # Do odd top-level filters on cols.
+            Lo = axis_convolve(X,h0o,axis=0,queue=queue)
+            Hi = axis_convolve(X,h1o,axis=0,queue=queue)
+            if len(self.biort) >= 6:
+                Ba = axis_convolve(X,h2o,axis=0,queue=queue)
+
+            # Do odd top-level filters on rows.
+            LoLo = axis_convolve(Lo,h0o,axis=1,queue=queue)
+
+            if len(self.biort) >= 6:
+                diag = axis_convolve(Ba,h2o,axis=1,queue=queue)
+            else:
+                diag = axis_convolve(Hi,h1o,axis=1,queue=queue)
+
+            Yh[0] = q2c(
+                axis_convolve(Hi,h0o,axis=1,queue=queue),
+                axis_convolve(Lo,h1o,axis=1,queue=queue),
+                diag,
+            )
+
+            if include_scale:
+                Yscale[0] = LoLo
+
+        for level in xrange(1, nlevels):
+            row_size, col_size = LoLo.shape
+
+            if row_size % 4 != 0:
+                # Extend by 2 rows if no. of rows of LoLo are not divisible by 4
+                LoLo = to_array(LoLo)
+                LoLo = np.vstack((LoLo[:1,:], LoLo, LoLo[-1:,:]))
+
+            if col_size % 4 != 0:
+                # Extend by 2 cols if no. of cols of LoLo are not divisible by 4
+                LoLo = to_array(LoLo)
+                LoLo = np.hstack((LoLo[:,:1], LoLo, LoLo[:,-1:]))
+
+            # Do even Qshift filters on rows.
+            Lo = axis_convolve_dfilter(LoLo,h0b,axis=0,queue=queue)
+            Hi = axis_convolve_dfilter(LoLo,h1b,axis=0,queue=queue)
+            if len(self.qshift) >= 10:
+                Ba = axis_convolve_dfilter(LoLo,h2b,axis=0,queue=queue)
+
+            # Do even Qshift filters on columns.
+            LoLo = axis_convolve_dfilter(Lo,h0b,axis=1,queue=queue)
+
+            if len(self.qshift) >= 10:
+                diag = axis_convolve_dfilter(Ba,h2b,axis=1,queue=queue)
+            else:
+                diag = axis_convolve_dfilter(Hi,h1b,axis=1,queue=queue)
+
+            Yh[level] = q2c(
+                axis_convolve_dfilter(Hi,h0b,axis=1,queue=queue),
+                axis_convolve_dfilter(Lo,h1b,axis=1,queue=queue),
+                diag,
+            )
+
+            if include_scale:
+                Yscale[level] = LoLo
+
+        Yl = to_array(LoLo,queue=queue)
+        Yh = list(to_array(x) for x in Yh)
         if include_scale:
-            Yscale[level] = LoLo
+            Yscale = list(to_array(x) for x in Yscale)
+
+        if initial_row_extend == 1 and initial_col_extend == 1:
+            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
+                'x'.join(list(str(s) for s in extended_size)),
+                'x'.join(list(str(s) for s in original_size))))
+            logging.warning(
+                'The bottom row and rightmost column have been duplicated, prior to decomposition.')
+
+        if initial_row_extend == 1 and initial_col_extend == 0:
+            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
+                'x'.join(list(str(s) for s in extended_size)),
+                'x'.join(list(str(s) for s in original_size))))
+            logging.warning(
+                'The bottom row has been duplicated, prior to decomposition.')
+
+        if initial_row_extend == 0 and initial_col_extend == 1:
+            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
+                'x'.join(list(str(s) for s in extended_size)),
+                'x'.join(list(str(s) for s in original_size))))
+            logging.warning(
+                'The rightmost column has been duplicated, prior to decomposition.')

-    Yl = to_array(LoLo,queue=queue)
-    Yh = list(to_array(x) for x in Yh)
-    if include_scale:
-        Yscale = list(to_array(x) for x in Yscale)
-
-    if initial_row_extend == 1 and initial_col_extend == 1:
-        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
-            'x'.join(list(str(s) for s in extended_size)),
-            'x'.join(list(str(s) for s in original_size))))
-        logging.warn(
-            'The bottom row and rightmost column have been duplicated, prior to decomposition.')
-
-    if initial_row_extend == 1 and initial_col_extend == 0:
-        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
-            'x'.join(list(str(s) for s in extended_size)),
-            'x'.join(list(str(s) for s in original_size))))
-        logging.warn(
-            'The bottom row has been duplicated, prior to decomposition.')
-
-    if initial_row_extend == 0 and initial_col_extend == 1:
-        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
-            'x'.join(list(str(s) for s in extended_size)),
-            'x'.join(list(str(s) for s in original_size))))
-        logging.warn(
-            'The rightmost column has been duplicated, prior to decomposition.')
-
-    if include_scale:
-        return Yl, tuple(Yh), tuple(Yscale)
-    else:
-        return Yl, tuple(Yh)

+        if include_scale:
+            return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
+        else:
+            return TransformDomainSignal(Yl, tuple(Yh))
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index ce56154..936cf70 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -8,7 +8,8 @@ from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
 from dtcwt.lowlevel import colfilter, coldfilt, colifilt
 from dtcwt.utils import appropriate_complex_type_for, asfarray
 
-from dtcwt.backend.numpy.transform2d import Transform2dNumPy
+from dtcwt.backend import TransformDomainSignal
+from dtcwt.backend.backend_numpy.transform2d import Transform2dNumPy
 
 def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
     """Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
@@ -44,9 +45,9 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
     res = trans.forward(X, nlevels, include_scale)
 
     if include_scale:
-        return res.lowpass, res.highpass_coeffs, res.scales
+        return res.lowpass, res.subbands, res.scales
     else:
-        return res.lowpass, res.highpass_coeffs
+        return res.lowpass, res.subbands
 
 def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
     """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
@@ -83,8 +84,8 @@ def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
 
     """
     trans = Transform2dNumPy(biort, qshift)
-    res = trans.inverse(Yl, Yh, gain_mask=gain_mask)
-    return res.to_array()
+    res = trans.inverse(TransformDomainSignal(Yl, Yh), gain_mask=gain_mask)
+    return res.value
 
 # BACKWARDS COMPATIBILITY: add a dtwave{i,x}fm2b function which is a copy of
 # dtwave{i,x}fm2b. The functionality of the ...b variant is rolled into the
diff --git a/tests/testopenclxfm2.py b/tests/testopenclxfm2.py
index a112f24..7d71563 100644
--- a/tests/testopenclxfm2.py
+++ b/tests/testopenclxfm2.py
@@ -85,4 +85,11 @@ def test_0_levels():
     b = dtwavexfm2_cl(lena, nlevels=0)
     _compare_transforms(a, b)
 
+@skip_if_no_cl
+@attr('transform')
+def test_modified():
+    a = dtwavexfm2_np(lena, biort=biort('near_sym_b_bp'), qshift=qshift('qshift_b_bp'))
+    b = dtwavexfm2_cl(lena, biort=biort('near_sym_b_bp'), qshift=qshift('qshift_b_bp'))
+    _compare_transforms(a, b)
+
 # vim:sw=4:sts=4:et

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list