[python-dtcwt] 257/497: registration: add example scripts
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Jul 21 18:06:12 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository python-dtcwt.
commit d10b84e0a03ed6a2769e4d2a19f2865b355c4913
Author: Rich Wareham <rjw57 at cam.ac.uk>
Date: Tue Jan 28 16:00:59 2014 +0000
registration: add example scripts
Add a couple of example tools which make use of the registration API.
---
examples/image-registration-1.py | 2 +-
examples/image-registration-2.py | 8 +-
examples/register_images.py | 64 ++++++++++++++
examples/register_video.py | 180 +++++++++++++++++++++++++++++++++++++++
4 files changed, 250 insertions(+), 4 deletions(-)
diff --git a/examples/image-registration-1.py b/examples/image-registration-1.py
index 2d57609..1c7da49 100755
--- a/examples/image-registration-1.py
+++ b/examples/image-registration-1.py
@@ -87,7 +87,7 @@ for iteration in xrange(4*(len(Yh1)-5)):
Yh3[level] = affinewarphighpass(Yh1[level], a, method='bilinear')
Qt_mats = qtildematrices(Yh3, Yh2, levels)
- Qt = np.sum(list(x.sum() for x in Qt_mats), axis=0)
+ Qt = np.sum(list(np.sum(np.sum(x, axis=0), axis=0) for x in Qt_mats), axis=0)
a += solvetransform(Qt)
logging.info('Computing velocity field')
diff --git a/examples/image-registration-2.py b/examples/image-registration-2.py
index 5575fd5..42d9d45 100755
--- a/examples/image-registration-2.py
+++ b/examples/image-registration-2.py
@@ -16,6 +16,7 @@ from matplotlib.pyplot import *
import numpy as np
import dtcwt
+from dtcwt.backend.backend_numpy import Transform2d
import dtcwt.sampling
from dtcwt.registration import *
@@ -31,12 +32,13 @@ def register_frames(filename):
# Take the DTCWT of both frames.
logging.info('Taking DTCWT')
nlevels = 6
- Yl1, Yh1 = dtcwt.dtwavexfm2(f1, nlevels=nlevels)
- Yl2, Yh2 = dtcwt.dtwavexfm2(f2, nlevels=nlevels)
+ trans = Transform2d()
+ t1 = trans.forward(f1, nlevels=nlevels)
+ t2 = trans.forward(f2, nlevels=nlevels)
# Solve for transform
logging.info('Finding flow')
- avecs = estimateflow(Yh1, Yh2)
+ avecs = estimatereg(t1, t2)
logging.info('Computing warped image')
warped_f1 = warp(f1, avecs, method='bilinear')
diff --git a/examples/register_images.py b/examples/register_images.py
new file mode 100755
index 0000000..0017940
--- /dev/null
+++ b/examples/register_images.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python
+
+"""
+Register a pair of images (a previous and a next frame) and save the inter-frame
+transform parameters and velocity field to a compressed NumPy file.
+
+Usage:
+ register_images.py [options] <prevframe> <nextframe> <output>
+ register_images.py (-h | --help)
+
+Options:
+
+ --cl Attempt to use OpenCL where possible.
+
+"""
+
+import logging
+
+from docopt import docopt
+import dtcwt
+from dtcwt.backend.backend_opencl import Transform2d as CLTransform2d
+from dtcwt.backend.backend_numpy import Transform2d as NumPyTransform2d
+import dtcwt.registration as reg
+import dtcwt.sampling
+from PIL import Image # Use 'Pillow', the PIL fork
+import numpy as np
+import tables
+
+# Parse command line options
+OPTS = docopt(__doc__)
+
+# Set logging options
+logging.basicConfig(level=logging.INFO)
+
+def avecs_for_pair(prev, next_):
+ trans = CLTransform2d() if OPTS['--cl'] else NumPyTransform2d()
+ t1 = trans.forward(prev, nlevels=5)
+ t2 = trans.forward(next_, nlevels=5)
+ return reg.estimatereg(t1, t2)
+
+class Metadata(tables.IsDescription):
+ previmpath = tables.StringCol(512)
+ nextimpath = tables.StringCol(512)
+
+def main():
+ logging.info('Launched')
+
+ logging.info('Loading "prev" image from "{0}"'.format(OPTS['<prevframe>']))
+ pim = np.array(Image.open(OPTS['<prevframe>']).convert('L')) / 255.0
+
+ logging.info('Loading "next" image from "{0}"'.format(OPTS['<nextframe>']))
+ nim = np.array(Image.open(OPTS['<nextframe>']).convert('L')) / 255.0
+
+ logging.info('Estimating registration')
+ avecs = avecs_for_pair(pim, nim)
+
+ logging.info('Calculating velocity field')
+ vxs, vys = dtcwt.registration.velocityfield(avecs, avecs.shape[:2], method='bilinear')
+
+ logging.info('Saving result to {0}'.format(OPTS['<output>']))
+ np.savez_compressed(OPTS['<output>'], avecs=avecs, vxs=vxs, vys=vys)
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/register_video.py b/examples/register_video.py
new file mode 100755
index 0000000..4bc2cd0
--- /dev/null
+++ b/examples/register_video.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python
+
+"""
+Register neighbouring frames of video and save inter-frame transform
+parameters to a file.
+
+Usage:
+ register_video.py [options] <inputvideo> <outputh5>
+ register_video.py (-h | --help)
+
+Options:
+ -g GROUPSIZE, --group=GROUPSIZE Frame pairs per work-group [default: 32]
+
+"""
+
+import logging
+
+import cv2
+from docopt import docopt
+import dtcwt
+from dtcwt.backend.backend_numpy import Transform2d
+import dtcwt.registration as reg
+import dtcwt.sampling
+import numpy as np
+from mpi4py import MPI
+import tables
+
+# Parse command line options
+OPTS = docopt(__doc__)
+
+class VideoReader(object):
+ def __init__(self, filename, groupsize=None):
+ self._vc = cv2.VideoCapture(filename)
+ self._last_frame = None
+ self._last_frame_idx = -1
+
+ self.groupsize = groupsize or int(OPTS['--group'])
+
+ def read_next_gof(self):
+ """
+ Read the next group of frames, preceded by the final frame of the previous group (if any).
+
+ """
+
+ frames = []
+ if self._last_frame is not None:
+ frames.append((self._last_frame_idx, self._last_frame))
+
+ for it in xrange(self.groupsize):
+ success, f = self._vc.read()
+ if not success:
+ break
+
+ self._last_frame = cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) / 255.0
+ self._last_frame_idx += 1
+
+ frames.append((self._last_frame_idx, self._last_frame))
+
+ return frames
+
+def avecs_for_frames(frames):
+ if len(frames) <= 1:
+ return [],[]
+
+ avecs = []
+ trans = Transform2d()
+
+ idx, frame = frames[0]
+ t = trans.forward(frame, nlevels=5)
+
+ h_pair = (None, t)
+ idx_pair = (None, idx)
+ idx_pairs = []
+
+ for idx, frame in frames[1:]:
+ t = trans.forward(frame, nlevels=5)
+
+ h_pair = (h_pair[1], t)
+ idx_pair = (idx_pair[1], idx)
+
+ idx_pairs.append(idx_pair)
+ avecs.append(reg.estimatereg(h_pair[0], h_pair[1]))
+ logging.info('Finished frame pair {0}'.format(idx_pair))
+
+ return idx_pairs, avecs
+
+class Metadata(tables.IsDescription):
+ videopath = tables.StringCol(512)
+
+def main():
+ comm = MPI.COMM_WORLD
+ rank = comm.Get_rank()
+ size = comm.Get_size()
+
+ logging.basicConfig(level=logging.INFO, format='Rank ' + str(comm.Get_rank()) + ': %(message)s')
+
+ logging.info('Launched')
+
+ if rank == 0:
+ # I'm rank 0. Open the video file for reading
+ logging.info('Loading video from "{0}"'.format(OPTS['<inputvideo>']))
+ v = VideoReader(OPTS['<inputvideo>'])
+ outfile = tables.openFile(OPTS['<outputh5>'], mode='w',
+ title='Frame to frame registration from "{0}"'.format(OPTS['<inputvideo>']))
+
+ # Create a frame pair e-array
+ frame_pairs = outfile.createEArray('/', 'frame_idx_pairs',
+ atom=tables.Int64Atom(),
+ shape=(0,2),
+ title='(before, after) frame indices for corresponding affine parameters')
+
+ metadata_table = outfile.createTable('/', 'metadata', Metadata, 'Metadata')
+ metadata = metadata_table.row
+ metadata['videopath'] = OPTS['<inputvideo>']
+ metadata.append()
+
+ affine_params = None # Will be created after first bit of data
+
+ # Commit to disk
+ outfile.flush()
+
+ is_last_iteration = False
+ last_frame = None
+ group_idx = 0
+ while not is_last_iteration:
+ groups = None
+
+ if rank == 0:
+ logging.info('Reading next set of work group frames')
+
+ # Read frames for next work group
+ groups = []
+ for group in xrange(size):
+ gof = v.read_next_gof()
+ groups.append((group_idx, gof))
+ group_idx += 1
+
+ # If we run out of frames, this is the last iteration
+ is_last_iteration = len(gof) <= 1
+
+ logging.info('Sending work group frames')
+
+ # Broadcast iteration flag to nodes
+ is_last_iteration = comm.bcast(is_last_iteration, root=0)
+
+ # Scatter to nodes
+ group_id, frames = comm.scatter(groups, root=0)
+
+ logging.info('received work group id={0} of {1} frame(s)'.format(group_id, len(frames)))
+
+ # Calculate result
+ idxs_and_avecs = avecs_for_frames(frames)
+
+ # Send result back to rank 0
+ logging.info('finished. Sending {1} results'.format(rank, len(idxs_and_avecs)))
+ gathered = comm.gather(idxs_and_avecs, root=0)
+
+ if rank == 0:
+ for idxs, av in gathered:
+ assert len(idxs) == len(av)
+
+ if affine_params is None and len(av) > 0:
+ affine_params = outfile.createEArray('/', 'affine_parameters',
+ atom=tables.Float64Atom(),
+ shape=(0,) + av[0].shape,
+ title='affine parameters for corresponding frame index pairs')
+
+ if len(av) > 0:
+ affine_params.append(av)
+ frame_pairs.append(idxs)
+
+ outfile.flush()
+
+ logging.info('Frame pairs processed: {0}'.format(affine_params.shape[0]))
+
+ if rank == 0:
+ outfile.close()
+
+if __name__ == '__main__':
+ main()
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-dtcwt.git
More information about the debian-science-commits
mailing list