[python-dtcwt] 96/497: handle float32 input correctly

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:53 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit 4a1d822bf53ae76cef130e0fe95921fd3ba4571f
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Fri Aug 9 16:45:16 2013 +0100

    handle float32 input correctly
    
    When presented with a float32 input, the library should perform a 32-bit
    single precision transform and not a full double-precision one.  Add
    test cases for float32 and make the transform library pass.
---
 dtcwt/lowlevel.py    | 45 ++++++++++++++++++++++++++++-----------------
 dtcwt/transform1d.py | 10 +++++-----
 dtcwt/transform2d.py |  6 +++---
 dtcwt/transform3d.py | 16 ++++++++--------
 tests/testifm1.py    |  9 +++++++++
 tests/testifm2.py    |  9 +++++++++
 tests/testxfm1.py    |  6 ++++++
 tests/testxfm2.py    |  6 ++++++
 tests/testxfm3.py    | 16 +++++++++++++++-
 9 files changed, 89 insertions(+), 34 deletions(-)

diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index 588a690..878e767 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -1,6 +1,15 @@
 import numpy as np
 from six.moves import xrange
 
+def asfarray(X):
+    """Similar to :py:func:`numpy.asfarray` except that this function tries to
+    preserve the original datatype of X if it is already a floating point type
+    and will pass floating point arrays through directly without copying.
+
+    """
+    X = np.asanyarray(X)
+    return np.asfarray(X, dtype=X.dtype)
+
 def appropriate_complex_type_for(X):
     """Return an appropriate complex data type depending on the type of X. If X
     is already complex, return that, if it is floating point return a complex
@@ -8,13 +17,13 @@ def appropriate_complex_type_for(X):
     floating point type depending on the result of :py:func:`numpy.asfarray`.
 
     """
-    X = np.asfarray(X)
+    X = asfarray(X)
     
-    if X.dtype is np.complex64 or X.dtype is np.complex128:
+    if np.issubsctype(X.dtype, np.complex64) or np.issubsctype(X.dtype, np.complex128):
         return X.dtype
-    elif X.dtype is np.float32:
+    elif np.issubsctype(X.dtype, np.float32):
         return np.complex64
-    elif X.dtype is np.float64:
+    elif np.issubsctype(X.dtype, np.float64):
         return np.complex128
 
     # God knows, err on the side of caution
@@ -33,7 +42,7 @@ def as_column_vector(v):
 def _centered(arr, newsize):
     # Return the center newsize portion of the array.
     # (Shamelessly cribbed from scipy.)
-    newsize = np.asarray(newsize)
+    newsize = np.asanyarray(newsize)
     currsize = np.array(arr.shape)
     startind = (currsize - newsize) // 2
     endind = startind + newsize
@@ -46,13 +55,15 @@ _irfft = np.fft.irfft
 
 def _column_convolve(X, h):
     """Convolve the columns of *X* with *h* returning only the 'valid' section,
-    i.e. those values unaffected by zero padding.
+    i.e. those values unaffected by zero padding. Irrespective of the dtype of
+    *h*, the output will have the dtype of *X* appropriately expanded to a
+    floating point type if necessary.
 
     We assume that h is small and so direct convolution is the most efficient.
 
     """
-    Xshape = np.asarray(X.shape)
-    h = h.flatten()
+    Xshape = np.asanyarray(X.shape)
+    h = h.flatten().astype(X.dtype)
     h_size = h.shape[0]
 
     full_size = X.shape[0] + h_size - 1
@@ -114,7 +125,7 @@ def colfilter(X, h):
     """
     
     # Interpret all inputs as arrays
-    X = np.array(X)
+    X = asfarray(X)
     h = as_column_vector(h)
 
     r, c = X.shape
@@ -162,9 +173,9 @@ def coldfilt(X, ha, hb):
 
     """
     # Make sure all inputs are arrays
-    X = np.array(X)
-    ha = np.array(ha)
-    hb = np.array(hb)
+    X = asfarray(X)
+    ha = asfarray(ha)
+    hb = asfarray(hb)
 
     r, c = X.shape
     if r % 4 != 0:
@@ -190,7 +201,7 @@ def coldfilt(X, ha, hb):
     hbe = as_column_vector(hb[1:m:2])
     t = np.arange(5, r+2*m-2, 4)
     r2 = r/2;
-    Y = np.zeros((r2,c))
+    Y = np.zeros((r2,c), dtype=X.dtype)
 
     if np.sum(ha*hb) > 0:
        s1 = slice(0, r2, 2)
@@ -233,9 +244,9 @@ def colifilt(X, ha, hb):
 
     """
     # Make sure all inputs are arrays
-    X = np.array(X)
-    ha = np.array(ha)
-    hb = np.array(hb)
+    X = asfarray(X)
+    ha = asfarray(ha)
+    hb = asfarray(hb)
 
     r, c = X.shape
     if r % 2 != 0:
@@ -250,7 +261,7 @@ def colifilt(X, ha, hb):
     m = ha.shape[0]
     m2 = np.fix(m*0.5)
 
-    Y = np.zeros((r*2,c))
+    Y = np.zeros((r*2,c), dtype=X.dtype)
     if not np.any(np.nonzero(X[:])[0]):
         return Y
 
diff --git a/dtcwt/transform1d.py b/dtcwt/transform1d.py
index c12cb60..b6c54e0 100644
--- a/dtcwt/transform1d.py
+++ b/dtcwt/transform1d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
 
 from dtcwt import biort as _biort, qshift as _qshift
 from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt, as_column_vector
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, as_column_vector, asfarray
 
 def dtwavexfm(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
     """Perform a *n*-level DTCWT decompostion on a 1D column vector *X* (or on
@@ -38,7 +38,7 @@ def dtwavexfm(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_
 
     """
     # Need this because colfilter and friends assumes input is 2d
-    X = np.asfarray(X)
+    X = asfarray(X)
     if len(X.shape) == 1:
        X = np.atleast_2d(X).T
 
@@ -54,7 +54,7 @@ def dtwavexfm(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_
     except TypeError:
         h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
 
-    L = np.asarray(X.shape)
+    L = np.asanyarray(X.shape)
 
     # ensure that X is an even length, thus enabling it to be extended if needs be.
     if X.shape[0] % 2 != 0:
@@ -162,7 +162,7 @@ def dtwaveifm(Yl, Yh, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, gain_mask=None
        if Lo.shape[0] != 2*Yh[level-1].shape[0]:  # If Lo is not the same length as the next Yh => t1 was extended.
           Lo = Lo[1:-1,...]                       # Therefore we have to clip Lo so it is the same height as the next Yh.
 
-       if np.any(np.asarray(Lo.shape) != np.asarray(Yh[level-1].shape * np.array((2,1)))):
+       if np.any(np.asanyarray(Lo.shape) != np.asanyarray(Yh[level-1].shape * np.array((2,1)))):
           raise ValueError('Yh sizes are not valid for DTWAVEIFM')
        
        level -= 1
@@ -187,7 +187,7 @@ def c2q1d(x):
 
     """
     a, b = x.shape
-    z = np.zeros((a*2, b))
+    z = np.zeros((a*2, b), dtype=x.real.dtype)
     z[::2, :] = np.real(x)
     z[1::2, :] = np.imag(x)
 
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index f500e06..57b7bb3 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
 
 from dtcwt import biort as _biort, qshift as _qshift
 from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt, appropriate_complex_type_for
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, appropriate_complex_type_for, asfarray
 
 def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
     """Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
@@ -36,7 +36,7 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
     .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
 
     """
-    X = np.atleast_2d(np.asfarray(X))
+    X = np.atleast_2d(asfarray(X))
 
     # Try to load coefficients if biort is a string parameter
     try:
@@ -289,7 +289,7 @@ def c2q(w,gain):
 
     """
 
-    x = np.zeros((w.shape[0] << 1, w.shape[1] << 1))
+    x = np.zeros((w.shape[0] << 1, w.shape[1] << 1), dtype=w.real.dtype)
 
     sc = np.sqrt(0.5) * gain
     P = w[:,:,0]*sc[0] + w[:,:,1]*sc[1]
diff --git a/dtcwt/transform3d.py b/dtcwt/transform3d.py
index 788910d..5ce71ca 100644
--- a/dtcwt/transform3d.py
+++ b/dtcwt/transform3d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
 
 from dtcwt import biort as _biort, qshift as _qshift
 from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, asfarray
 
 def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mode=4):
     """Perform a *n*-level DTCWT-3D decompostion on a 3D matrix *X*.
@@ -49,7 +49,7 @@ def dtwavexfm3(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, ext_mod
     .. codeauthor:: Nick Kingsbury, Cambridge University, July 1999.
 
     """
-    X = np.atleast_3d(np.asfarray(X))
+    X = np.atleast_3d(asfarray(X))
 
     # Try to load coefficients if biort is a string parameter
     try:
@@ -155,7 +155,7 @@ def _level1_xfm(X, h0o, h1o, ext_mode):
         raise ValueError('Input shape should be a multiple of 4 in each direction when ext_mode == 8')
 
     # Create work area
-    work_shape = np.asarray(X.shape) * 2
+    work_shape = np.asanyarray(X.shape) * 2
 
     # We need one extra row per octant if filter length is even
     if h0o.shape[0] % 2 == 0:
@@ -246,7 +246,7 @@ def _level2_xfm(X, h0a, h0b, h1a, h1b, ext_mode):
             X = np.concatenate((X[:,:,(0,0)], X, X[:,:,(-1,-1)]), 2)
 
     # Create work area
-    work_shape = np.asarray(X.shape)
+    work_shape = np.asanyarray(X.shape)
     work = np.zeros(work_shape, dtype=X.dtype)
 
     # Form some useful slices
@@ -298,10 +298,10 @@ def _level1_ifm(Yl, Yh, g0o, g1o):
 
     """
     # Create work area
-    work = np.zeros(np.asarray(Yl.shape) * 2, dtype=Yl.dtype)
+    work = np.zeros(np.asanyarray(Yl.shape) * 2, dtype=Yl.dtype)
 
     # Work out shape of output
-    Xshape = np.asarray(work.shape) >> 1
+    Xshape = np.asanyarray(work.shape) >> 1
     if g0o.shape[0] % 2 == 0:
         # if we have an even length filter, we need to shrink the output by 1
         # to compensate for the addition of an extra row/column/slice in 
@@ -355,7 +355,7 @@ def _level2_ifm(Yl, Yh, g0a, g0b, g1a, g1b, ext_mode, prev_level_size):
 
     """
     # Create work area
-    work = np.zeros(np.asarray(Yl.shape)*2, dtype=Yl.dtype)
+    work = np.zeros(np.asanyarray(Yl.shape)*2, dtype=Yl.dtype)
 
     # Form some useful slices
     s0a = slice(None, work.shape[0] >> 1)
@@ -507,7 +507,7 @@ def c2cube(z):
     rr, ri = r.real, r.imag
     sr, si = s.real, s.imag
 
-    y = np.zeros(np.asarray(z.shape[:3])*2, dtype=z.real.dtype)
+    y = np.zeros(np.asanyarray(z.shape[:3])*2, dtype=z.real.dtype)
 
     y[1::2, 1::2, 1::2] = ( pr+qr+rr+sr)
     y[0::2, 0::2, 1::2] = (-pr-qr+rr+sr)
diff --git a/tests/testifm1.py b/tests/testifm1.py
index 8b906fa..d208e98 100644
--- a/tests/testifm1.py
+++ b/tests/testifm1.py
@@ -23,4 +23,13 @@ def test_reconstruct_2d():
     vec_recon = dtwaveifm(Yl, Yh)
     assert np.all(np.abs(vec_recon - vec) < TOLERANCE)
 
+def test_float32_input():
+    # Check that a float32 input is correctly output as float32
+    Yl, Yh = dtwavexfm(np.array([1, 2, 3, 4]).astype(np.float32))
+    assert np.issubsctype(Yl.dtype, np.float32)
+    assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+    recon = dtwaveifm(Yl, Yh)
+    assert np.issubsctype(recon.dtype, np.float32)
+
 # vim:sw=4:sts=4:et
diff --git a/tests/testifm2.py b/tests/testifm2.py
index 02f370f..24095d1 100644
--- a/tests/testifm2.py
+++ b/tests/testifm2.py
@@ -39,5 +39,14 @@ def test_reconstruct_custom_filter():
     lena_recon = dtwaveifm2(Yl, Yh, biort('legall'), qshift('qshift_06'))
     assert np.all(np.abs(lena_recon - lena) < TOLERANCE)
 
+def test_float32_input():
+    # Check that a float32 input is correctly output as float32
+    Yl, Yh = dtwavexfm2(lena.astype(np.float32))
+    assert np.issubsctype(Yl.dtype, np.float32)
+    assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+    lena_recon = dtwaveifm2(Yl, Yh)
+    assert np.issubsctype(lena_recon.dtype, np.float32)
+
 
 # vim:sw=4:sts=4:et
diff --git a/tests/testxfm1.py b/tests/testxfm1.py
index 408dca6..15265ed 100644
--- a/tests/testxfm1.py
+++ b/tests/testxfm1.py
@@ -74,4 +74,10 @@ def test_integer_perfect_recon():
     B = dtwaveifm(Yl, Yh)
     assert np.max(np.abs(A-B)) < 1e-12
 
+def test_float32_input():
+    # Check that a float32 input is correctly output as float32
+    Yl, Yh = dtwavexfm(np.array([1,2,3,4]).astype(np.float32))
+    assert np.issubsctype(Yl.dtype, np.float32)
+    assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
 # vim:sw=4:sts=4:et
diff --git a/tests/testxfm2.py b/tests/testxfm2.py
index 94775fb..97595a4 100644
--- a/tests/testxfm2.py
+++ b/tests/testxfm2.py
@@ -77,4 +77,10 @@ def test_integer_perfect_recon():
     B = dtwaveifm2(Yl, Yh)
     assert np.max(np.abs(A-B)) < 1e-5
 
+def test_float32_input():
+    # Check that a float32 input is correctly output as float32
+    Yl, Yh = dtwavexfm2(lena.astype(np.float32))
+    assert np.issubsctype(Yl.dtype, np.float32)
+    assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
 # vim:sw=4:sts=4:et
diff --git a/tests/testxfm3.py b/tests/testxfm3.py
index cca6070..99f727f 100644
--- a/tests/testxfm3.py
+++ b/tests/testxfm3.py
@@ -56,7 +56,6 @@ def test_simple_level_1_recon_haar():
     Yl, Yh = dtwavexfm3(ellipsoid, 1, biort=haar)
     ellipsoid_recon = dtwaveifm3(Yl, Yh, biort=haar)
     assert ellipsoid.size == ellipsoid_recon.size
-    print(np.max(np.abs(ellipsoid - ellipsoid_recon)))
     assert np.max(np.abs(ellipsoid - ellipsoid_recon)) < TOLERANCE
 
 def test_simple_level_2_xfm():
@@ -136,4 +135,19 @@ def test_integer_perfect_recon():
     B = dtwaveifm3(Yl, Yh)
     assert np.max(np.abs(A-B)) < 1e-12
 
+def test_float32_input():
+    # Check that a float32 input is correctly output as float32
+    Yl, Yh = dtwavexfm3(ellipsoid.astype(np.float32))
+    assert np.issubsctype(Yl.dtype, np.float32)
+    assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+def test_float32_recon():
+    # Check that a float32 input is correctly output as float32
+    Yl, Yh = dtwavexfm3(ellipsoid.astype(np.float32))
+    assert np.issubsctype(Yl.dtype, np.float32)
+    assert np.all(list(np.issubsctype(x.dtype, np.complex64) for x in Yh))
+
+    recon = dtwaveifm3(Yl, Yh)
+    assert np.issubsctype(recon.dtype, np.float32)
+
 # vim:sw=4:sts=4:et

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list