[python-dtcwt] 166/497: add a first pass opencl transform function

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:01 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.

commit c54767d88637bdaecb18b0f8ec7e1bfa814cf9ee
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date:   Fri Nov 8 16:25:21 2013 +0000

    add a first pass opencl transform function
    
    Implement a 2d transform which simply uses the OpenCL col{i,d}filt
    implementations.
---
 dtcwt/opencl/transform2d.py | 159 ++++++++++++++++++++++++++++++++++++++++++++
 scripts/benchmark_opencl.py |   9 +++
 tests/testopenclxfm2.py     |  88 ++++++++++++++++++++++++
 3 files changed, 256 insertions(+)

diff --git a/dtcwt/opencl/transform2d.py b/dtcwt/opencl/transform2d.py
new file mode 100644
index 0000000..0a42392
--- /dev/null
+++ b/dtcwt/opencl/transform2d.py
@@ -0,0 +1,159 @@
+from __future__ import division
+
+import logging
+import numpy as np
+from six.moves import xrange
+
+from dtcwt import biort as _biort, qshift as _qshift
+from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
+from dtcwt.lowlevel import appropriate_complex_type_for, asfarray
+from dtcwt.opencl.lowlevel import colfilter, coldfilt, colifilt
+from dtcwt.transform2d import q2c
+
+def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False):
+    """Perform an *n*-level DTCWT-2D decomposition on a 2D matrix *X*.
+
+    :param X: 2D real array
+    :param nlevels: Number of levels of wavelet decomposition
+    :param biort: Level 1 wavelets to use. See :py:func:`biort`.
+    :param qshift: Level >= 2 wavelets to use. See :py:func:`qshift`.
+    :param include_scale: If True, also return the lowpass image computed at every level (*Yscale*).
+    :returns Yl: The real lowpass image from the final level
+    :returns Yh: A tuple containing the complex highpass subimages for each level.
+    :returns Yscale: If *include_scale* is True, a tuple containing real lowpass coefficients for every scale.
+
+    If *biort* or *qshift* are strings, they are used as an argument to the
+    :py:func:`biort` or :py:func:`qshift` functions. Otherwise, they are
+    interpreted as tuples of vectors giving filter coefficients. In the *biort*
+    case, this should be (h0o, g0o, h1o, g1o). In the *qshift* case, this should
+    be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).
+
+    Example::
+
+        # Performs a 3-level transform on the real image X using the 13,19-tap
+        # filters for level 1 and the Q-shift 14-tap filters for levels >= 2.
+        Yl, Yh = dtwavexfm2(X, 3, 'near_sym_b', 'qshift_b')
+
+    .. codeauthor:: Rich Wareham <rjw57 at cantab.net>, Aug 2013
+    .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
+    .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
+
+    """
+    X = np.atleast_2d(asfarray(X))
+
+    # Try to load coefficients if biort is a string parameter
+    try:
+        h0o, g0o, h1o, g1o = _biort(biort)
+    except TypeError:
+        h0o, g0o, h1o, g1o = biort
+
+    # Try to load coefficients if qshift is a string parameter
+    try:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
+    except TypeError:
+        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift
+
+    original_size = X.shape
+
+    if len(X.shape) >= 3:
+        raise ValueError('The entered image is {0}, please enter each image slice separately.'.
+                format('x'.join(list(str(s) for s in X.shape))))
+
+    # Check whether the image has odd dimensions; if so, an extra
+    # row/column will be added to the bottom/right of the image
+    initial_row_extend = 0  # initialise
+    initial_col_extend = 0
+    if original_size[0] % 2 != 0:
+        # if X.shape[0] is not divisible by 2 then we need to extend X by adding a row at the bottom
+        X = np.vstack((X, X[[-1],:]))  # Any further extension will be done in due course.
+        initial_row_extend = 1
+
+    if original_size[1] % 2 != 0:
+        # if X.shape[1] is not divisible by 2 then we need to extend X by adding a column on the right
+        X = np.hstack((X, X[:,[-1]]))
+        initial_col_extend = 1
+
+    extended_size = X.shape
+
+    if nlevels == 0:  # no decomposition requested: return the (possibly extended) input
+        if include_scale:
+            return X, (), ()
+        else:
+            return X, ()
+
+    # initialise the per-level output lists
+    Yh = [None,] * nlevels
+    if include_scale:
+        # only needed when include_scale requests the third (Yscale) output.
+        Yscale = [None,] * nlevels
+
+    complex_dtype = appropriate_complex_type_for(X)
+
+    if nlevels >= 1:
+        # Do odd top-level filters on cols.
+        Lo = colfilter(X,h0o).T
+        Hi = colfilter(X,h1o).T
+
+        # Do odd top-level filters on rows.
+        LoLo = colfilter(Lo,h0o).T
+        Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)  # 6 subbands, each half the size of LoLo
+        Yh[0][:,:,[0, 5]] = q2c(colfilter(Hi,h0o).T)     # Horizontal pair
+        Yh[0][:,:,[2, 3]] = q2c(colfilter(Lo,h1o).T)     # Vertical pair
+        Yh[0][:,:,[1, 4]] = q2c(colfilter(Hi,h1o).T)     # Diagonal pair
+
+        if include_scale:
+            Yscale[0] = LoLo
+
+    for level in xrange(1, nlevels):
+        row_size, col_size = LoLo.shape
+        if row_size % 4 != 0:
+            # Extend by 2 rows (repeating the first and last rows) if the no. of rows of LoLo is not divisible by 4
+            LoLo = np.vstack((LoLo[[0],:], LoLo, LoLo[[-1],:]))
+
+        if col_size % 4 != 0:
+            # Extend by 2 cols (repeating the first and last cols) if the no. of cols of LoLo is not divisible by 4
+            LoLo = np.hstack((LoLo[:,[0]], LoLo, LoLo[:,[-1]]))
+
+        # Do even Qshift filters on columns.
+        Lo = coldfilt(LoLo,h0b,h0a).T
+        Hi = coldfilt(LoLo,h1b,h1a).T
+
+        # Do even Qshift filters on rows.
+        LoLo = coldfilt(Lo,h0b,h0a).T
+
+        Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
+        Yh[level][:,:,[0, 5]] = q2c(coldfilt(Hi,h0b,h0a).T)  # Horizontal
+        Yh[level][:,:,[2, 3]] = q2c(coldfilt(Lo,h1b,h1a).T)  # Vertical
+        Yh[level][:,:,[1, 4]] = q2c(coldfilt(Hi,h1b,h1a).T)  # Diagonal
+
+        if include_scale:
+            Yscale[level] = LoLo
+
+    Yl = LoLo
+    # Warn the caller if the input was padded to even dimensions above.
+    if initial_row_extend == 1 and initial_col_extend == 1:
+        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
+            'x'.join(list(str(s) for s in extended_size)),
+            'x'.join(list(str(s) for s in original_size))))
+        logging.warn(
+            'The bottom row and rightmost column have been duplicated, prior to decomposition.')
+
+    if initial_row_extend == 1 and initial_col_extend == 0:
+        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
+            'x'.join(list(str(s) for s in extended_size)),
+            'x'.join(list(str(s) for s in original_size))))
+        logging.warn(
+            'The bottom row has been duplicated, prior to decomposition.')
+
+    if initial_row_extend == 0 and initial_col_extend == 1:
+        logging.warn('The image entered is now a {0} NOT a {1}.'.format(
+            'x'.join(list(str(s) for s in extended_size)),
+            'x'.join(list(str(s) for s in original_size))))
+        logging.warn(
+            'The rightmost column has been duplicated, prior to decomposition.')
+
+    if include_scale:
+        return Yl, tuple(Yh), tuple(Yscale)
+    else:
+        return Yl, tuple(Yh)
+
diff --git a/scripts/benchmark_opencl.py b/scripts/benchmark_opencl.py
index 23d5d79..88253de 100644
--- a/scripts/benchmark_opencl.py
+++ b/scripts/benchmark_opencl.py
@@ -75,5 +75,14 @@ def main():
     print('Percentage speed up: {0:.0f}%'.format(1e2*a/b))
     print('=====')
 
+    print('Running NumPy dtwavexfm2...')
+    a = benchmark('dtwavexfm2(lena)',
+            'from dtcwt import dtwavexfm2; from __main__ import lena')
+    print('Running OpenCL dtwavexfm2...')
+    b = benchmark('dtwavexfm2(lena)',
+            'from dtcwt.opencl.transform2d import dtwavexfm2; from __main__ import lena')
+    print('Percentage speed up: {0:.0f}%'.format(1e2*a/b))
+    print('=====')
+
 if __name__ == '__main__':
     main()
diff --git a/tests/testopenclxfm2.py b/tests/testopenclxfm2.py
new file mode 100644
index 0000000..a112f24
--- /dev/null
+++ b/tests/testopenclxfm2.py
@@ -0,0 +1,88 @@
+import os
+from nose.tools import raises
+from nose.plugins.attrib import attr
+
+import numpy as np
+from dtcwt import biort, qshift
+from dtcwt import dtwavexfm2 as dtwavexfm2_np, dtwaveifm2
+from dtcwt.opencl.transform2d import dtwavexfm2 as dtwavexfm2_cl
+
+from .util import assert_almost_equal, skip_if_no_cl
+
+TOLERANCE = 1e-12
+GOLD_TOLERANCE = 1e-5
+
+def setup():
+    global lena
+    lena = np.load(os.path.join(os.path.dirname(__file__), 'lena.npz'))['lena']
+
+def test_lena_loaded():
+    assert lena.shape == (512, 512)
+    assert lena.min() >= 0
+    assert lena.max() <= 1
+    assert lena.dtype == np.float32
+
+def _compare_transforms(A, B):
+    Yl_A, Yh_A = A
+    Yl_B, Yh_B = B
+    assert_almost_equal(Yl_A, Yl_B, tolerance=GOLD_TOLERANCE)
+    for x, y in zip(Yh_A, Yh_B):
+        assert_almost_equal(x, y, tolerance=GOLD_TOLERANCE)
+
+@skip_if_no_cl
+@attr('transform')
+def test_simple():
+    _compare_transforms(dtwavexfm2_np(lena), dtwavexfm2_cl(lena))
+
+@skip_if_no_cl
+@attr('transform')
+def test_specific_wavelet():
+    a = dtwavexfm2_np(lena, biort=biort('antonini'), qshift=qshift('qshift_06'))
+    b = dtwavexfm2_cl(lena, biort=biort('antonini'), qshift=qshift('qshift_06'))
+    _compare_transforms(a, b)
+
+@skip_if_no_cl
+def test_1d():
+    a = dtwavexfm2_np(lena[0,:])
+    b = dtwavexfm2_cl(lena[0,:])
+    _compare_transforms(a, b)
+
+@skip_if_no_cl
+@raises(ValueError)
+def test_3d():
+    Yl, Yh = dtwavexfm2_cl(np.dstack((lena, lena)))
+
+@skip_if_no_cl
+def test_simple_w_scale():
+    Yl, Yh, Yscale = dtwavexfm2_cl(lena, include_scale=True)
+
+    assert len(Yscale) > 0
+    for x in Yscale:
+        assert x is not None
+
+@skip_if_no_cl
+@skip_if_no_cl
+def test_odd_rows():
+    a = dtwavexfm2_np(lena[:509,:])
+    b = dtwavexfm2_cl(lena[:509,:])
+    _compare_transforms(a, b)
+
+@skip_if_no_cl
+def test_odd_cols():
+    a = dtwavexfm2_np(lena[:,:509])
+    b = dtwavexfm2_cl(lena[:,:509])
+    _compare_transforms(a, b)
+
+@skip_if_no_cl
+def test_odd_rows_and_cols():
+    a = dtwavexfm2_np(lena[:509,:509])
+    b = dtwavexfm2_cl(lena[:509,:509])
+    _compare_transforms(a, b)
+
+@skip_if_no_cl
+def test_0_levels():
+    a = dtwavexfm2_np(lena, nlevels=0)
+    b = dtwavexfm2_cl(lena, nlevels=0)
+    _compare_transforms(a, b)
+
+# vim:sw=4:sts=4:et

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git



More information about the debian-science-commits mailing list