diff --git a/shapefile.py b/shapefile.py index e56b3b2..c1a2d76 100644 --- a/shapefile.py +++ b/shapefile.py @@ -20,7 +20,7 @@ from collections.abc import Collection from datetime import date from struct import Struct, calcsize, error, pack, unpack -from typing import Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union +from typing import IO, Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union from urllib.error import HTTPError from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen @@ -912,6 +912,16 @@ class ShapefileException(Exception): pass +class _NoShpSentinel(object): + """For use as a default value for shp to preserve the + behaviour (from when all keyword args were gathered + in the **kwargs dict) in case someone explicitly + called Reader(shp=None) to load self.shx. + """ + + pass + + class Reader: """Reads the three files of a shapefile as a unit or separately. If one of the three files (.shp, .shx, @@ -933,24 +943,40 @@ class Reader: but they can be. 
""" - def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs): + CONSTITUENT_FILE_EXTS = ["shp", "shx", "dbf"] + assert all(ext.islower() for ext in CONSTITUENT_FILE_EXTS) + + def _assert_ext_is_supported(self, ext: str): + assert ext in self.CONSTITUENT_FILE_EXTS + + def __init__( + self, + shapefile_path: str = "", + *, + encoding="utf-8", + encodingErrors="strict", + shp=_NoShpSentinel, + shx=None, + dbf=None, + **kwargs, + ): self.shp = None self.shx = None self.dbf = None self._files_to_close = [] self.shapeName = "Not specified" - self._offsets = [] + self._offsets: list[int] = [] self.shpLength = None self.numRecords = None self.numShapes = None - self.fields = [] + self.fields: list[list[str]] = [] self.__dbfHdrLength = 0 - self.__fieldLookup = {} + self.__fieldLookup: dict[str, int] = {} self.encoding = encoding self.encodingErrors = encodingErrors # See if a shapefile name was passed as the first argument - if len(args) > 0: - path = pathlike_obj(args[0]) + if shapefile_path: + path = pathlike_obj(shapefile_path) if is_string(path): if ".zip" in path: # Shapefile is inside a zipfile @@ -967,6 +993,8 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs): else: zpath = path[: path.find(".zip") + 4] shapefile = path[path.find(".zip") + 4 + 1 :] + + zipfileobj: Union[tempfile._TemporaryFileWrapper, io.BufferedReader] # Create a zip file handle if zpath.startswith("http"): # Zipfile is from a url @@ -1014,19 +1042,20 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs): shapefile = os.path.splitext(shapefile)[ 0 ] # root shapefile name - for ext in ["SHP", "SHX", "DBF", "shp", "shx", "dbf"]: - try: - member = archive.open(shapefile + "." 
+ ext) - # write zipfile member data to a read+write tempfile and use as source, gets deleted on close() - fileobj = tempfile.NamedTemporaryFile( - mode="w+b", delete=True - ) - fileobj.write(member.read()) - fileobj.seek(0) - setattr(self, ext.lower(), fileobj) - self._files_to_close.append(fileobj) - except: - pass + for lower_ext in self.CONSTITUENT_FILE_EXTS: + for cased_ext in [lower_ext, lower_ext.upper()]: + try: + member = archive.open(f"{shapefile}.{cased_ext}") + # write zipfile member data to a read+write tempfile and use as source, gets deleted on close() + fileobj = tempfile.NamedTemporaryFile( + mode="w+b", delete=True + ) + fileobj.write(member.read()) + fileobj.seek(0) + setattr(self, lower_ext, fileobj) + self._files_to_close.append(fileobj) + except: + pass # Close and delete the temporary zipfile try: zipfileobj.close() @@ -1086,46 +1115,44 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs): self.load(path) return - # Otherwise, load from separate shp/shx/dbf args (must be path or file-like) - if "shp" in kwargs: - if hasattr(kwargs["shp"], "read"): - self.shp = kwargs["shp"] - # Copy if required - try: - self.shp.seek(0) - except (NameError, io.UnsupportedOperation): - self.shp = io.BytesIO(self.shp.read()) - else: - (baseName, ext) = os.path.splitext(kwargs["shp"]) - self.load_shp(baseName) - - if "shx" in kwargs: - if hasattr(kwargs["shx"], "read"): - self.shx = kwargs["shx"] - # Copy if required - try: - self.shx.seek(0) - except (NameError, io.UnsupportedOperation): - self.shx = io.BytesIO(self.shx.read()) - else: - (baseName, ext) = os.path.splitext(kwargs["shx"]) - self.load_shx(baseName) + if shp is not _NoShpSentinel: + self.shp = self._seek_0_on_file_obj_wrap_or_open_from_name("shp", shp) + self.shx = self._seek_0_on_file_obj_wrap_or_open_from_name("shx", shx) - if "dbf" in kwargs: - if hasattr(kwargs["dbf"], "read"): - self.dbf = kwargs["dbf"] - # Copy if required - try: - self.dbf.seek(0) - except 
(NameError, io.UnsupportedOperation): - self.dbf = io.BytesIO(self.dbf.read()) - else: - (baseName, ext) = os.path.splitext(kwargs["dbf"]) - self.load_dbf(baseName) + self.dbf = self._seek_0_on_file_obj_wrap_or_open_from_name("dbf", dbf) # Load the files if self.shp or self.dbf: - self.load() + self._try_to_set_constituent_file_headers() + + def _seek_0_on_file_obj_wrap_or_open_from_name( + self, + ext: str, + # File name, file object or anything with a read() method that returns bytes. + # TODO: Create simple Protocol with a read() method + file_: Optional[Union[str, IO[bytes]]], + ) -> Union[None, io.BytesIO, IO[bytes]]: + # assert ext in {'shp', 'dbf', 'shx'} + self._assert_ext_is_supported(ext) + + if file_ is None: + return None + + if isinstance(file_, str): + baseName, __ = os.path.splitext(file_) + return self._load_constituent_file(baseName, ext) + + if hasattr(file_, "read"): + # Copy if required + try: + file_.seek(0) # type: ignore + return file_ + except (NameError, io.UnsupportedOperation): + return io.BytesIO(file_.read()) + + raise ShapefileException( + f"Could not load shapefile constituent file from: {file_}" + ) def __str__(self): """ @@ -1232,6 +1259,9 @@ def load(self, shapefile=None): raise ShapefileException( f"Unable to open {shapeName}.dbf or {shapeName}.shp." ) + self._try_to_set_constituent_file_headers() + + def _try_to_set_constituent_file_headers(self): if self.shp: self.__shpHeader() if self.dbf: @@ -1239,50 +1269,51 @@ def load(self, shapefile=None): if self.shx: self.__shxHeader() - def load_shp(self, shapefile_name): + def _try_get_open_constituent_file(self, shapefile_name: str, ext: str): """ - Attempts to load file with .shp extension as both lower and upper case + Attempts to open a .shp, .dbf or .shx file, + with both lower case and upper case file extensions, + and return it. If it was not possible to open the file, None is returned. """ - shp_ext = "shp" + # typing.LiteralString is only available from Python 3.11 onwards. 
+ # https://docs.python.org/3/library/typing.html#typing.LiteralString + self._assert_ext_is_supported(ext) try: - self.shp = open(f"{shapefile_name}.{shp_ext}", "rb") - self._files_to_close.append(self.shp) + return open(f"{shapefile_name}.{ext}", "rb") except OSError: try: - self.shp = open(f"{shapefile_name}.{shp_ext.upper()}", "rb") - self._files_to_close.append(self.shp) + return open(f"{shapefile_name}.{ext.upper()}", "rb") except OSError: - pass + return None + + def _load_constituent_file(self, shapefile_name: str, ext: str): + """ + Attempts to open a .shp, .dbf or .shx file, with the extension + as both lower and upper case, and if successful append it to + self._files_to_close. + """ + shp_dbf_or_dhx_file = self._try_get_open_constituent_file(shapefile_name, ext) + if shp_dbf_or_dhx_file is not None: + self._files_to_close.append(shp_dbf_or_dhx_file) + return shp_dbf_or_dhx_file + + def load_shp(self, shapefile_name): + """ + Attempts to load file with .shp extension as both lower and upper case + """ + self.shp = self._load_constituent_file(shapefile_name, "shp") def load_shx(self, shapefile_name): """ Attempts to load file with .shx extension as both lower and upper case """ - shx_ext = "shx" - try: - self.shx = open(f"{shapefile_name}.{shx_ext}", "rb") - self._files_to_close.append(self.shx) - except OSError: - try: - self.shx = open(f"{shapefile_name}.{shx_ext.upper()}", "rb") - self._files_to_close.append(self.shx) - except OSError: - pass + self.shx = self._load_constituent_file(shapefile_name, "shx") def load_dbf(self, shapefile_name): """ Attempts to load file with .dbf extension as both lower and upper case """ - dbf_ext = "dbf" - try: - self.dbf = open(f"{shapefile_name}.{dbf_ext}", "rb") - self._files_to_close.append(self.dbf) - except OSError: - try: - self.dbf = open(f"{shapefile_name}.{dbf_ext.upper()}", "rb") - self._files_to_close.append(self.dbf) - except OSError: - pass + self.dbf = self._load_constituent_file(shapefile_name, "dbf") def 
__del__(self): self.close()