[python-dtcwt] 07/497: added inverse transform

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:42 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit e43f3d68dfa4760f86a70686838d5da007da0ab9
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Tue Aug 6 16:52:41 2013 +0100

    added inverse transform
---
 dtcwt/__init__.py    |   2 +-
 dtcwt/transform2d.py | 155 ++++++++++++++++++++++++++++++++++++++++++++++++---
 tests/testifm2.py    |  30 ++++++++++
 3 files changed, 177 insertions(+), 10 deletions(-)

diff --git a/dtcwt/__init__.py b/dtcwt/__init__.py
index 4dd830d..34d3cb0 100644
--- a/dtcwt/__init__.py
+++ b/dtcwt/__init__.py
@@ -1,3 +1,3 @@
 from coeffs import biort, qshift
 from lowlevel import colfilter, coldfilt, colifilt, reflect
-from transform2d import dtwavexfm2
+from transform2d import dtwavexfm2, dtwaveifm2
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index 78dc14f..fdc5673 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -1,9 +1,12 @@
 import numpy as np
 import logging
 
-from dtcwt import biort as _biort, qshift as _qshift, colfilter, coldfilt
+from dtcwt import biort as _biort, qshift as _qshift, colfilter, coldfilt, colifilt
 
-def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scale=False):
+BIORT_DEFAULT='near_sym_a' # Default biort filter set name (used at level 1).
+QSHIFT_DEFAULT='qshift_a' # Default qshift filter set name (used at levels >= 2).
+
+def dtwavexfm2(X, nlevels=3, biort=BIORT_DEFAULT, qshift=QSHIFT_DEFAULT, include_scale=False):
     """Function to perform a n-level DTCWT-2D decompostion on a 2D matrix X
 
     Yl, Yh = dtwavexfm2(X, nlevels, biort, qshift)
@@ -62,7 +65,7 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
 
     original_size = X.shape
-    
+
     if len(X.shape) >= 3:
         raise ValueError('The entered image is {0}, please enter each image slice separately.'.
                 format('x'.join(list(str(s) for s in X.shape))))
@@ -74,13 +77,13 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
     if original_size[0] % 2 != 0:
         # if X.shape[0] is not divisable by 2 then we need to extend X by adding a row at the bottom
         X = np.vstack((X, X[-1,:]))  # Any further extension will be done in due course.
-        initial_row_extend = 1;
+        initial_row_extend = 1
 
     if original_size[1] % 2 != 0:
         # if X.shape[1] is not divisable by 2 then we need to extend X by adding a col to the left
         X = np.hstack((X, np.atleast_2d(X[:,-1]).T))
         initial_col_extend = 1
-    
+
     extended_size = X.shape
 
     if nlevels == 0:
@@ -121,11 +124,11 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
             if col_size % 4 != 0:
                 # Extend by 2 cols if no. of cols of LoLo are not divisable by 4
                 LoLo = np.hstack((np.atleast_2d(LoLo[:,0]).T, LoLo, np.atleast_2d(LoLo[:,-1]).T))
-         
+
             # Do even Qshift filters on rows.
             Lo = coldfilt(LoLo,h0b,h0a).T
             Hi = coldfilt(LoLo,h1b,h1a).T
-          
+
             # Do even Qshift filters on columns.
             LoLo = coldfilt(Lo,h0b,h0a).T
 
@@ -138,7 +141,7 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
                 Yscale[0] = LoLo
 
     Yl = LoLo
-    
+
     if initial_row_extend == 1 and initial_col_extend == 1:
         logging.warn('The image entered is now a {0} NOT a {1}.'.format(
             'x'.join(list(str(s) for s in extended_size)),
@@ -165,8 +168,113 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
     else:
         return Yl, tuple(Yh)
 
+
+def dtwaveifm2(Yl,Yh,biort=BIORT_DEFAULT,qshift=QSHIFT_DEFAULT,gain_mask=None):
+    """
+    Function to perform an n-level dual-tree complex wavelet (DTCWT) 2-D
+    reconstruction.
+
+    Z = dtwaveifm2(Yl,Yh,biort,qshift,gain_mask)
+
+        Yl -> The real lowpass image from the final (coarsest) level.
+        Yh -> A tuple of the 6 complex highpass subimages per level; Yh[0] is level 1.
+
+        biort ->  'antonini'   => Antonini 9,7 tap filters.
+                  'legall'     => LeGall 5,3 tap filters.
+                  'near_sym_a' => Near-Symmetric 5,7 tap filters.
+                  'near_sym_b' => Near-Symmetric 13,19 tap filters.
+
+        qshift -> 'qshift_06' => Quarter Sample Shift Orthogonal (Q-Shift) 10,10 tap filters,
+                                 (only 6,6 non-zero taps).
+                  'qshift_a' =>  Q-shift 10,10 tap filters,
+                                 (with 10,10 non-zero taps, unlike qshift_06).
+                  'qshift_b' => Q-Shift 14,14 tap filters.
+                  'qshift_c' => Q-Shift 16,16 tap filters.
+                  'qshift_d' => Q-Shift 18,18 tap filters.
+
+        gain_mask -> Gain to be applied to each subband.
+                     gain_mask[d,l] is the gain for subband Yh[l][:,:,d] (direction d).
+                     A zero gain suppresses a band; c2q skips work only if both paired gains are 0.
+                     Default gain_mask = np.ones((6, len(Yh))). Note that both d and l
+                     are zero-indexed in this case.
+
+        Z -> Reconstructed real image matrix
+
+    If biort or qshift are not strings, they are interpreted as tuples of
+    vectors giving filter coefficients. In the biort case, this should be (h0o,
+    g0o, h1o, g1o). In the qshift case, this should be (h0a, h0b, g0a, g0b,
+    h1a, h1b, g1a, g1b).
+
+    For example:  Z = dtwaveifm2(Yl,Yh,'near_sym_b','qshift_b')
+    performs a len(Yh)-level reconstruction from Yl,Yh using the 13,19-tap filters
+    for level 1 and the Q-shift 14-tap filters for levels >= 2.
+
+    Nick Kingsbury and Cian Shaffrey
+    Cambridge University, May 2002
+
+    """
+
+    a = len(Yh) # Number of levels (one entry of Yh per level).
+
+    if gain_mask is None:
+        gain_mask = np.ones((6,a)) # Default gain_mask.
+
+    gain_mask = np.array(gain_mask) # Accept lists; needed for the fancy indexing below.
+
+    # Look up named biort coefficients, or unpack a user-supplied tuple
+    if isinstance(biort, basestring):
+        h0o, g0o, h1o, g1o = _biort(biort)
+    else:
+        h0o, g0o, h1o, g1o = biort
+
+    # Look up named qshift coefficients, or unpack a user-supplied tuple
+    if isinstance(qshift, basestring):
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
+    else:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
+
+    current_level = a # Start at the coarsest level and work down to level 1.
+    Z = Yl # Running lowpass image; Yl itself is returned (uncopied) if Yh is empty.
+    # Per level, c2q pairs directions (0,5) -> lh, (2,3) -> hl, (1,4) -> hh.
+    while current_level >= 2: # Q-shift levels; level 1 is handled after the loop.
+        lh = c2q(Yh[current_level-1][:,:,[0, 5]],gain_mask[[0, 5],current_level-1])
+        hl = c2q(Yh[current_level-1][:,:,[2, 3]],gain_mask[[2, 3],current_level-1])
+        hh = c2q(Yh[current_level-1][:,:,[1, 4]],gain_mask[[1, 4],current_level-1])
+
+        # Do even Qshift filters on columns.
+        y1 = colifilt(Z,g0b,g0a) + colifilt(lh,g1b,g1a)
+        y2 = colifilt(hl,g0b,g0a) + colifilt(hh,g1b,g1a)
+        # Do even Qshift filters on rows.
+        Z = (colifilt(y1.T,g0b,g0a) + colifilt(y2.T,g1b,g1a)).T
+
+        # Check size of Z against the next finer level and crop as required
+        [row_size, col_size] = Z.shape
+        S = 2*np.array(Yh[current_level-2].shape) # Z must be twice the finer level's subband size.
+        if row_size != S[0]:    # size mismatch: drop the first and last row
+            Z = Z[1:-1,:]
+        if col_size != S[1]:    # size mismatch: drop the first and last column
+            Z = Z[:,1:-1]
+
+        if np.any(np.array(Z.shape) != S[:2]):
+            raise ValueError('Sizes of subbands are not valid for DTWAVEIFM2')
+
+        current_level = current_level - 1
+
+    if current_level == 1: # Level 1 uses the biort filters (g0o, g1o) via colfilter.
+        lh = c2q(Yh[current_level-1][:,:,[0, 5]],gain_mask[[0, 5],current_level-1])
+        hl = c2q(Yh[current_level-1][:,:,[2, 3]],gain_mask[[2, 3],current_level-1])
+        hh = c2q(Yh[current_level-1][:,:,[1, 4]],gain_mask[[1, 4],current_level-1])
+
+        # Do odd top-level filters on columns.
+        y1 = colfilter(Z,g0o) + colfilter(lh,g1o)
+        y2 = colfilter(hl,g0o) + colfilter(hh,g1o)
+        # Do odd top-level filters on rows.
+        Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T
+
+    return Z
+
 #==========================================================================================
-#                       **********    INTERNAL FUNCTION    **********
+#                       **********    INTERNAL FUNCTIONS    **********
 #==========================================================================================
 
 def q2c(y):
@@ -191,4 +299,33 @@ def q2c(y):
 
     return z
 
+def c2q(w,gain):
+    """Scale by gain and convert complex w[:,:,0:2] to a real image of quads.
+    Returns x of shape (2*w.shape[0], 2*w.shape[1]); all zeros if w or gain is all 0.
+
+    Arrange pixels from the real and imag parts of the 2 subbands
+    into the 4 pixels of each 2x2 quad:
+     A----B     Re   Im of w[:,:,0]
+     |    |
+     |    |
+     C----D     Re   Im of w[:,:,1]
+    gain[0] scales w[:,:,0] and gain[1] scales w[:,:,1].
+    """
+
+    sw = w.shape # NOTE(review): unused; x below is sized from w.shape directly.
+    x = np.zeros((w.shape[0]*2, w.shape[1]*2)) # One 2x2 quad per complex sample.
+
+    if np.any(w[:] != 0) and np.any(gain[:] != 0): # Leave x all-zero if nothing contributes.
+        sc = np.sqrt(0.5) * gain # 1/sqrt(2) normalisation times the per-subband gain
+        P = w[:,:,0]*sc[0] + w[:,:,1]*sc[1] # scaled sum of the two subbands
+        Q = w[:,:,0]*sc[0] - w[:,:,1]*sc[1] # scaled difference of the two subbands
+
+        # Recover each of the 4 corners of the quads.
+        x[0::2, 0::2] = np.real(P)  # a = (A+C)*sc
+        x[0::2, 1::2] = np.imag(P)  # b = (B+D)*sc
+        x[1::2, 0::2] = np.imag(Q)  # c = (B-D)*sc
+        x[1::2, 1::2] = -np.real(Q) # d = (C-A)*sc
+
+    return x
+
 # vim:sw=4:sts=4:et
diff --git a/tests/testifm2.py b/tests/testifm2.py
new file mode 100644
index 0000000..7d0409c
--- /dev/null
+++ b/tests/testifm2.py
@@ -0,0 +1,30 @@
+import os
+from nose.tools import raises
+
+import numpy as np
+from dtcwt import dtwavexfm2, dtwaveifm2
+
+def setup():
+    global lena, lena_crop, Yl, Yh, Yl_crop, Yh_crop # Fixtures read by the tests in this module.
+    lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena']
+    lena_crop = lena[:233, :301] # Odd dimensions exercise dtwavexfm2's padding path.
+    Yl, Yh = dtwavexfm2(lena) # Forward transforms with default nlevels and filters.
+    Yl_crop, Yh_crop = dtwavexfm2(lena_crop)
+
+def test_lena_loaded():
+    assert lena.shape == (512, 512) # Sanity-check the fixture image before the round-trip tests.
+    assert lena.min() >= 0
+    assert lena.max() <= 1
+    assert lena.dtype == np.float32
+
+def test_reconstruct():
+    # Round trip: dtwaveifm2(dtwavexfm2(lena)) must match lena to within 1e-3 per pixel
+    lena_recon = dtwaveifm2(Yl, Yh)
+    assert np.all(np.abs(lena_recon - lena) < 1e-3)
+
+def test_reconstruct_ctop(): # NOTE(review): 'ctop' is presumably a typo for 'crop'.
+    # dtwavexfm2 pads odd-sized input, so crop the reconstruction back before comparing
+    lena_recon = dtwaveifm2(Yl_crop, Yh_crop)[:lena_crop.shape[0], :lena_crop.shape[1]]
+    assert np.all(np.abs(lena_recon - lena_crop) < 1e-3)
+
+# vim:sw=4:sts=4:et

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list