Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions xrspatial/geotiff/_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"""
from __future__ import annotations

import warnings

import numpy as np

from ._coords import (
Expand Down Expand Up @@ -206,6 +208,10 @@ def _resolve_nodata_attr(attrs: dict):
"""
nodata = attrs.get('nodata')
if nodata is not None:
try:
float(nodata)
except (TypeError, ValueError) as e:
raise ValueError(_nodata_attr_non_numeric_msg('nodata', nodata)) from e
return nodata

vals = attrs.get('nodatavals')
Expand All @@ -214,30 +220,59 @@ def _resolve_nodata_attr(attrs: dict):
seq = list(vals)
except TypeError:
seq = [vals]
saw_non_numeric = False
for v in seq:
if v is None:
continue
try:
fv = float(v)
except (TypeError, ValueError):
saw_non_numeric = True
continue
if np.isnan(fv):
continue
return v
# A tuple where every entry is non-numeric is almost certainly a
# user error (typo, stringified sentinel) rather than a legitimate
# "no sentinel" signal. Warn so the caller sees it, but still fall
# through to the rest of the resolution chain rather than raising:
# the rest of the function's contract is "skip non-numeric entries".
if saw_non_numeric:
warnings.warn(
f"attrs['nodatavals']={vals!r} contained only non-numeric "
f"entries; no usable sentinel could be resolved from it. "
f"Pass ``nodata=`` explicitly or fix the attr.",
UserWarning,
stacklevel=2,
)

fill = attrs.get('_FillValue')
if fill is not None:
try:
ffv = float(fill)
except (TypeError, ValueError):
return fill # non-numeric -- pass through verbatim
except (TypeError, ValueError) as e:
raise ValueError(_nodata_attr_non_numeric_msg('_FillValue', fill)) from e
if np.isnan(ffv):
return None
return fill

return None


def _nodata_attr_non_numeric_msg(attr_name: str, value) -> str:
"""Error string shared by the ``attrs['nodata']`` and ``attrs['_FillValue']``
non-numeric branches in ``_resolve_nodata_attr`` (#1973)."""
return (
f"attrs[{attr_name!r}]={value!r} is not numeric "
f"({type(value).__name__}). The writer needs a numeric "
f"sentinel to compare against pixel values; passing a "
f"non-numeric value would otherwise crash inside "
f"``np.isnan`` with an opaque ufunc error. Drop the "
f"attr, replace it with a numeric sentinel, or pass "
f"``nodata=`` explicitly (issue #1973)."
)


def _merge_friendly_extra_tags(extra_tags_list, attrs: dict) -> list | None:
"""Combine ``attrs['extra_tags']`` with friendly tag attrs.

Expand Down
33 changes: 33 additions & 0 deletions xrspatial/geotiff/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,36 @@ def _validate_predictor_sample_format(predictor, sample_format) -> None:
f"pair, e.g. `gdal_translate -co PREDICTOR=2` for integers or "
f"`-co PREDICTOR=1` to drop the predictor."
)


def _validate_nodata_arg(nodata) -> None:
"""Reject non-numeric ``nodata=`` at the writer entry point (#1973).

``None`` (no sentinel) passes through. ``bool`` is rejected with
``TypeError`` so all three writer entry points (eager, GPU, VRT)
refuse ``nodata=True`` / ``nodata=False`` the same way the eager
path already does for issue #1911 -- ``float(True) == 1.0`` would
otherwise slip a bool past the numeric branch on the GPU/VRT paths
that do not have their own bool guard. Anything else is run
through ``float()``: success means the writer's downstream
``np.isnan(nodata)`` and integer-cast paths will not blow up.
Failure raises ``ValueError`` with the offending repr, so users
see ``nodata='missing'`` flagged at the boundary instead of an
opaque ``ufunc 'isnan' not supported`` TypeError from inside the
writer.
"""
if nodata is None:
return
if isinstance(nodata, (bool, np.bool_)):
raise TypeError(
f"nodata must be numeric (int or float), got {nodata!r}")
try:
float(nodata)
except (TypeError, ValueError) as e:
raise ValueError(
f"nodata must be numeric or None, got {nodata!r} "
f"(type {type(nodata).__name__}). The writer compares it "
f"against pixel values via ``np.isnan`` and casts it to "
f"the array dtype; a non-numeric value would otherwise "
f"crash inside NumPy with a ufunc TypeError."
) from e
4 changes: 4 additions & 0 deletions xrspatial/geotiff/_writers/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from .._validation import (
_validate_3d_writer_dims,
_validate_nodata_arg,
_validate_tile_size_arg,
)
from .._writer import write
Expand Down Expand Up @@ -267,6 +268,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
if tiled:
_validate_tile_size_arg(tile_size)

_validate_nodata_arg(nodata)

# Up-front validation: catch bad compression names before they reach
# any of the deeper write paths (streaming, GPU, VRT, COG) where the
# error surfaces from _compression_tag with a less obvious traceback.
Expand Down Expand Up @@ -775,6 +778,7 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
This enables streaming dask arrays to disk without materializing the
full array in RAM.
"""
_validate_nodata_arg(nodata)
# Validate compression_level against codec-specific range
if compression_level is not None:
level_range = _LEVEL_RANGES.get(compression.lower())
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/geotiff/_writers/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .._runtime import GeoTIFFFallbackWarning
from .._validation import (
_validate_3d_writer_dims,
_validate_nodata_arg,
_validate_tile_size_arg,
)

Expand Down Expand Up @@ -259,6 +260,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
# write_geotiff_gpu is always tiled, so validate tile_size here and
# keep parity with the public to_geotiff entry point.
_validate_tile_size_arg(tile_size)
_validate_nodata_arg(nodata)
if max_z_error < 0:
raise ValueError(
f"max_z_error must be >= 0, got {max_z_error}")
Expand Down
183 changes: 183 additions & 0 deletions xrspatial/geotiff/tests/test_nodata_validation_1973.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Refuse non-numeric ``nodata=`` / ``attrs['_FillValue']`` (#1973).

The writer compares the resolved nodata against pixel values via
``np.isnan`` and casts it to the array dtype. A non-numeric value
used to fall through ``_resolve_nodata_attr`` (returned verbatim) or
the ``nodata=`` kwarg path and then crash inside NumPy with
``ufunc 'isnan' not supported``. Both the entry point and the attr
resolution path now refuse non-numeric values up front with a clear
error.
"""
from __future__ import annotations

import importlib.util
import io

import numpy as np
import pytest
import xarray as xr

from xrspatial.geotiff import to_geotiff
from xrspatial.geotiff._attrs import _resolve_nodata_attr
from xrspatial.geotiff._validation import _validate_nodata_arg


def _gpu_available() -> bool:
if importlib.util.find_spec("cupy") is None:
return False
try:
import cupy
return bool(cupy.cuda.is_available())
except Exception:
return False


# Probe for a usable GPU once, at import time, so every GPU-gated test
# shares the same answer.
_HAS_GPU = _gpu_available()
# Skip marker for tests that need cupy with a working CUDA runtime.
_gpu_only = pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required")


def _nan_square():
    """Build a 4x4 float32 DataArray of NaN with ``y``/``x`` coords 0..3."""
    values = np.full((4, 4), np.nan, dtype=np.float32)
    axis_coords = {'y': np.arange(4.0), 'x': np.arange(4.0)}
    return xr.DataArray(values, coords=axis_coords, dims=('y', 'x'))


@pytest.mark.parametrize(
    "bad",
    [
        'missing',
        object(),
        [1, 2],
    ],
)
def test_validate_nodata_arg_rejects_non_numeric(bad):
    """A string, a bare object, and a list each raise ValueError."""
    expect_rejection = pytest.raises(ValueError, match="nodata must be numeric")
    with expect_rejection:
        _validate_nodata_arg(bad)


@pytest.mark.parametrize(
    "ok",
    [
        None,
        0,
        -9999,
        1.5,
        np.float32(-1),
        np.int64(0),
    ],
)
def test_validate_nodata_arg_accepts_numeric_and_none(ok):
    """``None`` plus Python and NumPy numbers pass without raising."""
    _validate_nodata_arg(ok)


def test_resolve_nodata_attr_rejects_non_numeric_fillvalue():
    """A string ``_FillValue`` attr raises ValueError that names the attr."""
    attrs = {'_FillValue': 'missing'}
    with pytest.raises(ValueError, match="_FillValue"):
        _resolve_nodata_attr(attrs)


def test_resolve_nodata_attr_rejects_non_numeric_nodata_attr():
    """A string ``nodata`` attr raises ValueError mentioning ``attrs['nodata']``."""
    attrs = {'nodata': 'missing'}
    with pytest.raises(ValueError, match=r"attrs\['nodata'\]"):
        _resolve_nodata_attr(attrs)


def test_resolve_nodata_attr_skips_non_numeric_in_nodatavals():
    """Unusable ``nodatavals`` entries are skipped for a later numeric one."""
    # nodatavals (rioxarray's per-band tuple) keeps its skip-on-non-numeric
    # behaviour: those values often come from arbitrary upstream pipelines
    # and a single bad entry should not block writing.
    #
    # 'missing' makes ``float()`` raise, so this is the case that actually
    # exercises the non-numeric skip.
    assert _resolve_nodata_attr({'nodatavals': ('missing', -9999.0)}) == -9999.0
    # ``float('NaN ')`` does NOT raise: surrounding whitespace is ignored
    # and "nan" is matched case-insensitively, so this entry is skipped as
    # a NaN rather than as a non-numeric value. Kept so the stringified-NaN
    # input stays covered too.
    assert _resolve_nodata_attr({'nodatavals': ('NaN ', -9999.0)}) == -9999.0


def test_resolve_nodata_attr_still_accepts_numeric_fillvalue():
    """An integer ``_FillValue`` is returned as the resolved sentinel."""
    resolved = _resolve_nodata_attr({'_FillValue': -9999})
    assert resolved == -9999


def test_resolve_nodata_attr_returns_none_for_nan_fillvalue():
    """A NaN ``_FillValue`` resolves to ``None``."""
    resolved = _resolve_nodata_attr({'_FillValue': float('nan')})
    assert resolved is None


def test_to_geotiff_rejects_non_numeric_nodata_kwarg():
    """``to_geotiff(..., nodata='missing')`` raises ValueError."""
    sink = io.BytesIO()
    raster = _nan_square()
    with pytest.raises(ValueError, match="nodata must be numeric"):
        to_geotiff(raster, sink, nodata='missing')


def test_to_geotiff_rejects_non_numeric_fillvalue_attr():
    """A string ``_FillValue`` attr on the input makes ``to_geotiff`` raise."""
    raster = _nan_square()
    raster.attrs['_FillValue'] = 'missing'
    sink = io.BytesIO()
    expect_rejection = pytest.raises(ValueError, match="_FillValue")
    with expect_rejection:
        to_geotiff(raster, sink)


def test_to_geotiff_vrt_path_rejects_non_numeric_nodata(tmp_path):
    """A ``.vrt`` output path with ``nodata='missing'`` raises ValueError."""
    target = str(tmp_path / "tmp_1973_vrt.vrt")
    raster = _nan_square()
    with pytest.raises(ValueError, match="nodata must be numeric"):
        to_geotiff(raster, target, nodata='missing')


def test_to_geotiff_accepts_numeric_nodata_kwarg():
    """A numeric ``nodata`` writes successfully and produces a non-empty buffer."""
    sink = io.BytesIO()
    to_geotiff(_nan_square(), sink, nodata=-9999)
    written = sink.getbuffer().nbytes
    assert written > 0


# ---------------------------------------------------------------------------
# Bool rejection: ``nodata=True`` / ``nodata=False`` must raise TypeError at
# all three writer entry points (eager, GPU, VRT). The eager path already
# rejected bools for #1911 but the GPU/VRT validators previously routed bool
# through ``float(True) == 1.0`` and silently coerced. The shared validator
# now refuses bools so all three paths behave the same.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "bad",
    [True, False],
)
def test_validate_nodata_arg_rejects_bool(bad):
    """Both Python bools are refused with TypeError, not coerced to 1.0/0.0."""
    expect_rejection = pytest.raises(TypeError, match="nodata must be numeric")
    with expect_rejection:
        _validate_nodata_arg(bad)


def test_validate_nodata_arg_rejects_numpy_bool():
    """``np.bool_`` is refused with TypeError just like the builtin bool."""
    numpy_flag = np.bool_(True)
    with pytest.raises(TypeError, match="nodata must be numeric"):
        _validate_nodata_arg(numpy_flag)


def test_to_geotiff_eager_rejects_bool_nodata():
    """``to_geotiff`` to an in-memory buffer refuses ``nodata=True``."""
    sink = io.BytesIO()
    raster = _nan_square()
    with pytest.raises(TypeError, match="nodata must be numeric"):
        to_geotiff(raster, sink, nodata=True)


def test_to_geotiff_vrt_rejects_bool_nodata(tmp_path):
    """``to_geotiff`` to a ``.vrt`` path refuses ``nodata=True``."""
    target = str(tmp_path / "tmp_1973_bool_vrt.vrt")
    raster = _nan_square()
    with pytest.raises(TypeError, match="nodata must be numeric"):
        to_geotiff(raster, target, nodata=True)


@_gpu_only
def test_write_geotiff_gpu_rejects_bool_nodata(tmp_path):
    """``write_geotiff_gpu`` refuses ``nodata=True`` on a cupy-backed array."""
    # Imported lazily: cupy is only importable on machines the skip marker
    # lets through.
    import cupy

    from xrspatial.geotiff import write_geotiff_gpu

    host = _nan_square()
    device = host.copy(data=cupy.asarray(host.values))
    target = str(tmp_path / "tmp_1973_bool_gpu.tif")
    with pytest.raises(TypeError, match="nodata must be numeric"):
        write_geotiff_gpu(device, target, nodata=True)


# ---------------------------------------------------------------------------
# All-non-numeric ``attrs['nodatavals']``: warn but still return None and
# fall through. A tuple where every entry is non-numeric is almost certainly
# a user error rather than a legitimate "no sentinel" signal.
# ---------------------------------------------------------------------------


def test_resolve_nodata_attr_warns_when_nodatavals_all_non_numeric():
    """An all-string ``nodatavals`` emits a UserWarning and resolves to None."""
    attrs = {'nodatavals': ('foo', 'bar')}
    with pytest.warns(UserWarning, match="nodatavals"):
        resolved = _resolve_nodata_attr(attrs)
    assert resolved is None


def test_resolve_nodata_attr_no_warning_when_nodatavals_has_usable_entry():
    """No warning fires when a later ``nodatavals`` entry is usable."""
    from warnings import catch_warnings, simplefilter

    # First entry is non-numeric, second is a real sentinel. The loop
    # returns -9999.0 before reaching the warn site.
    attrs = {'nodatavals': ('foo', -9999.0)}
    with catch_warnings():
        # Promote warnings to exceptions so a stray one fails the test.
        simplefilter("error")
        resolved = _resolve_nodata_attr(attrs)
    assert resolved == -9999.0


def test_resolve_nodata_attr_no_warning_when_nodatavals_all_nan():
    """An all-NaN ``nodatavals`` resolves to None without warning."""
    from warnings import catch_warnings, simplefilter

    # NaN entries are skipped (they signal "the float NaN is the sentinel",
    # which doesn't need a GDAL_NODATA tag) but they ARE numeric, so the
    # all-non-numeric warning must stay silent here.
    attrs = {'nodatavals': (float('nan'),)}
    with catch_warnings():
        # Promote warnings to exceptions so a stray one fails the test.
        simplefilter("error")
        resolved = _resolve_nodata_attr(attrs)
    assert resolved is None
Loading