[python-debian/master 28/36] Fix up debian.arfile and debian.debfile for the string-handling changes in Python 3; this involves adding encoding= and errors= parameters in a number of places. Loosely inspired by tarfile.

Colin Watson cjwatson at canonical.com
Mon Oct 8 07:41:25 UTC 2012


---
 lib/debian/arfile.py  |   61 ++++++++++++++++++++++++++++++++++++------------
 lib/debian/debfile.py |   62 +++++++++++++++++++++++++++++++++++++------------
 tests/test_debfile.py |   29 +++++++++++++++-------
 3 files changed, 113 insertions(+), 39 deletions(-)

diff --git a/lib/debian/arfile.py b/lib/debian/arfile.py
index 841b716..eea9007 100644
--- a/lib/debian/arfile.py
+++ b/lib/debian/arfile.py
@@ -17,11 +17,13 @@
 
 from __future__ import print_function
 
-GLOBAL_HEADER = "!<arch>\n"
+import sys
+
+GLOBAL_HEADER = b"!<arch>\n"
 GLOBAL_HEADER_LENGTH = len(GLOBAL_HEADER)
 
 FILE_HEADER_LENGTH = 60
-FILE_MAGIC = "`\n"
+FILE_MAGIC = b"`\n"
 
 class ArError(Exception):
     pass
@@ -36,14 +38,30 @@ class ArFile(object):
         - members       same as getmembers()
     """
 
-    def __init__(self, filename=None, mode='r', fileobj=None):
+    def __init__(self, filename=None, mode='r', fileobj=None,
+                 encoding=None, errors=None):
         """ Build an ar file representation starting from either a filename or
-        an existing file object. The only supported mode is 'r' """
+        an existing file object. The only supported mode is 'r'.
+
+        In Python 3, the encoding and errors parameters control how member
+        names are decoded into Unicode strings. Like tarfile, the default
+        encoding is sys.getfilesystemencoding() and the default error handling
+        scheme is 'surrogateescape' (>= 3.2) or 'strict' (< 3.2).
+        """
 
         self.__members = [] 
         self.__members_dict = {}
         self.__fname = filename
         self.__fileobj = fileobj
+        if encoding is None:
+            encoding = sys.getfilesystemencoding()
+        self.__encoding = encoding
+        if errors is None:
+            if sys.version >= '3.2':
+                errors = 'surrogateescape'
+            else:
+                errors = 'strict'
+        self.__errors = errors
         
         if mode == "r":
             self.__index_archive()
@@ -61,7 +79,9 @@ class ArFile(object):
             raise ArError("Unable to find global header")
 
         while True:
-            newmember = ArMember.from_file(fp, self.__fname)
+            newmember = ArMember.from_file(fp, self.__fname,
+                                           encoding=self.__encoding,
+                                           errors=self.__errors)
             if not newmember:
                 break
             self.__members.append(newmember)
@@ -162,7 +182,7 @@ class ArMember(object):
         self.__offset = None    # start-of-data offset
         self.__end = None       # end-of-data offset
 
-    def from_file(fp, fname):
+    def from_file(fp, fname, encoding=None, errors=None):
         """fp is an open File object positioned on a valid file header inside
         an ar archive. Return a new ArMember on success, None otherwise. """
 
@@ -178,6 +198,15 @@ class ArMember(object):
         if buf[58:60] != FILE_MAGIC:
             raise IOError("Incorrect file magic")
 
+        if sys.version >= '3':
+            if encoding is None:
+                encoding = sys.getfilesystemencoding()
+            if errors is None:
+                if sys.version >= '3.2':
+                    errors = 'surrogateescape'
+                else:
+                    errors = 'strict'
+
         # http://en.wikipedia.org/wiki/Ar_(Unix)    
         #from   to     Name                      Format
         #0      15     File name                 ASCII
@@ -190,7 +219,9 @@ class ArMember(object):
 
         # XXX struct.unpack can be used as well here
         f = ArMember()
-        f.__name = buf[0:16].split("/")[0].strip()
+        f.__name = buf[0:16].split(b"/")[0].strip()
+        if sys.version >= '3':
+            f.__name = f.__name.decode(encoding, errors)
         f.__mtime = int(buf[16:28])
         f.__owner = int(buf[28:34])
         f.__group = int(buf[34:40])
@@ -210,7 +241,7 @@ class ArMember(object):
     # XXX this is not a sequence like file objects
     def read(self, size=0):
         if self.__fp is None:
-            self.__fp = open(self.__fname, "r")
+            self.__fp = open(self.__fname, "rb")
             self.__fp.seek(self.__offset)
 
         cur = self.__fp.tell()
@@ -219,31 +250,31 @@ class ArMember(object):
             return self.__fp.read(size)
 
         if cur >= self.__end or cur < self.__offset:
-            return ''
+            return b''
 
         return self.__fp.read(self.__end - cur)
 
     def readline(self, size=None):
         if self.__fp is None:
-            self.__fp = open(self.__fname, "r")
+            self.__fp = open(self.__fname, "rb")
             self.__fp.seek(self.__offset)
 
         if size is not None: 
             buf = self.__fp.readline(size)
             if self.__fp.tell() > self.__end:
-                return ''
+                return b''
 
             return buf
 
         buf = self.__fp.readline()
         if self.__fp.tell() > self.__end:
-            return ''
+            return b''
         else:
             return buf
 
     def readlines(self, sizehint=0):
         if self.__fp is None:
-            self.__fp = open(self.__fname, "r")
+            self.__fp = open(self.__fname, "rb")
             self.__fp.seek(self.__offset)
         
         buf = None
@@ -258,7 +289,7 @@ class ArMember(object):
 
     def seek(self, offset, whence=0):
         if self.__fp is None:
-            self.__fp = open(self.__fname, "r")
+            self.__fp = open(self.__fname, "rb")
             self.__fp.seek(self.__offset)
 
         if self.__fp.tell() < self.__offset:
@@ -276,7 +307,7 @@ class ArMember(object):
 
     def tell(self):
         if self.__fp is None:
-            self.__fp = open(self.__fname, "r")
+            self.__fp = open(self.__fname, "rb")
             self.__fp.seek(self.__offset)
 
         cur = self.__fp.tell()
diff --git a/lib/debian/debfile.py b/lib/debian/debfile.py
index da8c738..a728a77 100644
--- a/lib/debian/debfile.py
+++ b/lib/debian/debfile.py
@@ -108,20 +108,42 @@ class DebPart(object):
         return (('./' + fname in names) \
                 or (fname in names)) # XXX python << 2.5 TarFile compatibility
 
-    def get_file(self, fname):
-        """Return a file object corresponding to a given file name."""
+    def get_file(self, fname, encoding=None, errors=None):
+        """Return a file object corresponding to a given file name.
+
+        If encoding is given, then the file object will return Unicode data;
+        otherwise, it will return binary data.
+        """
 
         fname = DebPart.__normalize_member(fname)
         try:
-            return (self.tgz().extractfile('./' + fname))
+            fobj = self.tgz().extractfile('./' + fname)
         except KeyError:    # XXX python << 2.5 TarFile compatibility
-            return (self.tgz().extractfile(fname))
-
-    def get_content(self, fname):
+            fobj = self.tgz().extractfile(fname)
+        if encoding is not None:
+            if sys.version >= '3':
+                import io
+                if not hasattr(fobj, 'flush'):
+                    # XXX http://bugs.python.org/issue13815
+                    fobj.flush = lambda: None
+                return io.TextIOWrapper(fobj, encoding=encoding, errors=errors)
+            else:
+                import codecs
+                if errors is None:
+                    errors = 'strict'
+                return codecs.EncodedFile(fobj, encoding, errors=errors)
+        else:
+            return fobj
+
+    def get_content(self, fname, encoding=None, errors=None):
         """Return the string content of a given file, or None (e.g. for
-        directories)."""
+        directories).
 
-        f = self.get_file(fname)
+        If encoding is given, then the content will be a Unicode object;
+        otherwise, it will contain binary data.
+        """
+
+        f = self.get_file(fname, encoding=encoding, errors=errors)
         content = None
         if f:   # can be None for non regular or link files
             content = f.read()
@@ -174,24 +196,34 @@ class DebControl(DebPart):
 
         return Deb822(self.get_content(CONTROL_FILE))
 
-    def md5sums(self):
+    def md5sums(self, encoding=None, errors=None):
         """ Return a dictionary mapping filenames (of the data part) to
         md5sums. Fails if the control part does not contain a 'md5sum' file.
 
         Keys of the returned dictionary are the left-hand side values of lines
         in the md5sums member of control.tar.gz, usually file names relative to
-        the file system root (without heading '/' or './'). """
+        the file system root (without heading '/' or './').
+
+        The returned keys are Unicode objects if an encoding is specified,
+        otherwise binary. The returned values are always Unicode."""
 
         if not self.has_file(MD5_FILE):
             raise DebError("'%s' file not found, can't list MD5 sums" %
                     MD5_FILE)
 
-        md5_file = self.get_file(MD5_FILE)
+        md5_file = self.get_file(MD5_FILE, encoding=encoding, errors=errors)
         sums = {}
+        if encoding is None:
+            newline = b'\r\n'
+        else:
+            newline = '\r\n'
         for line in md5_file.readlines():
             # we need to support spaces in filenames, .split() is not enough
-            md5, fname = line.rstrip('\r\n').split(None, 1)
-            sums[fname] = md5
+            md5, fname = line.rstrip(newline).split(None, 1)
+            if sys.version >= '3' and isinstance(md5, bytes):
+                sums[fname] = md5.decode()
+            else:
+                sums[fname] = md5
         md5_file.close()
         return sums
 
@@ -260,9 +292,9 @@ class DebFile(ArFile):
         """ See .control.scripts() """
         return self.control.scripts()
 
-    def md5sums(self):
+    def md5sums(self, encoding=None, errors=None):
         """ See .control.md5sums() """
-        return self.control.md5sums()
+        return self.control.md5sums(encoding=encoding, errors=errors)
 
     def changelog(self):
         """ Return a Changelog object for the changelog.Debian.gz of the
diff --git a/tests/test_debfile.py b/tests/test_debfile.py
index 9b8c67f..9907b58 100755
--- a/tests/test_debfile.py
+++ b/tests/test_debfile.py
@@ -27,6 +27,8 @@ import sys
 import tempfile
 import uu
 
+import six
+
 sys.path.insert(0, '../lib/')
 
 from debian import arfile
@@ -82,7 +84,7 @@ class TestArFile(unittest.TestCase):
     def test_file_read(self):
         """ test for faked read """
         for m in self.a.getmembers():
-            f = open(m.name)
+            f = open(m.name, 'rb')
         
             for i in [10, 100, 10000]:
                 self.assertEqual(m.read(i), f.read(i))
@@ -94,7 +96,7 @@ class TestArFile(unittest.TestCase):
         """ test for faked readlines """
 
         for m in self.a.getmembers():
-            f = open(m.name)
+            f = open(m.name, 'rb')
         
             self.assertEqual(m.readlines(), f.readlines())
             
@@ -105,8 +107,8 @@ class TestDebFile(unittest.TestCase):
 
     def setUp(self):
         def uudecode(infile, outfile):
-            uu_deb = open(infile, 'r')
-            bin_deb = open(outfile, 'w')
+            uu_deb = open(infile, 'rb')
+            bin_deb = open(outfile, 'wb')
             uu.decode(uu_deb, bin_deb)
             uu_deb.close()
             bin_deb.close()
@@ -119,8 +121,8 @@ class TestDebFile(unittest.TestCase):
         uudecode('test-bz2.deb.uu', self.bz2_debname)
 
         self.debname = 'test.deb'
-        uu_deb = open('test.deb.uu', 'r')
-        bin_deb = open(self.debname, 'w')
+        uu_deb = open('test.deb.uu', 'rb')
+        bin_deb = open(self.debname, 'wb')
         uu.decode(uu_deb, bin_deb)
         uu_deb.close()
         bin_deb.close()
@@ -160,14 +162,23 @@ class TestDebFile(unittest.TestCase):
         with os.popen("dpkg-deb -f %s" % self.debname) as dpkg_deb:
             filecontrol = "".join(dpkg_deb.readlines())
 
-        self.assertEqual(self.d.control.get_content("control"), filecontrol)
+        self.assertEqual(
+            self.d.control.get_content("control").decode("utf-8"), filecontrol)
+        self.assertEqual(
+            self.d.control.get_content("control", encoding="utf-8"),
+            filecontrol)
 
     def test_md5sums(self):
         """test md5 extraction from .debs"""
         md5 = self.d.md5sums()
-        self.assertEqual(md5['usr/bin/hello'],
+        self.assertEqual(md5[b'usr/bin/hello'],
+                '9c1a72a78f82216a0305b6c90ab71058')
+        self.assertEqual(md5[b'usr/share/locale/zh_TW/LC_MESSAGES/hello.mo'],
+                'a7356e05bd420872d03cd3f5369de42f')
+        md5 = self.d.md5sums(encoding='UTF-8')
+        self.assertEqual(md5[six.u('usr/bin/hello')],
                 '9c1a72a78f82216a0305b6c90ab71058')
-        self.assertEqual(md5['usr/share/locale/zh_TW/LC_MESSAGES/hello.mo'],
+        self.assertEqual(md5[six.u('usr/share/locale/zh_TW/LC_MESSAGES/hello.mo')],
                 'a7356e05bd420872d03cd3f5369de42f')
 
 if __name__ == '__main__':
-- 
1.7.2.5





More information about the pkg-python-debian-commits mailing list