[python-dtcwt] 188/497: fold dtwave{i,x}fm2b into main transform
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:03 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit cece801deecef7437917a81288fecce55329abf4
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date: Mon Nov 11 16:08:21 2013 +0000
fold dtwave{i,x}fm2b into main transform
---
docs/variant.rst | 2 +
dtcwt/backend/numpy/transform2d.py | 185 +++++++++++++++-
dtcwt/coeffs.py | 6 +
dtcwt/transform2d.py | 418 +------------------------------------
tests/testxfm2.py | 5 +
5 files changed, 203 insertions(+), 413 deletions(-)
diff --git a/docs/variant.rst b/docs/variant.rst
index 1a90e2e..e639510 100644
--- a/docs/variant.rst
+++ b/docs/variant.rst
@@ -4,6 +4,8 @@ Variant transforms
In addition to the basic 1, 2 and 3 dimensional DT-CWT, this library also
supports a selection of variant transforms.
+.. _rot-symm-wavelets:
+
Rotational symmetry modified wavelet transform
----------------------------------------------
diff --git a/dtcwt/backend/numpy/transform2d.py b/dtcwt/backend/numpy/transform2d.py
index 431d51f..f18153c 100644
--- a/dtcwt/backend/numpy/transform2d.py
+++ b/dtcwt/backend/numpy/transform2d.py
@@ -16,6 +16,13 @@ class ForwardTransformResultNumPy(object):
self.highpass_coeffs = Yh
self.scales = Yscale
+class InverseTransformResultNumPy(object):
+ def __init__(self, X):
+ self._X = X
+
+ def to_array(self):
+ return self._X
+
class Transform2dNumPy(object):
def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
# Load bi-orthogonal wavelets
@@ -59,8 +66,25 @@ class Transform2dNumPy(object):
.. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
"""
- h0o, g0o, h1o, g1o = self.biort
- h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+ # If biort has 6 elements instead of 4, then it's a modified
+ # rotationally symmetric wavelet
+ # FIXME: there's probably a nicer way to do this
+ if len(self.biort) == 4:
+ h0o, g0o, h1o, g1o = self.biort
+ elif len(self.biort) == 6:
+ h0o, g0o, h1o, g1o, h2o, g2o = self.biort
+ else:
+ raise ValueError('Biort wavelet must have 6 or 4 components.')
+
+ # If qshift has 10 elements instead of 8, then it's a modified
+ # rotationally symmetric wavelet
+ # FIXME: there's probably a nicer way to do this
+ if len(self.qshift) == 8:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+ elif len(self.qshift) == 10:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift
+ else:
+ raise ValueError('Qshift wavelet must have 10 or 8 components.')
X = np.atleast_2d(asfarray(X))
original_size = X.shape
@@ -103,13 +127,18 @@ class Transform2dNumPy(object):
# Do odd top-level filters on cols.
Lo = colfilter(X,h0o).T
Hi = colfilter(X,h1o).T
+ if len(self.biort) >= 6:
+ Ba = colfilter(X,h2o).T
# Do odd top-level filters on rows.
LoLo = colfilter(Lo,h0o).T
Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)
Yh[0][:,:,0:6:5] = q2c(colfilter(Hi,h0o).T) # Horizontal pair
Yh[0][:,:,2:4:1] = q2c(colfilter(Lo,h1o).T) # Vertical pair
- Yh[0][:,:,1:5:3] = q2c(colfilter(Hi,h1o).T) # Diagonal pair
+ if len(self.biort) >= 6:
+ Yh[0][:,:,1:5:3] = q2c(colfilter(Ba,h2o).T) # Diagonal pair
+ else:
+ Yh[0][:,:,1:5:3] = q2c(colfilter(Hi,h1o).T) # Diagonal pair
if include_scale:
Yscale[0] = LoLo
@@ -127,6 +156,8 @@ class Transform2dNumPy(object):
# Do even Qshift filters on rows.
Lo = coldfilt(LoLo,h0b,h0a).T
Hi = coldfilt(LoLo,h1b,h1a).T
+ if len(self.qshift) >= 10:
+ Ba = coldfilt(LoLo,h2b,h2a).T
# Do even Qshift filters on columns.
LoLo = coldfilt(Lo,h0b,h0a).T
@@ -134,7 +165,10 @@ class Transform2dNumPy(object):
Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
Yh[level][:,:,0:6:5] = q2c(coldfilt(Hi,h0b,h0a).T) # Horizontal
Yh[level][:,:,2:4:1] = q2c(coldfilt(Lo,h1b,h1a).T) # Vertical
- Yh[level][:,:,1:5:3] = q2c(coldfilt(Hi,h1b,h1a).T) # Diagonal
+ if len(self.qshift) >= 10:
+ Yh[level][:,:,1:5:3] = q2c(coldfilt(Ba,h2b,h2a).T) # Diagonal
+ else:
+ Yh[level][:,:,1:5:3] = q2c(coldfilt(Hi,h1b,h1a).T) # Diagonal
if include_scale:
Yscale[level] = LoLo
@@ -167,6 +201,121 @@ class Transform2dNumPy(object):
else:
return ForwardTransformResultNumPy(Yl, tuple(Yh))
+ def inverse(self, Yl, Yh, gain_mask=None):
+ """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
+ reconstruction.
+
+ :param Yl: The real lowpass subband from the final level
+ :param Yh: A sequence containing the complex highpass subband for each level.
+ :param biort: Level 1 wavelets to use. See :py:func:`biort`.
+ :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+ :param gain_mask: Gain to be applied to each subband.
+
+ :returns Z: Reconstructed real array
+
+ The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
+ *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
+ band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
+ zero-indexed.
+
+ *biort* and *qshift* are the wavelets given to the constructor (as names for
+ :py:func:`biort` / :py:func:`qshift` or as tuples of filter vectors), not
+ arguments of this method. As tuples, *biort* is (h0o, g0o, h1o, g1o) and
+ *qshift* is (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b); the rotationally
+ symmetric variants append (h2o, g2o) and (h2a, h2b) respectively.
+
+ Example::
+
+ # Performs a 3-level reconstruction from Yl,Yh using the 13,19-tap
+ # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
+ Z = Transform2dNumPy('near_sym_b', 'qshift_b').inverse(Yl, Yh).to_array()
+
+ .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+ .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
+ .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
+
+ """
+ a = len(Yh) # No of levels.
+
+ if gain_mask is None:
+ gain_mask = np.ones((6,a)) # Default gain_mask.
+
+ gain_mask = np.array(gain_mask)
+
+ # If biort has 6 elements instead of 4, then it's a modified
+ # rotationally symmetric wavelet
+ # FIXME: there's probably a nicer way to do this
+ if len(self.biort) == 4:
+ h0o, g0o, h1o, g1o = self.biort
+ elif len(self.biort) == 6:
+ h0o, g0o, h1o, g1o, h2o, g2o = self.biort
+ else:
+ raise ValueError('Biort wavelet must have 6 or 4 components.')
+
+ # If qshift has 10 elements instead of 8, then it's a modified
+ # rotationally symmetric wavelet
+ # FIXME: there's probably a nicer way to do this
+ if len(self.qshift) == 8:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+ elif len(self.qshift) == 10:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift
+ else:
+ raise ValueError('Qshift wavelet must have 10 or 8 components.')
+
+ current_level = a
+ Z = Yl
+
+ while current_level >= 2: # this ensures that for level 1 we never do the following
+ lh = c2q(Yh[current_level-1][:,:,[0, 5]], gain_mask[[0, 5], current_level-1])
+ hl = c2q(Yh[current_level-1][:,:,[2, 3]], gain_mask[[2, 3], current_level-1])
+ hh = c2q(Yh[current_level-1][:,:,[1, 4]], gain_mask[[1, 4], current_level-1])
+
+ # Do even Qshift filters on columns.
+ y1 = colifilt(Z,g0b,g0a) + colifilt(lh,g1b,g1a)
+ y2 = colifilt(hl,g0b,g0a) + colifilt(hh,g1b,g1a)
+
+ # Do even Qshift filters on rows.
+ Z = (colifilt(y1.T,g0b,g0a) + colifilt(y2.T,g1b,g1a)).T
+
+ # Check size of Z and crop as required
+ [row_size, col_size] = Z.shape
+ S = 2*np.array(Yh[current_level-2].shape)
+ if row_size != S[0]: # check to see if this result needs to be cropped for the rows
+ Z = Z[1:-1,:]
+ if col_size != S[1]: # check to see if this result needs to be cropped for the cols
+ Z = Z[:,1:-1]
+
+ if np.any(np.array(Z.shape) != S[:2]):
+ raise ValueError('Sizes of subbands are not valid for DTWAVEIFM2')
+
+ current_level = current_level - 1
+
+ if current_level == 1:
+ lh = c2q(Yh[current_level-1][:,:,[0, 5]],gain_mask[[0, 5],current_level-1])
+ hl = c2q(Yh[current_level-1][:,:,[2, 3]],gain_mask[[2, 3],current_level-1])
+ hh = c2q(Yh[current_level-1][:,:,[1, 4]],gain_mask[[1, 4],current_level-1])
+
+ # Do odd top-level filters on columns.
+ y1 = colfilter(Z,g0o) + colfilter(lh,g1o)
+
+ if len(self.qshift) >= 10:
+ y2 = colfilter(hl,g0o)
+ y2bp = colfilter(hh,g2o)
+
+ # Do odd top-level filters on rows.
+ Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o) + colfilter(y2bp.T, g2o)).T
+ else:
+ y2 = colfilter(hl,g0o) + colfilter(hh,g1o)
+
+ # Do odd top-level filters on rows.
+ Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T
+
+ return InverseTransformResultNumPy(Z)
+
+#==========================================================================================
+# ********** INTERNAL FUNCTIONS **********
+#==========================================================================================
+
def q2c(y):
"""Convert from quads in y to complex numbers in z.
@@ -188,3 +337,31 @@ def q2c(y):
z = np.dstack((p-q,p+q))
return z
+
+def c2q(w,gain):
+ """Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers
+ in z.
+
+ Arrange pixels from the real and imag parts of the 2 subbands
+ into 4 separate subimages .
+ A----B Re Im of w(:,:,1)
+ | |
+ | |
+ C----D Re Im of w(:,:,2)
+
+ """
+
+ x = np.zeros((w.shape[0] << 1, w.shape[1] << 1), dtype=w.real.dtype)
+
+ sc = np.sqrt(0.5) * gain
+ P = w[:,:,0]*sc[0] + w[:,:,1]*sc[1]
+ Q = w[:,:,0]*sc[0] - w[:,:,1]*sc[1]
+
+ # Recover each of the 4 corners of the quads.
+ x[0::2, 0::2] = P.real # a = (A+C)*sc
+ x[0::2, 1::2] = P.imag # b = (B+D)*sc
+ x[1::2, 0::2] = Q.imag # c = (B-D)*sc
+ x[1::2, 1::2] = -Q.real # d = (C-A)*sc
+
+ return x
+
diff --git a/dtcwt/coeffs.py b/dtcwt/coeffs.py
index 821d14f..4b70545 100644
--- a/dtcwt/coeffs.py
+++ b/dtcwt/coeffs.py
@@ -38,6 +38,9 @@ def biort(name):
Return a tuple whose elements are a vector specifying the h0o, g0o, h1o and
g1o coefficients.
+ See :ref:`rot-symm-wavelets` for an explanation of the ``near_sym_b_bp``
+ wavelet filters.
+
:raises IOError: if name does not correspond to a set of wavelets known to the library.
:raises ValueError: if name specifies a :py:func:`qshift` wavelet.
@@ -69,6 +72,9 @@ def qshift(name):
Return a tuple whose elements are a vector specifying the h0a, h0b, g0a,
g0b, h1a, h1b, g1a and g1b coefficients.
+ See :ref:`rot-symm-wavelets` for an explanation of the ``qshift_b_bp``
+ wavelet filters.
+
:raises IOError: if name does not correspond to a set of wavelets known to the library.
:raises ValueError: if name specifies a :py:func:`biort` wavelet.
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index bbb9479..ce56154 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -48,187 +48,6 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
else:
return res.lowpass, res.highpass_coeffs
-def dtwavexfm2b(X, nlevels=3, biort='near_sym_b_bp', qshift='qshift_b_bp', include_scale=False):
- """Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
-
- :param X: 2D real array
- :param nlevels: Number of levels of wavelet decomposition
- :param biort: Level 1 wavelets to use. See :py:func:`biort`.
- :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-
- :returns Yl: The real lowpass image from the final level
- :returns Yh: A tuple containing the complex highpass subimages for each level.
- :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
-
- If *biort* or *qshift* are strings, they are used as an argument to the
- :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
- interpreted as tuples of vectors giving filter coefficients. In the *biort*
- case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
- be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
- Example::
-
- # Performs a 3-level transform on the real image X using the 13,19-tap
- # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
- Yl, Yh = dtwavexfm2b(X, 3, 'near_sym_b', 'qshift_b')
-
- .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
- .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
- .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
-
- """
- X = np.atleast_2d(asfarray(X))
-
- # Try to load coefficients if biort is a string parameter
- try:
- h0o, g0o, h1o, g1o, h2o, g2o = _biort(biort)
- except TypeError:
- h0o, g0o, h1o, g1o, h2o, g2o = biort
-
- # Try to load coefficients if qshift is a string parameter
- try:
- h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = _qshift(qshift)
- except TypeError:
- h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = qshift
-
- # Probably not the best way to do this, but I'm a noob.
- try:
- h2a, h2b
- bp_qsh = 1
- except:
- bp_qsh = 0
-
- try:
- h2o
- bp_lev1 = 1
- except:
- bp_lev1 = 0
-
- try:
- h1o
- hi_lev1 = 1
- except:
- hi_lev1 = 0
-
-
- original_size = X.shape
-
- if len(X.shape) >= 3:
- raise ValueError('The entered image is {0}, please enter each image slice separately.'.
- format('x'.join(list(str(s) for s in X.shape))))
-
- # The next few lines of code check to see if the image is odd in size, if so an extra ...
- # row/column will be added to the bottom/right of the image
- initial_row_extend = 0 #initialise
- initial_col_extend = 0
- if original_size[0] % 2 != 0:
- # if X.shape[0] is not divisable by 2 then we need to extend X by adding a row at the bottom
- X = np.vstack((X, X[[-1],:])) # Any further extension will be done in due course.
- initial_row_extend = 1
-
- if original_size[1] % 2 != 0:
- # if X.shape[1] is not divisable by 2 then we need to extend X by adding a col to the left
- X = np.hstack((X, X[:,[-1]]))
- initial_col_extend = 1
-
- extended_size = X.shape
-
- if nlevels == 0:
- if include_scale:
- return X, (), ()
- else:
- return X, ()
-
- # initialise
- Yh = [None,] * nlevels
- if include_scale:
- # this is only required if the user specifies a third output component.
- Yscale = [None,] * nlevels
-
- complex_dtype = appropriate_complex_type_for(X)
-
- if nlevels >= 1:
- # Do odd top-level filters on cols.
- Lo = colfilter(X,h0o).T
- if hi_lev1 == 1:
- Hi = colfilter(X,h1o).T
- if bp_lev1 == 1:
- Ba = colfilter(X,h2o).T
-
-
- # Do odd top-level filters on rows.
- LoLo = colfilter(Lo,h0o).T
- Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)
- Yh[0][:,:,[0, 5]] = q2c(colfilter(Hi,h0o).T) # Horizontal pair
- Yh[0][:,:,[2, 3]] = q2c(colfilter(Lo,h1o).T) # Vertical pair
- if bp_lev1 == 1: # Diagonal pair
- Yh[0][:,:,[1, 4]] = q2c(colfilter(Ba,h2o).T)
- else:
- Yh[0][:,:,[1, 4]] = q2c(colfilter(Hi,h1o).T)
-
- if include_scale:
- Yscale[0] = LoLo
-
- for level in xrange(1, nlevels):
- row_size, col_size = LoLo.shape
- if row_size % 4 != 0:
- # Extend by 2 rows if no. of rows of LoLo are not divisable by 4
- LoLo = np.vstack((LoLo[[0],:], LoLo, LoLo[[-1],:]))
-
- if col_size % 4 != 0:
- # Extend by 2 cols if no. of cols of LoLo are not divisable by 4
- LoLo = np.hstack((LoLo[:,[0]], LoLo, LoLo[:,[-1]]))
-
- # Do even Qshift filters on rows.
- Lo = coldfilt(LoLo,h0b,h0a).T
- Hi = coldfilt(LoLo,h1b,h1a).T
- if bp_qsh == 1:
- Ba = coldfilt(LoLo,h2b,h2a).T
-
- # Do even Qshift filters on columns.
- LoLo = coldfilt(Lo,h0b,h0a).T
-
- Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
- Yh[level][:,:,[0, 5]] = q2c(coldfilt(Hi,h0b,h0a).T) # Horizontal
- Yh[level][:,:,[2, 3]] = q2c(coldfilt(Lo,h1b,h1a).T) # Vertical
-
- if bp_qsh == 1:
- Yh[level][:,:,[1, 4]] = q2c(coldfilt(Ba,h2b,h2a).T) # Diagonal bandpass
- else:
- Yh[level][:,:,[1, 4]] = q2c(coldfilt(Hi,h1b,h1a).T) # Diagonal highpass
-
- if include_scale:
- Yscale[level] = LoLo
-
- Yl = LoLo
-
- if initial_row_extend == 1 and initial_col_extend == 1:
- logging.warn('The image entered is now a {0} NOT a {1}.'.format(
- 'x'.join(list(str(s) for s in extended_size)),
- 'x'.join(list(str(s) for s in original_size))))
- logging.warn(
- 'The bottom row and rightmost column have been duplicated, prior to decomposition.')
-
- if initial_row_extend == 1 and initial_col_extend == 0:
- logging.warn('The image entered is now a {0} NOT a {1}.'.format(
- 'x'.join(list(str(s) for s in extended_size)),
- 'x'.join(list(str(s) for s in original_size))))
- logging.warn(
- 'The bottom row has been duplicated, prior to decomposition.')
-
- if initial_row_extend == 0 and initial_col_extend == 1:
- logging.warn('The image entered is now a {0} NOT a {1}.'.format(
- 'x'.join(list(str(s) for s in extended_size)),
- 'x'.join(list(str(s) for s in original_size))))
- logging.warn(
- 'The rightmost column has been duplicated, prior to decomposition.')
-
- if include_scale:
- return Yl, tuple(Yh), tuple(Yscale)
- else:
- return Yl, tuple(Yh)
-
-
def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
"""Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
reconstruction.
@@ -263,231 +82,12 @@ def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
.. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
"""
- a = len(Yh) # No of levels.
-
- if gain_mask is None:
- gain_mask = np.ones((6,a)) # Default gain_mask.
-
- gain_mask = np.array(gain_mask)
-
- # Try to load coefficients if biort is a string parameter
- try:
- h0o, g0o, h1o, g1o = _biort(biort)
- except TypeError:
- h0o, g0o, h1o, g1o = biort
-
- # Try to load coefficients if qshift is a string parameter
- try:
- h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
- except TypeError:
- h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
-
- current_level = a
- Z = Yl
-
- while current_level >= 2: # this ensures that for level 1 we never do the following
- lh = c2q(Yh[current_level-1][:,:,[0, 5]], gain_mask[[0, 5], current_level-1])
- hl = c2q(Yh[current_level-1][:,:,[2, 3]], gain_mask[[2, 3], current_level-1])
- hh = c2q(Yh[current_level-1][:,:,[1, 4]], gain_mask[[1, 4], current_level-1])
-
- # Do even Qshift filters on columns.
- y1 = colifilt(Z,g0b,g0a) + colifilt(lh,g1b,g1a)
- y2 = colifilt(hl,g0b,g0a) + colifilt(hh,g1b,g1a)
-
- # Do even Qshift filters on rows.
- Z = (colifilt(y1.T,g0b,g0a) + colifilt(y2.T,g1b,g1a)).T
-
- # Check size of Z and crop as required
- [row_size, col_size] = Z.shape
- S = 2*np.array(Yh[current_level-2].shape)
- if row_size != S[0]: # check to see if this result needs to be cropped for the rows
- Z = Z[1:-1,:]
- if col_size != S[1]: # check to see if this result needs to be cropped for the cols
- Z = Z[:,1:-1]
-
- if np.any(np.array(Z.shape) != S[:2]):
- raise ValueError('Sizes of subbands are not valid for DTWAVEIFM2')
-
- current_level = current_level - 1
-
- if current_level == 1:
- lh = c2q(Yh[current_level-1][:,:,[0, 5]],gain_mask[[0, 5],current_level-1])
- hl = c2q(Yh[current_level-1][:,:,[2, 3]],gain_mask[[2, 3],current_level-1])
- hh = c2q(Yh[current_level-1][:,:,[1, 4]],gain_mask[[1, 4],current_level-1])
-
- # Do odd top-level filters on columns.
- y1 = colfilter(Z,g0o) + colfilter(lh,g1o)
- y2 = colfilter(hl,g0o) + colfilter(hh,g1o)
-
- # Do odd top-level filters on rows.
- Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T
-
- return Z
-
-def dtwaveifm2b(Yl,Yh,biort='near_sym_b_bp',qshift='qshift_b_bp',gain_mask=None):
- """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
- reconstruction, for use with symmetry-modified DTCWT subbands.
-
- :param Yl: The real lowpass subband from the final level
- :param Yh: A sequence containing the complex highpass subband for each level.
- :param biort: Level 1 wavelets to use. See :py:func:`biort`.
- :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
- :param gain_mask: Gain to be applied to each subband.
-
- :returns Z: Reconstructed real array
-
- The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
- *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
- band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
- zero-indexed.
-
- If *biort* or *qshift* are strings, they are used as an argument to the
- :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
- interpreted as tuples of vectors giving filter coefficients. In the *biort*
- case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
- be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
-
- Example::
-
- # Performs a 3-level reconstruction from Yl,Yh using the 13,19-tap
- # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
- Z = dtwaveifm2(Yl, Yh, 'near_sym_b', 'qshift_b')
-
- .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
- .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
- .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
-
- """
- a = len(Yh) # No of levels.
-
- if gain_mask is None:
- gain_mask = np.ones((6,a)) # Default gain_mask.
-
- gain_mask = np.array(gain_mask)
-
- # Try to load coefficients if biort is a string parameter
- try:
- h0o, g0o, h1o, g1o, h2o, g2o = _biort(biort)
- except TypeError:
- h0o, g0o, h1o, g1o, h2o, g2o = biort
-
- # Try to load coefficients if qshift is a string parameter
- try:
- h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = _qshift(qshift)
- except TypeError:
- h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = qshift
-
- # Probably not the best way to do this, but I'm a noob.
- try:
- h2a, h2b
- bp_qsh = 1
- except:
- bp_qsh = 0
-
- try:
- h2o
- bp_lev1 = 1
- except:
- bp_lev1 = 0
-
- current_level = a
- Z = Yl
-
- while current_level >= 2: # this ensures that for level 1 we never do the following
- lh = c2q(Yh[current_level-1][:,:,[0, 5]], gain_mask[[0, 5], current_level-1])
- hl = c2q(Yh[current_level-1][:,:,[2, 3]], gain_mask[[2, 3], current_level-1])
- hh = c2q(Yh[current_level-1][:,:,[1, 4]], gain_mask[[1, 4], current_level-1])
-
- # Do even Qshift filters on columns.
- y1 = colifilt(Z,g0b,g0a) + colifilt(lh,g1b,g1a)
- y2 = colifilt(hl,g0b,g0a) + colifilt(hh,g1b,g1a)
-
- # Do even Qshift filters on rows.
- Z = (colifilt(y1.T,g0b,g0a) + colifilt(y2.T,g1b,g1a)).T
-
- # Check size of Z and crop as required
- [row_size, col_size] = Z.shape
- S = 2*np.array(Yh[current_level-2].shape)
- if row_size != S[0]: # check to see if this result needs to be cropped for the rows
- Z = Z[1:-1,:]
- if col_size != S[1]: # check to see if this result needs to be cropped for the cols
- Z = Z[:,1:-1]
-
- if np.any(np.array(Z.shape) != S[:2]):
- raise ValueError('Sizes of subbands are not valid for DTWAVEIFM2')
-
- current_level = current_level - 1
-
- if current_level == 1:
- lh = c2q(Yh[current_level-1][:,:,[0, 5]],gain_mask[[0, 5],current_level-1])
- hl = c2q(Yh[current_level-1][:,:,[2, 3]],gain_mask[[2, 3],current_level-1])
- hh = c2q(Yh[current_level-1][:,:,[1, 4]],gain_mask[[1, 4],current_level-1])
-
- # Do odd top-level filters on columns.
- y1 = colfilter(Z,g0o) + colfilter(lh,g1o)
- if bp_lev1 == 1:
- y2 = colfilter(hl,g0o)
- y2bp = colfilter(hh,g2o)
- # Do odd top-level filters on rows.
- Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T + colfilter(y2bp.T,g2o).T
- else:
- y2 = colfilter(hl,g0o) + colfilter(hh,g1o)
- # Do odd top-level filters on rows.
- Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T
-
- return Z
-
-#==========================================================================================
-# ********** INTERNAL FUNCTIONS **********
-#==========================================================================================
-
-def q2c(y):
- """Convert from quads in y to complex numbers in z.
-
- """
- j2 = (np.sqrt(0.5) * np.array([1, 1j])).astype(appropriate_complex_type_for(y))
-
- # Arrange pixels from the corners of the quads into
- # 2 subimages of alternate real and imag pixels.
- # a----b
- # | |
- # | |
- # c----d
-
- # Combine (a,b) and (d,c) to form two complex subimages.
- p = y[0::2, 0::2]*j2[0] + y[0::2, 1::2]*j2[1] # p = (a + jb) / sqrt(2)
- q = y[1::2, 1::2]*j2[0] - y[1::2, 0::2]*j2[1] # q = (d - jc) / sqrt(2)
-
- # Form the 2 subbands in z.
- z = np.dstack((p-q,p+q))
-
- return z
-
-def c2q(w,gain):
- """Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers
- in z.
-
- Arrange pixels from the real and imag parts of the 2 subbands
- into 4 separate subimages .
- A----B Re Im of w(:,:,1)
- | |
- | |
- C----D Re Im of w(:,:,2)
-
- """
-
- x = np.zeros((w.shape[0] << 1, w.shape[1] << 1), dtype=w.real.dtype)
-
- sc = np.sqrt(0.5) * gain
- P = w[:,:,0]*sc[0] + w[:,:,1]*sc[1]
- Q = w[:,:,0]*sc[0] - w[:,:,1]*sc[1]
-
- # Recover each of the 4 corners of the quads.
- x[0::2, 0::2] = P.real # a = (A+C)*sc
- x[0::2, 1::2] = P.imag # b = (B+D)*sc
- x[1::2, 0::2] = Q.imag # c = (B-D)*sc
- x[1::2, 1::2] = -Q.real # d = (C-A)*sc
-
- return x
-
-# vim:sw=4:sts=4:et
+ trans = Transform2dNumPy(biort, qshift)
+ res = trans.inverse(Yl, Yh, gain_mask=gain_mask)
+ return res.to_array()
+
+# BACKWARDS COMPATIBILITY: add a dtwave{i,x}fm2b function which is a copy of
+# dtwave{i,x}fm2. The functionality of the ...b variant is rolled into the
+# original.
+dtwavexfm2b = dtwavexfm2
+dtwaveifm2b = dtwaveifm2
diff --git a/tests/testxfm2.py b/tests/testxfm2.py
index 168cc0a..b568271 100644
--- a/tests/testxfm2.py
+++ b/tests/testxfm2.py
@@ -57,6 +57,11 @@ def test_odd_rows_and_cols():
def test_odd_rows_and_cols_w_scale():
Yl, Yh, Yscale = dtwavexfm2(lena[:509,:509], include_scale=True)
+def test_rot_symm_modified():
+ # This test only checks there is no error running these functions, not that they work
+ Yl, Yh, Yscale = dtwavexfm2(lena, biort='near_sym_b_bp', qshift='qshift_b_bp', include_scale=True)
+ Z = dtwaveifm2(Yl, Yh, biort='near_sym_b_bp', qshift='qshift_b_bp')
+
def test_0_levels():
Yl, Yh = dtwavexfm2(lena, nlevels=0)
assert np.all(np.abs(Yl - lena) < TOLERANCE)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits
mailing list