[python-dtcwt] 66/497: make single level 3D transforms work
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:50 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit 2e9bcf8629f1fcd1ae81536d973f94fc0c2a1ff9
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date: Thu Aug 8 14:27:23 2013 +0100
make single level 3D transforms work
---
dtcwt/lowlevel.py | 6 ++
dtcwt/transform3d.py | 172 +++++++++++++++++++++++++++++++++++++++++++++++----
tests/testxfm3.py | 33 +++++++---
3 files changed, 191 insertions(+), 20 deletions(-)
diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index f7e8816..d2dbb47 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -29,6 +29,12 @@ def _column_convolve(X, h):
i.e. those values unaffected by zero padding.
"""
+
+ # This function should give the same result as:
+ #
+ from scipy.signal import convolve2d
+ return convolve2d(X, as_column_vector(h), 'valid')
+
h = h.flatten()
h_size = h.shape[0]
full_size = X.shape[0] + h_size - 1
diff --git a/dtcwt/transform3d.py b/dtcwt/transform3d.py
index 7d446f6..97760f6 100644
--- a/dtcwt/transform3d.py
+++ b/dtcwt/transform3d.py
@@ -7,17 +7,17 @@ from dtcwt import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
from dtcwt.lowlevel import colfilter, coldfilt, colifilt
-def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
+def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
"""Perform a *n*-level DTCWT-3D decomposition on a 3D matrix *X*.
:param X: 3D real matrix/Image of shape (N, M)
:param nlevels: Number of levels of wavelet decomposition
:param biort: Level 1 wavelets to use. See :py:func:`biort`.
:param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+ :param ext_mode: Extension mode. See below.
:returns Yl: The real lowpass image from the final level
- :returns Yh: A tuple containing the (N, M, 6) shape complex highpass subimages for each level.
- :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
+ :returns Yh: A tuple containing the (N, M, 7) shape complex highpass subimages for each level.
If *biort* or *qshift* are strings, they are used as an argument to the
:py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
@@ -25,6 +25,15 @@ def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+ There are two values for *ext_mode*, either 4 or 8. If *ext_mode* = 4,
+ check whether 1st level is divisible by 2 (if not we raise a
+ ``ValueError``). Also check whether from 2nd level onwards, the coefs can
+ be divided by 4. If any dimension size is not a multiple of 4, append extra
+ coefs by repeating the edges. If *ext_mode* = 8, check whether 1st level is
+ divisible by 4 (if not we raise a ``ValueError``). Also check whether from
+ 2nd level onwards, the coeffs can be divided by 8. If any dimension size is
+ not a multiple of 8, append extra coeffs by repeating the edges twice.
+
Example::
# Performs a 3-level transform on the real image X using the 13,19-tap
@@ -34,9 +43,36 @@ def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
.. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
"""
- return None
+ X = np.atleast_3d(X)
+
+ # Try to load coefficients if biort is a string parameter
+ try:
+ h0o, g0o, h1o, g1o = _biort(biort)
+ except TypeError:
+ h0o, g0o, h1o, g1o = biort
+
+ # Try to load coefficients if qshift is a string parameter
+ try:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
+ except TypeError:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
+
+ # Check value of ext_mode. TODO: this should really be an enum :S
+ if ext_mode != 4 and ext_mode != 8:
+ raise ValueError('ext_mode must be one of 4 or 8')
+
+ Yl = X
+ Yh = [None,] * nlevels
-def dtwaveifm3(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
+ # level is 0-indexed
+ for level in xrange(nlevels):
+ # Transform
+ if level == 0:
+ Yl, Yh[level] = _level1_xfm(Yl, h0o, h1o, ext_mode)
+
+ return Yl, tuple(Yh)
+
+def dtwaveifm3(Yl, Yh, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
"""Perform an *n*-level dual-tree complex wavelet (DTCWT) 3D
reconstruction.
@@ -44,15 +80,9 @@ def dtwaveifm3(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
:param Yh: A sequence containing the complex highpass subband for each level.
:param biort: Level 1 wavelets to use. See :py:func:`biort`.
:param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
- :param gain_mask: Gain to be applied to each subband.
:returns Z: Reconstructed real image matrix.
- The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
- *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
- band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
- zero-indexed.
-
If *biort* or *qshift* are strings, they are used as an argument to the
:py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
interpreted as tuples of vectors giving filter coefficients. In the *biort*
@@ -68,6 +98,124 @@ def dtwaveifm3(Yl,Yh,biort=DEFAULT_BIORT,qshift=DEFAULT_QSHIFT,gain_mask=None):
.. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
"""
- return None
+ # Try to load coefficients if biort is a string parameter
+ try:
+ h0o, g0o, h1o, g1o = _biort(biort)
+ except TypeError:
+ h0o, g0o, h1o, g1o = biort
+
+ # Try to load coefficients if qshift is a string parameter
+ try:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
+ except TypeError:
+ h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
+
+ X = Yl
+
+ # level is 0-indexed
+ nlevels = len(Yh)
+ for level in xrange(nlevels):
+ # Transform
+ if level == 0:
+ X = _level1_ifm(Yl, Yh[level], g0o, g1o)
+
+ return X
+
+def _level1_xfm(X, h0o, h1o, ext_mode):
+ """Perform level 1 of the 3d transform.
+
+ """
+ # Check shape of input according to ext_mode. Note that shape of X is
+ # double original input in each direction.
+ if ext_mode == 4 and np.any(np.fmod(X.shape, 2) != 0):
+ raise ValueError('Input shape should be a multiple of 2 in each direction when ext_mode == 4')
+ elif ext_mode == 8 and np.any(np.fmod(X.shape, 4) != 0):
+ raise ValueError('Input shape should be a multiple of 4 in each direction when ext_mode == 8')
+
+ # Create work area
+ work = np.zeros(np.asarray(X.shape) * 2, dtype=X.dtype)
+
+ # Form some useful slices
+ s0a = slice(None, work.shape[0] >> 1)
+ s1a = slice(None, work.shape[1] >> 1)
+ s2a = slice(None, work.shape[2] >> 1)
+ s0b = slice(work.shape[0] >> 1, None)
+ s1b = slice(work.shape[1] >> 1, None)
+ s2b = slice(work.shape[2] >> 1, None)
+
+ # Assign input
+ work[s0a, s1a, s2a] = X
+
+ # Loop over 2nd dimension extracting 2D slice from first and 3rd dimensions
+ for f in xrange(work.shape[1] >> 1):
+ # extract slice
+ y = work[s0a, f, s2a].T
+
+ # Do odd top-level filters on 3rd dim. The order here is important
+ # since the second filtering will modify the elements of y as well
+ # since y is merely a view onto work.
+ work[s0a, f, s2b] = colfilter(y, h1o).T
+ work[s0a, f, s2a] = colfilter(y, h0o).T
+
+ # Loop over 3rd dimension extracting 2D slice from first and 2nd dimensions
+ for f in xrange(work.shape[2]):
+ # Do odd top-level filters on rows.
+ y1 = work[s0a, s1a, f].T
+ y2 = np.vstack((colfilter(y1, h0o), colfilter(y1, h1o))).T
+
+ # Do odd top-level filters on columns.
+ work[s0a, :, f] = colfilter(y2, h0o)
+ work[s0b, :, f] = colfilter(y2, h1o)
+
+ # Return appropriate slices of output
+ return (work[s0a, s1a, s2a], # LLL
+ np.concatenate((
+ work[s0a, s1b, s2a, np.newaxis], # HLL
+ work[s0b, s1a, s2a, np.newaxis], # LHL
+ work[s0b, s1b, s2a, np.newaxis], # HHL
+ work[s0a, s1a, s2b, np.newaxis], # LLH
+ work[s0a, s1b, s2b, np.newaxis], # HLH
+ work[s0b, s1a, s2b, np.newaxis], # LHH
+ work[s0b, s1b, s2b, np.newaxis], # HHH
+ ), axis=3))
+
+def _level1_ifm(Yl, Yh, g0o, g1o):
+ """Perform level 1 of the inverse 3d transform.
+
+ """
+ # Create work area
+ work = np.zeros(np.asarray(Yl.shape) * 2, dtype=Yl.dtype)
+
+ # Form some useful slices
+ s0a = slice(None, work.shape[0] >> 1)
+ s1a = slice(None, work.shape[1] >> 1)
+ s2a = slice(None, work.shape[2] >> 1)
+ s0b = slice(work.shape[0] >> 1, None)
+ s1b = slice(work.shape[1] >> 1, None)
+ s2b = slice(work.shape[2] >> 1, None)
+
+ # Assign regions of work area
+ work[s0a, s1a, s2a] = Yl
+ work[s0a, s1b, s2a] = Yh[:,:,:,0]
+ work[s0b, s1a, s2a] = Yh[:,:,:,1]
+ work[s0b, s1b, s2a] = Yh[:,:,:,2]
+ work[s0a, s1a, s2b] = Yh[:,:,:,3]
+ work[s0a, s1b, s2b] = Yh[:,:,:,4]
+ work[s0b, s1a, s2b] = Yh[:,:,:,5]
+ work[s0b, s1b, s2b] = Yh[:,:,:,6]
+
+ for f in xrange(work.shape[2]):
+ # Do odd top-level filters on rows.
+ y = colfilter(work[:, s1a, f].T, g0o) + colfilter(work[:, s1b, f].T, g1o)
+
+ # Do odd top-level filters on columns.
+ work[s0a, s1a, f] = colfilter(y[:, s0a].T, g0o) + colfilter(y[:, s0b].T, g1o)
+
+ for f in xrange(work.shape[1]>>1):
+ # Do odd top-level filters on 3rd dim.
+ y = work[s0a, f, :].T
+ work[s0a, f, s2a] = (colfilter(y[s2a, :], g0o) + colfilter(y[s2b, :], g1o)).T
+
+ return work[s0a, s1a, s2a]
# vim:sw=4:sts=4:et
diff --git a/tests/testxfm3.py b/tests/testxfm3.py
index ac1c78d..9dc8453 100644
--- a/tests/testxfm3.py
+++ b/tests/testxfm3.py
@@ -5,21 +5,38 @@ from nose.plugins.attrib import attr
import numpy as np
from dtcwt import dtwavexfm3, dtwaveifm3
+GRID_SIZE=32
+SPHERE_RAD=25
+
def setup():
global sphere
- X, Y, Z = np.meshgrid(np.arange(-63,64), np.arange(-63,64), np.arange(-63,64))
+
+ grid = np.arange(-(GRID_SIZE>>1), (GRID_SIZE>>1))
+ X, Y, Z = np.meshgrid(grid, grid, grid)
+
+ Y *= 1.2
+ Z /= 1.2
r = np.sqrt(X*X + Y*Y + Z*Z)
- sphere = np.where(r <= 55, 1.0, 0.0)
+ sphere = np.where(r <= SPHERE_RAD, 1.0, 0.0)
def test_sphere():
# Check general aspects of sphere are OK
- assert sphere.shape == (127,127,127)
- assert sphere.min() >= 0
- assert sphere.max() <= 1
+ assert sphere.shape == (GRID_SIZE,GRID_SIZE,GRID_SIZE)
+ assert sphere.min() == 0
+ assert sphere.max() == 1
+
+def test_simple_level_1_xfm():
+ # Just tests that the transform broadly works and gives expected size output
+ Yl, Yh = dtwavexfm3(sphere, 1)
+ assert Yl.shape == (GRID_SIZE,GRID_SIZE,GRID_SIZE)
+ assert len(Yh) == 1
- # Check volume of sphere is ok to within 5%
- sphere_vol = (4.0/3.0) * np.pi * 55*55*55
- assert np.abs(np.sum(sphere.flatten()) - sphere_vol) < 5e-2*sphere_vol
+def test_simple_level_1_recon():
+ # Test for perfect reconstruction with 1 level
+ Yl, Yh = dtwavexfm3(sphere, 1)
+ sphere_recon = dtwaveifm3(Yl, Yh)
+ assert sphere.size == sphere_recon.size
+ assert np.max(np.abs(sphere - sphere_recon)) < 1e-11
# vim:sw=4:sts=4:et
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits
mailing list