diff --git a/xrspatial/geotiff/_attrs.py b/xrspatial/geotiff/_attrs.py
index df416b5d..37f8615c 100644
--- a/xrspatial/geotiff/_attrs.py
+++ b/xrspatial/geotiff/_attrs.py
@@ -17,6 +17,8 @@
 """
 from __future__ import annotations
 
+import warnings
+
 import numpy as np
 
 from ._coords import (
@@ -206,6 +208,10 @@ def _resolve_nodata_attr(attrs: dict):
     """
     nodata = attrs.get('nodata')
     if nodata is not None:
+        try:
+            float(nodata)
+        except (TypeError, ValueError) as e:
+            raise ValueError(_nodata_attr_non_numeric_msg('nodata', nodata)) from e
         return nodata
 
     vals = attrs.get('nodatavals')
@@ -214,23 +220,38 @@ def _resolve_nodata_attr(attrs: dict):
             seq = list(vals)
         except TypeError:
             seq = [vals]
+        saw_non_numeric = False
         for v in seq:
             if v is None:
                 continue
             try:
                 fv = float(v)
             except (TypeError, ValueError):
+                saw_non_numeric = True
                 continue
             if np.isnan(fv):
                 continue
             return v
+        # A tuple where every entry is non-numeric is almost certainly a
+        # user error (typo, stringified sentinel) rather than a legitimate
+        # "no sentinel" signal. Warn so the caller sees it, but still fall
+        # through to the rest of the resolution chain rather than raising:
+        # the rest of the function's contract is "skip non-numeric entries".
+        if saw_non_numeric:
+            warnings.warn(
+                f"attrs['nodatavals']={vals!r} contained only non-numeric "
+                f"entries; no usable sentinel could be resolved from it. "
+                f"Pass ``nodata=`` explicitly or fix the attr.",
+                UserWarning,
+                stacklevel=2,
+            )
 
     fill = attrs.get('_FillValue')
     if fill is not None:
         try:
             ffv = float(fill)
-        except (TypeError, ValueError):
-            return fill  # non-numeric -- pass through verbatim
+        except (TypeError, ValueError) as e:
+            raise ValueError(_nodata_attr_non_numeric_msg('_FillValue', fill)) from e
         if np.isnan(ffv):
             return None
         return fill
@@ -238,6 +259,20 @@ def _resolve_nodata_attr(attrs: dict):
     return None
 
 
+def _nodata_attr_non_numeric_msg(attr_name: str, value) -> str:
+    """Error string shared by the ``attrs['nodata']`` and ``attrs['_FillValue']``
+    non-numeric branches in ``_resolve_nodata_attr`` (#1973)."""
+    return (
+        f"attrs[{attr_name!r}]={value!r} is not numeric "
+        f"({type(value).__name__}). The writer needs a numeric "
+        f"sentinel to compare against pixel values; passing a "
+        f"non-numeric value would otherwise crash inside "
+        f"``np.isnan`` with an opaque ufunc error. Drop the "
+        f"attr, replace it with a numeric sentinel, or pass "
+        f"``nodata=`` explicitly (issue #1973)."
+    )
+
+
 def _merge_friendly_extra_tags(extra_tags_list, attrs: dict) -> list | None:
     """Combine ``attrs['extra_tags']`` with friendly tag attrs.
 
diff --git a/xrspatial/geotiff/_validation.py b/xrspatial/geotiff/_validation.py
index 790a3ea3..3ed94cba 100644
--- a/xrspatial/geotiff/_validation.py
+++ b/xrspatial/geotiff/_validation.py
@@ -230,3 +230,46 @@ def _validate_predictor_sample_format(predictor, sample_format) -> None:
             f"pair, e.g. `gdal_translate -co PREDICTOR=2` for integers or "
             f"`-co PREDICTOR=1` to drop the predictor."
         )
+
+
+def _validate_nodata_arg(nodata) -> None:
+    """Reject non-numeric ``nodata=`` at the writer entry point (#1973).
+
+    ``None`` (no sentinel) passes through. ``bool`` is rejected with
+    ``TypeError`` so all three writer entry points (eager, GPU, VRT)
+    refuse ``nodata=True`` / ``nodata=False`` the same way the eager
+    path already does for issue #1911 -- ``float(True) == 1.0`` would
+    otherwise slip a bool past the numeric branch on the GPU/VRT paths
+    that do not have their own bool guard. ``str`` / ``bytes`` /
+    ``bytearray`` are rejected with ``ValueError`` even when they parse
+    as a number: ``float('5')`` succeeds, but ``np.isnan('5')`` raises
+    the same ufunc ``TypeError`` this check exists to prevent. Anything
+    else is run through ``float()``: success means the writer's
+    downstream ``np.isnan(nodata)`` and integer-cast paths will not
+    blow up. Failure raises ``ValueError`` with the offending repr, so
+    users see ``nodata='missing'`` flagged at the boundary instead of
+    an opaque ``ufunc 'isnan' not supported`` TypeError from inside
+    the writer.
+    """
+    if nodata is None:
+        return
+    if isinstance(nodata, (bool, np.bool_)):
+        raise TypeError(
+            f"nodata must be numeric (int or float), got {nodata!r}")
+    cause = None
+    # Text is refused before the ``float()`` probe: numeric-looking
+    # strings would pass the probe yet still break ``np.isnan``.
+    if not isinstance(nodata, (str, bytes, bytearray)):
+        try:
+            float(nodata)
+        except (TypeError, ValueError) as e:
+            cause = e
+        else:
+            return
+    raise ValueError(
+        f"nodata must be numeric or None, got {nodata!r} "
+        f"(type {type(nodata).__name__}). The writer compares it "
+        f"against pixel values via ``np.isnan`` and casts it to "
+        f"the array dtype; a non-numeric value would otherwise "
+        f"crash inside NumPy with a ufunc TypeError."
+    ) from cause
diff --git a/xrspatial/geotiff/_writers/eager.py b/xrspatial/geotiff/_writers/eager.py
index 1915fb84..03fcada0 100644
--- a/xrspatial/geotiff/_writers/eager.py
+++ b/xrspatial/geotiff/_writers/eager.py
@@ -42,6 +42,7 @@
 )
 from .._validation import (
     _validate_3d_writer_dims,
+    _validate_nodata_arg,
     _validate_tile_size_arg,
 )
 from .._writer import write
@@ -267,6 +268,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
     if tiled:
         _validate_tile_size_arg(tile_size)
 
+    _validate_nodata_arg(nodata)
+
     # Up-front validation: catch bad compression names before they reach
     # any of the deeper write paths (streaming, GPU, VRT, COG) where the
     # error surfaces from _compression_tag with a less obvious traceback.
@@ -775,6 +778,7 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
     This enables streaming dask arrays to disk without materializing the
     full array in RAM.
     """
+    _validate_nodata_arg(nodata)
     # Validate compression_level against codec-specific range
     if compression_level is not None:
         level_range = _LEVEL_RANGES.get(compression.lower())
diff --git a/xrspatial/geotiff/_writers/gpu.py b/xrspatial/geotiff/_writers/gpu.py
index c08ab791..a5de64e0 100644
--- a/xrspatial/geotiff/_writers/gpu.py
+++ b/xrspatial/geotiff/_writers/gpu.py
@@ -27,6 +27,7 @@
 from .._runtime import GeoTIFFFallbackWarning
 from .._validation import (
     _validate_3d_writer_dims,
+    _validate_nodata_arg,
     _validate_tile_size_arg,
 )
 
@@ -259,6 +260,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
     # write_geotiff_gpu is always tiled, so validate tile_size here and
     # keep parity with the public to_geotiff entry point.
     _validate_tile_size_arg(tile_size)
+    _validate_nodata_arg(nodata)
     if max_z_error < 0:
         raise ValueError(
             f"max_z_error must be >= 0, got {max_z_error}")
diff --git a/xrspatial/geotiff/tests/test_nodata_validation_1973.py b/xrspatial/geotiff/tests/test_nodata_validation_1973.py
new file mode 100644
index 00000000..154604e0
--- /dev/null
+++ b/xrspatial/geotiff/tests/test_nodata_validation_1973.py
@@ -0,0 +1,193 @@
+"""Refuse non-numeric ``nodata=`` / ``attrs['_FillValue']`` (#1973).
+
+The writer compares the resolved nodata against pixel values via
+``np.isnan`` and casts it to the array dtype. A non-numeric value
+used to fall through ``_resolve_nodata_attr`` (returned verbatim) or
+the ``nodata=`` kwarg path and then crash inside NumPy with
+``ufunc 'isnan' not supported``. Both the entry point and the attr
+resolution path now refuse non-numeric values up front with a clear
+error.
+"""
+from __future__ import annotations
+
+import importlib.util
+import io
+
+import numpy as np
+import pytest
+import xarray as xr
+
+from xrspatial.geotiff import to_geotiff
+from xrspatial.geotiff._attrs import _resolve_nodata_attr
+from xrspatial.geotiff._validation import _validate_nodata_arg
+
+
+def _gpu_available() -> bool:
+    if importlib.util.find_spec("cupy") is None:
+        return False
+    try:
+        import cupy
+        return bool(cupy.cuda.is_available())
+    except Exception:
+        return False
+
+
+_HAS_GPU = _gpu_available()
+_gpu_only = pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required")
+
+
+def _nan_square():
+    return xr.DataArray(
+        np.full((4, 4), np.nan, dtype=np.float32),
+        coords={'y': np.arange(4.0), 'x': np.arange(4.0)},
+        dims=('y', 'x'),
+    )
+
+
+@pytest.mark.parametrize("bad", ['missing', object(), [1, 2]])
+def test_validate_nodata_arg_rejects_non_numeric(bad):
+    with pytest.raises(ValueError, match="nodata must be numeric"):
+        _validate_nodata_arg(bad)
+
+
+@pytest.mark.parametrize("bad", ['5', '-9999.0', 'nan', b'5', bytearray(b'5')])
+def test_validate_nodata_arg_rejects_numeric_looking_text(bad):
+    # ``float('5')`` succeeds but ``np.isnan('5')`` raises the ufunc
+    # TypeError, so text is refused even when it parses as a number.
+    with pytest.raises(ValueError, match="nodata must be numeric"):
+        _validate_nodata_arg(bad)
+
+
+@pytest.mark.parametrize("ok", [None, 0, -9999, 1.5, np.float32(-1), np.int64(0)])
+def test_validate_nodata_arg_accepts_numeric_and_none(ok):
+    _validate_nodata_arg(ok)
+
+
+def test_resolve_nodata_attr_rejects_non_numeric_fillvalue():
+    with pytest.raises(ValueError, match="_FillValue"):
+        _resolve_nodata_attr({'_FillValue': 'missing'})
+
+
+def test_resolve_nodata_attr_rejects_non_numeric_nodata_attr():
+    with pytest.raises(ValueError, match=r"attrs\['nodata'\]"):
+        _resolve_nodata_attr({'nodata': 'missing'})
+
+
+def test_resolve_nodata_attr_skips_non_numeric_in_nodatavals():
+    # nodatavals (rioxarray's per-band tuple) keeps its skip-on-unusable
+    # behaviour: those values often come from arbitrary upstream pipelines
+    # and a single bad entry should not block writing. 'n/a' is skipped as
+    # non-numeric; 'NaN ' parses via ``float()`` to NaN and is skipped as
+    # NaN, so the first finite entry wins.
+    assert _resolve_nodata_attr({'nodatavals': ('n/a', 'NaN ', -9999.0)}) == -9999.0
+
+
+def test_resolve_nodata_attr_still_accepts_numeric_fillvalue():
+    assert _resolve_nodata_attr({'_FillValue': -9999}) == -9999
+
+
+def test_resolve_nodata_attr_returns_none_for_nan_fillvalue():
+    assert _resolve_nodata_attr({'_FillValue': float('nan')}) is None
+
+
+def test_to_geotiff_rejects_non_numeric_nodata_kwarg():
+    buf = io.BytesIO()
+    with pytest.raises(ValueError, match="nodata must be numeric"):
+        to_geotiff(_nan_square(), buf, nodata='missing')
+
+
+def test_to_geotiff_rejects_non_numeric_fillvalue_attr():
+    da = _nan_square()
+    da.attrs['_FillValue'] = 'missing'
+    buf = io.BytesIO()
+    with pytest.raises(ValueError, match="_FillValue"):
+        to_geotiff(da, buf)
+
+
+def test_to_geotiff_vrt_path_rejects_non_numeric_nodata(tmp_path):
+    vrt_path = str(tmp_path / "tmp_1973_vrt.vrt")
+    with pytest.raises(ValueError, match="nodata must be numeric"):
+        to_geotiff(_nan_square(), vrt_path, nodata='missing')
+
+
+def test_to_geotiff_accepts_numeric_nodata_kwarg():
+    buf = io.BytesIO()
+    to_geotiff(_nan_square(), buf, nodata=-9999)
+    assert buf.getbuffer().nbytes > 0
+
+
+# ---------------------------------------------------------------------------
+# Bool rejection: ``nodata=True`` / ``nodata=False`` must raise TypeError at
+# all three writer entry points (eager, GPU, VRT). The eager path already
+# rejected bools for #1911 but the GPU/VRT validators previously routed bool
+# through ``float(True) == 1.0`` and silently coerced. The shared validator
+# now refuses bools so all three paths behave the same.
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("bad", [True, False])
+def test_validate_nodata_arg_rejects_bool(bad):
+    with pytest.raises(TypeError, match="nodata must be numeric"):
+        _validate_nodata_arg(bad)
+
+
+def test_validate_nodata_arg_rejects_numpy_bool():
+    with pytest.raises(TypeError, match="nodata must be numeric"):
+        _validate_nodata_arg(np.bool_(True))
+
+
+def test_to_geotiff_eager_rejects_bool_nodata():
+    buf = io.BytesIO()
+    with pytest.raises(TypeError, match="nodata must be numeric"):
+        to_geotiff(_nan_square(), buf, nodata=True)
+
+
+def test_to_geotiff_vrt_rejects_bool_nodata(tmp_path):
+    vrt_path = str(tmp_path / "tmp_1973_bool_vrt.vrt")
+    with pytest.raises(TypeError, match="nodata must be numeric"):
+        to_geotiff(_nan_square(), vrt_path, nodata=True)
+
+
+@_gpu_only
+def test_write_geotiff_gpu_rejects_bool_nodata(tmp_path):
+    import cupy
+
+    from xrspatial.geotiff import write_geotiff_gpu
+
+    da_cpu = _nan_square()
+    da_gpu = da_cpu.copy(data=cupy.asarray(da_cpu.values))
+    out = str(tmp_path / "tmp_1973_bool_gpu.tif")
+    with pytest.raises(TypeError, match="nodata must be numeric"):
+        write_geotiff_gpu(da_gpu, out, nodata=True)
+
+
+# ---------------------------------------------------------------------------
+# All-non-numeric ``attrs['nodatavals']``: warn but still return None and
+# fall through. A tuple where every entry is non-numeric is almost certainly
+# a user error rather than a legitimate "no sentinel" signal.
+# ---------------------------------------------------------------------------
+
+
+def test_resolve_nodata_attr_warns_when_nodatavals_all_non_numeric():
+    with pytest.warns(UserWarning, match="nodatavals"):
+        result = _resolve_nodata_attr({'nodatavals': ('foo', 'bar')})
+    assert result is None
+
+
+def test_resolve_nodata_attr_no_warning_when_nodatavals_has_usable_entry():
+    # First entry is non-numeric, second is a real sentinel. The loop
+    # returns -9999.0 before reaching the warn site, so no warning fires.
+    import warnings as _warnings
+    with _warnings.catch_warnings():
+        _warnings.simplefilter("error")
+        assert _resolve_nodata_attr({'nodatavals': ('foo', -9999.0)}) == -9999.0
+
+
+def test_resolve_nodata_attr_no_warning_when_nodatavals_all_nan():
+    # NaN entries are skipped (they signal "the float NaN is the sentinel",
+    # which doesn't need a GDAL_NODATA tag) but they ARE numeric, so the
+    # all-non-numeric warning must not fire for an all-NaN tuple.
+    import warnings as _warnings
+    with _warnings.catch_warnings():
+        _warnings.simplefilter("error")
+        assert _resolve_nodata_attr({'nodatavals': (float('nan'),)}) is None