[python-dtcwt] 66/497: make single level 3D transforms work

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:50 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit 2e9bcf8629f1fcd1ae81536d973f94fc0c2a1ff9
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Thu Aug 8 14:27:23 2013 +0100

    make single level 3D transforms work
---
 dtcwt/lowlevel.py    |   6 ++
 dtcwt/transform3d.py | 172 +++++++++++++++++++++++++++++++++++++++++++++++----
 tests/testxfm3.py    |  33 +++++++---
 3 files changed, 191 insertions(+), 20 deletions(-)

diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index f7e8816..d2dbb47 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -29,6 +29,12 @@ def _column_convolve(X, h):
     i.e. those values unaffected by zero padding.
 
     """
+
+    # This function should give the same result as (reference only, not executed):
+    #
+    #     from scipy.signal import convolve2d
+    #     return convolve2d(X, as_column_vector(h), 'valid')
+
     h = h.flatten()
     h_size = h.shape[0]
     full_size = X.shape[0] + h_size - 1
diff --git a/dtcwt/transform3d.py b/dtcwt/transform3d.py
index 7d446f6..97760f6 100644
--- a/dtcwt/transform3d.py
+++ b/dtcwt/transform3d.py
@@ -7,17 +7,17 @@ from dtcwt import biort as _biort, qshift as _qshift
 from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
 from dtcwt.lowlevel import colfilter, coldfilt, colifilt
 
-def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
+def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
     """Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
 
     :param X: 3D real matrix/Image of shape (N, M)
     :param nlevels: Number of levels of wavelet decomposition
     :param biort: Level 1 wavelets to use. See :py:func:`biort`.
     :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+    :param ext_mode: Extension mode. See below.
 
     :returns Yl: The real lowpass image from the final level
-    :returns Yh: A tuple containing the (N, M, 6) shape complex highpass subimages for each level.
-    :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
+    :returns Yh: A tuple containing the (N, M, P, 7) shape highpass subbands for each level.
 
     If *biort* or *qshift* are strings, they are used as an argument to the
     :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
@@ -25,6 +25,15 @@ def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
     case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
     be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
 
+    There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
+    check whether the input size is divisible by 2 (if not we raise a
+    ``ValueError``). Also check whether from 2nd level onwards, the coeffs can
+    be divided by 4. If any dimension size is not a multiple of 4, append extra
+    coeffs by repeating the edges. If *ext_mode* = 8, check whether the input
+    size is divisible by 4 (if not we raise a ``ValueError``). Also check whether
+    from 2nd level onwards, the coeffs can be divided by 8. If any dimension size
+    is not a multiple of 8, append extra coeffs by repeating the edges twice.
+
     Example::
 
         # Performs a 3-level transform on the real image X using the 13,19-tap
@@ -34,9 +43,36 @@ def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
     .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
 
     """
-    return None
+    X = np.atleast_3d(X)
+
+    # Try to load coefficients if biort is a string parameter
+    try:
+        h0o, g0o, h1o, g1o = _biort(biort)
+    except TypeError:
+        h0o, g0o, h1o, g1o = biort
+
+    # Try to load coefficients if qshift is a string parameter
+    try:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
+    except TypeError:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
+
+    # Check value of ext_mode. TODO: this should really be an enum :S
+    if ext_mode != 4 and ext_mode != 8:
+        raise ValueError('ext_mode must be one of 4 or 8')
+
+    Yl = X
+    Yh = [None,] * nlevels
 
-def dtwaveifm3(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
+    # level is 0-indexed; only level 1 is implemented so far
+    if nlevels > 1:
+        raise NotImplementedError('Only single level 3D transforms are currently implemented')
+    for level in xrange(nlevels):
+        if level == 0:
+            Yl, Yh[level] = _level1_xfm(Yl, h0o, h1o, ext_mode)
+    return Yl, tuple(Yh)
+
+def dtwaveifm3(Yl, Yh, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
     """Perform an *n*-level dual-tree complex wavelet (DTCWT) 3D
     reconstruction.
 
@@ -44,15 +80,9 @@ def dtwaveifm3(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
     :param Yh: A sequence containing the complex highpass subband for each level.
     :param biort: Level 1 wavelets to use. See :py:func:`biort`.
     :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-    :param gain_mask: Gain to be applied to each subband.
 
     :returns Z: Reconstructed real image matrix.
 
-    The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
-    *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
-    band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
-    zero-indexed.
-
     If *biort* or *qshift* are strings, they are used as an argument to the
     :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
     interpreted as tuples of vectors giving filter coefficients. In the *biort*
@@ -68,6 +98,124 @@ def dtwaveifm3(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
     .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
 
     """
-    return None
+    # Try to load coefficients if biort is a string parameter
+    try:
+        h0o, g0o, h1o, g1o = _biort(biort)
+    except TypeError:
+        h0o, g0o, h1o, g1o = biort
+
+    # Try to load coefficients if qshift is a string parameter
+    try:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
+    except TypeError:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
+
+    X = Yl
+
+    # level is 0-indexed; only level 1 is implemented so far
+    nlevels = len(Yh)
+    if nlevels > 1:
+        raise NotImplementedError('Only single level 3D transforms are currently implemented')
+    for level in xrange(nlevels):
+        if level == 0:
+            X = _level1_ifm(X, Yh[level], g0o, g1o)
+    return X
+
+def _level1_xfm(X, h0o, h1o, ext_mode):
+    """Perform level 1 of the 3d transform.
+
+    """
+    # Check shape of input according to ext_mode. Note that the work area
+    # allocated below is double the input shape in each direction.
+    if ext_mode == 4 and np.any(np.fmod(X.shape, 2) != 0):
+        raise ValueError('Input shape should be a multiple of 2 in each direction when ext_mode == 4')
+    elif ext_mode == 8 and np.any(np.fmod(X.shape, 4) != 0):
+        raise ValueError('Input shape should be a multiple of 4 in each direction when ext_mode == 8')
+
+    # Create work area
+    work = np.zeros(np.asarray(X.shape) * 2, dtype=X.dtype)
+
+    # Form some useful slices
+    s0a = slice(None, work.shape[0] >> 1)
+    s1a = slice(None, work.shape[1] >> 1)
+    s2a = slice(None, work.shape[2] >> 1)
+    s0b = slice(work.shape[0] >> 1, None)
+    s1b = slice(work.shape[1] >> 1, None)
+    s2b = slice(work.shape[2] >> 1, None)
+
+    # Assign input
+    work[s0a, s1a, s2a] = X
+
+    # Loop over 2nd dimension extracting 2D slice from first and 3rd dimensions
+    for f in xrange(work.shape[1] >> 1):
+        # extract slice
+        y = work[s0a, f, s2a].T
+
+        # Do odd top-level filters on 3rd dim. The order here is important
+        # since the second filtering will modify the elements of y as well
+        # since y is merely a view onto work.
+        work[s0a, f, s2b] = colfilter(y, h1o).T
+        work[s0a, f, s2a] = colfilter(y, h0o).T
+
+    # Loop over 3rd dimension extracting 2D slice from first and 2nd dimensions
+    for f in xrange(work.shape[2]):
+        # Do odd top-level filters on rows.
+        y1 = work[s0a, s1a, f].T
+        y2 = np.vstack((colfilter(y1, h0o), colfilter(y1, h1o))).T
+
+        # Do odd top-level filters on columns.
+        work[s0a, :, f] = colfilter(y2, h0o)
+        work[s0b, :, f] = colfilter(y2, h1o)
+
+    # Return appropriate slices of output
+    return (work[s0a, s1a, s2a],                # LLL
+        np.concatenate((
+            work[s0a, s1b, s2a, np.newaxis],    # HLL
+            work[s0b, s1a, s2a, np.newaxis],    # LHL
+            work[s0b, s1b, s2a, np.newaxis],    # HHL
+            work[s0a, s1a, s2b, np.newaxis],    # LLH
+            work[s0a, s1b, s2b, np.newaxis],    # HLH
+            work[s0b, s1a, s2b, np.newaxis],    # LHH
+            work[s0b, s1b, s2b, np.newaxis],    # HHH
+        ), axis=3))
+
+def _level1_ifm(Yl, Yh, g0o, g1o):
+    """Perform level 1 of the inverse 3d transform.
+
+    """
+    # Create work area
+    work = np.zeros(np.asarray(Yl.shape) * 2, dtype=Yl.dtype)
+
+    # Form some useful slices
+    s0a = slice(None, work.shape[0] >> 1)
+    s1a = slice(None, work.shape[1] >> 1)
+    s2a = slice(None, work.shape[2] >> 1)
+    s0b = slice(work.shape[0] >> 1, None)
+    s1b = slice(work.shape[1] >> 1, None)
+    s2b = slice(work.shape[2] >> 1, None)
+
+    # Assign regions of work area
+    work[s0a, s1a, s2a] = Yl
+    work[s0a, s1b, s2a] = Yh[:,:,:,0]
+    work[s0b, s1a, s2a] = Yh[:,:,:,1]
+    work[s0b, s1b, s2a] = Yh[:,:,:,2]
+    work[s0a, s1a, s2b] = Yh[:,:,:,3]
+    work[s0a, s1b, s2b] = Yh[:,:,:,4]
+    work[s0b, s1a, s2b] = Yh[:,:,:,5]
+    work[s0b, s1b, s2b] = Yh[:,:,:,6]
+
+    for f in xrange(work.shape[2]):
+        # Do odd top-level filters on rows.
+        y = colfilter(work[:, s1a, f].T, g0o) + colfilter(work[:, s1b, f].T, g1o)
+
+        # Do odd top-level filters on columns.
+        work[s0a, s1a, f] = colfilter(y[:, s0a].T, g0o) + colfilter(y[:, s0b].T, g1o)
+
+    for f in xrange(work.shape[1]>>1):
+        # Do odd top-level filters on 3rd dim.
+        y = work[s0a, f, :].T
+        work[s0a, f, s2a] = (colfilter(y[s2a, :], g0o) + colfilter(y[s2b, :], g1o)).T
+
+    return work[s0a, s1a, s2a]
 
 # vim:sw=4:sts=4:et
diff --git a/tests/testxfm3.py b/tests/testxfm3.py
index ac1c78d..9dc8453 100644
--- a/tests/testxfm3.py
+++ b/tests/testxfm3.py
@@ -5,21 +5,38 @@ from nose.plugins.attrib import attr
 import numpy as np
 from dtcwt import dtwavexfm3, dtwaveifm3
 
+GRID_SIZE=32
+SPHERE_RAD=25
+
 def setup():
     global sphere
-    X, Y, Z = np.meshgrid(np.arange(-63,64), np.arange(-63,64), np.arange(-63,64))
+
+    grid = np.arange(-(GRID_SIZE>>1), (GRID_SIZE>>1), dtype=np.float64)
+    X, Y, Z = np.meshgrid(grid, grid, grid)
+    # Squash the sphere into an ellipsoid; in-place scaling needs a float grid
+    Y *= 1.2
+    Z /= 1.2
 
     r = np.sqrt(X*X + Y*Y + Z*Z)
-    sphere = np.where(r <= 55, 1.0, 0.0)
+    sphere = np.where(r <= SPHERE_RAD, 1.0, 0.0)
 
 def test_sphere():
     # Check general aspects of sphere are OK
-    assert sphere.shape == (127,127,127)
-    assert sphere.min() >= 0
-    assert sphere.max() <= 1
+    assert sphere.shape == (GRID_SIZE,GRID_SIZE,GRID_SIZE)
+    assert sphere.min() == 0
+    assert sphere.max() == 1
+
+def test_simple_level_1_xfm():
+    # Just tests that the transform broadly works and gives expected size output
+    Yl, Yh = dtwavexfm3(sphere, 1)
+    assert Yl.shape == (GRID_SIZE,GRID_SIZE,GRID_SIZE)
+    assert len(Yh) == 1
 
-    # Check volume of sphere is ok to within 5%
-    sphere_vol = (4.0/3.0) * np.pi * 55*55*55
-    assert np.abs(np.sum(sphere.flatten()) - sphere_vol) < 5e-2*sphere_vol
+def test_simple_level_1_recon():
+    # Test for perfect reconstruction with 1 level
+    Yl, Yh = dtwavexfm3(sphere, 1)
+    sphere_recon = dtwaveifm3(Yl, Yh)
+    assert sphere.shape == sphere_recon.shape
+    assert np.max(np.abs(sphere - sphere_recon)) < 1e-11
 
 # vim:sw=4:sts=4:et

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list