diff --git a/examples/13_mps/channel_demo.py b/examples/13_mps/channel_demo.py new file mode 100644 index 00000000..32b1ad10 --- /dev/null +++ b/examples/13_mps/channel_demo.py @@ -0,0 +1,81 @@ +""" +Direct Sampling demo — Strebelle (2002) channelized fluvial TI. +Adapted for GSTools. +""" + +import os +import urllib.request + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import ListedColormap + +from gstools import mps + +# 1. Load TI +TI_URL = ( + "https://raw.githubusercontent.com/GeostatsGuy/" + "GeoDataSets/master/MPS_Training_image_and_Realizations_500.npz" +) +CACHE = "mps_strebelle.npz" +if not os.path.exists(CACHE): + print("Downloading Strebelle TI ...") + urllib.request.urlretrieve(TI_URL, CACHE) + +ti_arr = np.load(CACHE)["array1"].astype(int) # (256, 256) +ti_model = mps.TrainingImage(ti_arr, categorical=True) +print(f"TI shape: {ti_model.shape} sand={ti_arr.mean():.3f}") + +# 2. Conditioning: 100 random hard-data points from the TI +SG_SIZE = 100 +N_COND = 100 +rng = np.random.default_rng(0) +cond_row = rng.integers(0, SG_SIZE, N_COND) +cond_col = rng.integers(0, SG_SIZE, N_COND) +cond_pos = [cond_row.astype(float), cond_col.astype(float)] +cond_val = ti_arr[cond_row, cond_col].astype(float) +print(f"Conditioning: {N_COND} pts sand={cond_val.mean():.3f}") + +# 3. Simulate (n=30, f=1.0, t=0.01) +N_NEIGH = 30 +SCAN_F = 0.1 +THRESH = 0.01 + +ds = mps.DirectSampling( + ti_model, n_neighbors=N_NEIGH, scan_fraction=SCAN_F, threshold=THRESH +) +ds.set_condition(cond_pos, cond_val) + +print("Starting simulation (this may take a moment)...") +x = np.arange(SG_SIZE, dtype=float) +y = np.arange(SG_SIZE, dtype=float) +sg = ds([x, y], seed=42).astype(int) + +honored = (sg[cond_row, cond_col] == cond_val.astype(int)).sum() +print(f"Simulation complete. Conditioning: {honored}/{N_COND} honored") + +# 4. Plot +cmap = ListedColormap(["#c9a96e", "#2b6cb0"]) +fig, ax = plt.subplots(1, 2, figsize=(12, 6)) + +ax[0].imshow(ti_arr[:SG_SIZE, :SG_SIZE], cmap=cmap, origin="upper") +ax[0].set_title("Training Image (Crop)") + +im = ax[1].imshow(sg, cmap=cmap, origin="upper") +ax[1].scatter( + cond_col, + cond_row, + c=cond_val, + cmap=cmap, + edgecolors="k", + s=20, + label="Cond. Points", +) +ax[1].set_title(f"DS Realization (n={N_NEIGH}, f={SCAN_F}, t={THRESH})") +ax[1].legend() + +plt.colorbar(im, ax=ax.ravel().tolist(), ticks=[0.25, 0.75]).set_ticklabels( + ["shale", "sand"] +) +plt.savefig("channel_demo_mps.png", dpi=150, bbox_inches="tight") +print("Saved plot to channel_demo_mps.png") diff --git a/src/gstools/__init__.py b/src/gstools/__init__.py index 4d12007c..e61f7597 100644 --- a/src/gstools/__init__.py +++ b/src/gstools/__init__.py @@ -23,10 +23,21 @@ tools transform normalizer + mps Classes ======= +Multiple Point Statistics +^^^^^^^^^^^^^^^^^^^^^^^^^ +Classes for Multiple Point Statistics (MPS) simulations + +.. currentmodule:: gstools.mps + +.. autosummary:: + DirectSampling + TrainingImage + Kriging ^^^^^^^ Swiss-Army-Knife for Kriging. For short cut classes see: :any:`gstools.krige` @@ -139,6 +150,7 @@ covmodel, field, krige, + mps, normalizer, random, tools, @@ -169,6 +181,7 @@ ) from gstools.field import PGS, SRF, CondSRF from gstools.krige import Krige +from gstools.mps import DirectSampling, TrainingImage from gstools.tools import ( DEGREE_SCALE, EARTH_RADIUS, @@ -200,7 +213,7 @@ __all__ = ["__version__"] __all__ += ["covmodel", "field", "variogram", "krige", "random", "tools"] -__all__ += ["transform", "normalizer", "config"] +__all__ += ["transform", "normalizer", "config", "mps"] __all__ += [ "CovModel", "SumModel", @@ -237,6 +250,8 @@ "SRF", "CondSRF", "PGS", + "DirectSampling", + "TrainingImage", "rotated_main_axes", "generate_grid", "generate_st_grid", diff --git a/src/gstools/mps/__init__.py b/src/gstools/mps/__init__.py new file mode 100644 index 00000000..2a3be355 --- /dev/null +++ b/src/gstools/mps/__init__.py @@ -0,0 +1,18 @@ +""" +GStools subpackage for Multiple Point Statistics (MPS). + +.. currentmodule:: gstools.mps + +Multiple Point Statistics +^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autosummary:: + :toctree: + + DirectSampling + TrainingImage +""" + +from gstools.mps.direct_sampling import DirectSampling +from gstools.mps.training_image import TrainingImage + +__all__ = ["DirectSampling", "TrainingImage"] diff --git a/src/gstools/mps/direct_sampling.py b/src/gstools/mps/direct_sampling.py new file mode 100644 index 00000000..3e46ddbb --- /dev/null +++ b/src/gstools/mps/direct_sampling.py @@ -0,0 +1,775 @@ +""" +GStools subpackage providing the Direct Sampling MPS simulation class. + +.. currentmodule:: gstools.mps + +The following classes and functions are provided + +.. autosummary:: + DirectSampling +""" + +import queue +from concurrent.futures import ThreadPoolExecutor + +import numpy as np + +from gstools import config +from gstools.field.base import Field +from gstools.random.rng import RNG + +__all__ = ["DirectSampling"] + +_VALID_BOUNDARY = ("strict", "partial") + +# DS-mode scan block size. Large enough that per-call NumPy overhead is +# negligible (essentially full vectorization speed), small enough that the +# greedy DS scan does not overcompute far past the first accepted match. +# This is a call-overhead-amortization constant, not a cache-tuned one. +_SCAN_BLOCK = 4096 + + +def _precompute_offsets(shape, max_offset=None): + """Neighbour offsets from the origin, sorted by Euclidean distance. + + Parameters + ---------- + shape : tuple + Simulation grid shape. + max_offset : int, optional + Maximum offset in any dimension. + Default: ``max(shape)``. + + Returns + ------- + numpy.ndarray, shape (N, dim) + """ + dim = len(shape) + if max_offset is None: + max_offset = max(shape) + rng_vals = np.arange(-max_offset, max_offset + 1) + grid = np.array(np.meshgrid(*[rng_vals] * dim, indexing="ij")) + offsets = grid.reshape(dim, -1).T + offsets = offsets[np.any(offsets != 0, axis=1)] + idx = np.argsort(np.sum(offsets**2, axis=1)) + return offsets[idx] + + +def _select_neighbors( + x_i, + offset_arr, + sim_shape_arr, + sim_shape, + path_pos_map, + curr_idx, + informed, + max_radius, + n_neighbors, +): + """Closest valid neighbours of ``x_i``, with their path indices. + + A candidate is valid when it is in bounds, has path index ``< curr_idx`` + (already-simulated in path order) or ``-1`` (conditioning data), and — if + ``informed`` is given — is marked informed. ``offset_arr`` is + distance-sorted, so slicing the first ``n_neighbors`` survivors yields the + closest ones. + + Passing ``informed=None`` treats every in-bounds lower-index/conditioning + cell as available; this is correct when building the dependency DAG, where + all earlier-path nodes are informed by definition. + + Returns + ------- + coords : numpy.ndarray, shape (m, dim) + Neighbour coordinates, ``m <= n_neighbors``. + path_idx : numpy.ndarray, shape (m,) + Path index of each neighbour (``-1`` for conditioning data). + """ + # ``offset_arr`` is distance-sorted, so the closest ``n_neighbors`` valid + # candidates are the first ``n_neighbors`` survivors and we can stop the + # moment we have them. Iterating with an early break avoids masking the + # whole offset array for every node (O(N**2) on large grids without a + # ``max_radius`` cap). Tradeoff: when fewer than ``n_neighbors`` valid + # candidates exist (sparse early-path nodes), the scan still walks the full + # ``offset_arr`` in Python; this affects only the first few nodes and is + # bounded by the ``max_radius`` ball when one is set. + dim = offset_arr.shape[1] + r_sq = max_radius * max_radius if max_radius is not None else None + found_coords = [] + found_vidx = [] + for off in offset_arr: + # Distance-sorted: the first offset beyond the radius ends the scan. + if r_sq is not None and float(off @ off) > r_sq: + break + cand = x_i + off + if np.any(cand < 0) or np.any(cand >= sim_shape_arr): + continue + vi = path_pos_map[int(np.ravel_multi_index(tuple(cand), sim_shape))] + if not (vi < curr_idx or vi == -1): + continue + if informed is not None and not informed[tuple(cand)]: + continue + found_coords.append(cand) + found_vidx.append(vi) + if len(found_coords) >= n_neighbors: + break + if found_coords: + return ( + np.array(found_coords, dtype=offset_arr.dtype), + np.array(found_vidx, dtype=path_pos_map.dtype), + ) + return ( + np.empty((0, dim), dtype=offset_arr.dtype), + np.empty(0, dtype=path_pos_map.dtype), + ) + + +def _build_dag( + path, + n_neighbors, + sim_shape, + offset_arr, + path_pos_map, + max_radius=None, +): + """Build the simulation dependency DAG. + + Edge ``j -> i`` means path node ``j`` (j < i) is among the n-closest + neighbours used when simulating node ``i``. Conditioning data carry no + edge. Uses the same vectorized neighbour selection as the simulation + (:func:`_select_neighbors`), so the resulting dependencies match the set + each node would actually pick at simulation time. + """ + N = len(path) + sim_shape_arr = np.array(sim_shape) + indegree = np.zeros(N, dtype=np.int32) + out_edges = [[] for _ in range(N)] + + for i in range(N): + _, vidx = _select_neighbors( + path[i], + offset_arr, + sim_shape_arr, + sim_shape, + path_pos_map, + i, # build-time: all path nodes with index < i are informed + None, + max_radius, + n_neighbors, + ) + # path-node neighbours (conditioning data have index -1, no edge) + for j in vidx[vidx >= 0]: + indegree[i] += 1 + out_edges[int(j)].append(i) + + return indegree, out_edges + + +def ds_simulate( + training_image, + sim_shape, + n_neighbors, + threshold, + scan_fraction, + rng, + conditions=None, + cond_weight=1.0, + boundary="strict", + max_radius=None, + num_threads=None, +): + """Direct Sampling univariate simulation (Mariethoz2010, Juda2022). + + Parameters + ---------- + training_image : TrainingImage + Training image; provides ``training_image.distance()`` and + ``training_image.adjust_value()``. + sim_shape : tuple + Simulation grid shape. + n_neighbors : int + Maximum number of neighbours in the data event (Juda2022 §2). + threshold : float + Distance threshold for early acceptance (Juda2022 §2). + ``0.0`` → DSBC mode. + scan_fraction : float + Fraction of the per-node search window to scan (Mariethoz2010 §3 ¶24). + Evaluates at most ``floor(f · |window|)`` candidates per node. + ``1.0`` → full window scan. + rng : numpy.random.RandomState + Random number generator. + conditions : dict, optional + ``{tuple_index: value}`` mapping of conditioning data. + cond_weight : float, optional + Weight δ for conditioning nodes (Mariethoz2010 §3 ¶26). + boundary : str, optional + Search-window strategy: ``"strict"`` (default) or ``"partial"``. + max_radius : float, optional + If set, SG neighbours beyond this Euclidean distance are excluded + from the data event (Mariethoz2010 §3 ¶19). + num_threads : int or None, optional + Number of threads for outer DAG parallelism. ``None`` defaults to + ``config.NUM_THREADS``. + + Returns + ------- + numpy.ndarray + """ + ti_data = training_image.data + ti_shape = np.array(ti_data.shape) + sim_shape_arr = np.array(sim_shape) + sg = np.full(sim_shape, np.nan) + is_cond = np.zeros(sim_shape, dtype=bool) + informed = np.zeros(sim_shape, dtype=bool) + + if conditions: + for idx, val in conditions.items(): + sg[idx] = val + is_cond[idx] = True + informed[idx] = True + + n_threads = ( + num_threads if num_threads is not None else (config.NUM_THREADS or 1) + ) + executor = ( + ThreadPoolExecutor(max_workers=n_threads) if n_threads > 1 else None + ) + max_off_int = int(np.ceil(max_radius)) if max_radius is not None else None + offset_arr = _precompute_offsets(sim_shape, max_off_int) + + path = np.argwhere(np.isnan(sg)) + path = path[rng.permutation(len(path))] + node_seeds = rng.randint(0, 2**32, size=len(path), dtype=np.int64) + + path_flat = np.ravel_multi_index(path.T, sim_shape) + path_pos_map = np.full(int(np.prod(sim_shape)), -1, dtype=np.intp) + path_pos_map[path_flat] = np.arange(len(path_flat)) + + def _rand_ti(node_rng): + return ti_data[tuple(node_rng.randint(0, s) for s in ti_shape)] + + def _get_neighbors(x_i, informed_in): + curr_idx = path_pos_map[ + int(np.ravel_multi_index(tuple(x_i), sim_shape)) + ] + coords, _ = _select_neighbors( + x_i, + offset_arr, + sim_shape_arr, + sim_shape, + path_pos_map, + curr_idx, + informed_in, + max_radius, + n_neighbors, + ) + return coords + + def _scan_ti(lo, win_shape, lags, de_sim, cm, ln, node_rng): + # Precondition: win_size >= 1 (callers guarantee win_lo <= win_hi on all axes). + # Also captures scan_fraction, ti_data, threshold, cond_weight from outer scope. + win_size = int(np.prod(win_shape)) + max_scan = max(1, int(scan_fraction * win_size)) + start = int(node_rng.randint(0, win_size)) + + # All scan positions in visit order — shape (max_scan,) + positions = (start + np.arange(max_scan)) % win_size + # Anchor coordinates for each position — shape (max_scan, dim) + y_all = lo + np.column_stack(np.unravel_index(positions, win_shape)) + + # lags are integer-valued float64; cast once, reuse for all candidates + int_lags = lags.astype(int) # (k, dim) + + def _de_ti(y_rows): + # TI data events for the given anchor rows — shape (len(y_rows), k) + coords = y_rows[:, None, :] + int_lags[None, :, :] + return ti_data[tuple(coords.transpose(2, 0, 1))] + + # DSBC (threshold == 0): no early exit is possible — the global minimum + # over the whole scan is required — so evaluate every candidate in a + # single vectorized call. This is the fastest path and stays exact. + if threshold <= 0: + all_de_ti = _de_ti(y_all) + all_dists = training_image.vec_distance( + de_sim, all_de_ti, cm, cond_weight, ln + ) + best_k = int(np.argmin(all_dists)) + return ti_data[tuple(y_all[best_k])], all_de_ti[best_k] + + # DS (threshold > 0): chunked vectorized scan with an early-exit + # checkpoint between blocks. Each block is a full vectorized distance + # call (so the per-element cost matches the single-call version); only + # the threshold test runs per block. Blocks advance in scan order, so + # the first under-threshold candidate found is the first one globally — + # identical to the unchunked argmax(under) result. + best_d = np.inf + best_y = None + best_de = None + for b0 in range(0, max_scan, _SCAN_BLOCK): + y_blk = y_all[b0 : b0 + _SCAN_BLOCK] + de_blk = _de_ti(y_blk) + d_blk = training_image.vec_distance( + de_sim, de_blk, cm, cond_weight, ln + ) + under = d_blk <= threshold + if np.any(under): + k = int(np.argmax(under)) + return ti_data[tuple(y_blk[k])], de_blk[k] + # No acceptable match in this block. Track the running best with a + # strict ``<`` test so that, if no candidate ever falls below the + # threshold, the returned fallback equals the global argmin with the + # same first-occurrence tie-break as the unchunked version. + k = int(np.argmin(d_blk)) + if d_blk[k] < best_d: + best_d = float(d_blk[k]) + best_y = y_blk[k] + best_de = de_blk[k] + return ti_data[tuple(best_y)], best_de + + def _simulate_node(x_i, node_rng, sg_in, informed_in): + nbrs = _get_neighbors(x_i, informed_in) + if len(nbrs) == 0: + return _rand_ti(node_rng) + + lags = (nbrs - x_i).astype(np.float64) # (k, dim) + data_event_sim = sg_in[tuple(nbrs.T)] # (k,) + cond_mask = is_cond[tuple(nbrs.T)] # (k,) + lag_norms = np.linalg.norm(lags, axis=1) # (k,) + + if boundary == "strict": + # Search window Y(L_i) — Juda2022 Eq. 5, Mariethoz2010 §3 ¶19 + win_lo = np.maximum(0, np.ceil(-lags.min(axis=0))).astype(int) + win_hi = np.minimum( + ti_shape - 1, np.floor(ti_shape - 1 - lags.max(axis=0)) + ).astype(int) + if np.any(win_lo > win_hi): + return _rand_ti(node_rng) + best_v, best_de_ti = _scan_ti( + win_lo, + tuple(win_hi - win_lo + 1), + lags, + data_event_sim, + cond_mask, + lag_norms, + node_rng, + ) + return training_image.adjust_value( + best_v, data_event_sim, best_de_ti + ) + + else: # "partial" — Mariethoz2010 §6.2: global template reduction + # Lags are distance-sorted (closest first) because offset_arr is. + # Drop farthest neighbours one at a time until the bounding box of + # the remaining data event fits inside the TI, per the paper's + # "ignore until it becomes possible to scan" directive (§6.2). + valid_count = len(lags) + while valid_count > 0: + lags_p = lags[:valid_count] + sw_lo = np.maximum(0, np.ceil(-lags_p.min(axis=0))).astype(int) + sw_hi = np.minimum( + ti_shape - 1, np.floor(ti_shape - 1 - lags_p.max(axis=0)) + ).astype(int) + if np.all(sw_lo <= sw_hi): + break + valid_count -= 1 + else: + # No subset of the data event fits inside the TI (the closest + # neighbour's lag already exceeds the TI in some dimension). + # Recover like the empty-window case in strict mode rather than + # aborting the whole simulation. + return _rand_ti(node_rng) + best_v, best_de_ti = _scan_ti( + sw_lo, + tuple(sw_hi - sw_lo + 1), + lags_p, + data_event_sim[:valid_count], + cond_mask[:valid_count], + lag_norms[:valid_count], + node_rng, + ) + # For variation distance, adjust_value uses the mean of the + # truncated data event (valid_count neighbours), not the full + # neighbourhood mean. This is intentional — the mean-shift + # must be consistent with the lags actually used in the scan. + return training_image.adjust_value( + best_v, data_event_sim[:valid_count], best_de_ti + ) + + try: + if executor is not None: + indegree, out_edges = _build_dag( + path, + n_neighbors, + sim_shape, + offset_arr, + path_pos_map, + max_radius, + ) + # Running ready-queue: a node is dispatched the instant its last + # dependency completes (no per-wave barrier). Workers read the + # live sg / informed arrays; this is safe because (1) a node is + # only submitted once all its dependencies are written, so the + # values it reads are final, and (2) all shared-state mutation + # (sg, informed, in-degree, submission) happens on this main + # thread — workers only read. Each numpy access holds the GIL for + # its duration, so element reads never tear against the writes. + # The result of every node depends only on its seed and its + # (final) neighbour values, so the output is independent of + # completion order and stays identical to the serial run. + done_q = queue.Queue() + counts = {"submitted": 0, "done": 0} + + def _run(i): + return i, _simulate_node( + path[i], + RNG(int(node_seeds[i])).random, + sg, + informed, + ) + + def _submit(i): + executor.submit(_run, i).add_done_callback(done_q.put) + counts["submitted"] += 1 + + for i in range(len(path)): + if indegree[i] == 0: + _submit(i) + + while counts["done"] < counts["submitted"]: + i, val = done_q.get().result() + counts["done"] += 1 + x_i_t = tuple(path[i]) + if np.isnan(val): + raise ValueError( + f"Simulation produced NaN at {path[i]}. Check TI data." + ) + sg[x_i_t] = val + informed[x_i_t] = True + for j in out_edges[i]: + indegree[j] -= 1 + if indegree[j] == 0: + _submit(j) + else: + for i, x_i in enumerate(path): + x_i_t = tuple(x_i) + val = _simulate_node( + x_i, + RNG(int(node_seeds[i])).random, + sg, + informed, + ) + if np.isnan(val): + raise ValueError( + f"Simulation produced NaN at {x_i}. Check TI data." + ) + sg[x_i_t] = val + informed[x_i_t] = True + finally: + if executor is not None: + executor.shutdown(wait=True) + + return sg + + +class DirectSampling(Field): + """Multiple Point Statistics simulation using Direct Sampling. + + Subclasses :class:`gstools.field.base.Field`. Takes a :class:`TrainingImage` + (analogous to :class:`CovModel`) and produces fields on structured grids. + + Parameters + ---------- + ti : TrainingImage + The training image (the MPS model). + n_neighbors : int, optional + Maximum neighbors in data event. Default: 32. + scan_fraction : float, optional + Fraction of the per-node search window to scan. Default: 1. + threshold : float, optional + Distance threshold. 0.0 -> DSBC mode. Default: 0.0. + cond_weight : float, optional + Weight for conditioning nodes in distance. Default: 1.0. + boundary : str, optional + Search-window strategy: ``"strict"`` (default) or ``"partial"``. + max_radius : float, optional + Exclude SG neighbours beyond this Euclidean distance from the + data event. Default: ``None`` (no limit). + The minimum effective value is 1.0 (the grid-cell Euclidean + distance to the nearest neighbour). Values in ``(0, 1)`` accept + no neighbours at all, making every node fall back to a random TI + sample. + seed : int or nan, optional + Master RNG seed. Default: nan. + """ + + default_field_names = ["field"] + + def __init__( + self, + ti, + n_neighbors=32, + scan_fraction=1, + threshold=0.0, + cond_weight=1.0, + boundary="strict", + max_radius=None, + num_threads=None, + seed=np.nan, + ): + if boundary not in _VALID_BOUNDARY: + raise ValueError( + f"DirectSampling: boundary must be one of {_VALID_BOUNDARY!r}, " + f"got {boundary!r}" + ) + if int(n_neighbors) < 1: + raise ValueError( + f"DirectSampling: n_neighbors must be >= 1, got {n_neighbors!r}" + ) + if not (0 < float(scan_fraction) <= 1): + raise ValueError( + f"DirectSampling: scan_fraction must be in (0, 1], " + f"got {scan_fraction!r}" + ) + if float(threshold) < 0: + raise ValueError( + f"DirectSampling: threshold must be >= 0, got {threshold!r}" + ) + if float(threshold) > 1.0: + import warnings + + warnings.warn( + "threshold > 1.0 guarantees the first candidate is always accepted.", + stacklevel=2, + ) + if max_radius is not None and float(max_radius) <= 0: + raise ValueError( + f"DirectSampling: max_radius must be a positive float, " + f"got {max_radius!r}" + ) + super().__init__(model=None, dim=ti.ndim, value_type="scalar") + self._ti = ti + self._n_neighbors = int(n_neighbors) + self._scan_fraction = float(scan_fraction) + self._threshold = float(threshold) + self._cond_weight = float(cond_weight) + self._boundary = boundary + self._max_radius = ( + float(max_radius) if max_radius is not None else None + ) + self._num_threads = num_threads + self._cond_pos = None + self._cond_val = None + self.rng = RNG(None if np.isnan(seed) else int(seed)) + + def __call__( + self, + pos=None, + seed=np.nan, + mesh_type="structured", + post_process=True, + store=True, + ): + """Generate the spatial random field via Direct Sampling. + + The field is saved as ``self.field`` and is also returned. + + Parameters + ---------- + pos : :class:`list`, optional + The position tuple, containing main direction and transversal + directions. Only structured grids are supported. + seed : :class:`int`, optional + Seed for the RNG. If ``np.nan``, the current seed is kept. + Default: ``np.nan`` + mesh_type : :class:`str`, optional + Grid type. Must be ``"structured"``. + Default: ``"structured"`` + post_process : :class:`bool`, optional + Whether to apply post-processing transformations (mean, + normalizer, trend) to the field. Default: :any:`True` + store : :class:`bool` or :class:`str`, optional + Whether to store the field (``True``), not store it (``False``), + or store it under a custom name (string). + Default: :any:`True` + + Returns + ------- + field : :class:`numpy.ndarray` + The simulated field. + """ + if mesh_type != "structured": + raise ValueError( + "DirectSampling: only structured grids are supported." + ) + name, save = self.get_store_config(store) + pos, shape = self.pre_pos(pos, mesh_type) + conditions = self._conditions_to_grid(self.pos) + if not np.isnan(seed): + self.rng.seed = int(seed) + rng = np.random.RandomState( + int(self.rng.random.randint(0, 2**32, dtype=np.int64)) + ) + field = ds_simulate( + training_image=self._ti, + sim_shape=shape, + n_neighbors=self._n_neighbors, + threshold=self._threshold, + scan_fraction=self._scan_fraction, + rng=rng, + conditions=conditions, + cond_weight=self._cond_weight, + boundary=self._boundary, + max_radius=self._max_radius, + num_threads=self._num_threads, + ) + return self.post_field(field, name, post_process, save) + + def _conditions_to_grid(self, axes): + """Smart snapping: Mariethoz 2010 collision rule.""" + if self._cond_pos is None: + return {} + candidates = {} # idx -> (val, dist_sq) + for k in range(self._cond_val.shape[0]): + idx = tuple( + int(np.argmin(np.abs(axes[d] - self._cond_pos[d][k]))) + for d in range(self.dim) + ) + dist_sq = sum( + (axes[d][idx[d]] - self._cond_pos[d][k]) ** 2 + for d in range(self.dim) + ) + if idx not in candidates or dist_sq < candidates[idx][1]: + candidates[idx] = (self._cond_val[k], dist_sq) + return {idx: val for idx, (val, _) in candidates.items()} + + def set_condition(self, cond_pos, cond_val, cond_weight=None): + """Set the conditioning data for the simulation. + + Parameters + ---------- + cond_pos : :class:`list` + The position tuple of the conditioning data ``(x, [y, z])``. + cond_val : :class:`numpy.ndarray` + The values at the conditioning positions. + cond_weight : :class:`float`, optional + Conditioning weight δ. If given, overrides the ``cond_weight`` + set at construction. Default: :any:`None` (keep existing weight) + """ + from gstools.krige.tools import set_condition as _gs_set_condition + + self._cond_pos, self._cond_val = _gs_set_condition( + cond_pos, cond_val, self.dim + ) + if cond_weight is not None: + self._cond_weight = float(cond_weight) + + @property + def ti(self): + """TrainingImage: The training image model.""" + return self._ti + + @property + def n_neighbors(self): + """:class:`int`: Maximum neighbours in the data event.""" + return self._n_neighbors + + @n_neighbors.setter + def n_neighbors(self, value): + if int(value) < 1: + raise ValueError( + f"DirectSampling: n_neighbors must be >= 1, got {value!r}" + ) + self._n_neighbors = int(value) + + @property + def scan_fraction(self): + """:class:`float`: Fraction of the per-node search window to scan.""" + return self._scan_fraction + + @scan_fraction.setter + def scan_fraction(self, value): + if not (0 < float(value) <= 1): + raise ValueError( + f"DirectSampling: scan_fraction must be in (0, 1], got {value!r}" + ) + self._scan_fraction = float(value) + + @property + def threshold(self): + """:class:`float`: Distance threshold (0.0 → DSBC mode).""" + return self._threshold + + @threshold.setter + def threshold(self, value): + if float(value) < 0: + raise ValueError( + f"DirectSampling: threshold must be >= 0, got {value!r}" + ) + if float(value) > 1.0: + import warnings + + warnings.warn( + "threshold > 1.0 guarantees the first candidate is always accepted.", + stacklevel=2, + ) + self._threshold = float(value) + + @property + def cond_weight(self): + """:class:`float`: Weight for conditioning nodes in distance.""" + return self._cond_weight + + @cond_weight.setter + def cond_weight(self, value): + self._cond_weight = float(value) + + @property + def boundary(self): + """:class:`str`: Search-window strategy (``"strict"`` or ``"partial"``).""" + return self._boundary + + @boundary.setter + def boundary(self, value): + if value not in _VALID_BOUNDARY: + raise ValueError( + f"DirectSampling: boundary must be one of {_VALID_BOUNDARY!r}, " + f"got {value!r}" + ) + self._boundary = value + + @property + def max_radius(self): + """:class:`float` or :any:`None`: Euclidean cap on SG neighbour selection. + + Values in ``(0, 1)`` disable all neighbours (nearest grid cell is + at distance 1.0), causing every node to fall back to a random TI + sample. + """ + return self._max_radius + + @max_radius.setter + def max_radius(self, value): + if value is not None and float(value) <= 0: + raise ValueError( + f"DirectSampling: max_radius must be a positive float, " + f"got {value!r}" + ) + self._max_radius = float(value) if value is not None else None + + @property + def num_threads(self): + """:class:`int` or :any:`None`: Number of threads for outer DAG parallelism.""" + return self._num_threads + + @num_threads.setter + def num_threads(self, value): + self._num_threads = None if value is None else int(value) + + def __repr__(self): + return ( + f"DirectSampling(dim={self.dim}, " + f"n_neighbors={self.n_neighbors}, " + f"scan_fraction={self.scan_fraction}, " + f"threshold={self.threshold}, " + f"boundary={self.boundary!r})" + ) diff --git a/src/gstools/mps/distance.py b/src/gstools/mps/distance.py new file mode 100644 index 00000000..32c7241c --- /dev/null +++ b/src/gstools/mps/distance.py @@ -0,0 +1,308 @@ +"""Pure distance functions for MPS pattern comparison. + +No class state — takes arrays and scalars, returns floats. +``TrainingImage.distance()`` uses these internally; other algorithms +can import them directly. +""" + +import numpy as np + +__all__ = [ + "compute_node_weights", + "categorical_dist", + "l1_dist", + "l2_dist", + "lp_dist", + "variation_dist", + "vec_categorical_dist", + "vec_l1_dist", + "vec_l2_dist", + "vec_lp_dist", + "vec_variation_dist", +] + + +def compute_node_weights( + n, lag_norms, distance_power, cond_mask=None, cond_weight=1.0 +): + """Compute normalized spatial-decay weights for a data event. + + Combines spatial decay (Mariethoz2010 Eq. 5) with conditioning data + multipliers (Mariethoz2010 §3 ¶26). + + Parameters + ---------- + n : int + Number of neighbours in the data event. + lag_norms : array-like or None, shape (n,) + Euclidean norms ``‖h_i‖`` of each lag vector. ``None`` or + ``distance_power == 0`` → uniform spatial weights. + distance_power : float + Exponent δ. ``0.0`` → uniform. + cond_mask : array-like of bool, optional + ``True`` where the neighbour is a conditioning datum. + cond_weight : float, optional + Bonus weight multiplier for conditioning nodes. + + Returns + ------- + numpy.ndarray, shape (n,) + Node weights normalized to sum to 1. + """ + if lag_norms is not None and distance_power != 0.0: + norms = np.asarray(lag_norms, dtype=np.float64) + norms = np.where(norms == 0.0, 1e-10, norms) + raw_w = norms ** (-distance_power) + else: + raw_w = np.ones(n, dtype=np.float64) + + if cond_mask is not None: + raw_w = raw_w.copy() + raw_w[np.asarray(cond_mask, dtype=bool)] *= cond_weight + + total = raw_w.sum() + if not np.isfinite(total) or total == 0.0: + # e.g. all neighbours are conditioning data with cond_weight == 0: + # fall back to uniform weights rather than emit NaNs. + return np.full(n, 1.0 / n, dtype=np.float64) + return raw_w / total + + +def categorical_dist(data_event_sim, data_event_ti, node_weights): + """Weighted categorical distance (Mariethoz2010 Eq. 3). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + + Returns + ------- + float + Distance in [0, 1]. + """ + return float( + np.dot( + node_weights, + (data_event_sim != data_event_ti).astype(np.float64), + ) + ) + + +def l1_dist(data_event_sim, data_event_ti, node_weights, d_max): + """Weighted L1 distance / Manhattan (Mariethoz2010 Eq. 6). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + + Returns + ------- + float + Distance in [0, 1]. + """ + return float( + np.dot(node_weights, np.abs(data_event_sim - data_event_ti) / d_max) + ) + + +def l2_dist(data_event_sim, data_event_ti, node_weights, d_max): + """Weighted L2 / RMS distance (Mariethoz2010 Eq. 4–5). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + + Returns + ------- + float + Distance in [0, 1]. + """ + return float( + np.sqrt( + np.dot( + node_weights, + ((data_event_sim - data_event_ti) / d_max) ** 2, + ) + ) + ) + + +def lp_dist(data_event_sim, data_event_ti, node_weights, d_max, p): + """Weighted Lp (Minkowski) distance. + + Warning: Computationally heavier than l1_dist or l2_dist due to + the generic C-level pow() evaluation. Use only when p != 1.0 or 2.0. + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + p : float + The Minkowski exponent (e.g., 1.5, 3.0, 5.0). + + Returns + ------- + float + Distance in [0, 1]. + """ + diffs = np.abs(data_event_sim - data_event_ti) / d_max + return float(np.sum(node_weights * (diffs**p)) ** (1.0 / p)) + + +def variation_dist(data_event_sim, data_event_ti, node_weights, d_max, p=2.0): + """Weighted variation distance (Mariethoz2010 Eq. 9, de-meaned). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + p : float, optional + Lp aggregation exponent. Default ``2.0`` (RMS, Mariethoz2010 Eq. 9). + + Returns + ------- + float + Distance in [0, 1]. + """ + diffs = (data_event_sim - data_event_sim.mean()) - ( + data_event_ti - data_event_ti.mean() + ) + # 2*d_max normalises the common case to [0, 1]; SG values are not bounded + # by the TI range (conditioning data / accumulated mean-shifts), so clamp. + return float( + min( + 1.0, + np.dot(node_weights, np.abs(diffs / (2 * d_max)) ** p) + ** (1.0 / p), + ) + ) + + +# --------------------------------------------------------------------------- +# Vectorized variants — same maths, operate on all TI candidates at once. +# Each accepts all_de_ti of shape (max_scan, n) and returns (max_scan,). +# np.dot(X, w) with X (max_scan, n) and w (n,) is a standard BLAS matvec. +# --------------------------------------------------------------------------- + + +def vec_categorical_dist(data_event_sim, all_de_ti, node_weights): + """Vectorized categorical distance over all TI scan candidates. + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + all_de_ti : numpy.ndarray, shape (max_scan, n) + node_weights : numpy.ndarray, shape (n,) + + Returns + ------- + numpy.ndarray, shape (max_scan,) + Distance in [0, 1] for each candidate. + """ + return np.dot( + (data_event_sim != all_de_ti).astype(np.float64), node_weights + ) + + +def vec_l1_dist(data_event_sim, all_de_ti, node_weights, d_max): + """Vectorized L1 distance over all TI scan candidates. + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + all_de_ti : numpy.ndarray, shape (max_scan, n) + node_weights : numpy.ndarray, shape (n,) + d_max : float + + Returns + ------- + numpy.ndarray, shape (max_scan,) + """ + return np.dot(np.abs(data_event_sim - all_de_ti) / d_max, node_weights) + + +def vec_l2_dist(data_event_sim, all_de_ti, node_weights, d_max): + """Vectorized L2 distance over all TI scan candidates. + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + all_de_ti : numpy.ndarray, shape (max_scan, n) + node_weights : numpy.ndarray, shape (n,) + d_max : float + + Returns + ------- + numpy.ndarray, shape (max_scan,) + """ + return np.sqrt( + np.dot(((data_event_sim - all_de_ti) / d_max) ** 2, node_weights) + ) + + +def vec_lp_dist(data_event_sim, all_de_ti, node_weights, d_max, p): + """Vectorized Lp distance over all TI scan candidates. + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + all_de_ti : numpy.ndarray, shape (max_scan, n) + node_weights : numpy.ndarray, shape (n,) + d_max : float + p : float + + Returns + ------- + numpy.ndarray, shape (max_scan,) + """ + diffs = np.abs(data_event_sim - all_de_ti) / d_max + return np.dot(diffs**p, node_weights) ** (1.0 / p) + + +def vec_variation_dist(data_event_sim, all_de_ti, node_weights, d_max, p=2.0): + """Vectorized variation distance over all TI scan candidates. + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + all_de_ti : numpy.ndarray, shape (max_scan, n) + node_weights : numpy.ndarray, shape (n,) + d_max : float + p : float, optional + Lp aggregation exponent. Default ``2.0``. + + Returns + ------- + numpy.ndarray, shape (max_scan,) + Distance in [0, 1]. + """ + de_sim_c = data_event_sim - data_event_sim.mean() + all_de_ti_c = all_de_ti - all_de_ti.mean(axis=1, keepdims=True) + diffs = de_sim_c - all_de_ti_c + # 2*d_max normalises the common case to [0, 1]; SG values are not bounded + # by the TI range (conditioning data / accumulated mean-shifts), so clamp. + return np.minimum( + 1.0, + np.dot(np.abs(diffs / (2 * d_max)) ** p, node_weights) ** (1.0 / p), + ) diff --git a/src/gstools/mps/training_image.py b/src/gstools/mps/training_image.py new file mode 100644 index 00000000..6fba49dd --- /dev/null +++ b/src/gstools/mps/training_image.py @@ -0,0 +1,292 @@ +""" +GStools subpackage providing the TrainingImage class for MPS simulations. + +.. currentmodule:: gstools.mps + +The following classes and functions are provided + +.. autosummary:: + TrainingImage +""" + +import numpy as np + +from gstools.mps.distance import ( + categorical_dist, + compute_node_weights, + l1_dist, + l2_dist, + lp_dist, + variation_dist, + vec_categorical_dist, + vec_l1_dist, + vec_l2_dist, + vec_lp_dist, + vec_variation_dist, +) + +__all__ = ["TrainingImage"] + + +class TrainingImage: + """Training image for multiple point statistics simulation. + + The MPS analogue of :class:`gstools.CovModel`: encapsulates training + data and the distance function for comparing data events. + + Parameters + ---------- + data : numpy.ndarray + Training image data (n-d array). + categorical : bool, optional + Whether the variable is categorical. Default: ``True``. + distance : str, optional + Distance metric for continuous variables: ``"l1"`` (Juda2022 + Eq. 7, default), ``"l2"`` (Mariethoz2010 Eq. 4–5), or + ``"variation"`` (Mariethoz2010 Eq. 9). Ignored when categorical. + distance_power : float, optional + Exponent δ for spatial-decay weighting of neighbours + (Mariethoz2010 Eq. 3). Applied to **all** distance types. + ``0.0`` → uniform weights (oracle-compatible default). + ``1.0`` → closer neighbours weighted more heavily. + """ + + def __init__( + self, data, categorical=True, distance="l1", distance_power=0.0 + ): + self._data = np.array(data, copy=True) + self._categorical = bool(categorical) + self._distance_power = float(distance_power) + if self._distance_power < 0: + raise ValueError("distance_power must be >= 0") + self._distance_type = distance + self._p_norm = None + self._variation_p_norm = None + if not self._categorical: + distance_lower = str(distance).lower() + if distance_lower.startswith("l"): + try: + p_val = float(distance_lower[1:]) + except ValueError: + raise ValueError( + f"TrainingImage: distance starting with 'l' must be followed by " + f"a positive number (e.g. 'l1', 'l2', 'l3.5'). Got {distance!r}" + ) + if p_val <= 0: + raise ValueError( + f"TrainingImage: Lp norm exponent must be > 0, got {p_val}." + ) + self._p_norm = p_val + elif distance_lower == "variation": + self._variation_p_norm = 2.0 + elif distance_lower.startswith("variation"): + try: + p_val = float(distance_lower[len("variation") :]) + except ValueError: + raise ValueError( + f"TrainingImage: distance starting with 'variation' must be " + f"followed by a positive number (e.g. 'variation1', 'variation1.5'). " + f"Got {distance!r}" + ) + if p_val <= 0: + raise ValueError( + f"TrainingImage: variation exponent must be > 0, got {p_val}." + ) + self._variation_p_norm = p_val + else: + raise ValueError( + f"TrainingImage: distance must be 'l

' (e.g. 'l1', 'l2'), " + f"'variation', or 'variation

' (e.g. 'variation1'). " + f"Got {distance!r}" + ) + + if not self._categorical: + dmax = float(self._data.max() - self._data.min()) + self._d_max = dmax if dmax > 0 else 1.0 + else: + self._d_max = None + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def data(self): + """numpy.ndarray: Raw training image data.""" + return self._data + + @property + def ndim(self): + """int: Number of spatial dimensions.""" + return self._data.ndim + + @property + def shape(self): + """tuple: Shape of the training image.""" + return self._data.shape + + @property + def categorical(self): + """bool: Whether the variable is categorical.""" + return self._categorical + + @property + def distance_type(self): + """str: Distance metric (e.g. ``"l1"``, ``"l2"``, ``"variation"``, ``"variation1"``).""" + return self._distance_type + + @property + def distance_power(self): + """float: Spatial-decay exponent δ for node weighting.""" + return self._distance_power + + # ------------------------------------------------------------------ + # Distance + # ------------------------------------------------------------------ + + def distance( + self, + data_event_sim, + data_event_ti, + cond_mask=None, + cond_weight=1.0, + lag_norms=None, + ): + """Distance between two data events. + + Applies spatial-decay weights (Mariethoz2010 Eq. 3) to all + distance types when ``distance_power > 0``. + + Parameters + ---------- + data_event_sim : array-like, shape (n,) + Values at SG neighbourhood nodes. + data_event_ti : array-like, shape (n,) + Values at TI neighbourhood nodes. + cond_mask : array-like of bool, optional + True where the neighbour is a conditioning datum. + cond_weight : float, optional + Weight multiplier δ for conditioning nodes + (Mariethoz2010 §3 ¶26). Default: ``1.0``. + lag_norms : array-like, shape (n,), optional + Euclidean norms ``‖h_i‖`` of each lag vector. Required for + spatial-decay weighting (``distance_power > 0``). + + Returns + ------- + float + Distance in [0, 1]. + """ + data_event_sim = np.asarray(data_event_sim, dtype=np.float64) + data_event_ti = np.asarray(data_event_ti, dtype=np.float64) + n = len(data_event_sim) + if n == 0: + return 0.0 + + w = compute_node_weights( + n, lag_norms, self._distance_power, cond_mask, cond_weight + ) + + if self._categorical: + return categorical_dist(data_event_sim, data_event_ti, w) + if self._p_norm == 1.0: + return l1_dist(data_event_sim, data_event_ti, w, self._d_max) + if self._p_norm == 2.0: + return l2_dist(data_event_sim, data_event_ti, w, self._d_max) + if self._p_norm is not None: + return lp_dist( + data_event_sim, data_event_ti, w, self._d_max, self._p_norm + ) + return variation_dist( + data_event_sim, + data_event_ti, + w, + self._d_max, + self._variation_p_norm, + ) + + def vec_distance( + self, + data_event_sim, + all_de_ti, + cond_mask=None, + cond_weight=1.0, + lag_norms=None, + ): + """Vectorized distance between SG data event and all TI candidates. + + Same maths as :meth:`distance` but operates on all TI scan candidates + at once, returning a distance per candidate instead of a scalar. + + Parameters + ---------- + data_event_sim : array-like, shape (n,) + Values at SG neighbourhood nodes. + all_de_ti : array-like, shape (max_scan, n) + TI data events for every scan candidate. + cond_mask : array-like of bool, optional + True where the neighbour is a conditioning datum. + cond_weight : float, optional + Weight multiplier δ for conditioning nodes. Default: ``1.0``. + lag_norms : array-like, shape (n,), optional + Euclidean norms of each lag vector. + + Returns + ------- + numpy.ndarray, shape (max_scan,) + Distance in [0, 1] for each candidate. + """ + data_event_sim = np.asarray(data_event_sim, dtype=np.float64) + all_de_ti = np.asarray(all_de_ti, dtype=np.float64) + n = len(data_event_sim) + if n == 0: + return np.zeros(len(all_de_ti)) + w = compute_node_weights( + n, lag_norms, self._distance_power, cond_mask, cond_weight + ) + if self._categorical: + return vec_categorical_dist(data_event_sim, all_de_ti, w) + if self._p_norm == 1.0: + return vec_l1_dist(data_event_sim, all_de_ti, w, self._d_max) + if self._p_norm == 2.0: + return vec_l2_dist(data_event_sim, all_de_ti, w, self._d_max) + if self._p_norm is not None: + return vec_lp_dist( + data_event_sim, all_de_ti, w, self._d_max, self._p_norm + ) + return vec_variation_dist( + data_event_sim, all_de_ti, w, self._d_max, self._variation_p_norm + ) + + def adjust_value(self, ti_val, data_event_sim, data_event_ti): + """Adjust matched TI value before assignment to SG. + + For ``distance="variation"``, applies the mean-shift correction + (Mariethoz2010 Eq. 9): Z(x_i) = Z(y) − Z̄(y) + Z̄(x_i). + For all other metrics returns *ti_val* unchanged. + + Parameters + ---------- + ti_val : float + Raw value at the matched TI node. + data_event_sim : array-like + SG data event (used to compute Z̄(x_i)). + data_event_ti : array-like + TI data event (used to compute Z̄(y)). + + Returns + ------- + float + """ + if self._variation_p_norm is None or self._categorical: + return ti_val + data_event_sim = np.asarray(data_event_sim, dtype=np.float64) + data_event_ti = np.asarray(data_event_ti, dtype=np.float64) + return float(ti_val - data_event_ti.mean() + data_event_sim.mean()) + + def __repr__(self): + return ( + f"TrainingImage(shape={self.shape}, " + f"categorical={self._categorical}, " + f"distance={self._distance_type!r})" + ) diff --git a/tests/test_mps.py b/tests/test_mps.py new file mode 100644 index 00000000..dc50d158 --- /dev/null +++ b/tests/test_mps.py @@ -0,0 +1,648 @@ +#!/usr/bin/env python +"""Unittest for the MPS module (TrainingImage and DirectSampling).""" + +import unittest + +import numpy as np + +import gstools as gs +from gstools import config as gs_config +from gstools.mps.direct_sampling import ( + DirectSampling, + _precompute_offsets, + ds_simulate, +) +from gstools.mps.distance import ( + categorical_dist, + compute_node_weights, + l1_dist, + l2_dist, + lp_dist, + variation_dist, +) +from gstools.mps.training_image import TrainingImage + +class TestDirectSamplingParallel(unittest.TestCase): + def test_valid_values(self): + rng = np.random.default_rng(0) + data = rng.integers(0, 3, (20, 20)) + ti = TrainingImage(data) + ds = DirectSampling(ti, n_neighbors=8, scan_fraction=0.2, num_threads=2) + field = ds([np.arange(8, dtype=float)] * 2, seed=0) + self.assertEqual(field.shape, (8, 8)) + self.assertTrue(np.all(np.isin(field, [0, 1, 2]))) + + def test_reproducible(self): + # DAG parallelism is deterministic: same seed → same parallel result + rng = np.random.default_rng(0) + data = rng.integers(0, 3, (20, 20)) + ti = TrainingImage(data) + ds = DirectSampling(ti, n_neighbors=8, scan_fraction=0.2, num_threads=2) + pos = [np.arange(8, dtype=float)] * 2 + self.assertTrue(np.array_equal(ds(pos, seed=7), ds(pos, seed=7))) + + def test_conditioning_preserved(self): + rng = np.random.default_rng(0) + data = rng.integers(0, 3, (20, 20)) + ti = TrainingImage(data) + ds = DirectSampling(ti, n_neighbors=4, scan_fraction=0.2, num_threads=2) + ds.set_condition([[5.0], [5.0]], [2]) + field = ds([np.arange(10, dtype=float)] * 2, seed=0) + self.assertEqual(field[5, 5], 2) + + def test_global_config(self): + # num_threads=None reads gs_config.NUM_THREADS + rng = np.random.default_rng(0) + data = rng.integers(0, 3, (20, 20)) + ti = TrainingImage(data) + pos = [np.arange(8, dtype=float)] * 2 + old = gs_config.NUM_THREADS + try: + gs_config.NUM_THREADS = 2 + field = DirectSampling(ti, n_neighbors=8, scan_fraction=0.2)( + pos, seed=7 + ) + finally: + gs_config.NUM_THREADS = old + self.assertEqual(field.shape, (8, 8)) + self.assertTrue(np.all(np.isin(field, [0, 1, 2]))) + + def test_large_batches(self): + # n_neighbors=2 → sparse DAG → large ready batches + rng = np.random.default_rng(1) + data = rng.integers(0, 2, (30, 30)) + ti = TrainingImage(data) + pos = [np.arange(12, dtype=float)] * 2 + ds = DirectSampling(ti, n_neighbors=2, scan_fraction=0.3, num_threads=4) + field = ds(pos, seed=42) + self.assertEqual(field.shape, (12, 12)) + self.assertTrue(np.all(np.isin(field, [0.0, 1.0]))) + + def test_stress(self): + # large grid, sparse DAG, conditioning, varying thread counts + rng = np.random.default_rng(3) + data = rng.integers(0, 4, (40, 40)) + ti = TrainingImage(data) + pos = [np.arange(25, dtype=float)] * 2 + cond_pos = [ + rng.integers(0, 25, 20).astype(float), + rng.integers(0, 25, 20).astype(float), + ] + cond_val = rng.integers(0, 4, 20).astype(float) + for nt in (2, 4, 8): + ds = DirectSampling( + ti, n_neighbors=3, scan_fraction=0.4, num_threads=nt + ) + ds.set_condition(cond_pos, cond_val) + field = ds(pos, seed=11) + self.assertEqual(field.shape, (25, 25)) + self.assertTrue(np.all(np.isin(field, [0, 1, 2, 3]))) + + +class TestTrainingImage(unittest.TestCase): + def setUp(self): + arr_cat = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float) + self.ti_cat = TrainingImage(arr_cat, categorical=True) + + arr_cont = np.linspace(0.0, 1.0, 20) + self.ti_cont = TrainingImage( + arr_cont, categorical=False, distance="l1" + ) + + def test_properties(self): + np.testing.assert_array_equal( + self.ti_cat.data, + np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float), + ) + self.assertEqual(self.ti_cat.ndim, 2) + self.assertEqual(self.ti_cat.shape, (3, 3)) + self.assertTrue(self.ti_cat.categorical) + self.assertEqual( + self.ti_cat.distance_type, "l1" + ) # default ignored for cat + self.assertIsInstance(repr(self.ti_cat), str) + self.assertIn("TrainingImage", repr(self.ti_cat)) + + self.assertEqual(self.ti_cont.ndim, 1) + self.assertEqual(self.ti_cont.shape, (20,)) + self.assertFalse(self.ti_cont.categorical) + self.assertEqual(self.ti_cont.distance_type, "l1") + + def test_raise(self): + with self.assertRaises(ValueError): + TrainingImage(np.ones(10), categorical=False, distance="l0") + with self.assertRaises(ValueError): + TrainingImage(np.ones(10), categorical=False, distance="labc") + with self.assertRaises(ValueError): + TrainingImage(np.ones(10), categorical=False, distance="invalid") + + def test_distance_categorical(self): + # Identical events → 0.0 + a = np.array([0.0, 1.0, 0.0]) + dist = self.ti_cat.distance(a, a) + self.assertAlmostEqual(dist, 0.0) + + # Completely mismatched, uniform weights → 1.0 + b = np.array([1.0, 0.0, 1.0]) + dist = self.ti_cat.distance(a, b) + self.assertAlmostEqual(dist, 1.0) + + # One of three mismatched → 1/3 + c = np.array([1.0, 1.0, 0.0]) + dist = self.ti_cat.distance(a, c) + self.assertAlmostEqual(dist, 1.0 / 3.0) + + # Two of four mismatched → 0.5 (spec-required half-mismatch case) + a4 = np.array([0.0, 1.0, 0.0, 1.0]) + c4 = np.array([1.0, 0.0, 0.0, 1.0]) + dist = self.ti_cat.distance(a4, c4) + self.assertAlmostEqual(dist, 0.5) + + def test_distance_continuous(self): + x = np.array([0.0, 0.5, 1.0]) + y = np.array([0.2, 0.3, 0.8]) + + # l1 + ti_l1 = TrainingImage( + np.linspace(0.0, 1.0, 10), categorical=False, distance="l1" + ) + self.assertAlmostEqual(ti_l1.distance(x, x), 0.0) + self.assertAlmostEqual(ti_l1.distance(x, y), 0.2, places=6) + + # l2 + ti_l2 = TrainingImage( + np.linspace(0.0, 1.0, 10), categorical=False, distance="l2" + ) + self.assertAlmostEqual(ti_l2.distance(x, x), 0.0) + self.assertAlmostEqual(ti_l2.distance(x, y), 0.2, places=6) + + # lp (p=3.5) — non-uniform diffs [0.3, 0.1, 0.3] distinguish lp from l1/l2 + y_lp = np.array([0.3, 0.4, 0.7]) + ti_lp = TrainingImage( + np.linspace(0.0, 1.0, 10), categorical=False, distance="l3.5" + ) + self.assertAlmostEqual(ti_lp.distance(x, x), 0.0) + self.assertAlmostEqual(ti_lp.distance(x, y_lp), 0.2680, places=3) + self.assertGreater(ti_lp.distance(x, y_lp), ti_l1.distance(x, y_lp)) + + # variation (default p=2) + ti_var = TrainingImage( + np.linspace(0.0, 1.0, 10), categorical=False, distance="variation" + ) + self.assertAlmostEqual(ti_var.distance(x, x), 0.0) + self.assertAlmostEqual(ti_var.distance(x, y), 0.094281, places=5) + # constant shift → distance = 0 (key behavioral property of variation distance) + self.assertAlmostEqual(ti_var.distance(x, x + 0.15), 0.0, places=10) + + # variation1 (L^1 aggregation) + ti_var1 = TrainingImage( + np.linspace(0.0, 1.0, 10), categorical=False, distance="variation1" + ) + self.assertAlmostEqual(ti_var1.distance(x, x), 0.0) + self.assertAlmostEqual(ti_var1.distance(x, y), 0.08889, places=4) + self.assertAlmostEqual(ti_var1.distance(x, x + 0.15), 0.0, places=10) + # L^1 < L^2 for non-uniform diffs + self.assertLess(ti_var1.distance(x, y), ti_var.distance(x, y)) + + # variation2 explicit matches variation (regression guard) + ti_var2 = TrainingImage( + np.linspace(0.0, 1.0, 10), categorical=False, distance="variation2" + ) + self.assertAlmostEqual( + ti_var2.distance(x, y), ti_var.distance(x, y), places=10 + ) + + def test_adjust_value(self): + # Categorical and lp: passthrough + self.assertAlmostEqual( + self.ti_cat.adjust_value( + 0.7, np.array([0.1, 0.3]), np.array([0.4, 0.6]) + ), + 0.7, + ) + self.assertAlmostEqual( + self.ti_cont.adjust_value( + 0.7, np.array([0.1, 0.3]), np.array([0.4, 0.6]) + ), + 0.7, + ) + + # variation: Z(y) - Z_bar(y) + Z_bar(x) = 0.7 - 0.6 + 0.3 = 0.4 + ti_var = TrainingImage( + np.linspace(0.0, 1.0, 20), categorical=False, distance="variation" + ) + result = ti_var.adjust_value( + 0.7, np.array([0.1, 0.3, 0.5]), np.array([0.4, 0.6, 0.8]) + ) + self.assertAlmostEqual(result, 0.4, places=6) + self.assertNotAlmostEqual(result, 0.7) # must not be passthrough + + def test_distance_weights(self): + a = np.array([0.0, 1.0, 0.0]) + b = np.array([1.0, 1.0, 0.0]) # first element differs + + # cond_weight=2 on first node → it gets weight 0.5 (double) + d1 = self.ti_cat.distance( + a, b, cond_mask=[True, False, False], cond_weight=1.0 + ) + d2 = self.ti_cat.distance( + a, b, cond_mask=[True, False, False], cond_weight=2.0 + ) + self.assertGreater(d2, d1) + + # distance_power shifts weight toward closer neighbours — use non-uniform + # differences so the weighted sums actually differ: diffs = [0, 0, 0.5] + ti_p = TrainingImage( + np.linspace(0.0, 1.0, 10), + categorical=False, + distance="l1", + distance_power=1.0, + ) + ti_flat = TrainingImage( + np.linspace(0.0, 1.0, 10), + categorical=False, + distance="l1", + distance_power=0.0, + ) + x = np.array([0.0, 0.5, 1.0]) + z = np.array([0.0, 0.5, 0.5]) # only third element differs + lags = np.array([1.0, 2.0, 3.0]) + d_power = ti_p.distance(x, z, lag_norms=lags) + d_flat = ti_flat.distance(x, z, lag_norms=lags) + # power=1 weights far neighbours less → smaller distance for far mismatch + self.assertLess(d_power, d_flat) + + def test_distance_empty_event(self): + dist = self.ti_cat.distance(np.array([]), np.array([])) + self.assertAlmostEqual(dist, 0.0) + + def test_distance_functions_directly(self): + a = np.array([0.0, 1.0, 0.0]) + b = np.array([1.0, 0.0, 1.0]) + w = np.array([1 / 3, 1 / 3, 1 / 3]) + + # weights sum to 1 + w2 = compute_node_weights(3, None, 0.0) + self.assertAlmostEqual(w2.sum(), 1.0) + + # cond_weight=2 on first node → uniform spatial weights → w[0] = 2/(2+1+1) = 0.5 + w3 = compute_node_weights( + 3, + None, + 0.0, + cond_mask=[True, False, False], + cond_weight=2.0, + ) + self.assertAlmostEqual(w3.sum(), 1.0) + self.assertAlmostEqual(w3[0], 0.5, places=6) + + # categorical: identical → 0, opposite → 1 + self.assertAlmostEqual(categorical_dist(a, a, w), 0.0) + self.assertAlmostEqual(categorical_dist(a, b, w), 1.0) + + # continuous: identical → 0 + x = np.array([0.0, 0.5, 1.0]) + d_max = 1.0 + self.assertAlmostEqual(l1_dist(x, x, w, d_max), 0.0) + self.assertAlmostEqual(l2_dist(x, x, w, d_max), 0.0) + self.assertAlmostEqual(lp_dist(x, x, w, d_max, 3.5), 0.0) + self.assertAlmostEqual(variation_dist(x, x, w, d_max), 0.0) + self.assertAlmostEqual(variation_dist(x, x, w, d_max, p=1.0), 0.0) + + # distances in [0, 1] + y = np.array([0.2, 0.3, 0.8]) + self.assertAlmostEqual(l1_dist(x, y, w, d_max), 0.2, places=6) + self.assertAlmostEqual(l2_dist(x, y, w, d_max), 0.2, places=6) + self.assertAlmostEqual( + variation_dist(x, y, w, d_max), 0.094281, places=5 + ) + self.assertAlmostEqual( + variation_dist(x, y, w, d_max, p=1.0), 0.08889, places=4 + ) + # p=2 explicit matches default + self.assertAlmostEqual( + variation_dist(x, y, w, d_max, p=2.0), + variation_dist(x, y, w, d_max), + places=10, + ) + + # lp: non-uniform diffs [0.3, 0.1, 0.3] verify the p-norm exponent is used + y_lp = np.array([0.3, 0.4, 0.7]) + self.assertAlmostEqual( + lp_dist(x, y_lp, w, d_max, 3.5), 0.2680, places=3 + ) + self.assertGreater( + lp_dist(x, y_lp, w, d_max, 3.5), l1_dist(x, y_lp, w, d_max) + ) + + def test_variation_dist_bounded(self): + """variation_dist with distance_power > 0 must stay in [0, 1].""" + # Adversarial: weight concentrated on maximally anti-correlated element + x = np.array([0.0, 1.0, 0.0]) + y = np.array([1.0, 0.0, 1.0]) + lags = np.array([10.0, 0.1, 10.0]) + w = compute_node_weights(3, lags, 1.0) + d = variation_dist(x, y, w, 1.0) + self.assertGreaterEqual(d, 0.0) + self.assertLessEqual(d, 1.0) + self.assertAlmostEqual(d, 0.661747, places=5) + # Also via TrainingImage.distance() + ti = TrainingImage( + np.linspace(0.0, 1.0, 10), + categorical=False, + distance="variation", + distance_power=1.0, + ) + self.assertLessEqual(ti.distance(x, y, lag_norms=lags), 1.0) + + def test_variation_dist_out_of_range_clamped(self): + """Out-of-range SG values (from conditioning / mean-shift) must clamp to [0, 1].""" + ti = TrainingImage( + np.linspace(0.0, 1.0, 10), # d_max == 1.0 + categorical=False, + distance="variation", + ) + de_sim = np.array([5.0, 0.0]) # 5.0 is far outside the TI range + de_ti = np.array([0.0, 1.0]) + self.assertLessEqual(ti.distance(de_sim, de_ti), 1.0) + vec = ti.vec_distance(de_sim, de_ti[np.newaxis, :]) + self.assertEqual(vec.shape, (1,)) + self.assertLessEqual(vec[0], 1.0) + + def test_variation_lp_parsing(self): + """variation

string is parsed correctly and rejects bad inputs.""" + data = np.linspace(0.0, 1.0, 10) + for spec in ("variation", "variation1", "variation1.5", "variation2"): + ti = TrainingImage(data, categorical=False, distance=spec) + self.assertEqual(ti.distance_type, spec) + # invalid suffix + with self.assertRaises(ValueError): + TrainingImage(data, categorical=False, distance="variationX") + # non-positive exponent + with self.assertRaises(ValueError): + TrainingImage(data, categorical=False, distance="variation0") + with self.assertRaises(ValueError): + TrainingImage(data, categorical=False, distance="variation-1") + + def test_variation_lp_adjust_value(self): + """adjust_value mean-shift applies for all variation

variants.""" + de_sim = np.array([0.1, 0.3, 0.5]) # mean = 0.3 + de_ti = np.array([0.4, 0.6, 0.8]) # mean = 0.6 + # expected: 0.7 - 0.6 + 0.3 = 0.4 + for spec in ("variation", "variation1", "variation1.5"): + ti = TrainingImage( + np.linspace(0.0, 1.0, 20), categorical=False, distance=spec + ) + self.assertAlmostEqual( + ti.adjust_value(0.7, de_sim, de_ti), 0.4, places=6 + ) + + def test_node_weights_zero_cond_weight(self): + """All-conditioning event with cond_weight=0 must not yield NaN weights.""" + w = compute_node_weights( + 3, + lag_norms=None, + distance_power=0.0, + cond_mask=np.array([True, True, True]), + cond_weight=0.0, + ) + self.assertTrue(np.all(np.isfinite(w))) + self.assertAlmostEqual(w.sum(), 1.0) + np.testing.assert_allclose(w, np.full(3, 1.0 / 3.0)) + + +class TestDirectSampling(unittest.TestCase): + def setUp(self): + # 1-D categorical TI: alternating 0/1, length 20 + arr1d = np.tile([0, 1], 10).astype(float) + self.ti1d = TrainingImage(arr1d, categorical=True) + + # 2-D categorical TI: 8×8 checkerboard + self.ti2d = TrainingImage( + (np.indices((8, 8)).sum(axis=0) % 2).astype(float), + categorical=True, + ) + + rng = np.random.default_rng(0) + self.ti2d_rand = TrainingImage( + rng.integers(0, 2, size=(20, 20)).astype(float), + categorical=True, + ) + + # 1-D continuous TI + self.ti1d_cont = TrainingImage( + np.linspace(0.0, 1.0, 20), categorical=False, distance="l1" + ) + + self.x1d = np.arange(10, dtype=float) + self.x2d = np.arange(6, dtype=float) + self.y2d = np.arange(6, dtype=float) + + def test_raise(self): + with self.assertRaises(ValueError): + DirectSampling(self.ti1d, boundary="bad") + with self.assertRaises(ValueError): + DirectSampling(self.ti1d, max_radius=0) + with self.assertRaises(ValueError): + DirectSampling(self.ti1d, max_radius=-1.0) + ds = DirectSampling(self.ti1d) + with self.assertRaises(ValueError): + ds([self.x1d], seed=42, mesh_type="unstructured") + + def test_repr(self): + ds = DirectSampling(self.ti1d) + r = repr(ds) + self.assertIsInstance(r, str) + self.assertIn("DirectSampling", r) + + def test_properties_and_setters(self): + ds = DirectSampling( + self.ti1d, + n_neighbors=16, + scan_fraction=0.5, + threshold=0.05, + cond_weight=2.0, + boundary="partial", + max_radius=3.0, + ) + self.assertIs(ds.ti, self.ti1d) + self.assertEqual(ds.n_neighbors, 16) + self.assertAlmostEqual(ds.scan_fraction, 0.5) + self.assertAlmostEqual(ds.threshold, 0.05) + self.assertAlmostEqual(ds.cond_weight, 2.0) + self.assertEqual(ds.boundary, "partial") + self.assertAlmostEqual(ds.max_radius, 3.0) + + ds.n_neighbors = 8 + self.assertEqual(ds.n_neighbors, 8) + ds.scan_fraction = 1.0 + self.assertAlmostEqual(ds.scan_fraction, 1.0) + ds.threshold = 0.0 + self.assertAlmostEqual(ds.threshold, 0.0) + ds.cond_weight = 1.0 + self.assertAlmostEqual(ds.cond_weight, 1.0) + + def test_offsets_shape(self): + off = _precompute_offsets((5, 5)) + # shape: (N, 2) for 2-D, no zero row + self.assertEqual(off.ndim, 2) + self.assertEqual(off.shape[1], 2) + self.assertFalse(np.any(np.all(off == 0, axis=1))) + # sorted by Euclidean norm + norms = np.linalg.norm(off, axis=1) + self.assertTrue(np.all(norms[:-1] <= norms[1:])) + + def test_offsets_1d(self): + off = _precompute_offsets((10,)) + self.assertEqual(off.shape[1], 1) + self.assertFalse(np.any(off == 0)) + + def test_offsets_max_offset(self): + off = _precompute_offsets((5, 5), max_offset=1) + self.assertLessEqual(np.abs(off).max(), 1) + # 2-D, max_offset=1: 3^2 - 1 = 8 neighbours + self.assertEqual(off.shape, (8, 2)) + + def test_shape_1d(self): + ds = DirectSampling(self.ti1d, n_neighbors=4, scan_fraction=1.0) + field = ds([self.x1d], seed=42) + self.assertEqual(field.shape, (10,)) + self.assertFalse(np.any(np.isnan(field))) + + def test_shape_2d(self): + ds = DirectSampling(self.ti2d, n_neighbors=4, scan_fraction=1.0) + field = ds([self.x2d, self.y2d], seed=42) + self.assertEqual(field.shape, (6, 6)) + self.assertFalse(np.any(np.isnan(field))) + # All output values must be in the TI value set {0, 1} + unique_vals = set(np.unique(field)) + self.assertTrue(unique_vals.issubset({0.0, 1.0})) + + def test_regression_1d(self): + ds = DirectSampling(self.ti1d, n_neighbors=4, scan_fraction=1.0) + field = ds([self.x1d], seed=42) + self.assertAlmostEqual(field[0], 1.0) + self.assertAlmostEqual(field[5], 0.0) + self.assertAlmostEqual(field[9], 0.0) + + def test_regression_2d(self): + ds = DirectSampling(self.ti2d, n_neighbors=4, scan_fraction=1.0) + field = ds([self.x2d, self.y2d], seed=42) + self.assertAlmostEqual(field[0, 0], 1.0) + self.assertAlmostEqual(field[2, 3], 0.0) + self.assertAlmostEqual(field[5, 5], 1.0) + + def test_seeded_reproducibility(self): + ds = DirectSampling(self.ti2d_rand, n_neighbors=8, scan_fraction=0.5) + pos = [self.x2d, self.y2d] + fa = ds(pos, seed=99) + fb = ds(pos, seed=99) + fc = ds(pos, seed=100) + # Same seed → identical output + self.assertTrue(np.allclose(fa, fb)) + # Different seed → different output + self.assertFalse(np.allclose(fa, fc)) + # Pin two values for seed=99; stable across NumPy versions because DS + # uses RandomState (MT19937) throughout, matching the rest of GSTools. + self.assertAlmostEqual(fa[0, 0], 0.0) + self.assertAlmostEqual(fa[3, 4], 1.0) + + def test_conditioning_honored(self): + ds = DirectSampling(self.ti1d, n_neighbors=4, scan_fraction=1.0) + # Three exact grid node positions — spec requires ≥ 3 to exercise multi-point handling + cond_pos = [np.array([2.0, 4.0, 7.0])] + cond_val = np.array([0.0, 1.0, 1.0]) + ds.set_condition(cond_pos, cond_val) + field = ds([self.x1d], seed=5) + self.assertAlmostEqual(field[2], 0.0) + self.assertAlmostEqual(field[4], 1.0) + self.assertAlmostEqual(field[7], 1.0) + + def test_boundary_partial(self): + ds = DirectSampling( + self.ti2d, n_neighbors=4, scan_fraction=1.0, boundary="partial" + ) + field = ds([self.x2d, self.y2d], seed=42) + self.assertEqual(field.shape, (6, 6)) + self.assertFalse(np.any(np.isnan(field))) + + def test_boundary_partial_collapse_recovers(self): + # TI far smaller than lag span → partial mode must recover, not raise + ti_tiny = TrainingImage( + np.random.default_rng(0).random((3, 3)), + categorical=False, + distance="l1", + ) + ds = DirectSampling( + ti_tiny, n_neighbors=32, scan_fraction=1.0, boundary="partial" + ) + field = ds([np.arange(30, dtype=float)] * 2, seed=1) + self.assertEqual(field.shape, (30, 30)) + self.assertFalse(np.any(np.isnan(field))) + + def test_threshold_above_one_warns_in_constructor(self): + with self.assertWarns(UserWarning): + DirectSampling(self.ti1d, threshold=5.0) + + def test_scan_fraction_window_semantics(self): + """scan_fraction=0.1 applies to the window, not the TI — no crash, valid output.""" + rng = np.random.default_rng(0) + ti = TrainingImage( + rng.integers(0, 2, (20, 20)).astype(float), categorical=True + ) + ds = DirectSampling(ti, n_neighbors=4, scan_fraction=0.1) + field = ds([np.arange(6, dtype=float)] * 2, seed=0) + self.assertEqual(field.shape, (6, 6)) + self.assertFalse(np.any(np.isnan(field))) + self.assertTrue(set(np.unique(field)).issubset({0.0, 1.0})) + + def test_max_radius(self): + ds = DirectSampling( + self.ti2d, n_neighbors=4, scan_fraction=1.0, max_radius=2.0 + ) + field = ds([self.x2d, self.y2d], seed=42) + self.assertEqual(field.shape, (6, 6)) + self.assertFalse(np.any(np.isnan(field))) + + def test_continuous_ti(self): + ds = DirectSampling( + self.ti1d_cont, n_neighbors=4, scan_fraction=1.0, threshold=0.05 + ) + field = ds([np.arange(8, dtype=float)], seed=42) + self.assertEqual(field.shape, (8,)) + self.assertFalse(np.any(np.isnan(field))) + self.assertTrue(np.all(field >= 0.0)) + self.assertTrue(np.all(field <= 1.0)) + + def test_ds_simulate_direct(self): + result = ds_simulate( + self.ti1d, + sim_shape=(8,), + n_neighbors=4, + threshold=0.0, + scan_fraction=1.0, + rng=np.random.RandomState(7), + ) + self.assertEqual(result.shape, (8,)) + self.assertFalse(np.any(np.isnan(result))) + # Check values — seeded values for ds_simulate(seed=7) with ti1d + self.assertTrue(set(np.unique(result)).issubset({0.0, 1.0})) + + def test_empty_search_window_recovery(self): + # n_neighbors >> TI size collapses search windows → must recover silently + ti_tiny = TrainingImage(np.array([0.0, 1.0, 0.0]), categorical=True) + ds = DirectSampling(ti_tiny, n_neighbors=10, scan_fraction=1.0) + field = ds([np.arange(5, dtype=float)], seed=1) + self.assertEqual(field.shape, (5,)) + self.assertFalse(np.any(np.isnan(field))) + self.assertTrue(set(np.unique(field)).issubset({0.0, 1.0})) + + def test_gstools_namespace(self): + self.assertIs(gs.DirectSampling, DirectSampling) + self.assertIs(gs.TrainingImage, TrainingImage) + self.assertIs(gs.mps.DirectSampling, DirectSampling) + self.assertIs(gs.mps.TrainingImage, TrainingImage) + + +if __name__ == "__main__": + unittest.main()