[segyio] 246/376: Make trace.__getitem__ accept more int-like types

Jørgen Kvalsvik jokva-guest at moszumanska.debian.org
Wed Sep 20 08:04:40 UTC 2017


This is an automated email from the git hooks/post-receive script.

jokva-guest pushed a commit to branch debian
in repository segyio.

commit 94548e03de9ddc4ef97883ad63fe577a75234d5c
Author: Jørgen Kvalsvik <jokva at statoil.com>
Date:   Fri Mar 24 10:33:53 2017 +0100

    Make trace.__getitem__ accept more int-like types
    
    Under some conditions, perfectly fine int-like types would raise a type
    error when passed to getitem. Instead of checking for isinstance(int),
    a slice or an int-like type is expected, and we internally handle the
    conversion to an actual int.
---
 python/segyio/_raw_trace.py | 12 +++++++++++-
 python/segyio/_segyio.c     | 42 ++++++++++--------------------------------
 python/segyio/_trace.py     | 43 +++++++++++++++++++++----------------------
 python/test/segy.py         | 11 +++++++++++
 python/test/segyio_c.py     |  4 ++--
 5 files changed, 55 insertions(+), 57 deletions(-)

diff --git a/python/segyio/_raw_trace.py b/python/segyio/_raw_trace.py
index 4a7113e..9b0f581 100644
--- a/python/segyio/_raw_trace.py
+++ b/python/segyio/_raw_trace.py
@@ -1,6 +1,11 @@
 import numpy as np
 import segyio
 
+try: xrange
+except NameError: pass
+else: range = xrange
+
+
 class RawTrace(object):
     def __init__(self, trace):
         self.trace = trace
@@ -14,8 +19,13 @@ class RawTrace(object):
             mstart, mstop = min(start, stop), max(start, stop)
             length = max(0, (mstop - mstart + (step - (1 if step > 0 else -1))))
             buf = np.zeros(shape = (length, len(f.samples)), dtype = np.single)
+            l = len(range(start, stop, step))
+            return self.trace._readtr(start, step, l, buf)
+
+        if int(index) != index:
+            raise TypeError("Trace index must be integer or slice.")
 
-        return self.trace._readtr(index, buf)
+        return self.trace._readtr(int(index), 1, 1, buf)
 
     def __repr__(self):
         return self.trace.__repr__() + ".raw"
diff --git a/python/segyio/_segyio.c b/python/segyio/_segyio.c
index 4d475df..6a782ce 100644
--- a/python/segyio/_segyio.c
+++ b/python/segyio/_segyio.c
@@ -932,30 +932,27 @@ static PyObject *py_fread_trace0(PyObject *self, PyObject *args) {
 static PyObject *py_read_trace(PyObject *self, PyObject *args) {
     errno = 0;
     PyObject *file_capsule = NULL;
-    PyObject *trace_no;
     PyObject *buffer_out;
-    int trace_count;
+    int start, length, step;
     long trace0;
     int trace_bsize;
     int format;
     int samples;
 
-    PyArg_ParseTuple(args, "OOiOliii", &file_capsule, &trace_no, &trace_count, &buffer_out, &trace0, &trace_bsize, &format, &samples);
+    PyArg_ParseTuple(args, "OOiiiiili", &file_capsule,
+                                        &buffer_out,
+                                        &start,
+                                        &step,
+                                        &length,
+                                        &format,
+                                        &samples,
+                                        &trace0,
+                                        &trace_bsize );
 
     segy_file *p_FILE = get_FILE_pointer_from_capsule(file_capsule);
 
     if (PyErr_Occurred()) { return NULL; }
 
-    if( !trace_no || trace_no == Py_None ) {
-        PyErr_SetString(PyExc_TypeError, "Trace number must be int or slice." );
-        return NULL;
-    }
-
-    if( !integer_check( trace_no ) && !PySlice_Check( trace_no ) ) {
-        PyErr_SetString(PyExc_TypeError, "Trace number must be int or slice." );
-        return NULL;
-    }
-
     if (!PyObject_CheckBuffer(buffer_out)) {
         PyErr_SetString(PyExc_TypeError, "The destination buffer is not of the correct type.");
         return NULL;
@@ -963,25 +960,6 @@ static PyObject *py_read_trace(PyObject *self, PyObject *args) {
     Py_buffer buffer;
     PyObject_GetBuffer(buffer_out, &buffer, PyBUF_FORMAT | PyBUF_C_CONTIGUOUS | PyBUF_WRITEABLE);
 
-    Py_ssize_t start, stop, step, length;
-    if( PySlice_Check( trace_no ) ) {
-        int err = PySlice_GetIndicesEx( (PySliceObject*)trace_no,
-                                        trace_count,
-                                        &start, &stop, &step,
-                                        &length );
-        if( err != 0 ) {
-            PyBuffer_Release( &buffer );
-            return NULL;
-        }
-    }
-    else {
-        start = convert_integer( trace_no );
-        if( start < 0 ) start += trace_count;
-        step = 1;
-        stop = start + step;
-        length = 1;
-    }
-
     int error = 0;
     float* buf = buffer.buf;
     Py_ssize_t i;
diff --git a/python/segyio/_trace.py b/python/segyio/_trace.py
index 9643f1d..df5d425 100644
--- a/python/segyio/_trace.py
+++ b/python/segyio/_trace.py
@@ -15,21 +15,19 @@ class Trace:
 
         buf = self._trace_buffer(buf)
 
-        if isinstance(index, int):
-            if not 0 <= abs(index) < len(self):
-                raise IndexError("Trace %d not in range (-%d,%d)", (index, len(self), len(self)))
-
-            return self._readtr(index, buf)
-
-        elif isinstance(index, slice):
+        if isinstance(index, slice):
             def gen():
                 for i in range(*index.indices(len(self))):
-                    yield self._readtr(i, buf)
+                    yield self._readtr(i, 1, 1, buf)
 
             return gen()
 
-        else:
-            raise TypeError("Key must be int, slice, (int,np.ndarray) or (slice,np.ndarray)")
+        if not 0 <= abs(index) < len(self):
+            raise IndexError("Trace %d not in range (-%d,%d)", (index, len(self), len(self)))
+
+        # map a negative index to the corresponding positive value
+        start = (index + len(self)) % len(self)
+        return self._readtr(start, 1, 1, buf)
 
     def __setitem__(self, index, val):
         if not 0 <= abs(index) < len(self):
@@ -46,15 +44,11 @@ class Trace:
         if val.shape[0] < shape[0]:
             raise TypeError("Array wrong shape. Expected minimum %s, was %s" % (shape, val.shape))
 
-        if isinstance(index, int):
-            self._writetr(index, val)
-
-        elif isinstance(index, slice):
-            for i, buf in range(*index.indices(len(self))), val:
-                self._writetr(i, val)
+        if not isinstance(index, slice):
+            index = slice(index, index + 1, 1)
 
-        else:
-            raise KeyError("Wrong shape of index")
+        for i in range(*index.indices(len(self))):
+            self._writetr(i, val)
 
     def __len__(self):
         return self._file.tracecount
@@ -77,18 +71,23 @@ class Trace:
 
         return buf
 
-    def _readtr(self, traceno, buf=None):
+    def _readtr(self, start, step, length, buf=None):
         buf = self._trace_buffer(buf)
 
-        tracecount = self._file.tracecount
         trace0 = self._file._tr0
         bsz = self._file._bsz
         fmt = self._file._fmt
         smp = len(self._file.samples)
-        return segyio._segyio.read_trace(self._file.xfd, traceno, tracecount, buf, trace0, bsz, fmt, smp)
+        return segyio._segyio.read_trace(self._file.xfd, buf,
+                                         start, step, length,
+                                         fmt, smp,
+                                         trace0, bsz)
 
     def _writetr(self, traceno, buf):
-        self.write_trace(traceno, buf, self._file)
+        if int(traceno) != traceno:
+            raise TypeError("Trace index must be integer type")
+
+        self.write_trace(int(traceno), buf, self._file)
 
     @classmethod
     def write_trace(cls, traceno, buf, segy):
diff --git a/python/test/segy.py b/python/test/segy.py
index 6ab1480..20048a0 100644
--- a/python/test/segy.py
+++ b/python/test/segy.py
@@ -572,6 +572,17 @@ class TestSegy(TestCase):
             self.assertListEqual(list(f.attributes(189)[:]),
                                  [(i // 5) + 1 for i in range(len(f.trace))])
 
+    def test_traceaccess_from_array(self):
+        a = np.arange(10, dtype = np.int)
+        b = np.arange(10, dtype = np.int32)
+        c = np.arange(10, dtype = np.int64)
+        d = np.arange(10, dtype = np.intc)
+        with segyio.open(self.filename) as f:
+            f.trace[a[0]]
+            f.trace[b[1]]
+            f.trace[c[2]]
+            f.trace[d[3]]
+
     def test_create_sgy(self):
         with TestContext("create_sgy") as context:
             context.copy_file(self.filename)
diff --git a/python/test/segyio_c.py b/python/test/segyio_c.py
index 7bd8188..708e1f8 100644
--- a/python/test/segyio_c.py
+++ b/python/test/segyio_c.py
@@ -383,12 +383,12 @@ class _segyioTests(TestCase):
 
             buf = numpy.zeros(25, dtype=numpy.single)
 
-            _segyio.read_trace(f, 0, 25, buf, 0, 100, 1, 25)
+            _segyio.read_trace(f, buf, 0, 1, 1, 1, 25, 0, 100)
 
             self.assertAlmostEqual(buf[10], 1.0, places=4)
             self.assertAlmostEqual(buf[11], 3.1415, places=4)
 
-            _segyio.read_trace(f, 1, 25, buf, 0, 100, 1, 25)
+            _segyio.read_trace(f, buf, 1, 1, 1, 1, 25, 0, 100)
 
             self.assertAlmostEqual(sum(buf), 42.0 * 25, places=4)
 

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/segyio.git



More information about the debian-science-commits mailing list