[python-dtcwt] 189/497: refactor opencl backend to support modified rotational symmetry wavelets
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:03 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit d42a78369c6ad037d7b8b6e1641b91f847506fb7
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date: Mon Nov 11 17:17:30 2013 +0000
refactor opencl backend to support modified rotational symmetry wavelets
---
docs/reference.rst | 12 +
dtcwt/backend/__init__.py | 95 +++++++
dtcwt/backend/{numpy => backend_numpy}/__init__.py | 0
dtcwt/backend/{numpy => backend_numpy}/lowlevel.py | 0
.../{numpy => backend_numpy}/transform2d.py | 79 +-----
dtcwt/lowlevel.py | 2 +-
dtcwt/opencl/transform2d.py | 315 +++++++++++----------
dtcwt/transform2d.py | 11 +-
tests/testopenclxfm2.py | 7 +
9 files changed, 296 insertions(+), 225 deletions(-)
diff --git a/docs/reference.rst b/docs/reference.rst
index a2fc241..a962fce 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -33,3 +33,15 @@ here just in case you do.
.. automodule:: dtcwt.lowlevel
:members:
+
+Backends
+````````
+
+.. automodule:: dtcwt.backend
+ :members:
+
+NumPy
+'''''
+
+.. automodule:: dtcwt.backend.backend_numpy
+ :members:
diff --git a/dtcwt/backend/__init__.py b/dtcwt/backend/__init__.py
index e69de29..34d08bf 100644
--- a/dtcwt/backend/__init__.py
+++ b/dtcwt/backend/__init__.py
@@ -0,0 +1,95 @@
+from dtcwt.utils import asfarray
+from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
+
+class TransformDomainSignal(object):
+    """Container for a signal expressed in the DT-CWT transform domain.
+
+    Any backend may define its own class for transform-domain signals as long
+    as it exposes the attributes listed below. An inverse transform may accept
+    a backend-specific variant but should always accept any object providing
+    this interface.
+
+    Every array handed to the constructor is normalised with
+    :py:func:`dtcwt.utils.asfarray`.
+
+    .. py:attribute:: lowpass
+
+        A NumPy-compatible array holding the lowpass signal at the coarsest
+        scale.
+
+    .. py:attribute:: subbands
+
+        A tuple whose elements hold the complex subband coefficients of each
+        scale, ordered from finest to coarsest.
+
+    .. py:attribute:: scales
+
+        *(optional)* A tuple whose elements are NumPy-compatible arrays
+        holding the lowpass signal of each scale, ordered from finest to
+        coarsest. The inverse transform does not need it, so it may be *None*.
+
+    """
+    def __init__(self, lowpass, subbands, scales=None):
+        self.lowpass = asfarray(lowpass)
+        self.subbands = tuple(map(asfarray, subbands))
+
+        # Keep None (rather than an empty tuple) when scales were not
+        # requested, so "not computed" stays distinguishable from "zero levels".
+        if scales is None:
+            self.scales = None
+        else:
+            self.scales = tuple(map(asfarray, scales))
+
+class ReconstructedSignal(object):
+    """Container for the output of an inverse transform.
+
+    A backend may return its own class instead, provided that class exposes
+    the attribute documented here.
+
+    .. py:attribute:: value
+
+        A NumPy-compatible array holding the reconstructed signal, normalised
+        with :py:func:`dtcwt.utils.asfarray`.
+
+    """
+    def __init__(self, value):
+        reconstruction = asfarray(value)
+        self.value = reconstruction
+
+class Transform2d(object):
+    """
+    An implementation of a 2D DT-CWT transformation. Backends must provide a
+    transform class which provides an interface compatible with this base
+    class. This base class only defines that interface: each of its methods
+    raises :py:exc:`NotImplementedError`.
+
+    :param biort: Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.biort`.
+    :param qshift: Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`.
+
+    If *biort* or *qshift* are strings, they are used as an argument to the
+    :py:func:`dtcwt.coeffs.biort` or :py:func:`dtcwt.coeffs.qshift` functions.
+    Otherwise, they are interpreted as tuples of vectors giving filter
+    coefficients. In the *biort* case, this should be (h0o, g0o, h1o, g1o). In
+    the *qshift* case, this should be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+
+    In some cases the tuples may have more elements (e.g. 6 for *biort* and 10
+    for *qshift*). This is used to represent the :ref:`rot-symm-wavelets`.
+
+    """
+    def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
+        raise NotImplementedError()
+
+    def forward(self, X, nlevels=3, include_scale=False):
+        """Perform a *n*-level DTCWT-2D decomposition on a 2D matrix *X*.
+
+        :param X: 2D real array
+        :param nlevels: Number of levels of wavelet decomposition
+        :param include_scale: If *True*, also record the lowpass signal of
+            every scale in the ``scales`` attribute of the result.
+
+        :returns td: A :py:class:`dtcwt.backend.TransformDomainSignal` compatible object representing the transform-domain signal
+
+        """
+        raise NotImplementedError()
+
+    def inverse(self, td_signal, gain_mask=None):
+        """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
+        reconstruction.
+
+        :param td_signal: A :py:class:`dtcwt.backend.TransformDomainSignal`-like object holding the transform domain representation to invert. Only its ``lowpass`` and ``subbands`` attributes are required.
+        :param gain_mask: Gain to be applied to each subband.
+
+        :returns Z: A :py:class:`dtcwt.backend.ReconstructedSignal` compatible instance with the reconstruction.
+
+        The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
+        *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
+        band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
+        zero-indexed.
+
+        """
+        raise NotImplementedError()
diff --git a/dtcwt/backend/numpy/__init__.py b/dtcwt/backend/backend_numpy/__init__.py
similarity index 100%
rename from dtcwt/backend/numpy/__init__.py
rename to dtcwt/backend/backend_numpy/__init__.py
diff --git a/dtcwt/backend/numpy/lowlevel.py b/dtcwt/backend/backend_numpy/lowlevel.py
similarity index 100%
rename from dtcwt/backend/numpy/lowlevel.py
rename to dtcwt/backend/backend_numpy/lowlevel.py
diff --git a/dtcwt/backend/numpy/transform2d.py b/dtcwt/backend/backend_numpy/transform2d.py
similarity index 78%
rename from dtcwt/backend/numpy/transform2d.py
rename to dtcwt/backend/backend_numpy/transform2d.py
index f18153c..ff1b8f6 100644
--- a/dtcwt/backend/numpy/transform2d.py
+++ b/dtcwt/backend/backend_numpy/transform2d.py
@@ -3,26 +3,12 @@ import logging
from six.moves import xrange
-from dtcwt import biort as _biort, qshift as _qshift
+from dtcwt.backend import TransformDomainSignal, ReconstructedSignal
+from dtcwt.coeffs import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
from dtcwt.lowlevel import colfilter, coldfilt, colifilt
from dtcwt.utils import appropriate_complex_type_for, asfarray
-from dtcwt import biort as _biort, qshift as _qshift
-
-class ForwardTransformResultNumPy(object):
- def __init__(self, Yl, Yh, Yscale=None):
- self.lowpass = Yl
- self.highpass_coeffs = Yh
- self.scales = Yscale
-
-class InverseTransformResultNumPy(object):
- def __init__(self, X):
- self._X = X
-
- def to_array(self):
- return self._X
-
class Transform2dNumPy(object):
def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
# Load bi-orthogonal wavelets
@@ -40,27 +26,6 @@ class Transform2dNumPy(object):
def forward(self, X, nlevels=3, include_scale=False):
"""Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
- :param X: 2D real array
- :param nlevels: Number of levels of wavelet decomposition
- :param biort: Level 1 wavelets to use. See :py:func:`biort`.
- :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-
- :returns Yl: The real lowpass image from the final level
- :returns Yh: A tuple containing the complex highpass subimages for each level.
- :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
-
- If *biort* or *qshift* are strings, they are used as an argument to the
- :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
- interpreted as tuples of vectors giving filter coefficients. In the *biort*
- case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
- be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
- Example::
-
- # Performs a 3-level transform on the real image X using the 13,19-tap
- # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
- Yl, Yh = dtwavexfm2(X, 3, 'near_sym_b', 'qshift_b')
-
.. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
.. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
.. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
@@ -111,9 +76,9 @@ class Transform2dNumPy(object):
if nlevels == 0:
if include_scale:
- return ForwardTransformResultNumPy(X, (), ())
+ return TransformDomainSignal(X, (), ())
else:
- return ForwardTransformResultNumPy(X, ())
+ return TransformDomainSignal(X, ())
# initialise
Yh = [None,] * nlevels
@@ -197,44 +162,22 @@ class Transform2dNumPy(object):
'The rightmost column has been duplicated, prior to decomposition.')
if include_scale:
- return ForwardTransformResultNumPy(Yl, tuple(Yh), tuple(Yscale))
+ return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
else:
- return ForwardTransformResultNumPy(Yl, tuple(Yh))
+ return TransformDomainSignal(Yl, tuple(Yh))
- def inverse(self, Yl, Yh, gain_mask=None):
+ def inverse(self, td_signal, gain_mask=None):
"""Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
reconstruction.
- :param Yl: The real lowpass subband from the final level
- :param Yh: A sequence containing the complex highpass subband for each level.
- :param biort: Level 1 wavelets to use. See :py:func:`biort`.
- :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
- :param gain_mask: Gain to be applied to each subband.
-
- :returns Z: Reconstructed real array
-
- The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
- *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
- band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
- zero-indexed.
-
- If *biort* or *qshift* are strings, they are used as an argument to the
- :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
- interpreted as tuples of vectors giving filter coefficients. In the *biort*
- case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
- be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
- Example::
-
- # Performs a 3-level reconstruction from Yl,Yh using the 13,19-tap
- # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
- Z = dtwaveifm2(Yl, Yh, 'near_sym_b', 'qshift_b')
-
.. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
.. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
.. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
"""
+ Yl = td_signal.lowpass
+ Yh = td_signal.subbands
+
a = len(Yh) # No of levels.
if gain_mask is None:
@@ -310,7 +253,7 @@ class Transform2dNumPy(object):
# Do odd top-level filters on rows.
Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T
- return InverseTransformResultNumPy(Z)
+ return ReconstructedSignal(Z)
#==========================================================================================
# ********** INTERNAL FUNCTIONS **********
diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index 77307e5..355f1cf 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -1,4 +1,4 @@
-from dtcwt.backend.numpy.lowlevel import LowLevelBackendNumPy
+from dtcwt.backend.backend_numpy.lowlevel import LowLevelBackendNumPy
_BACKEND = LowLevelBackendNumPy()
diff --git a/dtcwt/opencl/transform2d.py b/dtcwt/opencl/transform2d.py
index 81f665a..fcc0596 100644
--- a/dtcwt/opencl/transform2d.py
+++ b/dtcwt/opencl/transform2d.py
@@ -11,6 +11,9 @@ from dtcwt.opencl.lowlevel import colfilter, coldfilt, colifilt
from dtcwt.opencl.lowlevel import axis_convolve, axis_convolve_dfilter, q2c
from dtcwt.opencl.lowlevel import to_device, to_queue, to_array, empty
+from dtcwt.backend import TransformDomainSignal, ReconstructedSignal
+from dtcwt.backend.backend_numpy.transform2d import Transform2dNumPy
+
try:
from pyopencl.array import concatenate
except ImportError:
@@ -18,159 +21,169 @@ except ImportError:
pass
def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False, queue=None):
-    """Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
-
-    :param X: 2D real array
-    :param nlevels: Number of levels of wavelet decomposition
-    :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-    :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-
-    :returns Yl: The real lowpass image from the final level
-    :returns Yh: A tuple containing the complex highpass subimages for each level.
-    :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
-
-    If *biort* or *qshift* are strings, they are used as an argument to the
-    :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-    interpreted as tuples of vectors giving filter coefficients. In the *biort*
-    case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-    be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
-    Example::
-
-        # Performs a 3-level transform on the real image X using the 13,19-tap
-        # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
-        Yl, Yh = dtwavexfm2(X, 3, 'near_sym_b', 'qshift_b')
-
-    .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
-    .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
-    .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
-
-    """
-    queue = to_queue(queue)
-    X = np.atleast_2d(asfarray(X))
-
-    # Try to load coefficients if biort is a string parameter
-    try:
-        h0o, g0o, h1o, g1o = tuple(to_device(x) for x in _biort(biort))
-    except TypeError:
-        h0o, g0o, h1o, g1o = tuple(to_device(x) for x in biort)
-
-    # Try to load coefficients if qshift is a string parameter
-    try:
-        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = tuple(to_device(x) for x in _qshift(qshift))
-    except TypeError:
-        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = tuple(to_device(x) for x in qshift)
-
-    original_size = X.shape
-
-    if len(X.shape) >= 3:
-        raise ValueError('The entered image is {0}, please enter each image slice separately.'.
-                format('x'.join(list(str(s) for s in X.shape))))
-
-    # The next few lines of code check to see if the image is odd in size, if so an extra ...
-    # row/column will be added to the bottom/right of the image
-    initial_row_extend = 0 #initialise
-    initial_col_extend = 0
-    if original_size[0] % 2 != 0:
-        # if X.shape[0] is not divisible by 2 then we need to extend X by adding a row at the bottom
-        X = np.vstack((X, X[[-1],:])) # Any further extension will be done in due course.
-        initial_row_extend = 1
-
-    if original_size[1] % 2 != 0:
-        # if X.shape[1] is not divisible by 2 then we need to extend X by adding a col to the left
-        X = np.hstack((X, X[:,[-1]]))
-        initial_col_extend = 1
-
-    extended_size = X.shape
-
-    if nlevels == 0:
-        if include_scale:
-            return X, (), ()
-        else:
-            return X, ()
-
-    # initialise
-    Yh = [None,] * nlevels
+    """Perform a *n*-level DTCWT-2D decomposition on a 2D matrix *X* using the
+    OpenCL backend.
+
+    This is a convenience wrapper which constructs a
+    :py:class:`Transform2dOpenCL` and calls its ``forward`` method.
+
+    :param X: 2D real array
+    :param nlevels: Number of levels of wavelet decomposition
+    :param biort: Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.biort`.
+    :param qshift: Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`.
+    :param include_scale: If *True*, also return the lowpass signal of every scale.
+    :param queue: OpenCL queue to use; passed to :py:func:`dtcwt.opencl.lowlevel.to_queue`.
+
+    :returns Yl: The real lowpass image from the final level
+    :returns Yh: A tuple containing the complex highpass subimages for each level.
+    :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
+
+    """
+    t = Transform2dOpenCL(biort=biort, qshift=qshift, queue=queue)
+    r = t.forward(X, nlevels=nlevels, include_scale=include_scale)
    if include_scale:
-        # this is only required if the user specifies a third output component.
-        Yscale = [None,] * nlevels
-
-    complex_dtype = appropriate_complex_type_for(X)
-
-    if nlevels >= 1:
-        # Do odd top-level filters on cols.
-        Lo = axis_convolve(X,h0o,axis=0,queue=queue)
-        Hi = axis_convolve(X,h1o,axis=0,queue=queue)
-
-        # Do odd top-level filters on rows.
-        LoLo = axis_convolve(Lo,h0o,axis=1)
-
-        Yh[0] = q2c(
-            axis_convolve(Hi,h0o,axis=1,queue=queue),
-            axis_convolve(Lo,h1o,axis=1,queue=queue),
-            axis_convolve(Hi,h1o,axis=1,queue=queue),
-        )
-
+        return r.lowpass, r.subbands, r.scales
+    else:
+        return r.lowpass, r.subbands
+
+class Transform2dOpenCL(Transform2dNumPy):
+    """2D DT-CWT whose forward transform runs its filtering on an OpenCL queue.
+
+    The wavelet coefficients are set up by the
+    :py:class:`dtcwt.backend.backend_numpy.transform2d.Transform2dNumPy`
+    constructor and read back from ``self.biort`` / ``self.qshift``.
+
+    :param biort: Level 1 wavelets to use.
+    :param qshift: Level >= 2 wavelets to use.
+    :param queue: OpenCL queue used for the device operations; passed to
+        :py:func:`dtcwt.opencl.lowlevel.to_queue`.
+
+    """
+    def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, queue=None):
+        super(Transform2dOpenCL, self).__init__(biort=biort, qshift=qshift)
+        self.queue = to_queue(queue)
+
+    def forward(self, X, nlevels=3, include_scale=False):
+        """Perform a *n*-level DTCWT-2D decomposition on a 2D matrix *X*.
+
+        :param X: 2D real array
+        :param nlevels: Number of levels of wavelet decomposition
+        :param include_scale: If *True*, also record the lowpass signal of
+            every scale.
+
+        :returns td: A :py:class:`dtcwt.backend.TransformDomainSignal`
+            representing the transform-domain signal.
+        :raises ValueError: if the wavelet tuples have an unsupported number
+            of components or *X* has three or more dimensions.
+
+        """
+        queue = self.queue
+        X = np.atleast_2d(asfarray(X))
+
+        # If biort has 6 elements instead of 4, then it's a modified
+        # rotationally symmetric wavelet
+        # FIXME: there's probably a nicer way to do this
+        # NOTE(review): the previous implementation moved every filter to the
+        # device with to_device(); here they are used as stored by
+        # Transform2dNumPy. Confirm axis_convolve* accept these without
+        # re-uploading them on every call.
+        if len(self.biort) == 4:
+            h0o, g0o, h1o, g1o = self.biort
+        elif len(self.biort) == 6:
+            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
+        else:
+            raise ValueError('Biort wavelet must have 6 or 4 components.')
+
+        # If qshift has 10 elements instead of 8, then it's a modified
+        # rotationally symmetric wavelet
+        # FIXME: there's probably a nicer way to do this
+        if len(self.qshift) == 8:
+            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+        elif len(self.qshift) == 10:
+            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift
+        else:
+            raise ValueError('Qshift wavelet must have 10 or 8 components.')
+
+        original_size = X.shape
+
+        if len(X.shape) >= 3:
+            raise ValueError('The entered image is {0}, please enter each image slice separately.'.
+                    format('x'.join(list(str(s) for s in X.shape))))
+
+        # The next few lines of code check to see if the image is odd in size, if so an extra ...
+        # row/column will be added to the bottom/right of the image
+        initial_row_extend = 0 #initialise
+        initial_col_extend = 0
+        if original_size[0] % 2 != 0:
+            # if X.shape[0] is not divisible by 2 then we need to extend X by adding a row at the bottom
+            X = np.vstack((X, X[[-1],:])) # Any further extension will be done in due course.
+            initial_row_extend = 1
+
+        if original_size[1] % 2 != 0:
+            # if X.shape[1] is not divisible by 2 then we need to extend X by adding a col to the right
+            X = np.hstack((X, X[:,[-1]]))
+            initial_col_extend = 1
+
+        extended_size = X.shape
+
+        if nlevels == 0:
+            if include_scale:
+                return TransformDomainSignal(X, (), ())
+            else:
+                return TransformDomainSignal(X, ())
+
+        # initialise
+        Yh = [None,] * nlevels
        if include_scale:
-            Yscale[0] = LoLo
-
-    for level in xrange(1, nlevels):
-        row_size, col_size = LoLo.shape
-
-        if row_size % 4 != 0:
-            # Extend by 2 rows if no. of rows of LoLo are not divisible by 4
-            LoLo = to_array(LoLo)
-            LoLo = np.vstack((LoLo[:1,:], LoLo, LoLo[-1:,:]))
-
-        if col_size % 4 != 0:
-            # Extend by 2 cols if no. of cols of LoLo are not divisible by 4
-            LoLo = to_array(LoLo)
-            LoLo = np.hstack((LoLo[:,:1], LoLo, LoLo[:,-1:]))
-
-        # Do even Qshift filters on rows.
-        Lo = axis_convolve_dfilter(LoLo,h0b,axis=0,queue=queue)
-        Hi = axis_convolve_dfilter(LoLo,h1b,axis=0,queue=queue)
-
-        # Do even Qshift filters on columns.
-        LoLo = axis_convolve_dfilter(Lo,h0b,axis=1,queue=queue)
-
-        Yh[level] = q2c(
-            axis_convolve_dfilter(Hi,h0b,axis=1,queue=queue),
-            axis_convolve_dfilter(Lo,h1b,axis=1,queue=queue),
-            axis_convolve_dfilter(Hi,h1b,axis=1,queue=queue),
-        )
-
+            # this is only required if the user specifies a third output component.
+            Yscale = [None,] * nlevels
+
+        if nlevels >= 1:
+            # Do odd top-level filters on cols.
+            Lo = axis_convolve(X,h0o,axis=0,queue=queue)
+            Hi = axis_convolve(X,h1o,axis=0,queue=queue)
+            if len(self.biort) >= 6:
+                Ba = axis_convolve(X,h2o,axis=0,queue=queue)
+
+            # Do odd top-level filters on rows.
+            # Pass the queue explicitly, like every other device call here, so
+            # this convolution also runs on self.queue.
+            LoLo = axis_convolve(Lo,h0o,axis=1,queue=queue)
+
+            # Modified rotationally symmetric wavelets build the diagonal
+            # band from the extra h2o filter instead of h1o.
+            if len(self.biort) >= 6:
+                diag = axis_convolve(Ba,h2o,axis=1,queue=queue)
+            else:
+                diag = axis_convolve(Hi,h1o,axis=1,queue=queue)
+
+            Yh[0] = q2c(
+                axis_convolve(Hi,h0o,axis=1,queue=queue),
+                axis_convolve(Lo,h1o,axis=1,queue=queue),
+                diag,
+            )
+
+            if include_scale:
+                Yscale[0] = LoLo
+
+        for level in xrange(1, nlevels):
+            row_size, col_size = LoLo.shape
+
+            if row_size % 4 != 0:
+                # Extend by 2 rows if no. of rows of LoLo are not divisible by 4
+                LoLo = to_array(LoLo)
+                LoLo = np.vstack((LoLo[:1,:], LoLo, LoLo[-1:,:]))
+
+            if col_size % 4 != 0:
+                # Extend by 2 cols if no. of cols of LoLo are not divisible by 4
+                LoLo = to_array(LoLo)
+                LoLo = np.hstack((LoLo[:,:1], LoLo, LoLo[:,-1:]))
+
+            # Do even Qshift filters on rows.
+            Lo = axis_convolve_dfilter(LoLo,h0b,axis=0,queue=queue)
+            Hi = axis_convolve_dfilter(LoLo,h1b,axis=0,queue=queue)
+            if len(self.qshift) >= 10:
+                Ba = axis_convolve_dfilter(LoLo,h2b,axis=0,queue=queue)
+
+            # Do even Qshift filters on columns.
+            LoLo = axis_convolve_dfilter(Lo,h0b,axis=1,queue=queue)
+
+            # As at level 1: the extra h2b filter replaces h1b for the
+            # diagonal band of modified rotationally symmetric wavelets.
+            if len(self.qshift) >= 10:
+                diag = axis_convolve_dfilter(Ba,h2b,axis=1,queue=queue)
+            else:
+                diag = axis_convolve_dfilter(Hi,h1b,axis=1,queue=queue)
+
+            Yh[level] = q2c(
+                axis_convolve_dfilter(Hi,h0b,axis=1,queue=queue),
+                axis_convolve_dfilter(Lo,h1b,axis=1,queue=queue),
+                diag,
+            )
+
+            if include_scale:
+                Yscale[level] = LoLo
+
+        # Copy results back from the device.
+        Yl = to_array(LoLo,queue=queue)
+        Yh = list(to_array(x) for x in Yh)
        if include_scale:
-            Yscale[level] = LoLo
+            Yscale = list(to_array(x) for x in Yscale)
+
+        if initial_row_extend == 1 and initial_col_extend == 1:
+            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
+                'x'.join(list(str(s) for s in extended_size)),
+                'x'.join(list(str(s) for s in original_size))))
+            logging.warning(
+                'The bottom row and rightmost column have been duplicated, prior to decomposition.')
+
+        if initial_row_extend == 1 and initial_col_extend == 0:
+            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
+                'x'.join(list(str(s) for s in extended_size)),
+                'x'.join(list(str(s) for s in original_size))))
+            logging.warning(
+                'The bottom row has been duplicated, prior to decomposition.')
+
+        if initial_row_extend == 0 and initial_col_extend == 1:
+            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
+                'x'.join(list(str(s) for s in extended_size)),
+                'x'.join(list(str(s) for s in original_size))))
+            logging.warning(
+                'The rightmost column has been duplicated, prior to decomposition.')
-    Yl = to_array(LoLo,queue=queue)
-    Yh = list(to_array(x) for x in Yh)
-    if include_scale:
-        Yscale = list(to_array(x) for x in Yscale)
-
-    if initial_row_extend == 1 and initial_col_extend == 1:
-        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
-            'x'.join(list(str(s) for s in extended_size)),
-            'x'.join(list(str(s) for s in original_size))))
-        logging.warn(
-            'The bottom row and rightmost column have been duplicated, prior to decomposition.')
-
-    if initial_row_extend == 1 and initial_col_extend == 0:
-        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
-            'x'.join(list(str(s) for s in extended_size)),
-            'x'.join(list(str(s) for s in original_size))))
-        logging.warn(
-            'The bottom row has been duplicated, prior to decomposition.')
-
-    if initial_row_extend == 0 and initial_col_extend == 1:
-        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
-            'x'.join(list(str(s) for s in extended_size)),
-            'x'.join(list(str(s) for s in original_size))))
-        logging.warn(
-            'The rightmost column has been duplicated, prior to decomposition.')
-
-    if include_scale:
-        return Yl, tuple(Yh), tuple(Yscale)
-    else:
-        return Yl, tuple(Yh)
+        if include_scale:
+            return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
+        else:
+            return TransformDomainSignal(Yl, tuple(Yh))
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index ce56154..936cf70 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -8,7 +8,8 @@ from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
from dtcwt.lowlevel import colfilter, coldfilt, colifilt
from dtcwt.utils import appropriate_complex_type_for, asfarray
-from dtcwt.backend.numpy.transform2d import Transform2dNumPy
+from dtcwt.backend import TransformDomainSignal
+from dtcwt.backend.backend_numpy.transform2d import Transform2dNumPy
def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
"""Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
@@ -44,9 +45,9 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
res = trans.forward(X, nlevels, include_scale)
if include_scale:
- return res.lowpass, res.highpass_coeffs, res.scales
+ return res.lowpass, res.subbands, res.scales
else:
- return res.lowpass, res.highpass_coeffs
+ return res.lowpass, res.subbands
def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
"""Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
@@ -83,8 +84,8 @@ def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
"""
trans = Transform2dNumPy(biort, qshift)
- res = trans.inverse(Yl, Yh, gain_mask=gain_mask)
- return res.to_array()
+ res = trans.inverse(TransformDomainSignal(Yl, Yh), gain_mask=gain_mask)
+ return res.value
# BACKWARDS COMPATIBILITY: add a dtwave{i,x}fm2b function which is a copy of
# dtwave{i,x}fm2b. The functionality of the ...b variant is rolled into the
diff --git a/tests/testopenclxfm2.py b/tests/testopenclxfm2.py
index a112f24..7d71563 100644
--- a/tests/testopenclxfm2.py
+++ b/tests/testopenclxfm2.py
@@ -85,4 +85,11 @@ def test_0_levels():
b = dtwavexfm2_cl(lena, nlevels=0)
_compare_transforms(a, b)
+@skip_if_no_cl
+@attr('transform')
+def test_modified():
+    # 'near_sym_b_bp' / 'qshift_b_bp' are presumably the modified rotationally
+    # symmetric wavelets this commit adds support for; they should exercise
+    # the 6-element biort / 10-element qshift paths of Transform2dOpenCL.
+    biort_coeffs = biort('near_sym_b_bp')
+    qshift_coeffs = qshift('qshift_b_bp')
+    a = dtwavexfm2_np(lena, biort=biort_coeffs, qshift=qshift_coeffs)
+    b = dtwavexfm2_cl(lena, biort=biort_coeffs, qshift=qshift_coeffs)
+    _compare_transforms(a, b)
+
# vim:sw=4:sts=4:et
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits
mailing list