diff --git a/src/mdio/builder/template_registry.py b/src/mdio/builder/template_registry.py index f5a974f8..fa690c56 100644 --- a/src/mdio/builder/template_registry.py +++ b/src/mdio/builder/template_registry.py @@ -27,7 +27,9 @@ from mdio.builder.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate from mdio.builder.templates.seismic_3d_prestack_coca import Seismic3DPreStackCocaTemplate from mdio.builder.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate -from mdio.builder.templates.seismic_prestack import SeismicPreStackTemplate +from mdio.builder.templates.seismic_3d_prestack_streamer_field_records import ( + Seismic3DPreStackStreamerFieldRecordsTemplate, +) if TYPE_CHECKING: from mdio.builder.templates.base import AbstractDatasetTemplate @@ -134,7 +136,7 @@ def _register_default_templates(self) -> None: self.register(Seismic3DPreStackCocaTemplate("depth")) # Field (shot) data - self.register(SeismicPreStackTemplate("time")) + self.register(Seismic3DPreStackStreamerFieldRecordsTemplate("time")) self.register(Seismic2DPreStackShotTemplate("time")) self.register(Seismic3DPreStackShotTemplate("time")) diff --git a/src/mdio/builder/templates/seismic_prestack.py b/src/mdio/builder/templates/seismic_3d_prestack_streamer_field_records.py similarity index 80% rename from src/mdio/builder/templates/seismic_prestack.py rename to src/mdio/builder/templates/seismic_3d_prestack_streamer_field_records.py index 7eb23e9d..b3b42512 100644 --- a/src/mdio/builder/templates/seismic_prestack.py +++ b/src/mdio/builder/templates/seismic_3d_prestack_streamer_field_records.py @@ -4,14 +4,16 @@ from mdio.builder.schemas.dtype import ScalarType from mdio.builder.schemas.v1.variable import CoordinateMetadata -from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.base import AbstractDatasetTemplate from mdio.builder.templates.types import SeismicDataDomain -class 
SeismicPreStackTemplate(AbstractDatasetTemplate): +class Seismic3DPreStackStreamerFieldRecordsTemplate(AbstractDatasetTemplate): """Seismic pre-stack time Dataset template. - This should be used for both 2D and 3D datasets. Common-shot or common-channel datasets + A generalized template for 3D streamer pre-stack field records. + - Common-shot dataset + - Common-channel dataset Args: data_domain: The domain of the dataset. @@ -24,17 +26,14 @@ def __init__(self, data_domain: SeismicDataDomain): self._dim_names = (*self._spatial_dim_names, self._data_domain) self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") self._logical_coord_names = ("orig_field_record_num",) - # TODO(Dmitriy Repin): Allow specifying full-dimension-extent chunk size in templates. - # https://github.com/TGSAI/mdio-python/issues/720 - # When implemented, the following will be requesting the chunk size of the last dimension - # to be equal to the size of the dimension. + # TODO(Anyone): Disable chunking in time domain when support is merged. + # https://github.com/TGSAI/mdio-python/pull/723 # self._var_chunk_shape = (1, 1, 16, 1, 32, -1) - # For now, we are hardcoding the chunk size to 1024. 
self._var_chunk_shape = (1, 1, 16, 1, 32, 1024) @property def _name(self) -> str: - return f"PreStackGathers3D{self._data_domain.capitalize()}" + return f"PreStackStreamerFieldRecords3D{self._data_domain.capitalize()}" def _load_dataset_attributes(self) -> dict[str, Any]: return { diff --git a/tests/integration/test_import_streamer_grid_overrides.py b/tests/integration/test_import_streamer_grid_overrides.py index 5149eac9..c9aad726 100644 --- a/tests/integration/test_import_streamer_grid_overrides.py +++ b/tests/integration/test_import_streamer_grid_overrides.py @@ -178,7 +178,7 @@ def test_import_6d_segy( # noqa: PLR0913 segy_to_mdio( segy_spec=segy_spec, - mdio_template=TemplateRegistry().get("PreStackGathers3DTime"), # Placeholder for the template + mdio_template=TemplateRegistry().get("PreStackStreamerFieldRecords3DTime"), input_path=segy_path, output_path=zarr_tmp, overwrite=True, diff --git a/tests/unit/v1/templates/test_seismic_3d_prestack_streamer_field_records.py b/tests/unit/v1/templates/test_seismic_3d_prestack_streamer_field_records.py new file mode 100644 index 00000000..d07bd03d --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_3d_prestack_streamer_field_records.py @@ -0,0 +1,146 @@ +"""Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate.""" + +import pytest +from tests.unit.v1.helpers import validate_variable + +from mdio.builder.schemas.chunk_grid import RegularChunkGrid +from mdio.builder.schemas.compressors import Blosc +from mdio.builder.schemas.compressors import BloscCname +from mdio.builder.schemas.dtype import ScalarType +from mdio.builder.schemas.dtype import StructuredType +from mdio.builder.schemas.v1.dataset import Dataset +from mdio.builder.schemas.v1.units import LengthUnitEnum +from mdio.builder.schemas.v1.units import LengthUnitModel +from mdio.builder.schemas.v1.units import TimeUnitEnum +from mdio.builder.schemas.v1.units import TimeUnitModel +from mdio.builder.templates.seismic_3d_prestack_streamer_field_records 
import ( + Seismic3DPreStackStreamerFieldRecordsTemplate, +) + +UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) +UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) + + +def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None: + """Validate the coordinate, headers, trace_mask variables in the dataset.""" + # Verify variables + # 6 dim coords + 5 non-dim coords + 1 data + 1 trace mask + 1 headers = 14 variables + assert len(dataset.variables) == 14 + + # Verify trace headers + validate_variable( + dataset, + name="headers", + dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)], + coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], + dtype=headers, + ) + + validate_variable( + dataset, + name="trace_mask", + dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)], + coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], + dtype=ScalarType.BOOL, + ) + + # Verify dimension coordinate variables + for dim_name in ["shot_line", "gun", "shot_point", "cable", "channel", domain]: + validate_variable( + dataset, + name=dim_name, + dims=[ + ( + dim_name, + {"shot_line": 1, "gun": 3, "shot_point": 256, "cable": 512, "channel": 24, domain: 2048}[dim_name], + ) + ], + coords=[dim_name], + dtype=ScalarType.INT32, + ) + + # Verify non-dimension coordinate variables + validate_variable( + dataset, + name="orig_field_record_num", + dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)], + coords=["orig_field_record_num"], + dtype=ScalarType.INT32, + ) + + # Verify coordinate variables with units + for coord_name in ["source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"]: + coord = validate_variable( + dataset, + name=coord_name, + dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)] + + ([("cable", 512), ("channel", 24)] if 
"group" in coord_name else []), + coords=[coord_name], + dtype=ScalarType.FLOAT64, + ) + assert coord.metadata.units_v1.length == LengthUnitEnum.METER + + +class TestSeismic3DPreStackStreamerFieldRecordsTemplate: + """Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate.""" + + def test_configuration(self) -> None: + """Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate.""" + t = Seismic3DPreStackStreamerFieldRecordsTemplate(data_domain="time") + + # Template attributes + assert t.name == "PreStackStreamerFieldRecords3DTime" + assert t._dim_names == ("shot_line", "gun", "shot_point", "cable", "channel", "time") + assert t._physical_coord_names == ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") + # TODO(Anyone): Disable chunking in time domain when support is merged. + # https://github.com/TGSAI/mdio-python/pull/723 + assert t.full_chunk_shape == (1, 1, 16, 1, 32, 1024) + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == () + + # Verify dataset attributes + attrs = t._load_dataset_attributes() + assert attrs == {"surveyDimensionality": "3D", "ensembleType": "shot_point", "processingStage": "pre-stack"} + assert t.default_variable_name == "amplitude" + + def test_build_dataset(self, structured_headers: StructuredType) -> None: + """Unit tests for Seismic3DPreStackStreamerFieldRecordsTemplate build.""" + t = Seismic3DPreStackStreamerFieldRecordsTemplate(data_domain="time") + t.add_units({"source_coord_x": UNITS_METER, "source_coord_y": UNITS_METER}) # spatial domain units + t.add_units({"group_coord_x": UNITS_METER, "group_coord_y": UNITS_METER}) # spatial domain units + t.add_units({"time": UNITS_SECOND}) # data domain units + + dataset = t.build_dataset( + "North Sea 3D Streamer Field Records", sizes=(1, 3, 256, 512, 24, 2048), header_dtype=structured_headers + ) + + assert dataset.metadata.name == "North Sea 3D Streamer Field Records" + assert 
dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "shot_point" + assert dataset.metadata.attributes["processingStage"] == "pre-stack" + + _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") + + # Verify seismic variable + seismic = validate_variable( + dataset, + name="amplitude", + dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24), ("time", 2048)], + coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.cname == BloscCname.zstd + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 16, 1, 32, 1024) + assert seismic.metadata.stats_v1 is None + + +@pytest.mark.parametrize("data_domain", ["Time", "TiME"]) +def test_domain_case_handling(data_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic3DPreStackStreamerFieldRecordsTemplate(data_domain=data_domain) + assert template._data_domain == data_domain.lower() + assert template.name.endswith(data_domain.capitalize()) diff --git a/tests/unit/v1/templates/test_seismic_prestack.py b/tests/unit/v1/templates/test_seismic_prestack.py deleted file mode 100644 index 35188fa3..00000000 --- a/tests/unit/v1/templates/test_seismic_prestack.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Unit tests for SeismicPreStackTemplate.""" - -import pytest -from tests.unit.v1.helpers import validate_variable - -from mdio.builder.schemas.chunk_grid import RegularChunkGrid -from mdio.builder.schemas.compressors import Blosc -from mdio.builder.schemas.compressors import BloscCname -from mdio.builder.schemas.dtype import ScalarType -from mdio.builder.schemas.dtype import StructuredType -from mdio.builder.schemas.v1.dataset import 
Dataset -from mdio.builder.schemas.v1.units import LengthUnitEnum -from mdio.builder.schemas.v1.units import LengthUnitModel -from mdio.builder.schemas.v1.units import TimeUnitEnum -from mdio.builder.schemas.v1.units import TimeUnitModel -from mdio.builder.templates.seismic_prestack import SeismicPreStackTemplate - -UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) -UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) - - -def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None: - """Validate the coordinate, headers, trace_mask variables in the dataset.""" - # Verify variables - # 6 dim coords + 5 non-dim coords + 1 data + 1 trace mask + 1 headers = 14 variables - assert len(dataset.variables) == 14 - - # Verify trace headers - validate_variable( - dataset, - name="headers", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)], - coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], - dtype=headers, - ) - - validate_variable( - dataset, - name="trace_mask", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)], - coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], - dtype=ScalarType.BOOL, - ) - - # Verify dimension coordinate variables - shot_line = validate_variable( - dataset, - name="shot_line", - dims=[("shot_line", 1)], - coords=["shot_line"], - dtype=ScalarType.INT32, - ) - assert shot_line.metadata is None - - gun = validate_variable( - dataset, - name="gun", - dims=[("gun", 3)], - coords=["gun"], - dtype=ScalarType.INT32, - ) - assert gun.metadata is None - - shot_point = validate_variable( - dataset, - name="shot_point", - dims=[("shot_point", 256)], - coords=["shot_point"], - dtype=ScalarType.INT32, - ) - assert shot_point.metadata is None - - cable = validate_variable( - dataset, - name="cable", - dims=[("cable", 
512)], - coords=["cable"], - dtype=ScalarType.INT32, - ) - assert cable.metadata is None - - channel = validate_variable( - dataset, - name="channel", - dims=[("channel", 24)], - coords=["channel"], - dtype=ScalarType.INT32, - ) - assert channel.metadata is None - - domain_var = validate_variable( - dataset, - name=domain, - dims=[(domain, 2048)], - coords=[domain], - dtype=ScalarType.INT32, - ) - assert domain_var.metadata is None - - # Verify non-dimension coordinate variables - validate_variable( - dataset, - name="orig_field_record_num", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)], - coords=["orig_field_record_num"], - dtype=ScalarType.INT32, - ) - - source_coord_x = validate_variable( - dataset, - name="source_coord_x", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)], - coords=["source_coord_x"], - dtype=ScalarType.FLOAT64, - ) - assert source_coord_x.metadata.units_v1.length == LengthUnitEnum.METER - - source_coord_y = validate_variable( - dataset, - name="source_coord_y", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)], - coords=["source_coord_y"], - dtype=ScalarType.FLOAT64, - ) - assert source_coord_y.metadata.units_v1.length == LengthUnitEnum.METER - - group_coord_x = validate_variable( - dataset, - name="group_coord_x", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)], - coords=["group_coord_x"], - dtype=ScalarType.FLOAT64, - ) - assert group_coord_x.metadata.units_v1.length == LengthUnitEnum.METER - - group_coord_y = validate_variable( - dataset, - name="group_coord_y", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)], - coords=["group_coord_y"], - dtype=ScalarType.FLOAT64, - ) - assert group_coord_y.metadata.units_v1.length == LengthUnitEnum.METER - - -class TestSeismic3DPreStackShotTemplate: - """Unit tests for SeismicPreStackTemplate.""" - - def test_configuration(self) -> None: - """Unit tests for SeismicPreStackTemplate in time 
domain.""" - t = SeismicPreStackTemplate(data_domain="time") - - # Template attributes for prestack shot - assert t.name == "PreStackGathers3DTime" - assert t.default_variable_name == "amplitude" - assert t.trace_domain == "time" - assert t.spatial_dimension_names == ("shot_line", "gun", "shot_point", "cable", "channel") - assert t.dimension_names == ("shot_line", "gun", "shot_point", "cable", "channel", "time") - assert t.physical_coordinate_names == ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") - assert t.logical_coordinate_names == ("orig_field_record_num",) - assert t.coordinate_names == ( - "source_coord_x", - "source_coord_y", - "group_coord_x", - "group_coord_y", - "orig_field_record_num", - ) - assert t.full_chunk_size == (1, 1, 16, 1, 32, 1024) - - # Variables instantiated when build_dataset() is called - assert t._builder is None - assert t._dim_sizes == () - assert t._units == {} - - # Verify prestack shot attributes - attrs = t._load_dataset_attributes() - assert attrs == {"surveyDimensionality": "3D", "ensembleType": "shot_point", "processingStage": "pre-stack"} - assert t.default_variable_name == "amplitude" - - assert t.name == "PreStackGathers3DTime" - - def test_build_dataset(self, structured_headers: StructuredType) -> None: - """Unit tests for SeismicPreStackTemplate build in time domain.""" - t = SeismicPreStackTemplate(data_domain="time") - t.add_units({"source_coord_x": UNITS_METER, "source_coord_y": UNITS_METER}) # spatial domain units - t.add_units({"group_coord_x": UNITS_METER, "group_coord_y": UNITS_METER}) # spatial domain units - t.add_units({"time": UNITS_SECOND}) # data domain units - - dataset = t.build_dataset( - "North Sea 3D Shot Time", sizes=(1, 3, 256, 512, 24, 2048), header_dtype=structured_headers - ) - - assert dataset.metadata.name == "North Sea 3D Shot Time" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "shot_point" - 
assert dataset.metadata.attributes["processingStage"] == "pre-stack" - - _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") - - # Verify seismic variable (prestack shot time data) - seismic = validate_variable( - dataset, - name="amplitude", - dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24), ("time", 2048)], - coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], - dtype=ScalarType.FLOAT32, - ) - assert isinstance(seismic.compressor, Blosc) - assert seismic.compressor.cname == BloscCname.zstd - assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) - assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 16, 1, 32, 1024) - assert seismic.metadata.stats_v1 is None - - -@pytest.mark.parametrize("data_domain", ["Time", "TiME"]) -def test_domain_case_handling(data_domain: str) -> None: - """Test that domain parameter handles different cases correctly.""" - template = SeismicPreStackTemplate(data_domain=data_domain) - assert template._data_domain == data_domain.lower() - assert template.name.endswith(data_domain.capitalize()) diff --git a/tests/unit/v1/templates/test_template_registry.py b/tests/unit/v1/templates/test_template_registry.py index a1320580..04d49a7a 100644 --- a/tests/unit/v1/templates/test_template_registry.py +++ b/tests/unit/v1/templates/test_template_registry.py @@ -31,7 +31,7 @@ "PreStackCdpAngleGathers3DDepth", "PreStackCocaGathers3DTime", "PreStackCocaGathers3DDepth", - "PreStackGathers3DTime", + "PreStackStreamerFieldRecords3DTime", "PreStackShotGathers2DTime", "PreStackShotGathers3DTime", ]