[python-dtcwt] 94/497: fix perfect reconstruction for 2D
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:53 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit 7f015ffb62e5a45dd8f697acff1e993cf0bcd740
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date: Fri Aug 9 16:09:37 2013 +0100
fix perfect reconstruction for 2D
The 2D transform was hard-coded to only ever return 64-bit complex
(i.e. two 32-bit floating point parts) results. Modify it to select a
complex data type depending on the input. Fixes #9.
---
dtcwt/lowlevel.py | 19 +++++++++++++++++++
dtcwt/transform2d.py | 10 ++++++----
tests/testifm2.py | 4 ++--
3 files changed, 27 insertions(+), 6 deletions(-)
diff --git a/dtcwt/lowlevel.py b/dtcwt/lowlevel.py
index 4e94c43..588a690 100644
--- a/dtcwt/lowlevel.py
+++ b/dtcwt/lowlevel.py
@@ -1,6 +1,25 @@
import numpy as np
from six.moves import xrange
+def appropriate_complex_type_for(X):
+ """Return an appropriate complex data type depending on the type of X. If X
+ is already complex, return that, if it is floating point return a complex
+ type of the appropriate size and if it is integer, choose a complex
+ floating point type depending on the result of :py:func:`numpy.asfarray`.
+
+ """
+ X = np.asfarray(X)
+
+ if X.dtype is np.complex64 or X.dtype is np.complex128:
+ return X.dtype
+ elif X.dtype is np.float32:
+ return np.complex64
+ elif X.dtype is np.float64:
+ return np.complex128
+
+ # God knows, err on the side of caution
+ return np.complex128
+
def as_column_vector(v):
"""Return *v* as a column vector with shape (N,1).
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index deb7565..f500e06 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -5,7 +5,7 @@ from six.moves import xrange
from dtcwt import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
-from dtcwt.lowlevel import colfilter, coldfilt, colifilt
+from dtcwt.lowlevel import colfilter, coldfilt, colifilt, appropriate_complex_type_for
def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
"""Perform an *n*-level DTCWT-2D decomposition on a 2D matrix *X*.
@@ -84,6 +84,8 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
# this is only required if the user specifies a third output component.
Yscale = [None,] * nlevels
+ complex_dtype = appropriate_complex_type_for(X)
+
if nlevels >= 1:
# Do odd top-level filters on cols.
Lo = colfilter(X,h0o).T
@@ -91,7 +93,7 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
# Do odd top-level filters on rows.
LoLo = colfilter(Lo,h0o).T
- Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=np.complex64)
+ Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)
Yh[0][:,:,[0, 5]] = q2c(colfilter(Hi,h0o).T) # Horizontal pair
Yh[0][:,:,[2, 3]] = q2c(colfilter(Lo,h1o).T) # Vertical pair
Yh[0][:,:,[1, 4]] = q2c(colfilter(Hi,h1o).T) # Diagonal pair
@@ -116,7 +118,7 @@ def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include
# Do even Qshift filters on columns.
LoLo = coldfilt(Lo,h0b,h0a).T
- Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=np.complex64)
+ Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
Yh[level][:,:,[0, 5]] = q2c(coldfilt(Hi,h0b,h0a).T) # Horizontal
Yh[level][:,:,[2, 3]] = q2c(coldfilt(Lo,h1b,h1a).T) # Vertical
Yh[level][:,:,[1, 4]] = q2c(coldfilt(Hi,h1b,h1a).T) # Diagonal
@@ -256,7 +258,7 @@ def q2c(y):
"""Convert from quads in y to complex numbers in z.
"""
- j2 = np.sqrt(0.5) * np.array([1, 1j])
+ j2 = (np.sqrt(0.5) * np.array([1, 1j])).astype(appropriate_complex_type_for(y))
# Arrange pixels from the corners of the quads into
# 2 subimages of alternate real and imag pixels.
diff --git a/tests/testifm2.py b/tests/testifm2.py
index b8906f0..02f370f 100644
--- a/tests/testifm2.py
+++ b/tests/testifm2.py
@@ -9,14 +9,14 @@ TOLERANCE = 1e-12
def setup():
global lena, lena_crop
- lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena']
+ lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena'].astype(np.float64)
lena_crop = lena[:233, :301]
def test_lena_loaded():
assert lena.shape == (512, 512)
assert lena.min() >= 0
assert lena.max() <= 1
- assert lena.dtype == np.float32
+ assert lena.dtype == np.float64
@attr('transform')
def test_reconstruct():
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits
mailing list