[python-dtcwt] 127/497: add rotational symmetry mod, dtwavexfm2b and dtwaveifm2b

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:56 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit 595cb33b9e3b54546678a316775fc24bd92a6937
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Thu Nov 7 11:28:58 2013 +0000

    add rotational symmetry mod, dtwavexfm2b and dtwaveifm2b
    
    Originally based on PR#16 by @scf32. Pretty much a line-by-line copy of
    dtwavexfm2b() as it appears in the MATLAB toolbox. You may prefer to
    merge dtwavexfm2() and dtwavexfm2b() into one function which can handle
    the extra coefficients returned by qshift and biort. As it is, the use
    of 'qshift_b_bp' and 'near_sym_b_bp' (which I have included in .npz
    format) is pretty much hardcoded in dtwavexfm2b.
---
 docs/index.rst               |   1 +
 docs/modified_wavelets.png   | Bin 0 -> 132182 bytes
 docs/reference.rst           |   6 +
 docs/variant.rst             |  46 +++++++
 dtcwt/__init__.py            |   4 +-
 dtcwt/coeffs.py              |  28 +++--
 dtcwt/data/near_sym_b_bp.npz | Bin 0 -> 2491 bytes
 dtcwt/data/qshift_b_bp.npz   | Bin 0 -> 4039 bytes
 dtcwt/sampling.py            |   2 +-
 dtcwt/transform2d.py         | 293 +++++++++++++++++++++++++++++++++++++++++++
 dtcwt/utils.py               |  59 +++++++++
 matlab/README.md             |  13 +-
 matlab/gen_verif.m           |   7 +-
 matlab/near_sym_b_bp.mat     | Bin 0 -> 1280 bytes
 matlab/qshift_b_bp.mat       | Bin 0 -> 2144 bytes
 matlab/regen_verification.sh |  29 +++++
 matlab/verif_m_to_npz.py     |   8 ++
 tests/testagainstmatlab.py   |  16 ++-
 tests/verification.npz       | Bin 128072 -> 251136 bytes
 19 files changed, 490 insertions(+), 22 deletions(-)

diff --git a/docs/index.rst b/docs/index.rst
index 0bb8574..917d17b 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -33,6 +33,7 @@ Contents
 
     gettingstarted
     examples
+    variant
     reference
 
 
diff --git a/docs/modified_wavelets.png b/docs/modified_wavelets.png
new file mode 100644
index 0000000..45dc8e3
Binary files /dev/null and b/docs/modified_wavelets.png differ
diff --git a/docs/reference.rst b/docs/reference.rst
index 9022605..a2fc241 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -19,6 +19,12 @@ Image sampling
 .. automodule:: dtcwt.sampling
     :members:
 
+Miscellaneous utility functions
+```````````````````````````````
+
+.. automodule:: dtcwt.utils
+    :members:
+
 Low-level support functions
 ```````````````````````````
 
diff --git a/docs/variant.rst b/docs/variant.rst
new file mode 100644
index 0000000..8413b8c
--- /dev/null
+++ b/docs/variant.rst
@@ -0,0 +1,46 @@
+Variant transforms
+==================
+
+In addition to the basic 1, 2 and 3 dimensional DT-CWT, this library also
+supports a selection of variant transforms.
+
+Rotational symmetry modified wavelet transform
+----------------------------------------------
+
+For some applications, one may prefer the subband responses to be more rotationally similar. 
+
+In the original 2-D DTCWT, the 45 and 135 degree subbands have passbands whose centre frequencies 
+are somewhat further from the origin than those of the other four subbands. This results from 
+the combination of two highpass 1-D wavelet filters to produce 2-D wavelets. The remaining 
+subbands combine highpass and lowpass 1-D filters, and hence their centre frequencies are a 
+factor of approximately sqrt(1.8) closer to the origin of the frequency plane.
+
+The ``dtwavexfm2b()`` function employs an alternative bandpass 1-D filter in place of the highpass 
+filter for the appropriate subbands. The image below illustrates the relevant differences in impulse 
+and frequency responses [1].
+
+.. figure:: modified_wavelets.png
+
+
+Usage is very similar to the standard 2-D transform function, but the only supported wavelets are
+``'near_sym_b_bp'`` and ``'qshift_b_bp'``. These arguments are optional, but it is best practice to include
+them so that your intentions are clear (and because they are easier for others to spot than the difference
+between ``dtwavexfm2()`` and ``dtwavexfm2b()``).
+
+.. code-block:: python
+
+    Yl, Yh = dtcwt.dtwavexfm2b(image, tfmlevel, 'near_sym_b_bp', 'qshift_b_bp')
+
+
+While the Hilbert transform property of the DTCWT is preserved, perfect reconstruction is lost.
+However, in applications such as machine vision, where all subsequent operations on the image
+take place in the transform domain, this is of relatively minor importance.
+
+For full details, refer to:
+
+[1] N. G. Kingsbury. Rotation-invariant local feature matching with complex
+wavelets. In *Proc. European Signal Processing Conference (EUSIPCO)*,
+pages 901–904, 2006.
+
+
+
diff --git a/dtcwt/__init__.py b/dtcwt/__init__.py
index 96aca00..64e822e 100644
--- a/dtcwt/__init__.py
+++ b/dtcwt/__init__.py
@@ -1,6 +1,6 @@
 from .coeffs import biort, qshift
 from .transform1d import dtwavexfm, dtwaveifm
-from .transform2d import dtwavexfm2, dtwaveifm2
+from .transform2d import dtwavexfm2, dtwaveifm2, dtwavexfm2b, dtwaveifm2b
 from .transform3d import dtwavexfm3, dtwaveifm3
 
 __all__ = [
@@ -9,6 +9,8 @@ __all__ = [
 
     'dtwavexfm2',
     'dtwaveifm2',
+    'dtwavexfm2b',
+    'dtwaveifm2b',
 
     'dtwavexfm3',
     'dtwaveifm3',
diff --git a/dtcwt/coeffs.py b/dtcwt/coeffs.py
index 2b5432a..821d14f 100644
--- a/dtcwt/coeffs.py
+++ b/dtcwt/coeffs.py
@@ -25,14 +25,15 @@ def biort(name):
     :param name: a string specifying the wavelet family name
     :returns: a tuple of vectors giving filter coefficients
 
-    ============ ============================================
-    Name         Wavelet
-    ============ ============================================
-    antonini     Antonini 9,7 tap filters.
-    legall       LeGall 5,3 tap filters.
-    near_sym_a   Near-Symmetric 5,7 tap filters.
-    near_sym_b   Near-Symmetric 13,19 tap filters.
-    ============ ============================================
+    =============  ============================================
+    Name           Wavelet
+    =============  ============================================
+    antonini       Antonini 9,7 tap filters.
+    legall         LeGall 5,3 tap filters.
+    near_sym_a     Near-Symmetric 5,7 tap filters.
+    near_sym_b     Near-Symmetric 13,19 tap filters.
+    near_sym_b_bp  Near-Symmetric 13,19 tap filters + BP filter
+    =============  ============================================
 
     Return a tuple whose elements are a vector specifying the h0o, g0o, h1o and
     g1o coefficients.
@@ -41,7 +42,10 @@ def biort(name):
     :raises ValueError: if name specifies a :py:func:`qshift` wavelet.
 
     """
-    return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o'))
+    if name == 'near_sym_b_bp':
+        return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o', 'h2o', 'g2o'))
+    else:
+        return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o'))
 
 def qshift(name):
     """Load level >=2 wavelet by name,
@@ -59,6 +63,7 @@ def qshift(name):
     qshift_b     Q-Shift 14,14 tap filters.
     qshift_c     Q-Shift 16,16 tap filters.
     qshift_d     Q-Shift 18,18 tap filters.
+    qshift_b_bp  Q-Shift 18,18 tap filters + BP
     ============ ============================================
 
     Return a tuple whose elements are a vector specifying the h0a, h0b, g0a,
@@ -68,6 +73,9 @@ def qshift(name):
     :raises ValueError: if name specifies a :py:func:`biort` wavelet.
 
     """
-    return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 'g1a', 'g1b'))
+    if name == 'qshift_b_bp':
+        return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 'g1a', 'g1b','h2a', 'h2b'))
+    else:
+        return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 'g1a', 'g1b'))
 
 # vim:sw=4:sts=4:et
diff --git a/dtcwt/data/near_sym_b_bp.npz b/dtcwt/data/near_sym_b_bp.npz
new file mode 100644
index 0000000..0c09f1c
Binary files /dev/null and b/dtcwt/data/near_sym_b_bp.npz differ
diff --git a/dtcwt/data/qshift_b_bp.npz b/dtcwt/data/qshift_b_bp.npz
new file mode 100644
index 0000000..9c55205
Binary files /dev/null and b/dtcwt/data/qshift_b_bp.npz differ
diff --git a/dtcwt/sampling.py b/dtcwt/sampling.py
index 7d21dae..16fd8f8 100644
--- a/dtcwt/sampling.py
+++ b/dtcwt/sampling.py
@@ -22,7 +22,7 @@ DTHETA_DX_2D = np.array((_W1, _W0, _W0, _W0, _W0, _W1))
 DTHETA_DY_2D = np.array((_W0, _W0, _W1, -_W1, -_W0, -_W0))
 
 def _sample_clipped(im, xs, ys):
-    """Truncated and symmatric sampling."""
+    """Truncated and symmetric sampling."""
     sym_xs = reflect(xs, -0.5, im.shape[1]-0.5).astype(np.int)
     sym_ys = reflect(ys, -0.5, im.shape[0]-0.5).astype(np.int)
     return im[sym_ys, sym_xs, ...]
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index 1b238b7..b7bd02d 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -154,6 +154,186 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
     else:
         return Yl, tuple(Yh)
 
+def dtwavexfm2b(X, nlevels=3, biort='near_sym_b_bp', qshift='qshift_b_bp', include_scale=False):
+    """Perform an *n*-level rotational-symmetry-modified DTCWT-2D decomposition of *X*.
+
+    :param X: 2D real array
+    :param nlevels: Number of levels of wavelet decomposition
+    :param biort: Level 1 wavelets to use; must include h2o/g2o. See :py:func:`biort`.
+    :param qshift: Level >= 2 wavelets to use; must include h2a/h2b. See :py:func:`qshift`.
+    :param include_scale: If True, also return the real lowpass image of every level.
+    :returns Yl: The real lowpass image from the final level
+    :returns Yh: A tuple containing the complex highpass subimages for each level.
+    :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
+
+    If *biort* or *qshift* are strings, they are used as an argument to the
+    :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+    interpreted as tuples of vectors giving filter coefficients. In the *biort*
+    case, this must be (h0o, g0o, h1o, g1o, h2o, g2o). In the *qshift* case, this
+    must be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b).
+
+    Example::
+
+        # Performs a 3-level transform on the real image X using the only named
+        # filter sets that provide the extra bandpass (h2) filters.
+        Yl, Yh = dtwavexfm2b(X, 3, 'near_sym_b_bp', 'qshift_b_bp')
+
+    .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+    .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
+    .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
+
+    """
+    X = np.atleast_2d(asfarray(X))
+    # Load filters by name; non-string arguments are used directly as coefficient tuples.
+    # Names without bandpass filters (e.g. 'near_sym_b') give too few values -> ValueError.
+    try:
+        h0o, g0o, h1o, g1o, h2o, g2o = _biort(biort)
+    except TypeError:
+        h0o, g0o, h1o, g1o, h2o, g2o = biort
+
+    # Same for qshift, which must yield the 10-tuple ending in h2a, h2b.
+    try:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = _qshift(qshift)
+    except TypeError:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = qshift
+
+    # NOTE(review): the unpacking above already guarantees that h2a, h2b, h2o and h1o
+    try:  # are bound, so bp_qsh, bp_lev1 and hi_lev1 always end up as 1 and the bare
+        h2a, h2b  # `except:` branches are unreachable (presumably kept from the MATLAB original).
+        bp_qsh = 1
+    except:
+        bp_qsh = 0
+        
+    try:
+        h2o
+        bp_lev1 = 1
+    except:
+        bp_lev1 = 0
+          
+    try:
+        h1o
+        hi_lev1 = 1
+    except:
+        hi_lev1 = 0
+
+
+    original_size = X.shape
+
+    if len(X.shape) >= 3:
+        raise ValueError('The entered image is {0}, please enter each image slice separately.'.
+                format('x'.join(list(str(s) for s in X.shape))))
+
+    # If a dimension is odd, duplicate the last row/column so that both are even:
+    # an extra row goes at the bottom and an extra column at the right of the image.
+    initial_row_extend = 0  # becomes 1 if a row is appended at the bottom
+    initial_col_extend = 0  # becomes 1 if a column is appended at the right
+    if original_size[0] % 2 != 0:
+        # X.shape[0] is not divisible by 2, so extend X by repeating its last row at the bottom
+        X = np.vstack((X, X[[-1],:]))  # Any further extension will be done in due course.
+        initial_row_extend = 1
+
+    if original_size[1] % 2 != 0:
+        # X.shape[1] is not divisible by 2, so extend X by repeating its last column at the right
+        X = np.hstack((X, X[:,[-1]]))
+        initial_col_extend = 1
+
+    extended_size = X.shape
+
+    if nlevels == 0:  # no decomposition: return the (possibly extended) input as Yl
+        if include_scale:
+            return X, (), ()
+        else:
+            return X, ()
+
+    # One slot per level for the highpass subbands (and, optionally, lowpass images).
+    Yh = [None,] * nlevels
+    if include_scale:
+        # Only needed when the caller asked for the per-level lowpass images.
+        Yscale = [None,] * nlevels
+
+    complex_dtype = appropriate_complex_type_for(X)
+
+    if nlevels >= 1:
+        # Level 1: odd-length biort filters (lowpass, highpass, bandpass) via colfilter, then transpose.
+        Lo = colfilter(X,h0o).T
+        if hi_lev1 == 1:
+             Hi = colfilter(X,h1o).T
+        if bp_lev1 == 1:
+             Ba = colfilter(X,h2o).T
+
+
+        # Second pass on the transposed results. Hi is used unconditionally (hi_lev1 is always 1).
+        LoLo = colfilter(Lo,h0o).T
+        Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)
+        Yh[0][:,:,[0, 5]] = q2c(colfilter(Hi,h0o).T)     # Horizontal pair
+        Yh[0][:,:,[2, 3]] = q2c(colfilter(Lo,h1o).T)     # Vertical pair
+        if bp_lev1 == 1:                                      # Diagonal pair: bandpass x bandpass
+            Yh[0][:,:,[1, 4]] = q2c(colfilter(Ba,h2o).T)
+        else:
+            Yh[0][:,:,[1, 4]] = q2c(colfilter(Hi,h1o).T)  # highpass x highpass (unreachable: bp_lev1 is always 1)
+
+        if include_scale:
+            Yscale[0] = LoLo
+
+    for level in xrange(1, nlevels):  # levels >= 2; NOTE(review): xrange is Python 2 only unless aliased in the module imports
+        row_size, col_size = LoLo.shape
+        if row_size % 4 != 0:
+            # Pad with a copy of the first and last row if the row count is not divisible by 4
+            LoLo = np.vstack((LoLo[[0],:], LoLo, LoLo[[-1],:]))
+
+        if col_size % 4 != 0:
+            # Pad with a copy of the first and last column if the column count is not divisible by 4
+            LoLo = np.hstack((LoLo[:,[0]], LoLo, LoLo[:,[-1]]))
+
+        # First pass: Q-shift lowpass, highpass and bandpass via coldfilt, then transpose.
+        Lo = coldfilt(LoLo,h0b,h0a).T
+        Hi = coldfilt(LoLo,h1b,h1a).T
+        if bp_qsh == 1:
+            Ba = coldfilt(LoLo,h2b,h2a).T
+            
+        # Second pass on the transposed lowpass; LoLo becomes the next level's input.
+        LoLo = coldfilt(Lo,h0b,h0a).T
+
+        Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
+        Yh[level][:,:,[0, 5]] = q2c(coldfilt(Hi,h0b,h0a).T)  # Horizontal
+        Yh[level][:,:,[2, 3]] = q2c(coldfilt(Lo,h1b,h1a).T)  # Vertical
+
+        if bp_qsh == 1:
+            Yh[level][:,:,[1, 4]] = q2c(coldfilt(Ba,h2b,h2a).T) # Diagonal bandpass
+        else:
+            Yh[level][:,:,[1, 4]] = q2c(coldfilt(Hi,h1b,h1a).T)  # Diagonal highpass (unreachable: bp_qsh is always 1)
+
+        if include_scale:
+            Yscale[level] = LoLo
+
+    Yl = LoLo
+    # Warn if the image was extended above. NOTE(review): logging.warn is a deprecated alias of logging.warning.
+    if initial_row_extend == 1 and initial_col_extend == 1:
+        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
+            'x'.join(list(str(s) for s in extended_size)),
+            'x'.join(list(str(s) for s in original_size))))
+        logging.warn(
+            'The bottom row and rightmost column have been duplicated, prior to decomposition.')
+
+    if initial_row_extend == 1 and initial_col_extend == 0:
+        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
+            'x'.join(list(str(s) for s in extended_size)),
+            'x'.join(list(str(s) for s in original_size))))
+        logging.warn(
+            'The bottom row has been duplicated, prior to decomposition.')
+
+    if initial_row_extend == 0 and initial_col_extend == 1:
+        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
+            'x'.join(list(str(s) for s in extended_size)),
+            'x'.join(list(str(s) for s in original_size))))
+        logging.warn(
+            'The rightmost column has been duplicated, prior to decomposition.')
+
+    if include_scale:
+        return Yl, tuple(Yh), tuple(Yscale)
+    else:
+        return Yl, tuple(Yh)
+
 
 def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
     """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
@@ -250,6 +430,119 @@ def dtwaveifm2(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
 
     return Z
 
+def dtwaveifm2b(Yl,Yh,biort='near_sym_b_bp',qshift='qshift_b_bp',gain_mask=None):
+    """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
+    reconstruction of symmetry-modified subbands (not a perfect inverse).
+
+    :param Yl: The real lowpass subband from the final level
+    :param Yh: A sequence containing the complex highpass subband for each level.
+    :param biort: Level 1 wavelets to use; must include h2o/g2o. See :py:func:`biort`.
+    :param qshift: Level >= 2 wavelets to use; must include h2a/h2b. See :py:func:`qshift`.
+    :param gain_mask: Gain to be applied to each subband.
+
+    :returns Z: Reconstructed real array
+
+    The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
+    *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
+    band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
+    zero-indexed.
+
+    If *biort* or *qshift* are strings, they are used as an argument to the
+    :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+    interpreted as tuples of vectors giving filter coefficients. In the *biort*
+    case, this must be (h0o, g0o, h1o, g1o, h2o, g2o). In the *qshift* case, this
+    must be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b).
+
+    Example::
+
+        # Performs a 3-level reconstruction from the Yl, Yh returned by
+        # dtwavexfm2b(X, 3, 'near_sym_b_bp', 'qshift_b_bp').
+        Z = dtwaveifm2b(Yl, Yh, 'near_sym_b_bp', 'qshift_b_bp')
+
+    .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+    .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
+    .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
+
+    """
+    a = len(Yh) # No of levels.
+
+    if gain_mask is None:
+        gain_mask = np.ones((6,a)) # Default gain_mask.
+
+    gain_mask = np.array(gain_mask)
+
+    # Load filters by name; non-string arguments are used directly as coefficient tuples.
+    try:
+        h0o, g0o, h1o, g1o, h2o, g2o = _biort(biort)
+    except TypeError:
+        h0o, g0o, h1o, g1o, h2o, g2o = biort
+
+    # Same for qshift, which must yield the 10-tuple ending in h2a, h2b.
+    try:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = _qshift(qshift)
+    except TypeError:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = qshift
+
+    # NOTE(review): the unpacking above guarantees h2a, h2b and h2o are bound, so both
+    try:  # flags always end up as 1. bp_qsh is never read below, and h2a/h2b are not used
+        h2a, h2b  # in reconstruction (the qshift tuple has no g2a/g2b synthesis filters).
+        bp_qsh = 1
+    except:
+        bp_qsh = 0
+        
+    try:
+        h2o
+        bp_lev1 = 1
+    except:
+        bp_lev1 = 0
+
+    current_level = a
+    Z = Yl
+
+    while current_level >= 2: # this ensures that for level 1 we never do the following
+        lh = c2q(Yh[current_level-1][:,:,[0, 5]], gain_mask[[0, 5], current_level-1])
+        hl = c2q(Yh[current_level-1][:,:,[2, 3]], gain_mask[[2, 3], current_level-1])
+        hh = c2q(Yh[current_level-1][:,:,[1, 4]], gain_mask[[1, 4], current_level-1])
+
+        # Even Q-shift filters on columns. NOTE(review): hh is synthesised with g1b/g1a although dtwavexfm2b analysed it with h2b/h2a.
+        y1 = colifilt(Z,g0b,g0a) + colifilt(lh,g1b,g1a)
+        y2 = colifilt(hl,g0b,g0a) + colifilt(hh,g1b,g1a)
+
+        # Do even Qshift filters on rows.
+        Z = (colifilt(y1.T,g0b,g0a) + colifilt(y2.T,g1b,g1a)).T
+
+        # Crop one row/column from each end when Z is larger than 2x the next subband (presumably undoing forward-transform padding)
+        [row_size, col_size] = Z.shape
+        S = 2*np.array(Yh[current_level-2].shape)
+        if row_size != S[0]:    # check to see if this result needs to be cropped for the rows
+            Z = Z[1:-1,:]
+        if col_size != S[1]:    # check to see if this result needs to be cropped for the cols
+            Z = Z[:,1:-1]
+
+        if np.any(np.array(Z.shape) != S[:2]):
+            raise ValueError('Sizes of subbands are not valid for DTWAVEIFM2')
+        
+        current_level = current_level - 1
+
+    if current_level == 1:
+        lh = c2q(Yh[current_level-1][:,:,[0, 5]],gain_mask[[0, 5],current_level-1])
+        hl = c2q(Yh[current_level-1][:,:,[2, 3]],gain_mask[[2, 3],current_level-1])
+        hh = c2q(Yh[current_level-1][:,:,[1, 4]],gain_mask[[1, 4],current_level-1])
+
+        # Do odd top-level filters on columns.
+        y1 = colfilter(Z,g0o) + colfilter(lh,g1o)
+        if bp_lev1 == 1:  # always true, see the NOTE above
+            y2 = colfilter(hl,g0o)
+            y2bp = colfilter(hh,g2o)  # diagonal band uses the level-1 bandpass synthesis filter
+            # Rows: the bandpass part is filtered with g2o and added separately.
+            Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T + colfilter(y2bp.T,g2o).T
+        else:
+            y2 = colfilter(hl,g0o) + colfilter(hh,g1o)
+            # Do odd top-level filters on rows.
+            Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T
+
+    return Z
+
 #==========================================================================================
 #                       **********    INTERNAL FUNCTIONS    **********
 #==========================================================================================
diff --git a/dtcwt/utils.py b/dtcwt/utils.py
new file mode 100644
index 0000000..8b716f1
--- /dev/null
+++ b/dtcwt/utils.py
@@ -0,0 +1,59 @@
+""" Useful utilities for testing the 2-D DTCWT with synthetic images"""
+
+__all__ = ( 'drawedge', 'drawcirc',)
+
+import numpy as np
+
+
+def drawedge(theta,r,w,N):
+    """Generate an N x N image of an edge going from 0 to 1 in height at
+    *theta* degrees to the horizontal (top of image = 1 if angle = 0).
+    *r* is a two-element coordinate, in ij coords, through which the step
+    should pass.  Returns an N x N float array with values in [0, 1].
+    The shape of the intensity step is half a raised cosine w pels wide (w>=1).
+    
+    T. E. Gale's enhancement to drawedge() for MATLAB, transliterated 
+    to Python by S. C. Forshaw, Nov. 2013. """
+    
+    # convert theta from degrees to radians
+    thetar = np.array(theta * np.pi / 180)  
+    
+    # Image centre as a 1-based index, e.g. 3 for N = 5
+    imCentre = (np.array([N,N]).T - 1) / 2 + 1  # NOTE(review): floors for even N under Python 2 integer division
+    
+    # Offset of r from the centre, scaled elementwise by -(cos, sin); the two
+    r = np.array([np.cos(thetar), np.sin(thetar)])*(-1) * (r - imCentre)  # components are subtracted separately in `plane` below
+
+    # Clamp the raised-cosine width to at least 1 pel (it is also a divisor below)
+    w = np.maximum(1,w)
+    # NOTE(review): ramp has mean -1 (not 0) for odd N, i.e. it is not centred on the
+    # image; confirm the intended index origin against the MATLAB drawedge.
+    ramp = np.arange(0,N) - (N+1)/2
+    hgrad = np.sin(thetar)*(-1) * np.ones([N,1])
+    vgrad = np.cos(thetar)*(-1) * np.ones([1,N])
+    plane = ((hgrad * ramp) - r[0]) + ((ramp * vgrad).T - r[1])  # (N, N) + (N, 1) broadcast -> (N, N)
+    x = 0.5 + 0.5 * np.sin(np.minimum(np.maximum(plane*(np.pi/w), np.pi/(-2)), np.pi/2))  # clip to +-pi/2, map sin to [0, 1]
+    
+    return x
+
+def drawcirc(r,w,du,dv,N):
+    
+    """Generate an N x N image containing a circle of radius r pels, centred
+    at (du, dv) relative to the centre of the image, where du shifts along
+    columns and dv along rows. Values lie in [0, 1]: above 0.5 inside the
+    circle and below 0.5 outside; the cosine-shaped edge has width w (10-90%).
+    
+    Python implementation by S. C. Forshaw, November 2013."""
+    
+    # check value of w to avoid dividing by zero
+    w = np.maximum(w,1)
+    
+    # x: (N, N) plane with x[i, j] = (j - (N+1)/2 - dv) / r  (assumes r > 0)
+    x = np.ones([N,1]) * ((np.arange(0,N,1, dtype='float') - (N+1) / 2 - dv) / r)
+    
+    # y: (N, 1) column with y[i, 0] = (i - (N+1)/2 - du) / r
+    y = (((np.arange(0,N,1, dtype='float') - (N+1) / 2 - du) / r) * np.ones([1,N])).T
+    # NOTE(review): (N+1)/2 floors for even N under Python 2, and the zero offset sits at 0-based index (N+1)/2, not (N-1)/2
+    # exp(-rho^2/2) - exp(-1/2) is zero on the circle rho = 1 and positive inside; it is scaled by 3r/w,
+    p = 0.5 + 0.5 * np.sin(np.minimum(np.maximum((np.exp(np.array([-0.5]) * (x**2 + y**2)).T - np.exp((-0.5))) * (r * 3 / w), np.pi/(-2)), np.pi/2))  # clipped to +-pi/2, mapped to [0, 1]; .T makes rows follow dv
+    return p
diff --git a/matlab/README.md b/matlab/README.md
index 7e38978..9191edb 100644
--- a/matlab/README.md
+++ b/matlab/README.md
@@ -5,13 +5,14 @@ data. Obviously you must have MATLAB installed but you also require the DT-CWT
 toolbox available on [Nick Kingsbury's](http://www-sigproc.eng.cam.ac.uk/~ngk/)
 home page.
 
-The ``gen_verif.m`` script is not sophisticated; they simply exercise a number
-of the DT-CWT toolbox routines and saves the result to ``verification.mat``.
-Run it with a command like the following:
+The ``regen_verification.sh`` script will run MATLAB and Python to re-generate
+the verification data. It uses the scripts ``gen_verif.m`` and
+``verif_m_to_npz.py``. You should use this script but you may need to configure
+it slightly to set the location of MATLAB on your system and the DTCWT
+toolboxes.
 
-```console
-$ MATLABPATH=/path/to/dtcwt_toolbox4_3 /path/to/matlab -nosplash -nodesktop -r "run /path/to/gen_verif; quit"
-```
+The ``gen_verif.m`` script is not sophisticated; it simply exercises a number
+of the DT-CWT toolbox routines and saves the result to ``verification.mat``.
 
 The ``verif_m_to_npz.py`` script uses SciPy to load the MATLAB output and
 convert it into NumPy's native ``.npz`` format. This file is used by the test
diff --git a/matlab/gen_verif.m b/matlab/gen_verif.m
index df161c4..be2d4d8 100644
--- a/matlab/gen_verif.m
+++ b/matlab/gen_verif.m
@@ -25,6 +25,9 @@ lena_colifilt = colifilt(lena, g0b, g0a);
 
 [lena_Yl, lena_Yh, lena_Yscale] = dtwavexfm2(lena, 4, 'near_sym_a', 'qshift_a');
 
-save('verification.mat', 'lena_coldfilt', 'lena_colifilt', 'lena_Yl', 'lena_Yh', 'lena_Yscale');
+near_sym_b_bp = load('near_sym_b_bp');
+qshift_b_bp = load('qshift_b_bp');
 
-%% 
+[lena_Ylb, lena_Yhb, lena_Yscaleb] = dtwavexfm2b(lena, 4, 'near_sym_b_bp', 'qshift_b_bp');
+
+save('verification.mat', 'lena_coldfilt', 'lena_colifilt', 'lena_Yl', 'lena_Yh', 'lena_Yscale', 'lena_Ylb', 'lena_Yhb', 'lena_Yscaleb');
diff --git a/matlab/near_sym_b_bp.mat b/matlab/near_sym_b_bp.mat
new file mode 100644
index 0000000..2be1601
Binary files /dev/null and b/matlab/near_sym_b_bp.mat differ
diff --git a/matlab/qshift_b_bp.mat b/matlab/qshift_b_bp.mat
new file mode 100644
index 0000000..b9bfa0e
Binary files /dev/null and b/matlab/qshift_b_bp.mat differ
diff --git a/matlab/regen_verification.sh b/matlab/regen_verification.sh
new file mode 100755
index 0000000..a13414b
--- /dev/null
+++ b/matlab/regen_verification.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+## CONFIGURATION
+
+# Path to dtcwt toolbox
+DTCWT_TOOLBOX=$HOME/Downloads/dtcwt_toolbox4_3
+
+# Path to dtcwt keypoints toolbox
+DTCWT_KEYPOINTS=$HOME/Downloads/DTCWTkeypoints
+
+# Path to MATLAB
+MATLAB=/opt/MATLAB/R2013b/bin/matlab
+
+## END OF CONFIGURATION
+
+# Update MATLAB path
+export MATLABPATH="$MATLABPATH:$DTCWT_TOOLBOX:$DTCWT_KEYPOINTS"
+
+# Change to this directory
+cd "`dirname "${BASH_SOURCE[0]}"`"
+
+echo "Generating verification data in MATLAB..."
+"$MATLAB" -nosplash -nodesktop -r "gen_verif; quit"
+
+echo "Converting to NumPy format..."
+python verif_m_to_npz.py
+
+echo "Done"
+
diff --git a/matlab/verif_m_to_npz.py b/matlab/verif_m_to_npz.py
index c291584..9ad94df 100755
--- a/matlab/verif_m_to_npz.py
+++ b/matlab/verif_m_to_npz.py
@@ -28,6 +28,14 @@ for idx, v in enumerate(verif['lena_Yscale']):
     verif['lena_Yscale_{0}'.format(idx)] = v[0]
 del verif['lena_Yscale']
 
+for idx, v in enumerate(verif['lena_Yhb']):
+    verif['lena_Yhb_{0}'.format(idx)] = v[0]
+del verif['lena_Yhb']
+
+for idx, v in enumerate(verif['lena_Yscaleb']):
+    verif['lena_Yscaleb_{0}'.format(idx)] = v[0]
+del verif['lena_Yscaleb']
+
 summaries = dict((k, summarise_mat(v)) for k, v in verif.iteritems())
 
 savez('../tests/verification.npz', **summaries)
diff --git a/tests/testagainstmatlab.py b/tests/testagainstmatlab.py
index 0170195..da05d62 100644
--- a/tests/testagainstmatlab.py
+++ b/tests/testagainstmatlab.py
@@ -3,7 +3,7 @@ from nose.tools import raises
 from nose.plugins.attrib import attr
 
 import numpy as np
-from dtcwt import dtwavexfm2, dtwaveifm2, biort, qshift
+from dtcwt import dtwavexfm2, dtwaveifm2, dtwavexfm2b, dtwaveifm2b, biort, qshift
 from dtcwt.lowlevel import coldfilt, colifilt
 
 from .util import assert_almost_equal, summarise_mat
@@ -30,6 +30,8 @@ from .util import assert_almost_equal, summarise_mat
 # matrix greatly reduces the amount of storage required.
 
 # Summary matching requires greater tolerance
+
+# We allow a little more tolerance for comparison with MATLAB
 TOLERANCE = 1e-5
 
 def assert_almost_equal_to_summary(a, summary, *args, **kwargs):
@@ -41,7 +43,7 @@ def setup():
 
     global verif
     verif = np.load(os.path.join(os.path.dirname(__file__), 'verification.npz'))
-
+    
 def test_lena_loaded():
     assert lena.shape == (512, 512)
     assert lena.min() >= 0
@@ -74,4 +76,14 @@ def test_dtwavexfm2():
     for idx, a in enumerate(Yscale):
         assert_almost_equal_to_summary(a, verif['lena_Yscale_{0}'.format(idx)], tolerance=TOLERANCE)
 
+def test_dtwavexfm2b():  # compares against MATLAB dtwavexfm2b output generated by matlab/gen_verif.m
+    Yl, Yh, Yscale = dtwavexfm2b(lena, 4, 'near_sym_b_bp', 'qshift_b_bp', include_scale=True)
+    assert_almost_equal_to_summary(Yl, verif['lena_Ylb'], tolerance=TOLERANCE)
+    # Per-level complex highpass subbands vs the lena_Yhb_<idx> summaries.
+    for idx, a in enumerate(Yh):
+        assert_almost_equal_to_summary(a, verif['lena_Yhb_{0}'.format(idx)], tolerance=TOLERANCE)
+    # Per-level real lowpass images vs the lena_Yscaleb_<idx> summaries.
+    for idx, a in enumerate(Yscale):
+        assert_almost_equal_to_summary(a, verif['lena_Yscaleb_{0}'.format(idx)], tolerance=TOLERANCE)
+
 # vim:sw=4:sts=4:et
diff --git a/tests/verification.npz b/tests/verification.npz
index 664902d..6843631 100644
Binary files a/tests/verification.npz and b/tests/verification.npz differ

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list