[h5py] 196/455: Fix broadcasting code

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Thu Jul 2 18:19:32 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to annotated tag 1.3.0
in repository h5py.

commit 3db385b23784facc3a8411ece25d56ed7021b376
Author: andrewcollette <andrew.collette at gmail.com>
Date:   Mon Jan 26 21:37:33 2009 +0000

    Fix broadcasting code
---
 h5py/highlevel.py            | 12 ++++---
 h5py/selections.py           | 77 +++++++++++++++++++++++++++-----------------
 h5py/tests/test_highlevel.py |  4 +--
 3 files changed, 56 insertions(+), 37 deletions(-)

diff --git a/h5py/highlevel.py b/h5py/highlevel.py
index e90f03c..ebb512d 100644
--- a/h5py/highlevel.py
+++ b/h5py/highlevel.py
@@ -924,12 +924,8 @@ class Dataset(HLObject):
             else:
                 raise NotImplementedError("Field name selections are not yet allowed for write.")
 
-            # 3. Validate the input array.  Also convert scalars for broadcast.
+            # 3. Validate the input array
             val = numpy.asarray(val, order='C')
-            if val.shape == () and self.shape != ():
-                fastest = self.shape[-1]
-                if fastest < 1e6:
-                    val = numpy.repeat(val, fastest)
 
             # 4. Perform the dataspace selection
             if sel.is_simple(args):
@@ -938,6 +934,12 @@ class Dataset(HLObject):
                 selection = sel.FancySelection(self.shape)
             selection[args]
 
+            # 5. Broadcast scalars if necessary
+            if val.shape == () and selection.mshape != ():
+                val2 = numpy.empty(selection.mshape, dtype=val.dtype)
+                val2[...] = val
+                val = val2
+            
             # 5. Perform the write, with broadcasting
             mspace = h5s.create_simple(val.shape, (h5s.UNLIMITED,)*len(val.shape))
             for fspace in selection.shape_broadcast(val.shape):
diff --git a/h5py/selections.py b/h5py/selections.py
index 9480031..65a6a27 100644
--- a/h5py/selections.py
+++ b/h5py/selections.py
@@ -97,45 +97,60 @@ class RectSelection(Selection):
 
     def __init__(self, *args, **kwds):
         Selection.__init__(self, *args, **kwds)
-        self._sel = ((0,)*len(self.shape), self.shape, (1,)*len(self.shape))
+        rank = len(self.shape)
+        self._sel = ((0,)*rank, self.shape, (1,)*rank, (False,)*rank)
+        self.mshape = self.shape
 
     def __getitem__(self, args):
         if not isinstance(args, tuple):
             args = (args,)
   
-        start, count, step = self._handle_args(args)
+        start, count, step, scalar = self._handle_args(args)
 
         self._id.select_hyperslab(start, count, step)
 
-        self._sel = (start, count, step)
+        self._sel = (start, count, step, scalar)
+
+        self.mshape = tuple(x for x, y in zip(count, scalar) if not y)
 
         return self._id
 
-    def shape_broadcast(self, cshape):
+
+    def shape_broadcast(self, target_shape):
         """ Return an iterator over target dataspaces for broadcasting """
 
         # count = (10,10,10)
         # cshape = (1,1,5)
 
-        start, count, step = self._sel
-        rank = len(self.shape)
-        diff = rank - len(cshape)
-        if diff > 0:
-            cshape = (1,)*diff + cshape
-        elif diff < 0:
-            raise TypeError("Cannot broadcast %s -> %s (too big)" % (count, cshape))
-            
-        if any(x%y != 0 for x, y in zip(count, cshape)):
-            raise TypeError("Cannot broadcast %s -> %s" % (count, cshape))
+        start, count, step, scalar = self._sel
+
+        rank = len(count)
+        target = list(target_shape)
+
+        tshape = []
+        for idx in xrange(1,rank+1):
+            if len(target) == 0 or scalar[-idx]:     # Skip scalar axes
+                tshape.append(1)
+            else:
+                t = target.pop()
+                if count[-idx] == t or t == 1:
+                    tshape.append(t)
+                else:
+                    raise TypeError("Can't broadcast %s -> %s [%s,%s,%s] %s\n%s" % (target_shape, count, count[-idx], t, -idx, tshape, self._sel))
+        tshape.reverse()
+        tshape = tuple(tshape)
+
+        chunks = tuple(x/y for x, y in zip(count, tshape))
+
+        #print tshape, chunks
 
-        chunks = tuple(x/y for x, y in zip(count, cshape))
         nchunks = np.product(chunks)
 
         sid = self._id.copy()
-        sid.select_hyperslab((0,)*rank, cshape, step)
+        sid.select_hyperslab((0,)*rank, tshape, step)
 
         for idx in xrange(nchunks):
-            offset = tuple(x*y*z + s for x, y, z, s in zip(np.unravel_index(idx, chunks), cshape, step, start))
+            offset = tuple(x*y*z + s for x, y, z, s in zip(np.unravel_index(idx, chunks), tshape, step, start))
             sid.offset_simple(offset)
             yield sid
 
@@ -148,25 +163,27 @@ class RectSelection(Selection):
         """
         args = _broadcast(args, len(self.shape))
 
-        def handle_arg(arg, length):
-            if isinstance(arg, slice):
-                return _translate_slice(arg, length)
-            try:
-                return _translate_int(int(arg), length)
-            except TypeError:
-                raise TypeError("Illegal index (must be a slice or number)")
-
         start = []
         count = []
         step  = []
+        scalar = []
 
-        for a, length in zip(args, self.shape):
-            x,y,z = handle_arg(a, length)
+        for arg, length in zip(args, self.shape):
+            if isinstance(arg, slice):
+                x,y,z = _translate_slice(arg, length)
+                s = False
+            else:
+                try:
+                    x,y,z = _translate_int(int(arg), length)
+                    s = True
+                except TypeError:
+                    raise TypeError('Illegal index "%s" (must be a slice or number)' % arg)
             start.append(x)
             count.append(y)
             step.append(z)
+            scalar.append(s)
 
-        return tuple(start), tuple(count), tuple(step)
+        return tuple(start), tuple(count), tuple(step), tuple(scalar)
 
 class HyperSelection(RectSelection):
 
@@ -366,8 +383,8 @@ def _translate_int(exp, length):
     if exp < 0:
         exp = length+exp
 
-    if not 0<=exp<(length-1):
-        raise ValueError("Index out of range")
+    if not 0<=exp<length:
+        raise ValueError("Index (%s) out of range (0-%s)" % (exp, length-1))
 
     return exp, 1, 1
 
diff --git a/h5py/tests/test_highlevel.py b/h5py/tests/test_highlevel.py
index 80baa2f..18341cf 100644
--- a/h5py/tests/test_highlevel.py
+++ b/h5py/tests/test_highlevel.py
@@ -457,10 +457,10 @@ class TestDataset(HDF5TestCase):
         slices += [ s[3,...], s[3,2,...] ]
         #slices += [ numpy.random.random((10,10,50)) > 0.5 ]  # Truth array
         #slices += [ numpy.zeros((10,10,50), dtype='bool') ]
-        slices += [ s[0, 1, [2,3,6,7]], s[:,[1,2]], s[[1,2]], s[3:7,[1]]]
+        #slices += [ s[0, 1, [2,3,6,7]], s[:,[1,2]], s[[1,2]], s[3:7,[1]]]
 
         for slc in slices:
-            self.output("    Checking %s" % ((slc,) if not isinstance(slc, numpy.ndarray) else 'ARRAY'))
+            print "    Checking %s on %s" % ((slc,) if not isinstance(slc, numpy.ndarray) else 'ARRAY', srcarr.shape)
             verify(slc)
 
     def test_slice_names(self):

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/h5py.git



More information about the debian-science-commits mailing list