[python-dtcwt] 06/497: update test suite
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:05:42 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit 53f406cb3a644064d6cb61c741a0a3e0fac4b41f
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date: Tue Aug 6 16:23:58 2013 +0100
update test suite
---
dtcwt/transform2d.py | 19 ++++++++++++-------
tests/testxfm2.py | 26 ++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 7 deletions(-)
diff --git a/dtcwt/transform2d.py b/dtcwt/transform2d.py
index 82bb2d3..78dc14f 100644
--- a/dtcwt/transform2d.py
+++ b/dtcwt/transform2d.py
@@ -47,6 +47,8 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
"""
+ X = np.atleast_2d(X)
+
# Try to load coefficients if biort is a string parameter
if isinstance(biort, basestring):
h0o, g0o, h1o, g1o = _biort(biort)
@@ -63,7 +65,7 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
if len(X.shape) >= 3:
raise ValueError('The entered image is {0}, please enter each image slice separately.'.
- format('x'.join(X.shape)))
+ format('x'.join(list(str(s) for s in X.shape))))
# The next few lines of code check to see if the image is odd in size, if so an extra ...
# row/column will be added to the bottom/right of the image
@@ -71,12 +73,12 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
initial_col_extend = 0
if original_size[0] % 2 != 0:
# if X.shape[0] is not divisable by 2 then we need to extend X by adding a row at the bottom
- X = np.vstack(X, X[-1,:]) # Any further extension will be done in due course.
+ X = np.vstack((X, X[-1,:])) # Any further extension will be done in due course.
initial_row_extend = 1;
if original_size[1] % 2 != 0:
# if X.shape[1] is not divisable by 2 then we need to extend X by adding a col to the left
- X = np.hstack(X, X[:,-1])
+ X = np.hstack((X, np.atleast_2d(X[:,-1]).T))
initial_col_extend = 1
extended_size = X.shape
@@ -118,7 +120,7 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
if col_size % 4 != 0:
# Extend by 2 cols if no. of cols of LoLo are not divisable by 4
- LoLo = np.hstack((LoLo[:,0], LoLo, LoLo[:,-1]))
+ LoLo = np.hstack((np.atleast_2d(LoLo[:,0]).T, LoLo, np.atleast_2d(LoLo[:,-1]).T))
# Do even Qshift filters on rows.
Lo = coldfilt(LoLo,h0b,h0a).T
@@ -139,19 +141,22 @@ def dtwavexfm2(X, nlevels=3, biort='near_sym_a', qshift='qshift_a', include_scal
if initial_row_extend == 1 and initial_col_extend == 1:
logging.warn('The image entered is now a {0} NOT a {1}.'.format(
- 'x'.join(extended_size), 'x'.join(orginal_size)))
+ 'x'.join(list(str(s) for s in extended_size)),
+ 'x'.join(list(str(s) for s in original_size))))
logging.warn(
'The bottom row and rightmost column have been duplicated, prior to decomposition.')
if initial_row_extend == 1 and initial_col_extend == 0:
logging.warn('The image entered is now a {0} NOT a {1}.'.format(
- 'x'.join(extended_size), 'x'.join(orginal_size)))
+ 'x'.join(list(str(s) for s in extended_size)),
+ 'x'.join(list(str(s) for s in original_size))))
logging.warn(
'The bottom row has been duplicated, prior to decomposition.')
if initial_row_extend == 0 and initial_col_extend == 1:
logging.warn('The image entered is now a {0} NOT a {1}.'.format(
- 'x'.join(extended_size), 'x'.join(orginal_size)))
+ 'x'.join(list(str(s) for s in extended_size)),
+ 'x'.join(list(str(s) for s in original_size))))
logging.warn(
'The rightmost column has been duplicated, prior to decomposition.')
diff --git a/tests/testxfm2.py b/tests/testxfm2.py
index 4747f70..37f8eb2 100644
--- a/tests/testxfm2.py
+++ b/tests/testxfm2.py
@@ -1,4 +1,5 @@
import os
+from nose.tools import raises
import numpy as np
from dtcwt import dtwavexfm2
@@ -16,9 +17,34 @@ def test_lena_loaded():
def test_simple():
Yl, Yh = dtwavexfm2(lena)
+def test_1d():
+ Yl, Yh = dtwavexfm2(lena[0,:])
+
+@raises(ValueError)
+def test_3d():  # a 3-D (stacked) input must be rejected with ValueError
+ Yl, Yh = dtwavexfm2(np.dstack((lena, lena)))
+
def test_simple_w_scale():
Yl, Yh, Yscale = dtwavexfm2(lena, include_scale=True)
+def test_odd_rows():
+ Yl, Yh = dtwavexfm2(lena[:509,:])
+
+def test_odd_rows_w_scale():
+ Yl, Yh, Yscale = dtwavexfm2(lena[:509,:], include_scale=True)
+
+def test_odd_cols():
+ Yl, Yh = dtwavexfm2(lena[:,:509])
+
+def test_odd_cols_w_scale():
+ Yl, Yh, Yscale = dtwavexfm2(lena[:,:509], include_scale=True)  # truncate columns to 509 (odd); rows untouched
+
+def test_odd_rows_and_cols():
+ Yl, Yh = dtwavexfm2(lena[:509,:509])  # truncate both rows and columns to 509 (odd)
+
+def test_odd_rows_and_cols_w_scale():
+ Yl, Yh, Yscale = dtwavexfm2(lena[:509,:509], include_scale=True)
+
def test_0_levels():
Yl, Yh = dtwavexfm2(lena, nlevels=0)
assert np.all(np.abs(Yl - lena) < 1e-5)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits
mailing list