Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/mdio/builder/template_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from mdio.builder.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate
from mdio.builder.templates.seismic_3d_prestack_coca import Seismic3DPreStackCocaTemplate
from mdio.builder.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate
from mdio.builder.templates.seismic_prestack import SeismicPreStackTemplate
from mdio.builder.templates.seismic_3d_prestack_streamer_field_records import (
Seismic3DPreStackStreamerFieldRecordsTemplate,
)

if TYPE_CHECKING:
from mdio.builder.templates.base import AbstractDatasetTemplate
Expand Down Expand Up @@ -134,7 +136,7 @@ def _register_default_templates(self) -> None:
self.register(Seismic3DPreStackCocaTemplate("depth"))

# Field (shot) data
self.register(SeismicPreStackTemplate("time"))
self.register(Seismic3DPreStackStreamerFieldRecordsTemplate("time"))
self.register(Seismic2DPreStackShotTemplate("time"))
self.register(Seismic3DPreStackShotTemplate("time"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

from mdio.builder.schemas.dtype import ScalarType
from mdio.builder.schemas.v1.variable import CoordinateMetadata
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import SeismicDataDomain


class SeismicPreStackTemplate(AbstractDatasetTemplate):
class Seismic3DPreStackStreamerFieldRecordsTemplate(AbstractDatasetTemplate):
"""Seismic pre-stack time Dataset template.

This should be used for both 2D and 3D datasets. Common-shot or common-channel datasets
A generalized template for pre-stack field records in either 2D or 3D.
- Common-shot dataset
- Common-channel dataset

Args:
data_domain: The domain of the dataset.
Expand All @@ -24,17 +26,14 @@ def __init__(self, data_domain: SeismicDataDomain):
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
self._logical_coord_names = ("orig_field_record_num",)
# TODO(Dmitriy Repin): Allow specifying full-dimension-extent chunk size in templates.
# https://github.com/TGSAI/mdio-python/issues/720
# When implemented, the following will be requesting the chunk size of the last dimension
# to be equal to the size of the dimension.
# TODO(Anyone): Disable chunking in time domain when support is merged.
# https://github.com/TGSAI/mdio-python/pull/723
# self._var_chunk_shape = (1, 1, 16, 1, 32, -1)
# For now, we are hardcoding the chunk size to 1024.
self._var_chunk_shape = (1, 1, 16, 1, 32, 1024)

@property
def _name(self) -> str:
return f"PreStackGathers3D{self._data_domain.capitalize()}"
return f"PreStackStreamerFieldRecords3D{self._data_domain.capitalize()}"

def _load_dataset_attributes(self) -> dict[str, Any]:
return {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_import_streamer_grid_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_import_6d_segy( # noqa: PLR0913

segy_to_mdio(
segy_spec=segy_spec,
mdio_template=TemplateRegistry().get("PreStackGathers3DTime"), # Placeholder for the template
mdio_template=TemplateRegistry().get("PreStackStreamerFieldRecords3DTime"),
input_path=segy_path,
output_path=zarr_tmp,
overwrite=True,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate."""

import pytest
from tests.unit.v1.helpers import validate_variable

from mdio.builder.schemas.chunk_grid import RegularChunkGrid
from mdio.builder.schemas.compressors import Blosc
from mdio.builder.schemas.compressors import BloscCname
from mdio.builder.schemas.dtype import ScalarType
from mdio.builder.schemas.dtype import StructuredType
from mdio.builder.schemas.v1.dataset import Dataset
from mdio.builder.schemas.v1.units import LengthUnitEnum
from mdio.builder.schemas.v1.units import LengthUnitModel
from mdio.builder.schemas.v1.units import TimeUnitEnum
from mdio.builder.schemas.v1.units import TimeUnitModel
from mdio.builder.templates.seismic_3d_prestack_streamer_field_records import (
Seismic3DPreStackStreamerFieldRecordsTemplate,
)

UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER)
UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND)


def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None:
"""Validate the coordinate, headers, trace_mask variables in the dataset."""
# Verify variables
# 6 dim coords + 5 non-dim coords + 1 data + 1 trace mask + 1 headers = 14 variables
assert len(dataset.variables) == 14

# Verify trace headers
validate_variable(
dataset,
name="headers",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
dtype=headers,
)

validate_variable(
dataset,
name="trace_mask",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
dtype=ScalarType.BOOL,
)

# Verify dimension coordinate variables
for dim_name in ["shot_line", "gun", "shot_point", "cable", "channel", domain]:
validate_variable(
dataset,
name=dim_name,
dims=[
(
dim_name,
{"shot_line": 1, "gun": 3, "shot_point": 256, "cable": 512, "channel": 24, domain: 2048}[dim_name],
)
],
coords=[dim_name],
dtype=ScalarType.INT32,
)

# Verify non-dimension coordinate variables
validate_variable(
dataset,
name="orig_field_record_num",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)],
coords=["orig_field_record_num"],
dtype=ScalarType.INT32,
)

# Verify coordinate variables with units
for coord_name in ["source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"]:
coord = validate_variable(
dataset,
name=coord_name,
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)]
+ ([("cable", 512), ("channel", 24)] if "group" in coord_name else []),
coords=[coord_name],
dtype=ScalarType.FLOAT64,
)
assert coord.metadata.units_v1.length == LengthUnitEnum.METER


class TestSeismic3DPreStackStreamerFieldRecordsTemplate:
"""Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate."""

def test_configuration(self) -> None:
"""Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate."""
t = Seismic3DPreStackStreamerFieldRecordsTemplate(data_domain="time")

# Template attributes
assert t.name == "PreStackStreamerFieldRecords3DTime"
assert t._dim_names == ("shot_line", "gun", "shot_point", "cable", "channel", "time")
assert t._physical_coord_names == ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
# TODO(Anyone): Disable chunking in time domain when support is merged.
# https://github.com/TGSAI/mdio-python/pull/723
assert t.full_chunk_shape == (1, 1, 16, 1, 32, 1024)

# Variables instantiated when build_dataset() is called
assert t._builder is None
assert t._dim_sizes == ()

# Verify dataset attributes
attrs = t._load_dataset_attributes()
assert attrs == {"surveyDimensionality": "3D", "ensembleType": "shot_point", "processingStage": "pre-stack"}
assert t.default_variable_name == "amplitude"

def test_build_dataset(self, structured_headers: StructuredType) -> None:
"""Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate build."""
t = Seismic3DPreStackStreamerFieldRecordsTemplate(data_domain="time")
t.add_units({"source_coord_x": UNITS_METER, "source_coord_y": UNITS_METER}) # spatial domain units
t.add_units({"group_coord_x": UNITS_METER, "group_coord_y": UNITS_METER}) # spatial domain units
t.add_units({"time": UNITS_SECOND}) # data domain units

dataset = t.build_dataset(
"North Sea 3D Streamer Field Records", sizes=(1, 3, 256, 512, 24, 2048), header_dtype=structured_headers
)

assert dataset.metadata.name == "North Sea 3D Streamer Field Records"
assert dataset.metadata.attributes["surveyDimensionality"] == "3D"
assert dataset.metadata.attributes["ensembleType"] == "shot_point"
assert dataset.metadata.attributes["processingStage"] == "pre-stack"

_validate_coordinates_headers_trace_mask(dataset, structured_headers, "time")

# Verify seismic variable
seismic = validate_variable(
dataset,
name="amplitude",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24), ("time", 2048)],
coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
dtype=ScalarType.FLOAT32,
)
assert isinstance(seismic.compressor, Blosc)
assert seismic.compressor.cname == BloscCname.zstd
assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid)
assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 16, 1, 32, 1024)
assert seismic.metadata.stats_v1 is None


@pytest.mark.parametrize("data_domain", ["Time", "TiME"])
def test_domain_case_handling(data_domain: str) -> None:
"""Test that domain parameter handles different cases correctly."""
template = Seismic3DPreStackStreamerFieldRecordsTemplate(data_domain=data_domain)
assert template._data_domain == data_domain.lower()
assert template.name.endswith(data_domain.capitalize())
Loading