[python-dtcwt] 94/497: fix perfect reconstruction for 2D

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:53 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit 7f015ffb62e5a45dd8f697acff1e993cf0bcd740
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Fri Aug 9 16:09:37 2013 +0100

    fix perfect reconstruction for 2D
    
    The 2D transform was hard-coded to only ever return 64-bit complex
    (i.e. two 32-bit floating point parts) results. Modify it to select a
    complex data type depending on the input. Fixes #9.
---
 dtcwt/lowlevel.py    | 19 +++++++++++++++++++
 dtcwt/transform2d.py | 10 ++++++----
 tests/testifm2.py    |  4 ++--
 3 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index 4e94c43..588a690 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -1,6 +1,25 @@
 import numpy as np
 from six.moves import xrange
 
+def appropriate_complex_type_for(X):
+    """Return a complex dtype matching the floating point precision of *X*.
+
+    *X* is converted with :py:func:`numpy.asfarray` first. Complex input keeps
+    its dtype, float32 maps to complex64 and everything else to complex128.
+
+    """
+    X = np.asfarray(X)
+
+    # X.dtype is a numpy.dtype instance, not the scalar type object, so an
+    # identity test ("is") never matches; dtypes must be compared with "==".
+    if X.dtype == np.complex64 or X.dtype == np.complex128:
+        return X.dtype
+    elif X.dtype == np.float32:
+        return np.complex64
+
+    # float64 and anything unrecognised: err on the side of caution
+    return np.complex128
+
 def as_column_vector(v):
     """Return *v* as a column vector with shape (N,1).
     
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index deb7565..f500e06 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
 
 from dtcwt import biort as _biort, qshift as _qshift
 from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, appropriate_complex_type_for
 
 def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
     """Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.
@@ -84,6 +84,8 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
         # this is only required if the user specifies a third output component.
         Yscale = [None,] * nlevels
 
+    complex_dtype = appropriate_complex_type_for(X)
+
     if nlevels >= 1:
         # Do odd top-level filters on cols.
         Lo = colfilter(X,h0o).T
@@ -91,7 +93,7 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
 
         # Do odd top-level filters on rows.
         LoLo = colfilter(Lo,h0o).T
-        Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=np.complex64)
+        Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)
         Yh[0][:,:,[0, 5]] = q2c(colfilter(Hi,h0o).T)     # Horizontal pair
         Yh[0][:,:,[2, 3]] = q2c(colfilter(Lo,h1o).T)     # Vertical pair
         Yh[0][:,:,[1, 4]] = q2c(colfilter(Hi,h1o).T)     # Diagonal pair
@@ -116,7 +118,7 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
         # Do even Qshift filters on columns.
         LoLo = coldfilt(Lo,h0b,h0a).T
 
-        Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=np.complex64)
+        Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
         Yh[level][:,:,[0, 5]] = q2c(coldfilt(Hi,h0b,h0a).T)  # Horizontal
         Yh[level][:,:,[2, 3]] = q2c(coldfilt(Lo,h1b,h1a).T)  # Vertical
         Yh[level][:,:,[1, 4]] = q2c(coldfilt(Hi,h1b,h1a).T)  # Diagonal   
@@ -256,7 +258,7 @@ def q2c(y):
     """Convert from quads in y to complex numbers in z.
 
     """
-    j2 = np.sqrt(0.5) * np.array([1, 1j])
+    j2 = (np.sqrt(0.5) * np.array([1, 1j])).astype(appropriate_complex_type_for(y))
 
     # Arrange pixels from the corners of the quads into
     # 2 subimages of alternate real and imag pixels.
diff --git a/tests/testifm2.py b/tests/testifm2.py
index b8906f0..02f370f 100644
--- a/tests/testifm2.py
+++ b/tests/testifm2.py
@@ -9,14 +9,14 @@ TOLERANCE = 1e-12
 
 def setup():
     global lena, lena_crop
-    lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena']
+    lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena'].astype(np.float64)
     lena_crop = lena[:233, :301]
 
 def test_lena_loaded():
     assert lena.shape == (512, 512)
     assert lena.min() >= 0
     assert lena.max() <= 1
-    assert lena.dtype == np.float32
+    assert lena.dtype == np.float64
 
 @attr('transform')
 def test_reconstruct():

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list