Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions xrspatial/geotiff/_crs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""
from __future__ import annotations

import numbers
import warnings

from ._runtime import GeoTIFFFallbackWarning, _geotiff_strict_mode
Expand Down Expand Up @@ -44,6 +45,69 @@ def _looks_like_wkt(s: str) -> bool:
return s.lstrip().upper().startswith(_WKT_ROOT_KEYWORDS)


def _validate_crs_arg(crs) -> None:
"""Reject malformed ``crs=`` arguments before they reach the writer.

Closes two gaps in the writer entry points (issue #1971):

* ``bool`` is an ``int`` subclass, so ``crs=True`` and ``crs=False``
would otherwise slip through ``isinstance(crs, int)`` and write
``EPSG=1`` / ``EPSG=0`` to the file. No CRS database resolves
those, so the result is silent metadata corruption.
* An ``int`` EPSG code that pyproj cannot resolve gets written
verbatim into ``ProjectedCSType`` / ``GeographicType``. The
file then round-trips with ``attrs['crs']`` set to the bad
value and only a ``GeoTIFFFallbackWarning`` to tell the caller
something is wrong.

Validates ``crs`` is one of ``None`` (no-op), ``int`` (a valid
EPSG code), or ``str`` (WKT/PROJ -- left for ``_wkt_to_epsg``
downstream). Pyproj is optional; the EPSG-resolves check is
skipped when pyproj is not installed, matching the rest of the
module's pyproj-optional posture. Under
``XRSPATIAL_GEOTIFF_STRICT=1`` the pyproj error is re-raised
instead of being wrapped.
"""
if crs is None:
return
if isinstance(crs, bool):
raise ValueError(
f"crs must be an int (EPSG code), str (WKT/PROJ), or None; "
f"got bool ({crs!r}). bool is an int subclass in Python, so "
f"passing True/False would otherwise be written as EPSG=1 / "
f"EPSG=0 -- neither resolves with any CRS database."
)
# ``numbers.Integral`` covers plain ``int`` and numpy integer scalars
# (``np.int32``, ``np.int64``, ...). Without this branch the type
# check below rejects numpy-typed CRS values that callers previously
# got away with (pre-PR they silently fell through to "no EPSG
# written"; the post-PR ``isinstance(crs, int)`` check would raise
# ``TypeError`` on the same input).
if isinstance(crs, numbers.Integral):
crs_int = int(crs)
try:
from pyproj import CRS
except ImportError:
return
try:
CRS.from_epsg(crs_int)
except Exception as e:
if _geotiff_strict_mode():
raise
raise ValueError(
f"crs={crs!r} is not a valid EPSG code "
f"(pyproj: {type(e).__name__}: {e}). Pass a valid "
f"EPSG integer, a WKT string, or None."
) from e
return
if isinstance(crs, str):
return
raise TypeError(
f"crs must be int (EPSG code), str (WKT/PROJ), or None; "
f"got {type(crs).__name__}."
)


def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
"""Try to extract an EPSG code from a WKT or PROJ string.

Expand Down Expand Up @@ -149,8 +213,14 @@ def _resolve_crs_to_wkt(crs) -> str | None:
other than a string. (A string is passed through verbatim so the
WKT-only path keeps working without pyproj.)
"""
_validate_crs_arg(crs)
if crs is None:
return None
# ``_validate_crs_arg`` already accepts ``numbers.Integral`` (incl.
# numpy integer scalars); coerce here so the pyproj path and the
# str-only branch below see a plain ``int``.
if isinstance(crs, numbers.Integral) and not isinstance(crs, bool):
crs = int(crs)
if not isinstance(crs, (int, str)):
raise TypeError(
f"crs must be int (EPSG code), str (WKT or PROJ), or None; "
Expand Down
14 changes: 13 additions & 1 deletion xrspatial/geotiff/_writers/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
require_transform_for_georeferenced as _require_transform_for_georeferenced,
transform_from_attr as _transform_from_attr,
)
from .._crs import _validate_crs_fallback, _wkt_to_epsg
from .._crs import _validate_crs_arg, _validate_crs_fallback, _wkt_to_epsg
from .._geotags import GeoTransform, RASTER_PIXEL_IS_AREA
from .._runtime import (
GeoTIFFFallbackWarning,
Expand Down Expand Up @@ -499,6 +499,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
extra_tags_list = None

# Resolve crs argument: can be int (EPSG) or str (WKT/PROJ)
_validate_crs_arg(crs)
if isinstance(crs, int):
epsg = crs
elif isinstance(crs, str):
Expand Down Expand Up @@ -530,6 +531,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
if epsg is None and wkt_fallback is None:
wkt_fallback = crs_attr
elif crs_attr is not None:
# Same gate as the kwarg path: reject bool / non-int
# types and confirm the EPSG resolves before writing it
# to disk. Without this, ``attrs={'crs': True}`` round-
# trips as EPSG=1 (issue #1971 follow-up).
_validate_crs_arg(crs_attr)
epsg = int(crs_attr)
if epsg is None:
wkt = data.attrs.get('crs_wkt')
Expand Down Expand Up @@ -798,6 +804,7 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
os.makedirs(tiles_dir, exist_ok=True)

# Resolve CRS
_validate_crs_arg(crs)
epsg = None
wkt_fallback = None
if isinstance(crs, int):
Expand All @@ -824,6 +831,11 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
if epsg is None and wkt_fallback is None:
wkt_fallback = crs_attr
elif crs_attr is not None:
# Same gate as the kwarg path: reject bool / non-int
# types and confirm the EPSG resolves before writing it
# to disk. Without this, ``attrs={'crs': True}`` round-
# trips as EPSG=1 (issue #1971 follow-up).
_validate_crs_arg(crs_attr)
epsg = int(crs_attr)
if epsg is None:
wkt = data.attrs.get('crs_wkt')
Expand Down
8 changes: 7 additions & 1 deletion xrspatial/geotiff/_writers/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
require_transform_for_georeferenced as _require_transform_for_georeferenced,
transform_from_attr as _transform_from_attr,
)
from .._crs import _validate_crs_fallback, _wkt_to_epsg
from .._crs import _validate_crs_arg, _validate_crs_fallback, _wkt_to_epsg
from .._runtime import GeoTIFFFallbackWarning
from .._validation import (
_validate_3d_writer_dims,
Expand Down Expand Up @@ -310,6 +310,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
y_res = None
res_unit = None

_validate_crs_arg(crs)
if isinstance(crs, int):
epsg = crs
elif isinstance(crs, str):
Expand Down Expand Up @@ -366,6 +367,11 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
if epsg is None and wkt_fallback is None:
wkt_fallback = crs_attr
elif crs_attr is not None:
# Same gate as the kwarg path: reject bool / non-int
# types and confirm the EPSG resolves before writing it
# to disk. Without this, ``attrs={'crs': True}`` round-
# trips as EPSG=1 (issue #1971 follow-up).
_validate_crs_arg(crs_attr)
epsg = int(crs_attr)
if epsg is None:
wkt = data.attrs.get('crs_wkt')
Expand Down
160 changes: 160 additions & 0 deletions xrspatial/geotiff/tests/test_crs_arg_validation_1971.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Validate the writer entry points reject bool / unresolvable EPSG (#1971).

``bool`` is an int subclass, so ``crs=True`` used to slip through
``isinstance(crs, int)`` and write EPSG=1 to the file (with EPSG=0 for
``crs=False``). Integer EPSG codes were also written without a pyproj
round-trip, so any int that does not resolve as a CRS produced a file
with garbage in ``ProjectedCSType`` / ``GeographicType`` and only a
``GeoTIFFFallbackWarning`` to flag it.

Locks down the rejection on the eager ``to_geotiff`` path and on the
VRT-tiled path (``to_geotiff`` with a ``.vrt`` target, which dispatches
to ``_write_vrt_tiled``). ``write_geotiff_gpu`` calls the same validator
but is not exercised in this module.
"""
from __future__ import annotations

import io

import numpy as np
import pytest
import xarray as xr

from xrspatial.geotiff import to_geotiff
from xrspatial.geotiff._crs import _validate_crs_arg

pyproj = pytest.importorskip("pyproj")


def _square(dtype=np.float32):
    """Return a 4x4 all-zero ``DataArray`` with float ``y``/``x`` coords."""
    zeros = np.zeros((4, 4), dtype=dtype)
    y_axis = np.arange(4.0)
    x_axis = np.arange(4.0)
    return xr.DataArray(
        zeros,
        coords={'y': y_axis, 'x': x_axis},
        dims=('y', 'x'),
    )


@pytest.mark.parametrize("bad_crs", [True, False])
def test_validate_crs_arg_rejects_bool(bad_crs):
    """Both bool values are refused with a ``ValueError`` naming bool."""
    with pytest.raises(ValueError) as caught:
        _validate_crs_arg(bad_crs)
    caught.match("bool")


def test_validate_crs_arg_rejects_unresolvable_epsg():
    """An integer with no EPSG entry is refused with a ``ValueError``."""
    bogus_code = 1  # EPSG:1 does not exist in any CRS database.
    with pytest.raises(ValueError) as caught:
        _validate_crs_arg(bogus_code)
    caught.match("EPSG")


def test_validate_crs_arg_accepts_valid_epsg():
    """A resolvable EPSG integer passes validation without raising."""
    wgs84 = 4326
    _validate_crs_arg(wgs84)


def test_validate_crs_arg_accepts_none():
    """``None`` is accepted by the validator without raising."""
    no_crs = None
    _validate_crs_arg(no_crs)


def test_validate_crs_arg_accepts_str():
    """Any string passes the entry-point validator.

    String contents are deferred to ``_wkt_to_epsg`` and the
    WKT-fallback path; this validator only catches bool and bogus int.
    """
    for text in ("EPSG:4326", 'PROJCS["foo",GEOGCS["bar"]]'):
        _validate_crs_arg(text)


def test_validate_crs_arg_rejects_non_int_non_str():
    """A float is the wrong type entirely and raises ``TypeError``.

    The ValueError / TypeError split is deliberate: bool and
    unresolvable EPSG are bad values, while float and other objects are
    bad types. Pinning both classes catches a regression either way.
    """
    not_an_epsg = 4326.0
    with pytest.raises(TypeError) as caught:
        _validate_crs_arg(not_an_epsg)
    caught.match("crs must be int")


@pytest.mark.parametrize("np_int", [np.int64(4326), np.int32(4326)])
def test_validate_crs_arg_accepts_numpy_integer(np_int):
    """Numpy integer scalars are accepted as EPSG codes.

    ``isinstance(np.int64(4326), int)`` is False on most platforms, so
    a plain ``int`` check would reject them with ``TypeError``; the
    validator's ``numbers.Integral`` branch covers them.
    """
    epsg_scalar = np_int
    _validate_crs_arg(epsg_scalar)


def test_validate_crs_arg_rejects_numpy_bool():
    """``np.bool_`` must never be accepted as an EPSG code.

    ``np.bool_`` is not a ``bool`` subclass. Whichever validator branch
    it reaches (a ValueError from a value check or the final
    TypeError), the call has to raise rather than coerce
    ``True`` -> ``1`` -> EPSG=1.
    """
    acceptable_errors = (ValueError, TypeError)
    numpy_true = np.bool_(True)
    with pytest.raises(acceptable_errors):
        _validate_crs_arg(numpy_true)


@pytest.mark.parametrize("bad_crs", [True, False])
def test_to_geotiff_rejects_bool_crs(bad_crs):
    """``to_geotiff`` refuses a bool ``crs=`` kwarg on the eager path."""
    sink = io.BytesIO()
    raster = _square()
    with pytest.raises(ValueError) as caught:
        to_geotiff(raster, sink, crs=bad_crs)
    caught.match("bool")


def test_to_geotiff_rejects_unresolvable_epsg():
    """``to_geotiff`` refuses an integer ``crs=`` that is not an EPSG code."""
    sink = io.BytesIO()
    raster = _square()
    with pytest.raises(ValueError) as caught:
        to_geotiff(raster, sink, crs=1)
    caught.match("EPSG")


def test_to_geotiff_accepts_valid_epsg():
    """A valid EPSG integer lets ``to_geotiff`` write a non-empty file."""
    sink = io.BytesIO()
    to_geotiff(_square(), sink, crs=4326)
    bytes_written = sink.getbuffer().nbytes
    assert bytes_written > 0


def test_to_geotiff_accepts_numpy_int_epsg():
    """A numpy integer ``crs=`` does not make ``to_geotiff`` raise.

    End-to-end check that the validator's ``numbers.Integral`` branch
    lets a numpy integer through ``to_geotiff``.
    """
    # NOTE(review): this only proves bytes were written, not that EPSG
    # 4326 landed in the file. The writer branches on
    # ``isinstance(crs, int)`` right after validation, which is False
    # for ``np.int64`` -- confirm the value is coerced rather than
    # silently dropped, and consider asserting on the read-back CRS.
    sink = io.BytesIO()
    numpy_epsg = np.int64(4326)
    to_geotiff(_square(), sink, crs=numpy_epsg)
    bytes_written = sink.getbuffer().nbytes
    assert bytes_written > 0


def test_to_geotiff_attrs_crs_bool_bypass():
    """``attrs['crs'] = True`` is rejected even without a ``crs=`` kwarg.

    Regression for the self-review blocker: ``attrs={'crs': True}``
    with no explicit ``crs=`` kwarg used to bypass ``_validate_crs_arg``
    and write EPSG=1 to the file. The fix calls the validator on the
    ``attrs['crs']`` value in the writer's int-EPSG branch.
    """
    # Output goes to an in-memory buffer, so no ``tmp_path`` fixture is
    # requested here.
    da = _square()
    da.attrs['crs'] = True
    buf = io.BytesIO()
    with pytest.raises(ValueError, match="bool"):
        to_geotiff(da, buf)


def test_to_geotiff_attrs_crs_unresolvable_epsg_bypass():
    """An unresolvable int in ``attrs['crs']`` is rejected by ``to_geotiff``."""
    raster = _square()
    raster.attrs['crs'] = 1  # int that pyproj cannot resolve
    sink = io.BytesIO()
    with pytest.raises(ValueError) as caught:
        to_geotiff(raster, sink)
    caught.match("EPSG")


def test_to_geotiff_vrt_path_rejects_bool_crs(tmp_path):
    """The VRT-tiled branch also refuses a bool ``crs=`` kwarg.

    ``to_geotiff(da, '*.vrt')`` dispatches to ``_write_vrt_tiled``,
    which has its own crs resolution block; the validator runs there
    too.
    """
    target = tmp_path / "tmp_1971_vrt_tiled.vrt"
    with pytest.raises(ValueError) as caught:
        to_geotiff(_square(), str(target), crs=True)
    caught.match("bool")


def test_to_geotiff_vrt_path_rejects_unresolvable_epsg(tmp_path):
    """The ``.vrt`` target path refuses an integer that is not an EPSG code."""
    target = tmp_path / "tmp_1971_vrt_bad_epsg.vrt"
    with pytest.raises(ValueError) as caught:
        to_geotiff(_square(), str(target), crs=1)
    caught.match("EPSG")


def test_to_geotiff_vrt_path_attrs_crs_bool_bypass(tmp_path):
    """``attrs['crs'] = True`` is rejected on the VRT-tiled branch too.

    Mirrors the eager-path attrs regression, but for the
    ``_write_vrt_tiled`` branch (``to_geotiff(da, '*.vrt')``).
    """
    raster = _square()
    raster.attrs['crs'] = True
    target = tmp_path / "tmp_1971_vrt_attrs_bool.vrt"
    with pytest.raises(ValueError) as caught:
        to_geotiff(raster, str(target))
    caught.match("bool")
Loading