[python-dtcwt] 96/497: handle float32 input correctly
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:53 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit 4a1d822bf53ae76cef130e0fe95921fd3ba4571f
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date: Fri Aug 9 16:45:16 2013 +0100
handle float32 input correctly
When presented with a float32 input, the library should perform a 32-bit
single precision transform and not a full double-precision one. Add
test cases for float32 and make the transform library pass.
---
dtcwt/lowlevel.py | 45 ++++++++++++++++++++++++++++-----------------
dtcwt/transform1d.py | 10 +++++-----
dtcwt/transform2d.py | 6 +++---
dtcwt/transform3d.py | 16 ++++++++--------
tests/testifm1.py | 9 +++++++++
tests/testifm2.py | 9 +++++++++
tests/testxfm1.py | 6 ++++++
tests/testxfm2.py | 6 ++++++
tests/testxfm3.py | 16 +++++++++++++++-
9 files changed, 89 insertions(+), 34 deletions(-)
diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index 588a690..878e767 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -1,6 +1,15 @@
import numpy as np
from six.moves import xrange
+def asfarray(X):
+ """Similar to :py:func:`numpy.asfarray` except that this function tries to
+ preserve the original datatype of X if it is already a floating point type
+ and will pass floating point arrays through directly without copying.
+
+ """
+ X = np.asanyarray(X)
+ return np.asfarray(X, dtype=X.dtype)
+
def appropriate_complex_type_for(X):
"""Return an appropriate complex data type depending on the type of X. If X
is already complex, return that, if it is floating point return a complex
@@ -8,13 +17,13 @@ def appropriate_complex_type_for(X):
floating point type depending on the result of :py:func:`numpy.asfarray`.
"""
- X = np.asfarray(X)
+ X = asfarray(X)
- if X.dtype is np.complex64 or X.dtype is np.complex128:
+ if np.issubsctype(X.dtype, np.complex64) or np.issubsctype(X.dtype, np.complex128):
return X.dtype
- elif X.dtype is np.float32:
+ elif np.issubsctype(X.dtype, np.float32):
return np.complex64
- elif X.dtype is np.float64:
+ elif np.issubsctype(X.dtype, np.float64):
return np.complex128
# God knows, err on the side of caution
@@ -33,7 +42,7 @@ def as_column_vector(v):
def _centered(arr, newsize):
# Return the center newsize portion of the array.
# (Shamelessly cribbed from scipy.)
- newsize = np.asarray(newsize)
+ newsize = np.asanyarray(newsize)
currsize = np.array(arr.shape)
startind = (currsize - newsize) // 2
endind = startind + newsize
@@ -46,13 +55,15 @@ _irfft = np.fft.irfft
def _column_convolve(X, h):
"""Convolve the columns of *X* with *h* returning only the 'valid' section,
- i.e. those values unaffected by zero padding.
+ i.e. those values unaffected by zero padding. Irrespective of the dtype of
+ *h*, the output will have the dtype of *X* appropriately expanded to a
+ floating point type if necessary.
We assume that h is small and so direct convolution is the most efficient.
"""
- Xshape = np.asarray(X.shape)
- h = h.flatten()
+ Xshape = np.asanyarray(X.shape)
+ h = h.flatten().astype(X.dtype)
h_size = h.shape[0]
full_size = X.shape[0] + h_size - 1
@@ -114,7 +125,7 @@ def colfilter(X, h):
"""
# Interpret all inputs as arrays
- X = np.array(X)
+ X = asfarray(X)
h = as_column_vector(h)
r, c = X.shape
@@ -162,9 +173,9 @@ def coldfilt(X, ha, hb):
"""
# Make sure all inputs are arrays
- X = np.array(X)
- ha = np.array(ha)
- hb = np.array(hb)
+ X = asfarray(X)
+ ha = asfarray(ha)
+ hb = asfarray(hb)
r, c = X.shape
if r % 4 != 0:
@@ -190,7 +201,7 @@ def coldfilt(X, ha, hb):
hbe = as_column_vector(hb[1:m:2])
t = np.arange(5, r+2*m-2, 4)
r2 = r/2;
- Y = np.zeros((r2,c))
+ Y = np.zeros((r2,c), dtype=X.dtype)
if np.sum(ha*hb) > 0:
s1 = slice(0, r2, 2)
@@ -233,9 +244,9 @@ def colifilt(X, ha, hb):
"""
# Make sure all inputs are arrays
- X = np.array(X)
- ha = np.array(ha)
- hb = np.array(hb)
+ X = asfarray(X)
+ ha = asfarray(ha)
+ hb = asfarray(hb)
r, c = X.shape
if r % 2 != 0:
@@ -250,7 +261,7 @@ def colifilt(X, ha, hb):
m = ha.shape[0]
m2 = np.fix(m*0.5)
- Y = np.zeros((r*2,c))
+ Y = np.zeros((r*2,c), dtype=X.dtype)
if not np.any(np.nonzero(X[:])[0]):
return Y
diff --git a/dtcwt/transform1d.py b/dtcwt/transform1d.py
index c12cb60..b6c54e0 100644
--- a/dtcwt/transform1d.py
+++ b/dtcwt/transform1d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
from dtcwt import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt, as_column_vector
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, as_column_vector, asfarray
def dtwavexfm(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
"""Perform a *n*-level DTCWT decompostion on a 1D column vector *X* (or on
@@ -38,7 +38,7 @@ def dtwavexfm(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_
"""
# Need this because colfilter and friends assumes input is 2d
- X = np.asfarray(X)
+ X = asfarray(X)
if len(X.shape) == 1:
X = np.atleast_2d(X).T
@@ -54,7 +54,7 @@ def dtwavexfm(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_
except TypeError:
h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
- L = np.asarray(X.shape)
+ L = np.asanyarray(X.shape)
# ensure that X is an even length, thus enabling it to be extended if needs be.
if X.shape[0] % 2 != 0:
@@ -162,7 +162,7 @@ def dtwaveifm(Yl, Yh, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, gain_mask=None
if Lo.shape[0] != 2*Yh[level-1].shape[0]: # If Lo is not the same length as the next Yh => t1 was extended.
Lo = Lo[1:-1,...] # Therefore we have to clip Lo so it is the same height as the next Yh.
- if np.any(np.asarray(Lo.shape) != np.asarray(Yh[level-1].shape * np.array((2,1)))):
+ if np.any(np.asanyarray(Lo.shape) != np.asanyarray(Yh[level-1].shape * np.array((2,1)))):
raise ValueError('Yh sizes are not valid for DTWAVEIFM')
level -= 1
@@ -187,7 +187,7 @@ def c2q1d(x):
"""
a, b = x.shape
- z = np.zeros((a*2, b))
+ z = np.zeros((a*2, b), dtype=x.real.dtype)
z[::2, :] = np.real(x)
z[1::2, :] = np.imag(x)
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index f500e06..57b7bb3 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
from dtcwt import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt, appropriate_complex_type_for
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, appropriate_complex_type_for, asfarray
def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
"""Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
@@ -36,7 +36,7 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
.. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
"""
- X = np.atleast_2d(np.asfarray(X))
+ X = np.atleast_2d(asfarray(X))
# Try to load coefficients if biort is a string parameter
try:
@@ -289,7 +289,7 @@ def c2q(w,gain):
"""
- x = np.zeros((w.shape[0] << 1, w.shape[1] << 1))
+ x = np.zeros((w.shape[0] << 1, w.shape[1] << 1), dtype=w.real.dtype)
sc = np.sqrt(0.5) * gain
P = w[:,:,0]*sc[0] + w[:,:,1]*sc[1]
diff --git a/dtcwt/transform3d.py b/dtcwt/transform3d.py
index 788910d..5ce71ca 100644
--- a/dtcwt/transform3d.py
+++ b/dtcwt/transform3d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
from dtcwt import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, asfarray
def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
"""Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
@@ -49,7 +49,7 @@ def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mod
.. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
"""
- X = np.atleast_3d(np.asfarray(X))
+ X = np.atleast_3d(asfarray(X))
# Try to load coefficients if biort is a string parameter
try:
@@ -155,7 +155,7 @@ def _level1_xfm(X, h0o, h1o, ext_mode):
raise ValueError('Input shape should be a multiple of 4 in each direction when ext_mode == 8')
# Create work area
- work_shape = np.asarray(X.shape) * 2
+ work_shape = np.asanyarray(X.shape) * 2
# We need one extra row per octant if filter length is even
if h0o.shape[0] % 2 == 0:
@@ -246,7 +246,7 @@ def _level2_xfm(X, h0a, h0b, h1a, h1b, ext_mode):
X = np.concatenate((X[:,:,(0,0)], X, X[:,:,(-1,-1)]), 2)
# Create work area
- work_shape = np.asarray(X.shape)
+ work_shape = np.asanyarray(X.shape)
work = np.zeros(work_shape, dtype=X.dtype)
# Form some useful slices
@@ -298,10 +298,10 @@ def _level1_ifm(Yl, Yh, g0o, g1o):
"""
# Create work area
- work = np.zeros(np.asarray(Yl.shape) * 2, dtype=Yl.dtype)
+ work = np.zeros(np.asanyarray(Yl.shape) * 2, dtype=Yl.dtype)
# Work out shape of output
- Xshape = np.asarray(work.shape) >> 1
+ Xshape = np.asanyarray(work.shape) >> 1
if g0o.shape[0] % 2 == 0:
# if we have an even length filter, we need to shrink the output by 1
# to compensate for the addition of an extra row/column/slice in
@@ -355,7 +355,7 @@ def _level2_ifm(Yl, Yh, g0a, g0b, g1a, g1b, ext_mode, prev_level_size):
"""
# Create work area
- work = np.zeros(np.asarray(Yl.shape)*2, dtype=Yl.dtype)
+ work = np.zeros(np.asanyarray(Yl.shape)*2, dtype=Yl.dtype)
# Form some useful slices
s0a = slice(None, work.shape[0] >> 1)
@@ -507,7 +507,7 @@ def c2cube(z):
rr, ri = r.real, r.imag
sr, si = s.real, s.imag
- y = np.zeros(np.asarray(z.shape[:3])*2, dtype=z.real.dtype)
+ y = np.zeros(np.asanyarray(z.shape[:3])*2, dtype=z.real.dtype)
y[1::2, 1::2, 1::2] = ( pr+qr+rr+sr)
y[0::2, 0::2, 1::2] = (-pr-qr+rr+sr)
diff --git a/tests/testifm1.py b/tests/testifm1.py
index 8b906fa..d208e98 100644
--- a/tests/testifm1.py
+++ b/tests/testifm1.py
@@ -23,4 +23,13 @@ def test_reconstruct_2d():
vec_recon = dtwaveifm(Yl, Yh)
assert np.all(np.abs(vec_recon - vec) < TOLERANCE)
+def test_float32_input():
+ # Check that a float32 input is correctly output as float32
+ Yl, Yh = dtwavexfm(np.array([1, 2, 3, 4]).astype(np.float32))
+ assert np.issubsctype(Yl.dtype, np.float32)
+ assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+ recon = dtwaveifm(Yl, Yh)
+ assert np.issubsctype(recon.dtype, np.float32)
+
# vim:sw=4:sts=4:et
diff --git a/tests/testifm2.py b/tests/testifm2.py
index 02f370f..24095d1 100644
--- a/tests/testifm2.py
+++ b/tests/testifm2.py
@@ -39,5 +39,14 @@ def test_reconstruct_custom_filter():
lena_recon = dtwaveifm2(Yl, Yh, biort('legall'), qshift('qshift_06'))
assert np.all(np.abs(lena_recon - lena) < TOLERANCE)
+def test_float32_input():
+ # Check that a float32 input is correctly output as float32
+ Yl, Yh = dtwavexfm2(lena.astype(np.float32))
+ assert np.issubsctype(Yl.dtype, np.float32)
+ assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+ lena_recon = dtwaveifm2(Yl, Yh)
+ assert np.issubsctype(lena_recon.dtype, np.float32)
+
# vim:sw=4:sts=4:et
diff --git a/tests/testxfm1.py b/tests/testxfm1.py
index 408dca6..15265ed 100644
--- a/tests/testxfm1.py
+++ b/tests/testxfm1.py
@@ -74,4 +74,10 @@ def test_integer_perfect_recon():
B = dtwaveifm(Yl, Yh)
assert np.max(np.abs(A-B)) < 1e-12
+def test_float32_input():
+ # Check that a float32 input is correctly output as float32
+ Yl, Yh = dtwavexfm(np.array([1,2,3,4]).astype(np.float32))
+ assert np.issubsctype(Yl.dtype, np.float32)
+ assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
# vim:sw=4:sts=4:et
diff --git a/tests/testxfm2.py b/tests/testxfm2.py
index 94775fb..97595a4 100644
--- a/tests/testxfm2.py
+++ b/tests/testxfm2.py
@@ -77,4 +77,10 @@ def test_integer_perfect_recon():
B = dtwaveifm2(Yl, Yh)
assert np.max(np.abs(A-B)) < 1e-5
+def test_float32_input():
+ # Check that a float32 input is correctly output as float32
+ Yl, Yh = dtwavexfm2(lena.astype(np.float32))
+ assert np.issubsctype(Yl.dtype, np.float32)
+ assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
# vim:sw=4:sts=4:et
diff --git a/tests/testxfm3.py b/tests/testxfm3.py
index cca6070..99f727f 100644
--- a/tests/testxfm3.py
+++ b/tests/testxfm3.py
@@ -56,7 +56,6 @@ def test_simple_level_1_recon_haar():
Yl, Yh = dtwavexfm3(ellipsoid, 1, biort=haar)
ellipsoid_recon = dtwaveifm3(Yl, Yh, biort=haar)
assert ellipsoid.size == ellipsoid_recon.size
- print(np.max(np.abs(ellipsoid - ellipsoid_recon)))
assert np.max(np.abs(ellipsoid - ellipsoid_recon)) < TOLERANCE
def test_simple_level_2_xfm():
@@ -136,4 +135,19 @@ def test_integer_perfect_recon():
B = dtwaveifm3(Yl, Yh)
assert np.max(np.abs(A-B)) < 1e-12
+def test_float32_input():
+ # Check that a float32 input is correctly output as float32
+ Yl, Yh = dtwavexfm3(ellipsoid.astype(np.float32))
+ assert np.issubsctype(Yl.dtype, np.float32)
+ assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+def test_float32_recon():
+ # Check that a float32 input is correctly transformed and reconstructed as float32
+ Yl, Yh = dtwavexfm3(ellipsoid.astype(np.float32))
+ assert np.issubsctype(Yl.dtype, np.float32)
+ assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+ recon = dtwaveifm3(Yl, Yh)
+ assert np.issubsctype(recon.dtype, np.float32)
+
# vim:sw=4:sts=4:et
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits mailing list