[python-dtcwt] 217/497: fix modified inverse transform

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:07 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit 622464fa35c6c5a5e03dce3c65c39404bc2a87f5
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Fri Nov 22 12:53:29 2013 +0000

    fix modified inverse transform
    
    The modified 'bp' variant of the inverse transform completely ignored
    the modified q-shift wavelets at every level above level 1, because
    the wavelets g2a and g2b were never loaded or used.
---
 dtcwt/backend/backend_numpy/transform2d.py  | 36 ++++++++++++++++++-----------
 dtcwt/backend/backend_opencl/transform2d.py | 12 +++++-----
 dtcwt/coeffs.py                             |  2 +-
 3 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/dtcwt/backend/backend_numpy/transform2d.py b/dtcwt/backend/backend_numpy/transform2d.py
index 4b821c4..cd93092 100644
--- a/dtcwt/backend/backend_numpy/transform2d.py
+++ b/dtcwt/backend/backend_numpy/transform2d.py
@@ -56,15 +56,15 @@ class Transform2d(Transform2dBase):
         else:
             raise ValueError('Biort wavelet must have 6 or 4 components.')
 
-        # If qshift has 10 elements instead of 8, then it's a modified
+        # If qshift has 12 elements instead of 8, then it's a modified
         # rotationally symmetric wavelet
         # FIXME: there's probably a nicer way to do this
         if len(self.qshift) == 8:
             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
-        elif len(self.qshift) == 10:
-            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift
+        elif len(self.qshift) == 12:
+            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
         else:
-            raise ValueError('Qshift wavelet must have 10 or 8 components.')
+            raise ValueError('Qshift wavelet must have 12 or 8 components.')
 
         X = np.atleast_2d(asfarray(X))
         original_size = X.shape
@@ -136,7 +136,7 @@ class Transform2d(Transform2dBase):
             # Do even Qshift filters on rows.
             Lo = _BACKEND.coldfilt(LoLo,h0b,h0a).T
             Hi = _BACKEND.coldfilt(LoLo,h1b,h1a).T
-            if len(self.qshift) >= 10:
+            if len(self.qshift) >= 12:
                 Ba = _BACKEND.coldfilt(LoLo,h2b,h2a).T
 
             # Do even Qshift filters on columns.
@@ -145,7 +145,7 @@ class Transform2d(Transform2dBase):
             Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
             Yh[level][:,:,0:6:5] = q2c(_BACKEND.coldfilt(Hi,h0b,h0a).T)  # Horizontal
             Yh[level][:,:,2:4:1] = q2c(_BACKEND.coldfilt(Lo,h1b,h1a).T)  # Vertical
-            if len(self.qshift) >= 10:
+            if len(self.qshift) >= 12:
                 Yh[level][:,:,1:5:3] = q2c(_BACKEND.coldfilt(Ba,h2b,h2a).T)  # Diagonal   
             else:
                 Yh[level][:,:,1:5:3] = q2c(_BACKEND.coldfilt(Hi,h1b,h1a).T)  # Diagonal   
@@ -220,15 +220,15 @@ class Transform2d(Transform2dBase):
         else:
             raise ValueError('Biort wavelet must have 6 or 4 components.')
 
-        # If qshift has 10 elements instead of 8, then it's a modified
+        # If qshift has 12 elements instead of 8, then it's a modified
         # rotationally symmetric wavelet
         # FIXME: there's probably a nicer way to do this
         if len(self.qshift) == 8:
             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
-        elif len(self.qshift) == 10:
-            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift
+        elif len(self.qshift) == 12:
+            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b, g2a, g2b = self.qshift
         else:
-            raise ValueError('Qshift wavelet must have 10 or 8 components.')
+            raise ValueError('Qshift wavelet must have 12 or 8 components.')
 
         current_level = a
         Z = Yl
@@ -240,10 +240,18 @@ class Transform2d(Transform2dBase):
 
             # Do even Qshift filters on columns.
             y1 = _BACKEND.colifilt(Z,g0b,g0a) + _BACKEND.colifilt(lh,g1b,g1a)
-            y2 = _BACKEND.colifilt(hl,g0b,g0a) + _BACKEND.colifilt(hh,g1b,g1a)
 
-            # Do even Qshift filters on rows.
-            Z = (_BACKEND.colifilt(y1.T,g0b,g0a) + _BACKEND.colifilt(y2.T,g1b,g1a)).T
+            if len(self.qshift) >= 12:
+                y2 = _BACKEND.colifilt(hl,g0b,g0a)
+                y2bp = _BACKEND.colifilt(hh,g2b,g2a)
+
+                # Do even Qshift filters on rows.
+                Z = (_BACKEND.colifilt(y1.T,g0b,g0a) + _BACKEND.colifilt(y2.T,g1b,g1a) + _BACKEND.colifilt(y2bp.T, g2b, g2a)).T
+            else:
+                y2 = _BACKEND.colifilt(hl,g0b,g0a) + _BACKEND.colifilt(hh,g1b,g1a)
+
+                # Do even Qshift filters on rows.
+                Z = (_BACKEND.colifilt(y1.T,g0b,g0a) + _BACKEND.colifilt(y2.T,g1b,g1a)).T
 
             # Check size of Z and crop as required
             [row_size, col_size] = Z.shape
@@ -266,7 +274,7 @@ class Transform2d(Transform2dBase):
             # Do odd top-level filters on columns.
             y1 = _BACKEND.colfilter(Z,g0o) + _BACKEND.colfilter(lh,g1o)
 
-            if len(self.qshift) >= 10:
+            if len(self.biort) >= 6:
                 y2 = _BACKEND.colfilter(hl,g0o)
                 y2bp = _BACKEND.colfilter(hh,g2o)
 
diff --git a/dtcwt/backend/backend_opencl/transform2d.py b/dtcwt/backend/backend_opencl/transform2d.py
index 7d7b729..d771034 100644
--- a/dtcwt/backend/backend_opencl/transform2d.py
+++ b/dtcwt/backend/backend_opencl/transform2d.py
@@ -143,15 +143,15 @@ class Transform2d(Transform2dNumPy):
         else:
             raise ValueError('Biort wavelet must have 6 or 4 components.')
 
-        # If qshift has 10 elements instead of 8, then it's a modified
+        # If qshift has 12 elements instead of 8, then it's a modified
         # rotationally symmetric wavelet
         # FIXME: there's probably a nicer way to do this
         if len(self.qshift) == 8:
             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
-        elif len(self.qshift) == 10:
-            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift
+        elif len(self.qshift) == 12:
+            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
         else:
-            raise ValueError('Qshift wavelet must have 10 or 8 components.')
+            raise ValueError('Qshift wavelet must have 12 or 8 components.')
 
         original_size = X.shape
 
@@ -234,13 +234,13 @@ class Transform2d(Transform2dNumPy):
             # Do even Qshift filters on rows.
             Lo = axis_convolve_dfilter(LoLo,h0b,axis=0,queue=queue)
             Hi = axis_convolve_dfilter(LoLo,h1b,axis=0,queue=queue)
-            if len(self.qshift) >= 10:
+            if len(self.qshift) >= 12:
                 Ba = axis_convolve_dfilter(LoLo,h2b,axis=0,queue=queue)
 
             # Do even Qshift filters on columns.
             LoLo = axis_convolve_dfilter(Lo,h0b,axis=1,queue=queue)
 
-            if len(self.qshift) >= 10:
+            if len(self.qshift) >= 12:
                 diag = axis_convolve_dfilter(Ba,h2b,axis=1,queue=queue)
             else:
                 diag = axis_convolve_dfilter(Hi,h1b,axis=1,queue=queue)
diff --git a/dtcwt/coeffs.py b/dtcwt/coeffs.py
index 4b70545..163e4de 100644
--- a/dtcwt/coeffs.py
+++ b/dtcwt/coeffs.py
@@ -80,7 +80,7 @@ def qshift(name):
 
     """
     if name == 'qshift_b_bp':
-        return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 'g1a', 'g1b','h2a', 'h2b'))
+        return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 'g1a', 'g1b', 'h2a', 'h2b', 'g2a','g2b'))
     else:
         return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 'g1a', 'g1b'))
 

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list