Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions src/mdio/builder/templates/seismic_prestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,17 @@ class SeismicPreStackTemplate(AbstractDatasetTemplate):
def __init__(self, data_domain: SeismicDataDomain):
super().__init__(data_domain=data_domain)

self._coord_dim_names = [
"shot_line",
"gun",
"shot_point",
"cable",
"channel",
] # Custom coordinates for shot gathers
self._dim_names = [*self._coord_dim_names, self._data_domain]
self._coord_names = [
"energy_source_point_number",
"source_coord_x",
"source_coord_y",
"group_coord_x",
"group_coord_y",
]
self._var_chunk_shape = [1, 1, 16, 1, 32, -1]
self._spatial_dim_names = ("shot_line", "gun", "shot_point", "cable", "channel")
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
self._logical_coord_names = ("orig_field_record_num",)
# TODO(Dmitriy Repin): Allow specifying full-dimension-extent chunk size in templates.
# https://github.com/TGSAI/mdio-python/issues/720
# When implemented, the following will be requesting the chunk size of the last dimension
# to be equal to the size of the dimension.
# self._var_chunk_shape = (1, 1, 16, 1, 32, -1)
# For now, we are hardcoding the chunk size to 1024.
self._var_chunk_shape = (1, 1, 16, 1, 32, 1024)

@property
def _name(self) -> str:
Expand All @@ -55,31 +50,31 @@ def _add_coordinates(self) -> None:

# Add non-dimension coordinates
self._builder.add_coordinate(
"energy_source_point_number",
"orig_field_record_num",
dimensions=("shot_line", "gun", "shot_point"),
data_type=ScalarType.INT32,
)
self._builder.add_coordinate(
"source_coord_x",
dimensions=("shot_line", "gun", "shot_point"),
data_type=ScalarType.FLOAT64,
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("source_coord_x")),
)
self._builder.add_coordinate(
"source_coord_y",
dimensions=("shot_line", "gun", "shot_point"),
data_type=ScalarType.FLOAT64,
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("source_coord_y")),
)
self._builder.add_coordinate(
"group_coord_x",
dimensions=("shot_line", "gun", "shot_point", "cable", "channel"),
data_type=ScalarType.FLOAT64,
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("group_coord_x")),
)
self._builder.add_coordinate(
"group_coord_y",
dimensions=("shot_line", "gun", "shot_point", "cable", "channel"),
data_type=ScalarType.FLOAT64,
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("group_coord_y")),
)
4 changes: 2 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def get_segy_mock_4d_spec() -> SegySpec:
"""Create a mock 4D SEG-Y specification."""
trace_header_fields = [
HeaderField(name="field_rec_no", byte=9, format="int32"),
HeaderField(name="orig_field_record_num", byte=9, format="int32"),
HeaderField(name="channel", byte=13, format="int32"),
HeaderField(name="shot_point", byte=17, format="int32"),
HeaderField(name="offset", byte=37, format="int32"),
Expand Down Expand Up @@ -118,7 +118,7 @@ def create_segy_mock_4d( # noqa: PLR0913
channel, gun, shot_line = 0, 0, 0

# Assign dimension coordinate fields with calculated mock data
header_fields = ["field_rec_no", "channel", "shot_point", "offset", "shot_line", "cable", "gun"]
header_fields = ["orig_field_record_num", "channel", "shot_point", "offset", "shot_line", "cable", "gun"]
headers[header_fields][trc_idx] = (shot, channel, shot, offset, shot_line, cable, gun)

# Assign coordinate fields with mock data
Expand Down
15 changes: 8 additions & 7 deletions tests/integration/test_import_streamer_grid_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
# TODO(Altay): Finish implementing these grid overrides.
# https://github.com/TGSAI/mdio-python/issues/612
@pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.")
@pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}])
@pytest.mark.parametrize(
"grid_override", [{"NonBinned": True}, {"HasDuplicates": True}], ids=["NonBinned", "HasDuplicates"]
)
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C])
class TestImport4DNonReg: # pragma: no cover - tests is skipped
"""Test for 4D segy import with grid overrides."""
Expand Down Expand Up @@ -78,7 +80,7 @@ def test_import_4d_segy( # noqa: PLR0913
xrt.assert_duckarray_equal(ds["time"], times_expected)


@pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True}, None])
@pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True}, None], ids=["AutoChannelWrap", "None"])
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.A, StreamerShotGeometryType.B])
class TestImport4D:
"""Test for 4D segy import with grid overrides."""
Expand Down Expand Up @@ -156,10 +158,9 @@ def test_import_4d_segy( # noqa: PLR0913
assert "This grid is very sparse and most likely user error with indexing." in str(execinfo.value)


# TODO(Altay): Finish implementing these grid overrides.
# https://github.com/TGSAI/mdio-python/issues/612
@pytest.mark.skip(reason="AutoShotWrap requires a template that is not implemented yet.")
@pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True}, {"AutoShotWrap": True}, None])
@pytest.mark.parametrize(
"grid_override", [{"AutoChannelWrap": True, "AutoShotWrap": True}, None], ids=["Channel&ShotWrap", "None"]
)
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.A, StreamerShotGeometryType.B])
class TestImport6D: # pragma: no cover - tests is skipped
"""Test for 6D segy import with grid overrides."""
Expand All @@ -177,7 +178,7 @@ def test_import_6d_segy( # noqa: PLR0913

segy_to_mdio(
segy_spec=segy_spec,
mdio_template=TemplateRegistry().get("XYZ"), # Placeholder for the template
mdio_template=TemplateRegistry().get("PreStackGathers3DTime"), # Placeholder for the template
input_path=segy_path,
output_path=zarr_tmp,
overwrite=True,
Expand Down
43 changes: 25 additions & 18 deletions tests/unit/v1/templates/test_seismic_prestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
from mdio.builder.schemas.v1.dataset import Dataset
from mdio.builder.schemas.v1.units import LengthUnitEnum
from mdio.builder.schemas.v1.units import LengthUnitModel
from mdio.builder.schemas.v1.units import TimeUnitEnum
from mdio.builder.schemas.v1.units import TimeUnitModel
from mdio.builder.templates.seismic_prestack import SeismicPreStackTemplate

UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER)
UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND)


def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None:
Expand All @@ -27,15 +30,15 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur
dataset,
name="headers",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
coords=["energy_source_point_number", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
dtype=headers,
)

validate_variable(
dataset,
name="trace_mask",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
coords=["energy_source_point_number", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
dtype=ScalarType.BOOL,
)

Expand Down Expand Up @@ -97,9 +100,9 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur
# Verify non-dimension coordinate variables
validate_variable(
dataset,
name="energy_source_point_number",
name="orig_field_record_num",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)],
coords=["energy_source_point_number"],
coords=["orig_field_record_num"],
dtype=ScalarType.INT32,
)

Expand Down Expand Up @@ -148,39 +151,43 @@ def test_configuration(self) -> None:
t = SeismicPreStackTemplate(data_domain="time")

# Template attributes for prestack shot
assert t._data_domain == "time"
assert t._coord_dim_names == ["shot_line", "gun", "shot_point", "cable", "channel"]
assert t._dim_names == ["shot_line", "gun", "shot_point", "cable", "channel", "time"]
assert t._coord_names == [
"energy_source_point_number",
assert t.name == "PreStackGathers3DTime"
assert t.default_variable_name == "amplitude"
assert t.trace_domain == "time"
assert t.spatial_dimension_names == ("shot_line", "gun", "shot_point", "cable", "channel")
assert t.dimension_names == ("shot_line", "gun", "shot_point", "cable", "channel", "time")
assert t.physical_coordinate_names == ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
assert t.logical_coordinate_names == ("orig_field_record_num",)
assert t.coordinate_names == (
"source_coord_x",
"source_coord_y",
"group_coord_x",
"group_coord_y",
]
assert t._var_chunk_shape == [1, 1, 16, 1, 32, -1]
"orig_field_record_num",
)
assert t.full_chunk_size == (1, 1, 16, 1, 32, -1)

# Variables instantiated when build_dataset() is called
assert t._builder is None
assert t._dim_sizes == ()
assert t._horizontal_coord_unit is None
assert t._units == {}

# Verify prestack shot attributes
attrs = t._load_dataset_attributes()
assert attrs == {"surveyDimensionality": "3D", "ensembleType": "shot_point", "processingStage": "pre-stack"}
assert t.default_variable_name == "amplitude"

assert t.name == "PreStackGathers3DTime"

def test_build_dataset(self, structured_headers: StructuredType) -> None:
"""Unit tests for SeismicPreStackTemplate build in time domain."""
t = SeismicPreStackTemplate(data_domain="time")
t.add_units({"source_coord_x": UNITS_METER, "source_coord_y": UNITS_METER}) # spatial domain units
t.add_units({"group_coord_x": UNITS_METER, "group_coord_y": UNITS_METER}) # spatial domain units
t.add_units({"time": UNITS_SECOND}) # data domain units

assert t.name == "PreStackGathers3DTime"
dataset = t.build_dataset(
"North Sea 3D Shot Time",
sizes=(1, 3, 256, 512, 24, 2048),
horizontal_coord_unit=UNITS_METER,
header_dtype=structured_headers,
"North Sea 3D Shot Time", sizes=(1, 3, 256, 512, 24, 2048), header_dtype=structured_headers
)

assert dataset.metadata.name == "North Sea 3D Shot Time"
Expand All @@ -195,7 +202,7 @@ def test_build_dataset(self, structured_headers: StructuredType) -> None:
dataset,
name="amplitude",
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24), ("time", 2048)],
coords=["energy_source_point_number", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
coords=["orig_field_record_num", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
dtype=ScalarType.FLOAT32,
)
assert isinstance(seismic.compressor, Blosc)
Expand Down
Loading