[python-dtcwt] 247/497: added verification phantom for the unit-testing 3d transform against the matlab implementation
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:11 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit af34ebc60d22f917fc17988a20f3235ee0226290
Author: tim <tim at tim.(none)>
Date: Mon Jan 27 20:18:47 2014 +0000
added verification phantom for the unit-testing 3d transform against the matlab implementation
---
dtcwt/backend/backend_numpy/transform2d.py | 8 ++--
dtcwt/backend/backend_numpy/transform3d.py | 58 ++++++++++++++++++++++-------
dtcwt/transform3d.py | 2 +-
matlab/gen_verif.m | 15 +++++++-
matlab/qbgn/gaussian.m | 27 ++++++++++++++
matlab/qbgn/gen_qbgn.m | 31 +++++++++++++++
matlab/qbgn/softquant.m | 30 +++++++++++++++
matlab/verif_m_to_npz.py | 30 +++++++++++++--
tests/qbgn.mat | Bin 0 -> 1193947 bytes
tests/testagainstmatlab.py | 31 +++++++++++++--
tests/util.py | 12 ++++++
tests/verification.npz | Bin 279070 -> 4364786 bytes
12 files changed, 219 insertions(+), 25 deletions(-)
diff --git a/dtcwt/backend/backend_numpy/transform2d.py b/dtcwt/backend/backend_numpy/transform2d.py
index cd93092..c0ff392 100644
--- a/dtcwt/backend/backend_numpy/transform2d.py
+++ b/dtcwt/backend/backend_numpy/transform2d.py
@@ -293,9 +293,10 @@ class Transform2d(Transform2dBase):
#==========================================================================================
def q2c(y):
- """Convert from quads in y to complex numbers in z.
-
"""
+ Convert from quads in y to complex numbers in z.
+ """
+
j2 = (np.sqrt(0.5) * np.array([1, 1j])).astype(appropriate_complex_type_for(y))
# Arrange pixels from the corners of the quads into
@@ -315,7 +316,8 @@ def q2c(y):
return z
def c2q(w,gain):
- """Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers
+ """
+ Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers
in z.
Arrange pixels from the real and imag parts of the 2 subbands
diff --git a/dtcwt/backend/backend_numpy/transform3d.py b/dtcwt/backend/backend_numpy/transform3d.py
index 7e784d3..2e5da6c 100644
--- a/dtcwt/backend/backend_numpy/transform3d.py
+++ b/dtcwt/backend/backend_numpy/transform3d.py
@@ -37,7 +37,7 @@ class Transform3d(Transform3dBase):
self.ext_mode = ext_mode
- def forward(self, X, nlevels=3, discard_level_1=False):
+ def forward(self, X, nlevels=3, include_scale=False, discard_level_1=False):
"""Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
:param X: 3D real array-like object
@@ -88,6 +88,9 @@ class Transform3d(Transform3dBase):
"""
X = np.atleast_3d(asfarray(X))
+ # If biort has 6 elements instead of 4, then it's a modified
+ # rotationally symmetric wavelet
+ # FIXME: there's probably a nicer way to do this
if len(self.biort) == 4:
h0o, g0o, h1o, g1o = self.biort
elif len(self.biort) == 6:
@@ -95,6 +98,9 @@ class Transform3d(Transform3dBase):
else:
raise ValueError('Biort wavelet must have 6 or 4 components.')
+ # If qshift has 12 elements instead of 8, then it's a modified
+ # rotationally symmetric wavelet
+ # FIXME: there's probably a nicer way to do this
if len(self.qshift) == 8:
h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
elif len(self.qshift) == 12:
@@ -108,18 +114,31 @@ class Transform3d(Transform3dBase):
Yl = X
Yh = [None,] * nlevels
+
+ if include_scale:
+ # this is only required if the user specifies a third output component.
+ Yscale = [None,] * nlevels
+
#pdb.set_trace()
# level is 0-indexed
for level in xrange(nlevels):
# Transform
if level == 0 and discard_level_1:
Yl = _level1_xfm_no_highpass(Yl, h0o, h1o, self.ext_mode)
+ if include_scale:
+ Yscale[0] = Yl
elif level == 0 and not discard_level_1:
Yl, Yh[level] = _level1_xfm(Yl, h0o, h1o, self.ext_mode)
+ if include_scale:
+ Yscale[0] = Yl
else:
Yl, Yh[level] = _level2_xfm(Yl, h0a, h0b, h1a, h1b, self.ext_mode)
- #FIXME: need some way to separate the Yscale component to include the scale when necessary.
- return TransformDomainSignal(Yl, tuple(Yh))
+ if include_scale:
+ Yscale[level] = Yl
+ if include_scale:
+ return TransformDomainSignal(Yl, tuple(Yh), tuple(Yscale))
+ else:
+ return TransformDomainSignal(Yl, tuple(Yh))
def inverse(self, td_signal):
"""Perform an *n*-level dual-tree complex wavelet (DTCWT) 3D
@@ -374,13 +393,13 @@ def _level2_xfm(X, h0a, h0b, h1a, h1b, ext_mode):
cube2c(work[s0a, s1a, s2b]), # LLH
cube2c(work[s0a, s1b, s2b]), # HLH
cube2c(work[s0b, s1a, s2b]), # LHH
- cube2c(work[s0b, s1b, s2b]), # HLH
+ cube2c(work[s0b, s1b, s2b]), # HHH
), axis=3)
)
def _level1_ifm(Yl, Yh, g0o, g1o):
- """Perform level 1 of the inverse 3d transform.
-
+ """
+ Perform level 1 of the inverse 3d transform.
"""
# Create work area
work = np.zeros(np.asanyarray(Yl.shape) * 2, dtype=Yl.dtype)
@@ -559,14 +578,25 @@ def cube2c(y):
# TODO: check if the above should be the below and, if so, fix c2cube
#
- # A = y[0::2, 0::2, 0::2]
- # B = y[0::2, 0::2, 1::2]
- # C = y[0::2, 1::2, 0::2]
- # D = y[0::2, 1::2, 1::2]
- # E = y[1::2, 0::2, 0::2]
- # F = y[1::2, 0::2, 1::2]
- # G = y[1::2, 1::2, 0::2]
- # H = y[1::2, 1::2, 1::2]
+ #A = y[0::2, 0::2, 0::2]
+ #B = y[0::2, 0::2, 1::2]
+ #C = y[0::2, 1::2, 0::2]
+ #D = y[0::2, 1::2, 1::2]
+ #E = y[1::2, 0::2, 0::2]
+ #F = y[1::2, 0::2, 1::2]
+ #G = y[1::2, 1::2, 0::2]
+ #H = y[1::2, 1::2, 1::2]
+
+ # TODO: check if the above should be the below and, if so, fix c2cube
+ #
+ #A = y[0::2, 0::2, 0::2]
+ #B = y[0::2, 1::2, 0::2]
+ #C = y[1::2, 0::2, 0::2]
+ #D = y[1::2, 1::2, 0::2]
+ #E = y[0::2, 0::2, 1::2]
+ #F = y[0::2, 1::2, 1::2]
+ #G = y[1::2, 0::2, 1::2]
+ #H = y[1::2, 1::2, 1::2]
# Combine to form subbands
p = ( A-G-D-F) * j2[0] + ( B-H+C+E) * j2[1]
diff --git a/dtcwt/transform3d.py b/dtcwt/transform3d.py
index 1199812..63f6651 100644
--- a/dtcwt/transform3d.py
+++ b/dtcwt/transform3d.py
@@ -158,7 +158,7 @@ def dtwaveifm3(Yl, Yh, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
else:
prev_shape = np.array(Yh[-level-1].shape) * 2
- Yl = _level2_ifm(Yl, Yh[-level-1], g0a, g0b, g1a, g1b, self.ext_mode, prev_shape)
+ Yl = _level2_ifm(Yl, Yh[-level-1], g0a, g0b, g1a, g1b, ext_mode, prev_shape)
return Yl
diff --git a/matlab/gen_verif.m b/matlab/gen_verif.m
index 45a27b2..8f2de14 100644
--- a/matlab/gen_verif.m
+++ b/matlab/gen_verif.m
@@ -46,4 +46,17 @@ for sb=1:6
lena_upsample(:,:,sb) = tmp;
end
-save('verification.mat', 'lena_coldfilt', 'lena_colifilt', 'lena_Yl', 'lena_Yh', 'lena_Yscale', 'lena_Ylb', 'lena_Yhb', 'lena_Yscaleb', 'lena_upsample');
+%% Generate quantized band-limited Gaussian noise (qbgn) phantom
+strFilePath=[fileparts(which(mfilename('fullpath'))) '/'];
+addpath([strFilePath 'qbgn/']);
+%generate quantized band-limited Gaussian noise, and cast to 8-bit to save space
+qbgn = uint8(gen_qbgn(128,128));
+%take the 3D wavelet transform, which defaults to near_sym_a (5,7) and qshift_b (14 taps)
+[qbgn_Yl, qbgn_Yh, qbgn_Yscale] = dtwavexfm3(double(qbgn), 3);
+%now re-arrange the coefficients to form complex-valued high-pass subbands instead of alternating real/imag parts
+qbgn_Yh = ri2c(qbgn_Yh);
+
+save('../tests/qbgn.mat','qbgn');
+
+save('verification.mat', 'lena_coldfilt', 'lena_colifilt', 'lena_Yl', 'lena_Yh', 'lena_Yscale', ...
+ 'lena_Ylb', 'lena_Yhb', 'lena_Yscaleb', 'lena_upsample', 'qbgn_Yl', 'qbgn_Yh', 'qbgn_Yscale');
diff --git a/matlab/qbgn/gaussian.m b/matlab/qbgn/gaussian.m
new file mode 100644
index 0000000..bd78b1b
--- /dev/null
+++ b/matlab/qbgn/gaussian.m
@@ -0,0 +1,27 @@
+function h = gaussian(sigma,thresh)
+
+% function h = gaussian(sigma,thresh)
+%
+% Generate a gaussian vector / impulse response, h, with
+% a standard deviation of sigma samples and a total value of unity.
+% The length of h is odd and is truncated at the points where
+% exp(-x^2/(2*sigma^2)) < thresh. By default, thresh = 0.01.
+%
+% Nick Kingsbury, Cambridge University, Nov 2005.
+
+if nargin < 2, thresh = 0.01; end
+
+% Solve exp(-x^2 / (2*sigma^2)) = thresh for x; max(...,1e-6) keeps the sqrt argument positive when thresh >= 1.
+xmax = sigma * sqrt(max(-2*log(thresh),1e-6));
+
+% Sample at integer x in [-n, n] with n = floor(xmax): a column vector of length 2n+1.
+n = floor(xmax);
+x = [-n:n]';
+h = exp(x.*x/(-2*sigma*sigma));
+
+% Normalise h so it sums to unity.
+h = h / sum(h);
+
+return;
+
+
diff --git a/matlab/qbgn/gen_qbgn.m b/matlab/qbgn/gen_qbgn.m
new file mode 100644
index 0000000..8fa9d8d
--- /dev/null
+++ b/matlab/qbgn/gen_qbgn.m
@@ -0,0 +1,31 @@
+% gen_qbgn.m
+% Generate a 3-D fruitcake-like volume of quantized band-limited Gaussian noise.
+% bg = gen_qbgn(N,F,seed): N x N x F volume; seed (default 0) seeds a new global mt19937ar stream.
+% Nick Kingsbury, Cambridge University, Nov 2011.
+function bg=gen_qbgn(N,F,seed)
+  if nargin<3
+    seed=0;
+  end
+  % Band-limiting kernel: unit-sum Gaussian, sigma = 8 samples, truncated
+  % where it falls below 1% of its peak; it is used for every filtering
+  % pass below.
+  h = gaussian(8,0.01);
+  y = zeros(N,N,F);
+  RandStream.setGlobalStream(RandStream('mt19937ar', 'seed', seed));
+  for f=1:F,
+    x=randn(N,N);
+
+    y(:,:,f) = colfilter(colfilter(x,h).',h).';
+  end
+  for k=1:N,
+    y(k,:,:) = colfilter(squeeze(y(k,:,:)).',h).';
+  end
+  yvar = var(y(:));
+  y = y*(16/sqrt(yvar)) + 128;
+  t1=ceil(min(y(:)));
+  t2=floor(max(y(:)));
+  % 50 soft quantiser steps spanning the integer range [t1, t2]
+  step=(t2-t1)/50;
+  z = double(softquant(y,[t1 t2 step]));
+  bg=z;
+end
\ No newline at end of file
diff --git a/matlab/qbgn/softquant.m b/matlab/qbgn/softquant.m
new file mode 100644
index 0000000..350472d
--- /dev/null
+++ b/matlab/qbgn/softquant.m
@@ -0,0 +1,30 @@
+function y = softquant(x,q)
+% function y = softquant(x,q)
+%
+% Soft quantise x, using quantiser thresholds specified by q as follows:
+%  q = [qmin qmax stepsize];
+% As x increases linearly, y will smoothly transition
+% between adjacent quantiser steps around the nominal transition point.
+%
+% Nick Kingsbury, Cambridge University, Nov 2011.
+
+% Quantise x (inputs lying exactly on a level qmin + k*stepsize map to themselves).
+
+% Apply an offset and a scale factor to x so that y goes from zero to 2*pi*ymax
+% and the stepsize is 2*pi.
+ymax = round((q(2) - q(1))/q(3)); % no. of steps.
+% x is clipped to [q(1), q(1) + ymax*q(3)] before scaling, so y lies in [0, 2*pi*ymax].
+y = (2*pi) * max(min((x - q(1)) * (1/q(3)),ymax),0);
+
+% Apply the non-linear periodic function that is periodic over 2*pi, such
+% that it generates smooth steps and has approx zero gradient in between
+% the transition regions.
+y = y - 1.4*sin(y) + 0.2*sin(2*y);
+
+% Rescale y back to the original range of x and reinsert the offset.
+y = y * (q(3)*0.5/pi) + q(1);
+return;
+
+
+
+
diff --git a/matlab/verif_m_to_npz.py b/matlab/verif_m_to_npz.py
index 249d0ee..a6af710 100755
--- a/matlab/verif_m_to_npz.py
+++ b/matlab/verif_m_to_npz.py
@@ -14,6 +14,11 @@ def _mean(a, axis=None, *args, **kwargs):
return rv
+def centre_indices(ndim=2,apron=8):
+    """Return a tuple of *ndim* slices, each dropping *apron* elements from both ends of an axis.
+    """
+    return tuple([slice(apron,-apron) for i in xrange(ndim)])  # NOTE: apron=0 gives slice(0, 0), i.e. empty
+
def summarise_mat(M, apron=8):
"""HACK to provide a 'summary' matrix consisting of the corners of the
matrix and summed versions of the sub matrices.
@@ -30,8 +35,17 @@ def summarise_mat(M, apron=8):
np.hstack((M[-apron:,:apron,...], _mean(M[-apron:,apron:-apron,...], axis=1), M[-apron:,-apron:,...])),
))
-verif = loadmat('verification.mat')
-verif = dict((k,v) for k, v in verif.iteritems() if not k.startswith('_'))
+def summarise_cube(M, apron=4):
+    """Summarise each axis-2 slice M[:,:,i,...] with summarise_mat and stack the results with np.dstack.
+    Axis 2 itself is not reduced: every slice contributes to the output."""
+    return np.dstack(
+        [summarise_mat(M[:,:,i,...], apron) for i in xrange(M.shape[2])]  # axis 2, not -2: also correct for 3-D input
+    )
+
+verif_temp = loadmat('verification.mat')
+verif = dict((k,v) for k, v in verif_temp.iteritems() if (not k.startswith('_') and not k.startswith('qbgn')))
+verif_cube = dict((k,v) for k, v in verif_temp.iteritems() if (not k.startswith('_') and k.startswith('qbgn')))
+del verif_temp
for idx, v in enumerate(verif['lena_Yh']):
verif['lena_Yh_{0}'.format(idx)] = v[0]
@@ -49,6 +63,16 @@ for idx, v in enumerate(verif['lena_Yscaleb']):
verif['lena_Yscaleb_{0}'.format(idx)] = v[0]
del verif['lena_Yscaleb']
-summaries = dict((k, summarise_mat(v)) for k, v in verif.iteritems())
+for idx, v in enumerate(verif_cube['qbgn_Yh']):
+ verif_cube['qbgn_Yh_{0}'.format(idx)] = v[0]
+del verif_cube['qbgn_Yh']
+
+for idx, v in enumerate(verif_cube['qbgn_Yscale']):
+ verif_cube['qbgn_Yscale_{0}'.format(idx)] = v[0]
+del verif_cube['qbgn_Yscale']
+summaries = dict((k, summarise_mat(v)) for k, v in verif.iteritems())
+for k,v in verif_cube.iteritems():
+ summaries[k] = summarise_cube(v)
+
savez('../tests/verification.npz', **summaries)
diff --git a/tests/qbgn.mat b/tests/qbgn.mat
new file mode 100644
index 0000000..1d9445d
Binary files /dev/null and b/tests/qbgn.mat differ
diff --git a/tests/testagainstmatlab.py b/tests/testagainstmatlab.py
index c7ad694..b37ab1c 100644
--- a/tests/testagainstmatlab.py
+++ b/tests/testagainstmatlab.py
@@ -3,11 +3,15 @@ from nose.tools import raises
from nose.plugins.attrib import attr
import numpy as np
-from dtcwt import dtwavexfm2, dtwaveifm2, dtwavexfm2b, dtwaveifm2b, biort, qshift
+
+from scipy.io import loadmat
+from dtcwt import dtwavexfm2, dtwavexfm3, dtwaveifm2, dtwavexfm2b, dtwaveifm2b, biort, qshift
from dtcwt.lowlevel import coldfilt, colifilt
from dtcwt.sampling import rescale_highpass
-from .util import assert_almost_equal, summarise_mat, assert_percentile_almost_equal
+from dtcwt.backend.base import TransformDomainSignal, ReconstructedSignal
+
+from .util import assert_almost_equal, summarise_mat, summarise_cube, assert_percentile_almost_equal
## IMPORTANT NOTE ##
@@ -21,7 +25,7 @@ from .util import assert_almost_equal, summarise_mat, assert_percentile_almost_e
# G | H | I
#
# Where A, C, G and I are NxN and N is some agreed 'apron' size. E is replaced
-# my it's element-wise mean and thus becomes 1x1. The remaining matrices are
+# by its element-wise mean and thus becomes 1x1. The remaining matrices are
# replaced by the element-wise mean along the apropriate axis to result in a
# (2N+1) x (2N+1) matrix. These matrices are compared.
#
@@ -41,10 +45,19 @@ def assert_almost_equal_to_summary(a, summary, *args, **kwargs):
def assert_percentile_almost_equal_to_summary(a, summary, *args, **kwargs):
assert_percentile_almost_equal(summarise_mat(a), summary, *args, **kwargs)
+def assert_almost_equal_to_summary_cube(a, summary, *args, **kwargs):  # 3-D analogue of assert_almost_equal_to_summary
+    assert_almost_equal(summarise_cube(a), summary, *args, **kwargs)  # summarise *a*, then compare against the stored summary
+
+def assert_percentile_almost_equal_to_summary_cube(a, summary, *args, **kwargs):  # 3-D analogue of assert_percentile_almost_equal_to_summary
+    assert_percentile_almost_equal(summarise_cube(a), summary, *args, **kwargs)
+
def setup():
global lena
lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena']
+ global qbgn
+ qbgn = loadmat(os.path.join(os.path.dirname(__file__), 'qbgn.mat'))['qbgn']
+
global verif
verif = np.load(os.path.join(os.path.dirname(__file__), 'verification.npz'))
@@ -101,4 +114,16 @@ def test_rescale_highpass():
# quite an amount. Use a percentile approach to look at the bigger picture.
assert_percentile_almost_equal_to_summary(Xrescale, verif['lena_upsample'], 60, tolerance=TOLERANCE)
+def test_transform3d_numpy():  # compare the NumPy 3-D DTCWT of the qbgn phantom with the MATLAB-derived summaries
+    from dtcwt.backend.backend_numpy import Transform3d
+    transform = Transform3d(biort='near_sym_a',qshift='qshift_b')
+    td_signal = transform.forward(qbgn, nlevels=3, include_scale=True, discard_level_1=False)
+    Yl, Yh, Yscale = td_signal.lowpass, td_signal.subbands, td_signal.scales
+    #assert_almost_equal_to_summary_cube(Yl, verif['qbgn_Yl'], tolerance=TOLERANCE)
+    for idx, a in enumerate(Yh):
+        assert_almost_equal_to_summary_cube(a, verif['qbgn_Yh_{0}'.format(idx)], tolerance=TOLERANCE)
+    # NOTE(review): the Yl check above is disabled although qbgn_Yl is exported to verification.npz; re-enable or explain why.
+    for idx, a in enumerate(Yscale):
+        assert_almost_equal_to_summary_cube(a, verif['qbgn_Yscale_{0}'.format(idx)], tolerance=TOLERANCE)
+
# vim:sw=4:sts=4:et
diff --git a/tests/util.py b/tests/util.py
index 0c86b49..567cbc9 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -34,6 +34,11 @@ def _mean(a, axis=None, *args, **kwargs):
return rv
+def centre_indices(ndim=2,apron=8):
+    """Return a tuple of *ndim* slices, each dropping *apron* elements from both ends of an axis.
+    """
+    return tuple([slice(apron,-apron) for i in xrange(ndim)])  # NOTE: apron=0 gives slice(0, 0), i.e. empty
+
def summarise_mat(M, apron=8):
"""HACK to provide a 'summary' matrix consisting of the corners of the
matrix and summed versions of the sub matrices.
@@ -50,6 +55,13 @@ def summarise_mat(M, apron=8):
np.hstack((M[-apron:,:apron,...], _mean(M[-apron:,apron:-apron,...], axis=1), M[-apron:,-apron:,...])),
))
+def summarise_cube(M, apron=4):
+    """Summarise each axis-2 slice M[:,:,i,...] with summarise_mat and stack the results with np.dstack.
+    Axis 2 itself is not reduced: every slice contributes to the output."""
+    return np.dstack(
+        [summarise_mat(M[:,:,i,...], apron) for i in xrange(M.shape[2])]  # axis 2, not -2: also correct for 3-D input
+    )
+
def skip_if_no_cl(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
diff --git a/tests/verification.npz b/tests/verification.npz
index 5c8c5de..fa586ce 100644
Binary files a/tests/verification.npz and b/tests/verification.npz differ
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits
mailing list