Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 115 additions & 84 deletions shapefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections.abc import Collection
from datetime import date
from struct import Struct, calcsize, error, pack, unpack
from typing import Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union
from typing import IO, Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union
from urllib.error import HTTPError
from urllib.parse import urlparse, urlunparse
from urllib.request import Request, urlopen
Expand Down Expand Up @@ -912,6 +912,16 @@ class ShapefileException(Exception):
pass


class _NoShpSentinel:
    """Sentinel used as the default value of the shp keyword argument.

    Preserves the behaviour (from when all keyword args were gathered
    in the **kwargs dict) in case someone explicitly called
    Reader(shp=None) to load self.shx: a default of None could not be
    told apart from an explicit shp=None.

    The class object itself is used as the sentinel and is compared
    by identity (``shp is not _NoShpSentinel``).
    """


class Reader:
"""Reads the three files of a shapefile as a unit or
separately. If one of the three files (.shp, .shx,
Expand All @@ -933,24 +943,40 @@ class Reader:
but they can be.
"""

def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs):
CONSTITUENT_FILE_EXTS = ["shp", "shx", "dbf"]
assert all(ext.islower() for ext in CONSTITUENT_FILE_EXTS)

def _assert_ext_is_supported(self, ext: str):
    """
    Internal sanity check: asserts that ext is one of the entries
    of self.CONSTITUENT_FILE_EXTS.
    """
    supported_exts = self.CONSTITUENT_FILE_EXTS
    assert ext in supported_exts

def __init__(
self,
shapefile_path: str = "",
*,
encoding="utf-8",
encodingErrors="strict",
shp=_NoShpSentinel,
shx=None,
dbf=None,
**kwargs,
):
self.shp = None
self.shx = None
self.dbf = None
self._files_to_close = []
self.shapeName = "Not specified"
self._offsets = []
self._offsets: list[int] = []
self.shpLength = None
self.numRecords = None
self.numShapes = None
self.fields = []
self.fields: list[list[str]] = []
self.__dbfHdrLength = 0
self.__fieldLookup = {}
self.__fieldLookup: dict[str, int] = {}
self.encoding = encoding
self.encodingErrors = encodingErrors
# See if a shapefile name was passed as the first argument
if len(args) > 0:
path = pathlike_obj(args[0])
if shapefile_path:
path = pathlike_obj(shapefile_path)
if is_string(path):
if ".zip" in path:
# Shapefile is inside a zipfile
Expand All @@ -967,6 +993,8 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs):
else:
zpath = path[: path.find(".zip") + 4]
shapefile = path[path.find(".zip") + 4 + 1 :]

zipfileobj: Union[tempfile._TemporaryFileWrapper, io.BufferedReader]
# Create a zip file handle
if zpath.startswith("http"):
# Zipfile is from a url
Expand Down Expand Up @@ -1014,19 +1042,20 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs):
shapefile = os.path.splitext(shapefile)[
0
] # root shapefile name
for ext in ["SHP", "SHX", "DBF", "shp", "shx", "dbf"]:
try:
member = archive.open(shapefile + "." + ext)
# write zipfile member data to a read+write tempfile and use as source, gets deleted on close()
fileobj = tempfile.NamedTemporaryFile(
mode="w+b", delete=True
)
fileobj.write(member.read())
fileobj.seek(0)
setattr(self, ext.lower(), fileobj)
self._files_to_close.append(fileobj)
except:
pass
for lower_ext in self.CONSTITUENT_FILE_EXTS:
for cased_ext in [lower_ext, lower_ext.upper()]:
try:
member = archive.open(f"{shapefile}.{cased_ext}")
# write zipfile member data to a read+write tempfile and use as source, gets deleted on close()
fileobj = tempfile.NamedTemporaryFile(
mode="w+b", delete=True
)
fileobj.write(member.read())
fileobj.seek(0)
setattr(self, lower_ext, fileobj)
self._files_to_close.append(fileobj)
except:
pass
# Close and delete the temporary zipfile
try:
zipfileobj.close()
Expand Down Expand Up @@ -1086,46 +1115,44 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs):
self.load(path)
return

# Otherwise, load from separate shp/shx/dbf args (must be path or file-like)
if "shp" in kwargs:
if hasattr(kwargs["shp"], "read"):
self.shp = kwargs["shp"]
# Copy if required
try:
self.shp.seek(0)
except (NameError, io.UnsupportedOperation):
self.shp = io.BytesIO(self.shp.read())
else:
(baseName, ext) = os.path.splitext(kwargs["shp"])
self.load_shp(baseName)

if "shx" in kwargs:
if hasattr(kwargs["shx"], "read"):
self.shx = kwargs["shx"]
# Copy if required
try:
self.shx.seek(0)
except (NameError, io.UnsupportedOperation):
self.shx = io.BytesIO(self.shx.read())
else:
(baseName, ext) = os.path.splitext(kwargs["shx"])
self.load_shx(baseName)
if shp is not _NoShpSentinel:
self.shp = self._seek_0_on_file_obj_wrap_or_open_from_name("shp", shp)
self.shx = self._seek_0_on_file_obj_wrap_or_open_from_name("shx", shx)

if "dbf" in kwargs:
if hasattr(kwargs["dbf"], "read"):
self.dbf = kwargs["dbf"]
# Copy if required
try:
self.dbf.seek(0)
except (NameError, io.UnsupportedOperation):
self.dbf = io.BytesIO(self.dbf.read())
else:
(baseName, ext) = os.path.splitext(kwargs["dbf"])
self.load_dbf(baseName)
self.dbf = self._seek_0_on_file_obj_wrap_or_open_from_name("dbf", dbf)

# Load the files
if self.shp or self.dbf:
self.load()
self._try_to_set_constituent_file_headers()

def _seek_0_on_file_obj_wrap_or_open_from_name(
    self,
    ext: str,
    # File name, path-like object, file object or anything with a read()
    # method that returns bytes.
    # TODO: Create simple Protocol with a read() method
    file_: Optional[Union[str, os.PathLike, IO[bytes]]],
) -> Union[None, io.BytesIO, IO[bytes]]:
    """
    Turns a shp/shx/dbf constructor argument into a readable file object.

    - None is passed through unchanged (the constituent file was not given).
    - A file name (str or os.PathLike) has its extension stripped and is
      opened via self._load_constituent_file(base_name, ext).
    - An object with a read() method is rewound with seek(0) and returned
      as is; if it cannot be rewound, its contents are read into a new
      io.BytesIO which is returned instead.

    Raises ShapefileException for any other kind of argument.
    """
    self._assert_ext_is_supported(ext)

    if file_ is None:
        return None

    # Path objects are accepted as well as plain strings; os.fsdecode
    # normalises both (and bytes paths) to str before the f-string based
    # file name is built further down the call chain.
    if isinstance(file_, (str, os.PathLike)):
        base_name, __ = os.path.splitext(os.fsdecode(file_))
        return self._load_constituent_file(base_name, ext)

    if hasattr(file_, "read"):
        try:
            file_.seek(0)  # type: ignore
        except (NameError, AttributeError, io.UnsupportedOperation):
            # Not seekable (or no seek() at all): copy into memory so the
            # rest of the reader can seek freely.
            return io.BytesIO(file_.read())
        return file_

    raise ShapefileException(
        f"Could not load shapefile constituent file from: {file_}"
    )

def __str__(self):
"""
Expand Down Expand Up @@ -1232,57 +1259,61 @@ def load(self, shapefile=None):
raise ShapefileException(
f"Unable to open {shapeName}.dbf or {shapeName}.shp."
)
self._try_to_set_constituent_file_headers()

def _try_to_set_constituent_file_headers(self):
    """
    Calls the private header method for each constituent file whose
    attribute on self is truthy, in the order shp, dbf, shx.
    Constituent files that are falsy (e.g. None) are skipped.
    """
    if self.shp:
        self.__shpHeader()
    if self.dbf:
        self.__dbfHeader()
    if self.shx:
        self.__shxHeader()

def load_shp(self, shapefile_name):
def _try_get_open_constituent_file(self, shapefile_name: str, ext: str):
    """
    Attempts to open a .shp, .dbf or .shx file,
    with both lower case and upper case file extensions,
    and return it. If it was not possible to open the file, None is returned.

    The file is opened in binary read mode. This method does not store
    the file object anywhere on self; the caller is responsible for
    keeping track of it and closing it.
    """
    # typing.LiteralString is only available from Python 3.11 onwards.
    # https://docs.python.org/3/library/typing.html#typing.LiteralString
    self._assert_ext_is_supported(ext)
    # Try the extension as given first, then its upper-case form.
    for cased_ext in (ext, ext.upper()):
        try:
            return open(f"{shapefile_name}.{cased_ext}", "rb")
        except OSError:
            continue
    return None

def _load_constituent_file(self, shapefile_name: str, ext: str):
    """
    Tries to open the .shp, .dbf or .shx file for shapefile_name via
    self._try_get_open_constituent_file. An opened file is registered
    in self._files_to_close and returned; otherwise None is returned.
    """
    opened_file = self._try_get_open_constituent_file(shapefile_name, ext)
    if opened_file is None:
        return None
    self._files_to_close.append(opened_file)
    return opened_file

def load_shp(self, shapefile_name):
    """
    Sets self.shp to the result of self._load_constituent_file for the
    "shp" extension: the opened file, or None if it could not be opened.
    """
    shp_file = self._load_constituent_file(shapefile_name, "shp")
    self.shp = shp_file

def load_shx(self, shapefile_name):
    """
    Attempts to load file with .shx extension as both lower and upper case

    Sets self.shx to the result of self._load_constituent_file for the
    "shx" extension (the opened file, or None if it could not be opened).
    Opening and registering the file for closing is delegated entirely to
    that helper, so the file is opened only once.
    """
    self.shx = self._load_constituent_file(shapefile_name, "shx")

def load_dbf(self, shapefile_name):
    """
    Attempts to load file with .dbf extension as both lower and upper case

    Sets self.dbf to the result of self._load_constituent_file for the
    "dbf" extension (the opened file, or None if it could not be opened).
    Opening and registering the file for closing is delegated entirely to
    that helper, so the file is opened only once.
    """
    self.dbf = self._load_constituent_file(shapefile_name, "dbf")

def __del__(self):
self.close()
Expand Down
Loading