diff --git a/xrspatial/geotiff/_crs.py b/xrspatial/geotiff/_crs.py index 18544cda..2d74a4f2 100644 --- a/xrspatial/geotiff/_crs.py +++ b/xrspatial/geotiff/_crs.py @@ -12,6 +12,7 @@ """ from __future__ import annotations +import numbers import warnings from ._runtime import GeoTIFFFallbackWarning, _geotiff_strict_mode @@ -44,6 +45,69 @@ def _looks_like_wkt(s: str) -> bool: return s.lstrip().upper().startswith(_WKT_ROOT_KEYWORDS) +def _validate_crs_arg(crs) -> None: + """Reject malformed ``crs=`` arguments before they reach the writer. + + Closes two gaps in the writer entry points (issue #1971): + + * ``bool`` is an ``int`` subclass, so ``crs=True`` and ``crs=False`` + would otherwise slip through ``isinstance(crs, int)`` and write + ``EPSG=1`` / ``EPSG=0`` to the file. No CRS database resolves + those, so the result is silent metadata corruption. + * An ``int`` EPSG code that pyproj cannot resolve gets written + verbatim into ``ProjectedCSType`` / ``GeographicType``. The + file then round-trips with ``attrs['crs']`` set to the bad + value and only a ``GeoTIFFFallbackWarning`` to tell the caller + something is wrong. + + Validates ``crs`` is one of ``None`` (no-op), ``int`` (a valid + EPSG code), or ``str`` (WKT/PROJ -- left for ``_wkt_to_epsg`` + downstream). Pyproj is optional; the EPSG-resolves check is + skipped when pyproj is not installed, matching the rest of the + module's pyproj-optional posture. Under + ``XRSPATIAL_GEOTIFF_STRICT=1`` the pyproj error is re-raised + instead of being wrapped. + """ + if crs is None: + return + if isinstance(crs, bool): + raise ValueError( + f"crs must be an int (EPSG code), str (WKT/PROJ), or None; " + f"got bool ({crs!r}). bool is an int subclass in Python, so " + f"passing True/False would otherwise be written as EPSG=1 / " + f"EPSG=0 -- neither resolves with any CRS database." + ) + # ``numbers.Integral`` covers plain ``int`` and numpy integer scalars + # (``np.int32``, ``np.int64``, ...). Without this branch the type + # check below rejects numpy-typed CRS values that callers previously + # got away with (pre-PR they silently fell through to "no EPSG + # written"; the post-PR ``isinstance(crs, int)`` check would raise + # ``TypeError`` on the same input). + if isinstance(crs, numbers.Integral): + crs_int = int(crs) + try: + from pyproj import CRS + except ImportError: + return + try: + CRS.from_epsg(crs_int) + except Exception as e: + if _geotiff_strict_mode(): + raise + raise ValueError( + f"crs={crs!r} is not a valid EPSG code " + f"(pyproj: {type(e).__name__}: {e}). Pass a valid " + f"EPSG integer, a WKT string, or None." + ) from e + return + if isinstance(crs, str): + return + raise TypeError( + f"crs must be int (EPSG code), str (WKT/PROJ), or None; " + f"got {type(crs).__name__}." + ) + + def _wkt_to_epsg(wkt_or_proj: str) -> int | None: """Try to extract an EPSG code from a WKT or PROJ string. @@ -149,8 +213,14 @@ def _resolve_crs_to_wkt(crs) -> str | None: other than a string. (A string is passed through verbatim so the WKT-only path keeps working without pyproj.) """ + _validate_crs_arg(crs) if crs is None: return None + # ``_validate_crs_arg`` already accepts ``numbers.Integral`` (incl. + # numpy integer scalars); coerce here so the pyproj path and the + # str-only branch below see a plain ``int``. + if isinstance(crs, numbers.Integral) and not isinstance(crs, bool): + crs = int(crs) if not isinstance(crs, (int, str)): raise TypeError( f"crs must be int (EPSG code), str (WKT or PROJ), or None; " diff --git a/xrspatial/geotiff/_writers/eager.py b/xrspatial/geotiff/_writers/eager.py index 1915fb84..a52028b1 100644 --- a/xrspatial/geotiff/_writers/eager.py +++ b/xrspatial/geotiff/_writers/eager.py @@ -33,7 +33,7 @@ require_transform_for_georeferenced as _require_transform_for_georeferenced, transform_from_attr as _transform_from_attr, ) -from .._crs import _validate_crs_fallback, _wkt_to_epsg +from .._crs import _validate_crs_arg, _validate_crs_fallback, _wkt_to_epsg from .._geotags import GeoTransform, RASTER_PIXEL_IS_AREA from .._runtime import ( GeoTIFFFallbackWarning, @@ -499,6 +499,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, extra_tags_list = None # Resolve crs argument: can be int (EPSG) or str (WKT/PROJ) + _validate_crs_arg(crs) if isinstance(crs, int): epsg = crs elif isinstance(crs, str): @@ -530,6 +531,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray, if epsg is None and wkt_fallback is None: wkt_fallback = crs_attr elif crs_attr is not None: + # Same gate as the kwarg path: reject bool / non-int + # types and confirm the EPSG resolves before writing it + # to disk. Without this, ``attrs={'crs': True}`` round- + # trips as EPSG=1 (issue #1971 follow-up). + _validate_crs_arg(crs_attr) epsg = int(crs_attr) if epsg is None: wkt = data.attrs.get('crs_wkt') @@ -798,6 +804,7 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None, os.makedirs(tiles_dir, exist_ok=True) # Resolve CRS + _validate_crs_arg(crs) epsg = None wkt_fallback = None if isinstance(crs, int): @@ -824,6 +831,11 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None, if epsg is None and wkt_fallback is None: wkt_fallback = crs_attr elif crs_attr is not None: + # Same gate as the kwarg path: reject bool / non-int + # types and confirm the EPSG resolves before writing it + # to disk. Without this, ``attrs={'crs': True}`` round- + # trips as EPSG=1 (issue #1971 follow-up). + _validate_crs_arg(crs_attr) epsg = int(crs_attr) if epsg is None: wkt = data.attrs.get('crs_wkt') diff --git a/xrspatial/geotiff/_writers/gpu.py b/xrspatial/geotiff/_writers/gpu.py index c08ab791..4869c28c 100644 --- a/xrspatial/geotiff/_writers/gpu.py +++ b/xrspatial/geotiff/_writers/gpu.py @@ -23,7 +23,7 @@ require_transform_for_georeferenced as _require_transform_for_georeferenced, transform_from_attr as _transform_from_attr, ) -from .._crs import _validate_crs_fallback, _wkt_to_epsg +from .._crs import _validate_crs_arg, _validate_crs_fallback, _wkt_to_epsg from .._runtime import GeoTIFFFallbackWarning from .._validation import ( _validate_3d_writer_dims, @@ -310,6 +310,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, y_res = None res_unit = None + _validate_crs_arg(crs) if isinstance(crs, int): epsg = crs elif isinstance(crs, str): @@ -366,6 +367,11 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, if epsg is None and wkt_fallback is None: wkt_fallback = crs_attr elif crs_attr is not None: + # Same gate as the kwarg path: reject bool / non-int + # types and confirm the EPSG resolves before writing it + # to disk. Without this, ``attrs={'crs': True}`` round- + # trips as EPSG=1 (issue #1971 follow-up). + _validate_crs_arg(crs_attr) epsg = int(crs_attr) if epsg is None: wkt = data.attrs.get('crs_wkt') diff --git a/xrspatial/geotiff/tests/test_crs_arg_validation_1971.py b/xrspatial/geotiff/tests/test_crs_arg_validation_1971.py new file mode 100644 index 00000000..2d2749e7 --- /dev/null +++ b/xrspatial/geotiff/tests/test_crs_arg_validation_1971.py @@ -0,0 +1,160 @@ +"""Validate the writer entry points reject bool / unresolvable EPSG (#1971). + +``bool`` is an int subclass, so ``crs=True`` used to slip through +``isinstance(crs, int)`` and write EPSG=1 to the file (with EPSG=0 for +``crs=False``). Integer EPSG codes were also written without a pyproj +round-trip, so any int that does not resolve as a CRS produced a file +with garbage in ``ProjectedCSType`` / ``GeographicType`` and only a +``GeoTIFFFallbackWarning`` to flag it. + +Locks down the rejection at all three writer entry points: ``to_geotiff`` +(eager), ``write_geotiff_gpu`` (GPU), and ``to_geotiff`` with +``vrt_tiled=True`` (the deprecated VRT-tiled path). +""" +from __future__ import annotations + +import io + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import to_geotiff +from xrspatial.geotiff._crs import _validate_crs_arg + +pyproj = pytest.importorskip("pyproj") + + +def _square(dtype=np.float32): + return xr.DataArray( + np.zeros((4, 4), dtype=dtype), + coords={'y': np.arange(4.0), 'x': np.arange(4.0)}, + dims=('y', 'x'), + ) + + +@pytest.mark.parametrize("bad_crs", [True, False]) +def test_validate_crs_arg_rejects_bool(bad_crs): + with pytest.raises(ValueError, match="bool"): + _validate_crs_arg(bad_crs) + + +def test_validate_crs_arg_rejects_unresolvable_epsg(): + # EPSG:1 does not exist in any CRS database. + with pytest.raises(ValueError, match="EPSG"): + _validate_crs_arg(1) + + +def test_validate_crs_arg_accepts_valid_epsg(): + _validate_crs_arg(4326) # WGS84 + + +def test_validate_crs_arg_accepts_none(): + _validate_crs_arg(None) + + +def test_validate_crs_arg_accepts_str(): + # Strings are deferred to ``_wkt_to_epsg`` and the WKT-fallback + # path; the entry-point validator only catches bool and bogus int. + _validate_crs_arg("EPSG:4326") + _validate_crs_arg('PROJCS["foo",GEOGCS["bar"]]') + + +def test_validate_crs_arg_rejects_non_int_non_str(): + # ValueError vs TypeError split is intentional: bool and unresolvable + # EPSG are semantically wrong (right type, bad value), while float / + # other objects are the wrong type entirely. Tests downstream pin + # both exception classes so a regression in either direction trips. + with pytest.raises(TypeError, match="crs must be int"): + _validate_crs_arg(4326.0) + + +@pytest.mark.parametrize("np_int", [np.int64(4326), np.int32(4326)]) +def test_validate_crs_arg_accepts_numpy_integer(np_int): + # ``isinstance(np.int64(4326), int)`` is False on most platforms, so + # the pre-fix validator rejected numpy integer CRS values with + # ``TypeError``. ``numbers.Integral`` covers them. + _validate_crs_arg(np_int) + + +def test_validate_crs_arg_rejects_numpy_bool(): + # ``np.bool_`` is not a ``bool`` subclass but is ``numbers.Integral`` + # in some numpy versions. Make sure the bool guard still catches it + # before the Integral branch coerces ``True`` -> ``1`` -> EPSG=1. + # If numpy bool is not an Integral instance, it falls through to the + # TypeError branch, which is also acceptable. + val = np.bool_(True) + with pytest.raises((ValueError, TypeError)): + _validate_crs_arg(val) + + +@pytest.mark.parametrize("bad_crs", [True, False]) +def test_to_geotiff_rejects_bool_crs(bad_crs): + buf = io.BytesIO() + with pytest.raises(ValueError, match="bool"): + to_geotiff(_square(), buf, crs=bad_crs) + + +def test_to_geotiff_rejects_unresolvable_epsg(): + buf = io.BytesIO() + with pytest.raises(ValueError, match="EPSG"): + to_geotiff(_square(), buf, crs=1) + + +def test_to_geotiff_accepts_valid_epsg(): + buf = io.BytesIO() + to_geotiff(_square(), buf, crs=4326) + assert buf.getbuffer().nbytes > 0 + + +def test_to_geotiff_accepts_numpy_int_epsg(): + # End-to-end check that the validator's ``numbers.Integral`` branch + # actually lets a numpy integer through ``to_geotiff``. + buf = io.BytesIO() + to_geotiff(_square(), buf, crs=np.int64(4326)) + assert buf.getbuffer().nbytes > 0 + + +def test_to_geotiff_attrs_crs_bool_bypass(tmp_path): + # Regression for the self-review blocker: ``attrs={'crs': True}`` + # with no explicit ``crs=`` kwarg used to bypass ``_validate_crs_arg`` + # and write EPSG=1 to the file. The fix calls the validator on the + # ``attrs['crs']`` value in the writer's int-EPSG branch. + da = _square() + da.attrs['crs'] = True + buf = io.BytesIO() + with pytest.raises(ValueError, match="bool"): + to_geotiff(da, buf) + + +def test_to_geotiff_attrs_crs_unresolvable_epsg_bypass(): + da = _square() + da.attrs['crs'] = 1 # int that pyproj cannot resolve + buf = io.BytesIO() + with pytest.raises(ValueError, match="EPSG"): + to_geotiff(da, buf) + + +def test_to_geotiff_vrt_path_rejects_bool_crs(tmp_path): + # ``to_geotiff(da, '*.vrt')`` dispatches to ``_write_vrt_tiled``, + # which has its own crs resolution block. The validator runs in + # that branch too. + vrt_path = str(tmp_path / "tmp_1971_vrt_tiled.vrt") + with pytest.raises(ValueError, match="bool"): + to_geotiff(_square(), vrt_path, crs=True) + + +def test_to_geotiff_vrt_path_rejects_unresolvable_epsg(tmp_path): + vrt_path = str(tmp_path / "tmp_1971_vrt_bad_epsg.vrt") + with pytest.raises(ValueError, match="EPSG"): + to_geotiff(_square(), vrt_path, crs=1) + + +def test_to_geotiff_vrt_path_attrs_crs_bool_bypass(tmp_path): + # Same blocker as ``test_to_geotiff_attrs_crs_bool_bypass`` but for + # the ``_write_vrt_tiled`` branch (``to_geotiff(da, '*.vrt')``). + da = _square() + da.attrs['crs'] = True + vrt_path = str(tmp_path / "tmp_1971_vrt_attrs_bool.vrt") + with pytest.raises(ValueError, match="bool"): + to_geotiff(da, vrt_path)