[h5py] 196/455: Fix broadcasting code
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Thu Jul 2 18:19:32 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to annotated tag 1.3.0
in repository h5py.
commit 3db385b23784facc3a8411ece25d56ed7021b376
Author: andrewcollette <andrew.collette at gmail.com>
Date: Mon Jan 26 21:37:33 2009 +0000
Fix broadcasting code
---
h5py/highlevel.py | 12 ++++---
h5py/selections.py | 77 +++++++++++++++++++++++++++-----------------
h5py/tests/test_highlevel.py | 4 +--
3 files changed, 56 insertions(+), 37 deletions(-)
diff --git a/h5py/highlevel.py b/h5py/highlevel.py
index e90f03c..ebb512d 100644
--- a/h5py/highlevel.py
+++ b/h5py/highlevel.py
@@ -924,12 +924,8 @@ class Dataset(HLObject):
else:
raise NotImplementedError("Field name selections are not yet allowed for write.")
- # 3. Validate the input array. Also convert scalars for broadcast.
+ # 3. Validate the input array
val = numpy.asarray(val, order='C')
- if val.shape == () and self.shape != ():
- fastest = self.shape[-1]
- if fastest < 1e6:
- val = numpy.repeat(val, fastest)
# 4. Perform the dataspace selection
if sel.is_simple(args):
@@ -938,6 +934,12 @@ class Dataset(HLObject):
selection = sel.FancySelection(self.shape)
selection[args]
+ # 5. Broadcast scalars if necessary
+ if val.shape == () and selection.mshape != ():
+ val2 = numpy.empty(selection.mshape, dtype=val.dtype)
+ val2[...] = val
+ val = val2
+
# 5. Perform the write, with broadcasting
mspace = h5s.create_simple(val.shape, (h5s.UNLIMITED,)*len(val.shape))
for fspace in selection.shape_broadcast(val.shape):
diff --git a/h5py/selections.py b/h5py/selections.py
index 9480031..65a6a27 100644
--- a/h5py/selections.py
+++ b/h5py/selections.py
@@ -97,45 +97,60 @@ class RectSelection(Selection):
def __init__(self, *args, **kwds):
Selection.__init__(self, *args, **kwds)
- self._sel = ((0,)*len(self.shape), self.shape, (1,)*len(self.shape))
+ rank = len(self.shape)
+ self._sel = ((0,)*rank, self.shape, (1,)*rank, (False,)*rank)
+ self.mshape = self.shape
def __getitem__(self, args):
if not isinstance(args, tuple):
args = (args,)
- start, count, step = self._handle_args(args)
+ start, count, step, scalar = self._handle_args(args)
self._id.select_hyperslab(start, count, step)
- self._sel = (start, count, step)
+ self._sel = (start, count, step, scalar)
+
+ self.mshape = tuple(x for x, y in zip(count, scalar) if not y)
return self._id
- def shape_broadcast(self, cshape):
+
+ def shape_broadcast(self, target_shape):
""" Return an iterator over target dataspaces for broadcasting """
# count = (10,10,10)
# cshape = (1,1,5)
- start, count, step = self._sel
- rank = len(self.shape)
- diff = rank - len(cshape)
- if diff > 0:
- cshape = (1,)*diff + cshape
- elif diff < 0:
- raise TypeError("Cannot broadcast %s -> %s (too big)" % (count, cshape))
-
- if any(x%y != 0 for x, y in zip(count, cshape)):
- raise TypeError("Cannot broadcast %s -> %s" % (count, cshape))
+ start, count, step, scalar = self._sel
+
+ rank = len(count)
+ target = list(target_shape)
+
+ tshape = []
+ for idx in xrange(1,rank+1):
+ if len(target) == 0 or scalar[-idx]: # Skip scalar axes
+ tshape.append(1)
+ else:
+ t = target.pop()
+ if count[-idx] == t or t == 1:
+ tshape.append(t)
+ else:
+ raise TypeError("Can't broadcast %s -> %s [%s,%s,%s] %s\n%s" % (target_shape, count, count[-idx], t, -idx, tshape, self._sel))
+ tshape.reverse()
+ tshape = tuple(tshape)
+
+ chunks = tuple(x/y for x, y in zip(count, tshape))
+
+ #print tshape, chunks
- chunks = tuple(x/y for x, y in zip(count, cshape))
nchunks = np.product(chunks)
sid = self._id.copy()
- sid.select_hyperslab((0,)*rank, cshape, step)
+ sid.select_hyperslab((0,)*rank, tshape, step)
for idx in xrange(nchunks):
- offset = tuple(x*y*z + s for x, y, z, s in zip(np.unravel_index(idx, chunks), cshape, step, start))
+ offset = tuple(x*y*z + s for x, y, z, s in zip(np.unravel_index(idx, chunks), tshape, step, start))
sid.offset_simple(offset)
yield sid
@@ -148,25 +163,27 @@ class RectSelection(Selection):
"""
args = _broadcast(args, len(self.shape))
- def handle_arg(arg, length):
- if isinstance(arg, slice):
- return _translate_slice(arg, length)
- try:
- return _translate_int(int(arg), length)
- except TypeError:
- raise TypeError("Illegal index (must be a slice or number)")
-
start = []
count = []
step = []
+ scalar = []
- for a, length in zip(args, self.shape):
- x,y,z = handle_arg(a, length)
+ for arg, length in zip(args, self.shape):
+ if isinstance(arg, slice):
+ x,y,z = _translate_slice(arg, length)
+ s = False
+ else:
+ try:
+ x,y,z = _translate_int(int(arg), length)
+ s = True
+ except TypeError:
+ raise TypeError('Illegal index "%s" (must be a slice or number)' % arg)
start.append(x)
count.append(y)
step.append(z)
+ scalar.append(s)
- return tuple(start), tuple(count), tuple(step)
+ return tuple(start), tuple(count), tuple(step), tuple(scalar)
class HyperSelection(RectSelection):
@@ -366,8 +383,8 @@ def _translate_int(exp, length):
if exp < 0:
exp = length+exp
- if not 0<=exp<(length-1):
- raise ValueError("Index out of range")
+ if not 0<=exp<length:
+ raise ValueError("Index (%s) out of range (0-%s)" % (exp, length-1))
return exp, 1, 1
diff --git a/h5py/tests/test_highlevel.py b/h5py/tests/test_highlevel.py
index 80baa2f..18341cf 100644
--- a/h5py/tests/test_highlevel.py
+++ b/h5py/tests/test_highlevel.py
@@ -457,10 +457,10 @@ class TestDataset(HDF5TestCase):
slices += [ s[3,...], s[3,2,...] ]
#slices += [ numpy.random.random((10,10,50)) > 0.5 ] # Truth array
#slices += [ numpy.zeros((10,10,50), dtype='bool') ]
- slices += [ s[0, 1, [2,3,6,7]], s[:,[1,2]], s[[1,2]], s[3:7,[1]]]
+ #slices += [ s[0, 1, [2,3,6,7]], s[:,[1,2]], s[[1,2]], s[3:7,[1]]]
for slc in slices:
- self.output(" Checking %s" % ((slc,) if not isinstance(slc, numpy.ndarray) else 'ARRAY'))
+ print " Checking %s on %s" % ((slc,) if not isinstance(slc, numpy.ndarray) else 'ARRAY', srcarr.shape)
verify(slc)
def test_slice_names(self):
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/h5py.git
More information about the debian-science-commits mailing list