[python-dtcwt] 373/497: Merge remote-tracking branch 'timseries/master' into timseries-test-merge

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:29 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit 2f308f084c82b97bc5b42e49170c7ae052fe0f21
Merge: ae6f664 99ffbc1
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Fri Feb 7 16:36:40 2014 +0000

    Merge remote-tracking branch 'timseries/master' into timseries-test-merge
    
    Conflicts:
    	dtcwt/backend/base.py
    	dtcwt/numpy/__init__.py
    	dtcwt/opencl/__init__.py
    	tests/testagainstmatlab.py

 dtcwt/numpy/__init__.py             |   2 +
 dtcwt/numpy/transform2d.py          |   8 +-
 dtcwt/{ => numpy}/transform3d.py    | 439 ++++++++++++++++++++----------------
 dtcwt/opencl/__init__.py            |   2 +
 dtcwt/{ => opencl}/transform3d.py   | 382 ++++++++++++++++++-------------
 dtcwt/transform3d.py                |  44 ++--
 examples/3d_dtcwt_directionality.py | 147 ++++++------
 matlab/gen_verif.m                  |  15 +-
 matlab/qbgn/gaussian.m              |  27 +++
 matlab/qbgn/gen_qbgn.m              |  31 +++
 matlab/qbgn/softquant.m             |  30 +++
 matlab/verif_m_to_npz.py            |  30 ++-
 tests/qbgn.mat                      | Bin 0 -> 826597 bytes
 tests/testagainstmatlab.py          |  27 ++-
 tests/util.py                       |  12 +
 tests/verification.npz              | Bin 279070 -> 4510508 bytes
 16 files changed, 726 insertions(+), 470 deletions(-)

diff --cc dtcwt/numpy/__init__.py
index 4cdda5e,df98a81..56300bd
--- a/dtcwt/numpy/__init__.py
+++ b/dtcwt/numpy/__init__.py
@@@ -4,12 -4,11 +4,14 @@@ be available
  
  """
  
 -from .transform2d import TransformDomainSignal, Transform2d
 +from .common import Pyramid
 +from .transform1d import Transform1d
 +from .transform2d import Transform2d
+ from .transform3d import Transform3d
  
  __all__ = [
 -    'TransformDomainSignal',
 +    'Pyramid',
 +    'Transform1d',
      'Transform2d',
+     'Transform3d',
  ]
diff --cc dtcwt/numpy/transform2d.py
index f824bf5,c0ff392..a3da89f
--- a/dtcwt/numpy/transform2d.py
+++ b/dtcwt/numpy/transform2d.py
@@@ -319,10 -316,11 +320,11 @@@ def q2c(y)
      return z
  
  def c2q(w,gain):
-     """Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers
+     """
+     Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers
      in z.
  
 -    Arrange pixels from the real and imag parts of the 2 subbands
 +    Arrange pixels from the real and imag parts of the 2 highpasses
      into 4 separate subimages .
       A----B     Re   Im of w(:,:,1)
       |    |
diff --cc dtcwt/numpy/transform3d.py
index 7274c16,adf0454..f3d7412
--- a/dtcwt/numpy/transform3d.py
+++ b/dtcwt/numpy/transform3d.py
@@@ -5,165 -3,220 +3,217 @@@ import loggin
  
  from six.moves import xrange
  
 -from dtcwt.backend.base import TransformDomainSignal, ReconstructedSignal, Transform3d as Transform3dBase
++from dtcwt.numpy.common import Pyramid
  from dtcwt.coeffs import biort as _biort, qshift as _qshift
  from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
- from dtcwt.numpy.lowlevel import colfilter, coldfilt, colifilt
- from dtcwt.utils import asfarray
- 
- def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4, discard_level_1=False):
-     """Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
- 
-     :param X: 3D real array-like object
-     :param nlevels: Number of levels of wavelet decomposition
-     :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-     :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-     :param ext_mode: Extension mode. See below.
-     :param discard_level_1: True if level 1 high-pass bands are to be discarded.
- 
-     :returns Yl: The real lowpass image from the final level
-     :returns Yh: A tuple containing the complex highpass subimages for each level.
- 
-     Each element of *Yh* is a 4D complex array with the 4th dimension having
-     size 28. The 3D slice ``Yh[l][:,:,:,d]`` corresponds to the complex higpass
-     coefficients for direction d at level l where d and l are both 0-indexed.
- 
-     If *biort* or *qshift* are strings, they are used as an argument to the
-     :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-     interpreted as tuples of vectors giving filter coefficients. In the *biort*
-     case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-     be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
- 
-     There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
-     check whether 1st level is divisible by 2 (if not we raise a
-     ``ValueError``). Also check whether from 2nd level onwards, the coefs can
-     be divided by 4. If any dimension size is not a multiple of 4, append extra
-     coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
-     divisible by 4 (if not we raise a ``ValueError``). Also check whether from
-     2nd level onwards, the coeffs can be divided by 8. If any dimension size is
-     not a multiple of 8, append extra coeffs by repeating the edges twice.
- 
-     If *discard_level_1* is True the highpass coefficients at level 1 will be
-     discarded. (And, in fact, will never be calculated.) This turns the
-     transform from being 8:1 redundant to being 1:1 redundant at the cost of
-     no-longer allowing perfect reconstruction. If this option is selected then
-     `Yh[0]` will be `None`. Note that :py:func:`dtwaveifm3` will accepts
-     `Yh[0]` being `None` and will treat it as being zero.
- 
-     Example::
+ from dtcwt.utils import appropriate_complex_type_for, asfarray
  
-         # Performs a 3-level transform on the real 3D array X using the 13,19-tap
-         # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
-         Yl, Yh = dtwavexfm3(X, 3, 'near_sym_b', 'qshift_b')
 -from dtcwt.backend.backend_numpy.lowlevel import LowLevelBackendNumPy
++from dtcwt.numpy.lowlevel import *
  
-     .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
-     .. codeauthor:: Huizhong Chen, Jan 2009
-     .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
+ import pdb
  
 -# Use the NumPy low-level backend
 -_BACKEND = LowLevelBackendNumPy()
 -
 -class Transform3d(Transform3dBase):
++class Transform3d(object):
      """
-     X = np.atleast_3d(asfarray(X))
- 
-     # Try to load coefficients if biort is a string parameter
-     try:
-         h0o, g0o, h1o, g1o = _biort(biort)
-     except TypeError:
-         h0o, g0o, h1o, g1o = biort
- 
-     # Try to load coefficients if qshift is a string parameter
-     try:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
-     except TypeError:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
- 
-     # Check value of ext_mode. TODO: this should really be an enum :S
-     if ext_mode != 4 and ext_mode != 8:
-         raise ValueError('ext_mode must be one of 4 or 8')
- 
-     Yl = X
-     Yh = [None,] * nlevels
- 
-     # level is 0-indexed
-     for level in xrange(nlevels):
-         # Transform
-         if level == 0 and discard_level_1:
-             Yl = _level1_xfm_no_highpass(Yl, h0o, h1o, ext_mode)
-         elif level == 0 and not discard_level_1:
-             Yl, Yh[level] = _level1_xfm(Yl, h0o, h1o, ext_mode)
-         else:
-             Yl, Yh[level] = _level2_xfm(Yl, h0a, h0b, h1a, h1b, ext_mode)
- 
-     return Yl, tuple(Yh)
- 
- def dtwaveifm3(Yl, Yh, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
-     """Perform an *n*-level dual-tree complex wavelet (DTCWT) 3D
-     reconstruction.
- 
-     :param Yl: The real lowpass subband from the final level
-     :param Yh: A sequence containing the complex highpass subband for each level.
-     :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-     :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-     :param ext_mode: Extension mode. See below.
+     An implementation of the 3D DT-CWT via NumPy. *biort* and *qshift* are the
+     wavelets which parameterise the transform. Valid values are documented in
+     :py:func:`dtcwt.dtwavexfm3`.
  
-     :returns Z: Reconstructed real image matrix.
- 
-     If *biort* or *qshift* are strings, they are used as an argument to the
-     :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-     interpreted as tuples of vectors giving filter coefficients. In the *biort*
-     case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-     be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
- 
-     There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
-     check whether 1st level is divisible by 2 (if not we raise a
-     ``ValueError``). Also check whether from 2nd level onwards, the coefs can
-     be divided by 4. If any dimension size is not a multiple of 4, append extra
-     coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
-     divisible by 4 (if not we raise a ``ValueError``). Also check whether from
-     2nd level onwards, the coeffs can be divided by 8. If any dimension size is
-     not a multiple of 8, append extra coeffs by repeating the edges twice.
+     """
+     def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
+         # Load bi-orthogonal wavelets
+         try:
+             self.biort = _biort(biort)
+         except TypeError:
+             self.biort = biort
+ 
+         # Load quarter sample shift wavelets
+         try:
+             self.qshift = _qshift(qshift)
+         except TypeError:
+             self.qshift = qshift
+ 
+         self.ext_mode = ext_mode
+             
+     def forward(self, X, nlevels=3, include_scale=False, discard_level_1=False):
+         """Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
+         
+         :param X: 3D real array-like object
+         :param nlevels: Number of levels of wavelet decomposition
+         :param biort: Level 1 wavelets to use. See :py:func:`biort`.
+         :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+         :param discard_level_1: True if level 1 high-pass bands are to be discarded.
+         
+         :returns Yl: The real lowpass image from the final level
+         :returns Yh: A tuple containing the complex highpass subimages for each level.
+         
+         Each element of *Yh* is a 4D complex array with the 4th dimension having
+         size 28. The 3D slice ``Yh[l][:,:,:,d]`` corresponds to the complex higpass
+         coefficients for direction d at level l where d and l are both 0-indexed.
+         
+         If *biort* or *qshift* are strings, they are used as an argument to the
+         :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+         interpreted as tuples of vectors giving filter coefficients. In the *biort*
+         case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
+         be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+         
+         There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
+         check whether 1st level is divisible by 2 (if not we raise a
+         ``ValueError``). Also check whether from 2nd level onwards, the coefs can
+         be divided by 4. If any dimension size is not a multiple of 4, append extra
+         coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
+         divisible by 4 (if not we raise a ``ValueError``). Also check whether from
+         2nd level onwards, the coeffs can be divided by 8. If any dimension size is
+         not a multiple of 8, append extra coeffs by repeating the edges twice.
+         
+         If *discard_level_1* is True the highpass coefficients at level 1 will not be
+         discarded. (And, in fact, will never be calculated.) This turns the
+         transform from being 8:1 redundant to being 1:1 redundant at the cost of
+         no-longer allowing perfect reconstruction. If this option is selected then
+         `Yh[0]` will be `None`. Note that :py:func:`dtwaveifm3` will accepts
+         `Yh[0]` being `None` and will treat it as being zero.
+         
+         Example::
+         
+         # Performs a 3-level transform on the real 3D array X using the 13,19-tap
+         # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
+         Yl, Yh = dtwavexfm3(X, 3, 'near_sym_b', 'qshift_b')
+         
+         .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+         .. codeauthor:: Huizhong Chen, Jan 2009
+         .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
+         
+         """
+         X = np.atleast_3d(asfarray(X))
+ 
+         # If biort has 6 elements instead of 4, then it's a modified
+         # rotationally symmetric wavelet
+         # FIXME: there's probably a nicer way to do this
+         if len(self.biort) == 4:
+             h0o, g0o, h1o, g1o = self.biort
+         elif len(self.biort) == 6:
+             h0o, g0o, h1o, g1o, h2o, g2o = self.biort
+         else:
+             raise ValueError('Biort wavelet must have 6 or 4 components.')
+         # If qshift has 12 elements instead of 8, then it's a modified
+         # rotationally symmetric wavelet
+         # FIXME: there's probably a nicer way to do this
+         if len(self.qshift) == 8:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+         elif len(self.qshift) == 12:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
+         else:
+             raise ValueError('Qshift wavelet must have 12 or 8 components.')
  
-     Example::
+         # Check value of ext_mode. TODO: this should really be an enum :S
+         if self.ext_mode != 4 and self.ext_mode != 8:
+             raise ValueError('ext_mode must be one of 4 or 8')
  
+         Yl = X
+         Yh = [None,] * nlevels
+         
+         if include_scale:
+             # this is only required if the user specifies a third output component.
+             Yscale = [None,] * nlevels
+ 
+         #pdb.set_trace()
+         # level is 0-indexed
+         for level in xrange(nlevels):
+             # Transform
+             if level == 0 and discard_level_1:
+                 Yl = _level1_xfm_no_highpass(Yl, h0o, h1o, self.ext_mode)
+             elif level == 0 and not discard_level_1:
+                 Yl, Yh[level] = _level1_xfm(Yl, h0o, h1o, self.ext_mode)
+             else:
+                 Yl, Yh[level] = _level2_xfm(Yl, h0a, h0b, h1a, h1b, self.ext_mode)
+             if include_scale:
+                 Yscale[level] = Yl.copy()
+         
+                 #Yh[nlevels+1]=1 #to throw an error for debugging in nose
+         if include_scale:
 -            return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
++            return Pyramid(Yl, tuple(Yh), tuple(Yscale))
+         else: 
 -            return TransformDomainSignal(Yl, tuple(Yh))
++            return Pyramid(Yl, tuple(Yh))
+ 
+     def inverse(self, td_signal):
+         """Perform an *n*-level dual-tree complex wavelet (DTCWT) 3D
+         reconstruction.
+         
+         :param Yl: The real lowpass subband from the final level
+         :param Yh: A sequence containing the complex highpass subband for each level.
+         :param biort: Level 1 wavelets to use. See :py:func:`biort`.
+         :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+         :param ext_mode: Extension mode. See below.
+         
+         :returns Z: Reconstructed real image matrix.
+         
+         If *biort* or *qshift* are strings, they are used as an argument to the
+         :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+         interpreted as tuples of vectors giving filter coefficients. In the *biort*
+         case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
+         be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+         
+         There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
+         check whether 1st level is divisible by 2 (if not we raise a
+         ``ValueError``). Also check whether from 2nd level onwards, the coefs can
+         be divided by 4. If any dimension size is not a multiple of 4, append extra
+         coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
+         divisible by 4 (if not we raise a ``ValueError``). Also check whether from
+         2nd level onwards, the coeffs can be divided by 8. If any dimension size is
+         not a multiple of 8, append extra coeffs by repeating the edges twice.
+         
+         Example::
+         
          # Performs a 3-level reconstruction from Yl,Yh using the 13,19-tap
          # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
          Z = dtwaveifm3(Yl, Yh, 'near_sym_b', 'qshift_b')
- 
-     .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
-     .. codeauthor:: Huizhong Chen, Jan 2009
-     .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
- 
-     """
-     # Try to load coefficients if biort is a string parameter
-     try:
-         h0o, g0o, h1o, g1o = _biort(biort)
-     except TypeError:
-         h0o, g0o, h1o, g1o = biort
- 
-     # Try to load coefficients if qshift is a string parameter
-     try:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
-     except TypeError:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
- 
-     X = Yl
- 
-     nlevels = len(Yh)
-     # level is 0-indexed but interpreted starting from the *last* level
-     for level in xrange(nlevels):
-         # Transform
-         if level == nlevels-1: # non-obviously this is the 'first' level
-             if Yh[-level-1] is None:
-                 Yl = _level1_ifm_no_highpass(Yl, g0o, g1o)
-             else:
-                 Yl = _level1_ifm(Yl, Yh[-level-1], g0o, g1o)
+         
+         .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+         .. codeauthor:: Huizhong Chen, Jan 2009
+         .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
+         
+         """
+         Yl = td_signal.lowpass
 -        Yh = td_signal.subbands
++        Yh = td_signal.highpasses
+ 
+         # Try to load coefficients if biort is a string parameter
+         if len(self.biort) == 4:
+             h0o, g0o, h1o, g1o = self.biort
+         elif len(self.biort) == 6:
+             h0o, g0o, h1o, g1o, h2o, g2o = self.biort
          else:
-             # Gracefully handle the Yh[0] is None case.
-             if Yh[-level-2] is not None:
-                 prev_shape = Yh[-level-2].shape
+             raise ValueError('Biort wavelet must have 6 or 4 components.')
+         
+         # If qshift has 12 elements instead of 8, then it's a modified
+         # rotationally symmetric wavelet
+         # FIXME: there's probably a nicer way to do this
+         if len(self.qshift) == 8:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+         elif len(self.qshift) == 12:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b, g2a, g2b = self.qshift
+         else:
+             raise ValueError('Qshift wavelet must have 12 or 8 components.')
+ 
+         X = Yl
+ 
+         nlevels = len(Yh)
+         # level is 0-indexed but interpreted starting from the *last* level
+         for level in xrange(nlevels):
+             # Transform
+             if level == nlevels-1: # non-obviously this is the 'first' level
+                 if Yh[-level-1] is None:
+                     Yl = _level1_ifm_no_highpass(Yl, g0o, g1o)
+                 else:
+                     Yl = _level1_ifm(Yl, Yh[-level-1], g0o, g1o)
              else:
-                 prev_shape = np.array(Yh[-level-1].shape) * 2
- 
-             Yl = _level2_ifm(Yl, Yh[-level-1], g0a, g0b, g1a, g1b, ext_mode, prev_shape)
+                 # Gracefully handle the Yh[0] is None case.
+                 if Yh[-level-2] is not None:
+                     prev_shape = Yh[-level-2].shape
+                 else:
+                     prev_shape = np.array(Yh[-level-1].shape) * 2
+                     
+                 Yl = _level2_ifm(Yl, Yh[-level-1], g0a, g0b, g1a, g1b, self.ext_mode, prev_shape)
  
-     return Yl
 -        return ReconstructedSignal(Yl)
++        return Yl
  
  def _level1_xfm(X, h0o, h1o, ext_mode):
      """Perform level 1 of the 3d transform.
@@@ -227,11 -279,13 +276,13 @@@
      for f in xrange(work.shape[2]):
          # Do odd top-level filters on rows.
          y1 = work[x0a, x1a, f].T
 -        y2 = np.vstack((_BACKEND.colfilter(y1, h0o), _BACKEND.colfilter(y1, h1o))).T
 +        y2 = np.vstack((colfilter(y1, h0o), colfilter(y1, h1o))).T
  
          # Do odd top-level filters on columns.
 -        work[s0a, :, f] = _BACKEND.colfilter(y2, h0o)
 -        work[s0b, :, f] = _BACKEND.colfilter(y2, h1o)
 +        work[s0a, :, f] = colfilter(y2, h0o)
 +        work[s0b, :, f] = colfilter(y2, h1o)
+         #if f==2:
+         #work[:,:,work.shape[2]+1]=1 #to throw an error so we can inspect y in the unit test
  
      # Return appropriate slices of output
      return (
@@@ -264,12 -318,12 +315,12 @@@ def _level1_xfm_no_highpass(X, h0o, h1o
      for f in xrange(X.shape[1]):
          # extract slice
          y = X[:, f, :].T
 -        out[:, f, :] = _BACKEND.colfilter(y, h0o).T
 +        out[:, f, :] = colfilter(y, h0o).T
- 
-     # Loop over 3rd dimension extracting 2D slice from first and 2nd dimensions
+         
+   # Loop over 3rd dimension extracting 2D slice from first and 2nd dimensions
      for f in xrange(X.shape[2]):
 -        y = _BACKEND.colfilter(out[:, :, f].T, h0o).T
 -        out[:, :, f] = _BACKEND.colfilter(y, h0o)
 +        y = colfilter(out[:, :, f].T, h0o).T
 +        out[:, :, f] = colfilter(y, h0o)
  
      return out
  
@@@ -321,11 -375,11 +372,11 @@@ def _level2_xfm(X, h0a, h0b, h1a, h1b, 
      for f in xrange(work.shape[2]):
          # Do even Qshift filters on rows.
          y1 = work[:, :, f].T
 -        y2 = np.vstack((_BACKEND.coldfilt(y1, h0b, h0a), _BACKEND.coldfilt(y1, h1b, h1a))).T
 +        y2 = np.vstack((coldfilt(y1, h0b, h0a), coldfilt(y1, h1b, h1a))).T
- 
+         
          # Do even Qshift filters on columns.
 -        work[s0a, :, f] = _BACKEND.coldfilt(y2, h0b, h0a)
 -        work[s0b, :, f] = _BACKEND.coldfilt(y2, h1b, h1a)
 +        work[s0a, :, f] = coldfilt(y2, h0b, h0a)
 +        work[s0b, :, f] = coldfilt(y2, h1b, h1a)
  
      # Return appropriate slices of output
      return (
@@@ -443,10 -497,10 +494,10 @@@ def _level2_ifm(Yl, Yh, g0a, g0b, g1a, 
  
      for f in xrange(work.shape[2]):
          # Do even Qshift filters on rows.
 -        y = _BACKEND.colifilt(work[:, s1a, f].T, g0b, g0a) + _BACKEND.colifilt(work[:, s1b, f].T, g1b, g1a)
 +        y = colifilt(work[:, s1a, f].T, g0b, g0a) + colifilt(work[:, s1b, f].T, g1b, g1a)
  
-         # Do even Qshift filters on columns.
+           # Do even Qshift filters on columns.
 -        work[:, :, f] = _BACKEND.colifilt(y[:, s0a].T, g0b, g0a) + _BACKEND.colifilt(y[:,s0b].T, g1b, g1a)
 +        work[:, :, f] = colifilt(y[:, s0a].T, g0b, g0a) + colifilt(y[:,s0b].T, g1b, g1a)
  
      for f in xrange(work.shape[1]):
          # Do even Qshift filters on 3rd dim.
diff --cc dtcwt/opencl/__init__.py
index 80e799f,4f65c11..8ad5e6d
--- a/dtcwt/opencl/__init__.py
+++ b/dtcwt/opencl/__init__.py
@@@ -4,9 -4,11 +4,11 @@@ PyOpenCL be installed
  
  """
  
 -from .transform2d import TransformDomainSignal, Transform2d
 +from .transform2d import Pyramid, Transform2d
+ from .transform3d import Transform3d
  
  __all__ = [
 -    'TransformDomainSignal',
 +    'Pyramid',
      'Transform2d',
+     'Transform3d',
  ]
diff --cc dtcwt/opencl/transform3d.py
index 7274c16,35b56ff..dc5af80
--- a/dtcwt/opencl/transform3d.py
+++ b/dtcwt/opencl/transform3d.py
@@@ -1,169 -1,225 +1,225 @@@
- from __future__ import absolute_import
+ from __future__ import division
  
- import numpy as np
  import logging
- 
+ import numpy as np
  from six.moves import xrange
  
 -from dtcwt import biort as _biort, qshift as _qshift
 +from dtcwt.coeffs import biort as _biort, qshift as _qshift
  from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
- from dtcwt.numpy.lowlevel import colfilter, coldfilt, colifilt
- from dtcwt.utils import asfarray
- 
- def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4, discard_level_1=False):
-     """Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
- 
-     :param X: 3D real array-like object
-     :param nlevels: Number of levels of wavelet decomposition
-     :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-     :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-     :param ext_mode: Extension mode. See below.
-     :param discard_level_1: True if level 1 high-pass bands are to be discarded.
- 
-     :returns Yl: The real lowpass image from the final level
-     :returns Yh: A tuple containing the complex highpass subimages for each level.
- 
-     Each element of *Yh* is a 4D complex array with the 4th dimension having
-     size 28. The 3D slice ``Yh[l][:,:,:,d]`` corresponds to the complex higpass
-     coefficients for direction d at level l where d and l are both 0-indexed.
- 
-     If *biort* or *qshift* are strings, they are used as an argument to the
-     :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-     interpreted as tuples of vectors giving filter coefficients. In the *biort*
-     case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-     be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
- 
-     There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
-     check whether 1st level is divisible by 2 (if not we raise a
-     ``ValueError``). Also check whether from 2nd level onwards, the coefs can
-     be divided by 4. If any dimension size is not a multiple of 4, append extra
-     coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
-     divisible by 4 (if not we raise a ``ValueError``). Also check whether from
-     2nd level onwards, the coeffs can be divided by 8. If any dimension size is
-     not a multiple of 8, append extra coeffs by repeating the edges twice.
- 
-     If *discard_level_1* is True the highpass coefficients at level 1 will be
-     discarded. (And, in fact, will never be calculated.) This turns the
-     transform from being 8:1 redundant to being 1:1 redundant at the cost of
-     no-longer allowing perfect reconstruction. If this option is selected then
-     `Yh[0]` will be `None`. Note that :py:func:`dtwaveifm3` will accepts
-     `Yh[0]` being `None` and will treat it as being zero.
- 
-     Example::
+ from dtcwt.utils import appropriate_complex_type_for, asfarray, memoize
 -from dtcwt.backend.backend_opencl.lowlevel import colfilter, coldfilt, colifilt
 -from dtcwt.backend.backend_opencl.lowlevel import axis_convolve, axis_convolve_dfilter, q2c
 -from dtcwt.backend.backend_opencl.lowlevel import to_device, to_queue, to_array, empty
++from dtcwt.opencl.lowlevel import colfilter, coldfilt, colifilt
++from dtcwt.opencl.lowlevel import axis_convolve, axis_convolve_dfilter, q2c
++from dtcwt.opencl.lowlevel import to_device, to_queue, to_array, empty
  
-         # Performs a 3-level transform on the real 3D array X using the 13,19-tap
-         # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
-         Yl, Yh = dtwavexfm3(X, 3, 'near_sym_b', 'qshift_b')
 -from dtcwt.backend.base import TransformDomainSignal, ReconstructedSignal
 -from dtcwt.backend.backend_numpy import Transform3d as Transform3dNumPy
++from dtcwt.numpy import Pyramid
++from dtcwt.numpy import Transform3d as Transform3dNumPy
  
-     .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
-     .. codeauthor:: Huizhong Chen, Jan 2009
-     .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
+ try:
+     from pyopencl.array import concatenate, Array as CLArray
+ except ImportError:
+     # The lack of OpenCL will be caught by the low-level routines.
+     pass
  
+ class Transform3d(Transform3dNumPy):
      """
-     X = np.atleast_3d(asfarray(X))
- 
-     # Try to load coefficients if biort is a string parameter
-     try:
-         h0o, g0o, h1o, g1o = _biort(biort)
-     except TypeError:
-         h0o, g0o, h1o, g1o = biort
- 
-     # Try to load coefficients if qshift is a string parameter
-     try:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
-     except TypeError:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
- 
-     # Check value of ext_mode. TODO: this should really be an enum :S
-     if ext_mode != 4 and ext_mode != 8:
-         raise ValueError('ext_mode must be one of 4 or 8')
- 
-     Yl = X
-     Yh = [None,] * nlevels
- 
-     # level is 0-indexed
-     for level in xrange(nlevels):
-         # Transform
-         if level == 0 and discard_level_1:
-             Yl = _level1_xfm_no_highpass(Yl, h0o, h1o, ext_mode)
-         elif level == 0 and not discard_level_1:
-             Yl, Yh[level] = _level1_xfm(Yl, h0o, h1o, ext_mode)
-         else:
-             Yl, Yh[level] = _level2_xfm(Yl, h0a, h0b, h1a, h1b, ext_mode)
- 
-     return Yl, tuple(Yh)
+     An implementation of the 3D DT-CWT via OpenCL. *biort* and *qshift* are the
+     wavelets which parameterise the transform. Valid values are documented in
+     :py:func:`dtcwt.dtwavexfm2`.
  
- def dtwaveifm3(Yl, Yh, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
-     """Perform an *n*-level dual-tree complex wavelet (DTCWT) 3D
-     reconstruction.
+     If *queue* is non-*None* it is an instance of
+     :py:class:`pyopencl.CommandQueue` which is used to compile and execute the
+     OpenCL kernels which implement the transform. If it is *None*, the first
+     available compute device is used.
  
-     :param Yl: The real lowpass subband from the final level
-     :param Yh: A sequence containing the complex highpass subband for each level.
-     :param biort: Level 1 wavelets to use. See :py:func:`biort`.
-     :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
-     :param ext_mode: Extension mode. See below.
- 
-     :returns Z: Reconstructed real image matrix.
- 
-     If *biort* or *qshift* are strings, they are used as an argument to the
-     :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
-     interpreted as tuples of vectors giving filter coefficients. In the *biort*
-     case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
-     be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+     .. note::
+         
+         At the moment *only* the **forward** transform is accelerated. The
+         inverse transform uses the NumPy backend.
  
-     There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
-     check whether 1st level is divisible by 2 (if not we raise a
-     ``ValueError``). Also check whether from 2nd level onwards, the coefs can
-     be divided by 4. If any dimension size is not a multiple of 4, append extra
-     coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
-     divisible by 4 (if not we raise a ``ValueError``). Also check whether from
-     2nd level onwards, the coeffs can be divided by 8. If any dimension size is
-     not a multiple of 8, append extra coeffs by repeating the edges twice.
+     """
+     def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT,ext_mode=4, queue=None):
+         super(Transform3d, self).__init__(biort=biort, qshift=qshift, ext_mode=ext_mode)
+         self.queue = to_queue(queue)
  
-     Example::
+     def forward(self, X, nlevels=3, include_scale=False, discard_level_1=False):
+         """Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
+         
+         :param X: 3D real array-like object
+         :param nlevels: Number of levels of wavelet decomposition
+         :param biort: Level 1 wavelets to use. See :py:func:`biort`.
+         :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+         :param ext_mode: Extension mode. See below.
+         :param include_scale: True if level 1 high-pass bands are to be discarded.
+         
+         :returns Yl: The real lowpass image from the final level
+         :returns Yh: A tuple containing the complex highpass subimages for each level.
+         
+         Each element of *Yh* is a 4D complex array with the 4th dimension having
+         size 28. The 3D slice ``Yh[l][:,:,:,d]`` corresponds to the complex higpass
+         coefficients for direction d at level l where d and l are both 0-indexed.
+         
+         If *biort* or *qshift* are strings, they are used as an argument to the
+         :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+         interpreted as tuples of vectors giving filter coefficients. In the *biort*
+         case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
+         be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+         
+         There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
+         check whether 1st level is divisible by 2 (if not we raise a
+         ``ValueError``). Also check whether from 2nd level onwards, the coefs can
+         be divided by 4. If any dimension size is not a multiple of 4, append extra
+         coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
+         divisible by 4 (if not we raise a ``ValueError``). Also check whether from
+         2nd level onwards, the coeffs can be divided by 8. If any dimension size is
+         not a multiple of 8, append extra coeffs by repeating the edges twice.
+         
+         If *include_scale* is True the highpass coefficients at level 1 will not be
+         discarded. (And, in fact, will never be calculated.) This turns the
+         transform from being 8:1 redundant to being 1:1 redundant at the cost of
+         no-longer allowing perfect reconstruction. If this option is selected then
+         `Yh[0]` will be `None`. Note that :py:func:`dtwaveifm3` will accepts
+         `Yh[0]` being `None` and will treat it as being zero.
+         
+         Example::
+         
+         # Performs a 3-level transform on the real 3D array X using the 13,19-tap
+         # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
+         Yl, Yh = dtwavexfm3(X, 3, 'near_sym_b', 'qshift_b')
+         
+         .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+         .. codeauthor:: Huizhong Chen, Jan 2009
+         .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
+         
+         """
+         X = np.atleast_3d(asfarray(X))
  
+         if len(self.biort) == 4:
+             h0o, g0o, h1o, g1o = self.biort
+         elif len(self.biort) == 6:
+             h0o, g0o, h1o, g1o, h2o, g2o = self.biort
+         else:
+             raise ValueError('Biort wavelet must have 6 or 4 components.')
+         
+         if len(self.qshift) == 8:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+         elif len(self.qshift) == 12:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
+         else:
+             raise ValueError('Qshift wavelet must have 12 or 8 components.')
+ 
+         # Check value of ext_mode. TODO: this should really be an enum :S
+         if self.ext_mode != 4 and self.ext_mode != 8:
+             raise ValueError('ext_mode must be one of 4 or 8')
+ 
+         Yl = X
+         Yh = [None,] * nlevels
+ 
+         if include_scale:
+             # this is only required if the user specifies a third output component.
+             Yscale = [None,] * nlevels
+ 
+         # level is 0-indexed
+         for level in xrange(nlevels):
+             # Transform
+             if level == 0 and discard_level_1:
+                 Yl = _level1_xfm_no_highpass(Yl, h0o, h1o, self.ext_mode)
+                 if include_scale:
+                     Yscale[0] = Yl
+             elif level == 0 and not discard_level_1:
+                 Yl, Yh[level] = _level1_xfm(Yl, h0o, h1o, self.ext_mode)
+                 if include_scale:
+                     Yscale[0] = Yl
+             else:
+                 Yl, Yh[level] = _level2_xfm(Yl, h0a, h0b, h1a, h1b, self.ext_mode)
+                 if include_scale:
+                     Yscale[level] = Yl
+         #FIXME: need some way to separate the Yscale component to include the scale when necessary.
+         if include_scale:
 -            return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
++            return Pyramid(Yl, tuple(Yh), tuple(Yscale))
+         else: 
 -            return TransformDomainSignal(Yl, tuple(Yh))
 -        return TransformDomainSignal(Yl, tuple(Yh))
++            return Pyramid(Yl, tuple(Yh))
++        return Pyramid(Yl, tuple(Yh))
+ 
+     def inverse(self, td_signal):
+         """Perform an *n*-level dual-tree complex wavelet (DTCWT) 3D
+         reconstruction.
+         
+         :param Yl: The real lowpass subband from the final level
+         :param Yh: A sequence containing the complex highpass subband for each level.
+         :param biort: Level 1 wavelets to use. See :py:func:`biort`.
+         :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+         :param ext_mode: Extension mode. See below.
+         
+         :returns Z: Reconstructed real image matrix.
+         
+         If *biort* or *qshift* are strings, they are used as an argument to the
+         :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+         interpreted as tuples of vectors giving filter coefficients. In the *biort*
+         case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
+         be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+         
+         There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
+         check whether 1st level is divisible by 2 (if not we raise a
+         ``ValueError``). Also check whether from 2nd level onwards, the coefs can
+         be divided by 4. If any dimension size is not a multiple of 4, append extra
+         coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
+         divisible by 4 (if not we raise a ``ValueError``). Also check whether from
+         2nd level onwards, the coeffs can be divided by 8. If any dimension size is
+         not a multiple of 8, append extra coeffs by repeating the edges twice.
+         
+         Example::
+         
          # Performs a 3-level reconstruction from Yl,Yh using the 13,19-tap
          # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
          Z = dtwaveifm3(Yl, Yh, 'near_sym_b', 'qshift_b')
- 
-     .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
-     .. codeauthor:: Huizhong Chen, Jan 2009
-     .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
- 
-     """
-     # Try to load coefficients if biort is a string parameter
-     try:
-         h0o, g0o, h1o, g1o = _biort(biort)
-     except TypeError:
-         h0o, g0o, h1o, g1o = biort
- 
-     # Try to load coefficients if qshift is a string parameter
-     try:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
-     except TypeError:
-         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
- 
-     X = Yl
- 
-     nlevels = len(Yh)
-     # level is 0-indexed but interpreted starting from the *last* level
-     for level in xrange(nlevels):
-         # Transform
-         if level == nlevels-1: # non-obviously this is the 'first' level
-             if Yh[-level-1] is None:
-                 Yl = _level1_ifm_no_highpass(Yl, g0o, g1o)
-             else:
-                 Yl = _level1_ifm(Yl, Yh[-level-1], g0o, g1o)
+         
+         .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+         .. codeauthor:: Huizhong Chen, Jan 2009
+         .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
+         
+         """
+         Yl = td_signal.lowpass
 -        Yh = td_signal.subbands
++        Yh = td_signal.highpasses
+ 
+         # Try to load coefficients if biort is a string parameter
+         if len(self.biort) == 4:
+             h0o, g0o, h1o, g1o = self.biort
+         elif len(self.biort) == 6:
+             h0o, g0o, h1o, g1o, h2o, g2o = self.biort
+         else:
+             raise ValueError('Biort wavelet must have 6 or 4 components.')
+         
+         # If qshift has 12 elements instead of 8, then it's a modified
+         # rotationally symmetric wavelet
+         # FIXME: there's probably a nicer way to do this
+         if len(self.qshift) == 8:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
+         elif len(self.qshift) == 12:
+             h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b, g2a, g2b = self.qshift
          else:
-             # Gracefully handle the Yh[0] is None case.
-             if Yh[-level-2] is not None:
-                 prev_shape = Yh[-level-2].shape
+             raise ValueError('Qshift wavelet must have 12 or 8 components.')
+ 
+         X = Yl
+ 
+         nlevels = len(Yh)
+         # level is 0-indexed but interpreted starting from the *last* level
+         for level in xrange(nlevels):
+             # Transform
+             if level == nlevels-1: # non-obviously this is the 'first' level
+                 if Yh[-level-1] is None:
+                     Yl = _level1_ifm_no_highpass(Yl, g0o, g1o)
+                 else:
+                     Yl = _level1_ifm(Yl, Yh[-level-1], g0o, g1o)
              else:
-                 prev_shape = np.array(Yh[-level-1].shape) * 2
+                 # Gracefully handle the Yh[0] is None case.
+                 if Yh[-level-2] is not None:
+                     prev_shape = Yh[-level-2].shape
+                 else:
+                     prev_shape = np.array(Yh[-level-1].shape) * 2
+                     
+                 Yl = _level2_ifm(Yl, Yh[-level-1], g0a, g0b, g1a, g1b, self.ext_mode, prev_shape)
  
-             Yl = _level2_ifm(Yl, Yh[-level-1], g0a, g0b, g1a, g1b, ext_mode, prev_shape)
- 
-     return Yl
 -        return ReconstructedSignal(Yl)
++        return Yl
  
  def _level1_xfm(X, h0o, h1o, ext_mode):
      """Perform level 1 of the 3d transform.
diff --cc examples/3d_dtcwt_directionality.py
index ca45be2,bd3cebf..78fc2f1
--- a/examples/3d_dtcwt_directionality.py
+++ b/examples/3d_dtcwt_directionality.py
@@@ -16,38 -16,83 +16,84 @@@ from matplotlib.pyplot import 
  import numpy as np
  from mpl_toolkits.mplot3d import Axes3D
  from mpl_toolkits.mplot3d.art3d import Poly3DCollection
 -from dtcwt import dtwavexfm3, dtwaveifm3, biort, qshift
 +from dtcwt.compat import dtwavexfm3, dtwaveifm3
 +from dtcwt.coeffs import biort, qshift
  
  # Specify details about sphere and grid size
- GRID_SIZE = 128
- SPHERE_RAD = int(0.45 * GRID_SIZE) + 0.5
- 
- # Compute an image of the sphere
- grid = np.arange(-(GRID_SIZE>>1), GRID_SIZE>>1)
- X, Y, Z = np.meshgrid(grid, grid, grid)
- r = np.sqrt(X*X + Y*Y + Z*Z)
- sphere = (0.5 + np.clip(SPHERE_RAD-r, -0.5, 0.5)).astype(np.float32)
- 
- # Specify number of levels and wavelet family to use
- nlevels = 2
- b = biort('near_sym_a')
- q = qshift('qshift_a')
- 
- # Form the DT-CWT of the sphere. We use discard_level_1 since we're
- # uninterested in the inverse transform and this saves us some memory.
- Yl, Yh = dtwavexfm3(sphere, nlevels, b, q, discard_level_1=True)
- 
- # Plot maxima
- figure(figsize=(8,8))
- 
- ax = gcf().add_subplot(1,1,1, projection='3d')
- ax.set_aspect('equal')
- ax.view_init(35, 75)
- 
- # Plot unit sphere +ve octant
- thetas = np.linspace(0, np.pi/2, 10)
- phis = np.linspace(0, np.pi/2, 10)
+ def main():
+     GRID_SIZE = 128
+     SPHERE_RAD = int(0.45 * GRID_SIZE) + 0.5
+     
+     # Compute an image of the sphere
+     grid = np.arange(-(GRID_SIZE>>1), GRID_SIZE>>1)
+     X, Y, Z = np.meshgrid(grid, grid, grid)
+     r = np.sqrt(X*X + Y*Y + Z*Z)
+     sphere = (0.5 + np.clip(SPHERE_RAD-r, -0.5, 0.5)).astype(np.float32)
+     
+     # Specify number of levels and wavelet family to use
+     nlevels = 2
+     b = biort('near_sym_a')
+     q = qshift('qshift_a')
+     
+     # Form the DT-CWT of the sphere. We use discard_level_1 since we're
+     # uninterested in the inverse transform and this saves us some memory.
+     Yl, Yh = dtwavexfm3(sphere, nlevels, b, q, discard_level_1=False)
+     
+     # Plot maxima
+     figure(figsize=(8,8))
+     
+     ax = gcf().add_subplot(1,1,1, projection='3d')
+     ax.set_aspect('equal')
+     ax.view_init(35, 75)
+     
+     # Plot unit sphere +ve octant
+     thetas = np.linspace(0, np.pi/2, 10)
+     phis = np.linspace(0, np.pi/2, 10)
+   
+     
+     tris = []
+     rad = 0.99 # so that points plotted latter are not z-clipped
+     for t1, t2 in zip(thetas[:-1], thetas[1:]):
+         for p1, p2 in zip(phis[:-1], phis[1:]):
+             tris.append([
+                 sphere_to_xyz(rad, t1, p1),
+                 sphere_to_xyz(rad, t1, p2),
+                 sphere_to_xyz(rad, t2, p2),
+                 sphere_to_xyz(rad, t2, p1),
+                 ])
+             
+     sphere = Poly3DCollection(tris, facecolor='w', edgecolor=(0.6,0.6,0.6))
+     ax.add_collection3d(sphere)
+             
+     locs = []
+     scale = 1.1
+     for idx in xrange(Yh[-1].shape[3]):
+         Z = Yh[-1][:,:,:,idx]
+         C = np.abs(Z)
+         max_loc = np.asarray(np.unravel_index(np.argmax(C), C.shape)) - np.asarray(C.shape)*0.5
+         max_loc /= np.sqrt(np.sum(max_loc * max_loc))
+         
+         # Only record directions in the +ve octant (or those from the -ve quadrant
+         # which can be flipped).
+         if np.all(np.sign(max_loc) == 1):
+             locs.append(max_loc)
+             ax.text(max_loc[0] * scale, max_loc[1] * scale, max_loc[2] * scale, str(idx+1))
+         elif np.all(np.sign(max_loc) == -1):
+             locs.append(-max_loc)
+             ax.text(-max_loc[0] * scale, -max_loc[1] * scale, -max_loc[2] * scale, str(idx+1))
+             
+             # Plot all directions as a scatter plot
+     locs = np.asarray(locs)
+     ax.scatter(locs[:,0], locs[:,1], locs[:,2], c=np.arange(locs.shape[0]))
+     
+     w = 1.1
+     ax.auto_scale_xyz([0, w], [0, w], [0, w])
+     
+     legend()
+     title('3D DT-CWT subband directions for +ve hemisphere quadrant')
+     tight_layout()
+     
+     show()
  
  def sphere_to_xyz(r, theta, phi):
      st, ct = np.sin(theta), np.cos(theta)
diff --cc tests/testagainstmatlab.py
index 8b61690,287b9a8..8676cb9
--- a/tests/testagainstmatlab.py
+++ b/tests/testagainstmatlab.py
@@@ -3,13 -3,16 +3,16 @@@ from nose.tools import raise
  from nose.plugins.attrib import attr
  
  import numpy as np
+ 
+ from scipy.io import loadmat
 -from dtcwt import dtwavexfm2, dtwavexfm3, dtwaveifm2, dtwavexfm2b, dtwaveifm2b, biort, qshift
 -from dtcwt.lowlevel import coldfilt, colifilt
 +from dtcwt.compat import dtwavexfm2, dtwaveifm2, dtwavexfm2b, dtwaveifm2b
 +from dtcwt.coeffs import biort, qshift
 +from dtcwt.numpy.lowlevel import coldfilt, colifilt
++from dtcwt.numpy import Transform3d, Pyramid
  from dtcwt.sampling import rescale_highpass
  
- from .util import assert_almost_equal, summarise_mat, assert_percentile_almost_equal
 -from dtcwt.backend.base import TransformDomainSignal, ReconstructedSignal
 -from dtcwt.backend.backend_numpy import Transform3d
 -
+ from .util import assert_almost_equal, summarise_mat, summarise_cube, assert_percentile_almost_equal
 +import tests.datasets as datasets
  
  ## IMPORTANT NOTE ##
  
@@@ -43,10 -46,19 +46,19 @@@ def assert_almost_equal_to_summary(a, s
  def assert_percentile_almost_equal_to_summary(a, summary, *args, **kwargs):
      assert_percentile_almost_equal(summarise_mat(a), summary, *args, **kwargs)
  
+ def assert_almost_equal_to_summary_cube(a, summary, *args, **kwargs):
+     assert_almost_equal(summarise_cube(a), summary, *args, **kwargs)
+ 
+ def assert_percentile_almost_equal_to_summary_cube(a, summary, *args, **kwargs):
+     assert_percentile_almost_equal(summarise_cube(a), summary, *args, **kwargs)
+ 
  def setup():
      global lena
 -    lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena']
 +    lena = datasets.lena()
  
+     global qbgn
+     qbgn = loadmat(os.path.join(os.path.dirname(__file__), 'qbgn.mat'))['qbgn']
+ 
      global verif
      verif = np.load(os.path.join(os.path.dirname(__file__), 'verification.npz'))
      
@@@ -103,4 -115,15 +115,15 @@@ def test_rescale_highpass()
      # quite an amount. Use a percentile approach to look at the bigger picture.
      assert_percentile_almost_equal_to_summary(Xrescale, verif['lena_upsample'], 60, tolerance=TOLERANCE)
  
+ def test_transform3d_numpy():
+     transform = Transform3d(biort='near_sym_b',qshift='qshift_b')
+     td_signal = transform.forward(qbgn, nlevels=3, include_scale=True, discard_level_1=False)
 -    Yl, Yh, Yscale = td_signal.lowpass, td_signal.subbands, td_signal.scales
++    Yl, Yh, Yscale = td_signal.lowpass, td_signal.highpasses, td_signal.scales
+     assert_almost_equal_to_summary_cube(Yl, verif['qbgn_Yl'], tolerance=TOLERANCE)
+     for idx, a in enumerate(Yh):
+         assert_almost_equal_to_summary_cube(a, verif['qbgn_Yh_{0}'.format(idx)], tolerance=TOLERANCE)
+ 
+     for idx, a in enumerate(Yscale):
+         assert_almost_equal_to_summary_cube(a, verif['qbgn_Yscale_{0}'.format(idx)], tolerance=TOLERANCE)
+ 
  # vim:sw=4:sts=4:et

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list