diff --git a/packages/fixer-cmip7/fixer_cmip7/fixes.py b/packages/fixer-cmip7/fixer_cmip7/fixes.py index 1a0fa54..66c7fc8 100644 --- a/packages/fixer-cmip7/fixer_cmip7/fixes.py +++ b/packages/fixer-cmip7/fixer_cmip7/fixes.py @@ -142,12 +142,14 @@ def from_cmor_table( ) -def reformat( +def reformat( # noqa: PLR0913 ds: xr.Dataset, realm: str, branded_variable: str, dim_map: dict[str, str] | None = None, variable_map: dict[str, str] | None = None, + *, + keep_global_attrs: bool = False, ) -> xr.Dataset: """Reformat a dataset using the definition from the CMIP7 CMOR tables. @@ -163,6 +165,8 @@ def reformat( A mapping of dimension names to rename. variable_map: A mapping of variable names to rename. + keep_global_attrs: + Whether to keep the global attributes. Returns ------- @@ -172,4 +176,9 @@ def reformat( return CMIP7Variable.from_cmor_table( table_id=realm, entry=branded_variable, - ).to_dataset(ds, dim_map=dim_map, variable_map=variable_map) + ).to_dataset( + ds, + dim_map=dim_map, + variable_map=variable_map, + keep_global_attrs=keep_global_attrs, + ) diff --git a/packages/fixer-esa-cci/fixer_esa_cci/fixes.yaml b/packages/fixer-esa-cci/fixer_esa_cci/fixes.yaml index 6e44c23..4ad9e1a 100644 --- a/packages/fixer-esa-cci/fixer_esa_cci/fixes.yaml +++ b/packages/fixer-esa-cci/fixer_esa_cci/fixes.yaml @@ -6,6 +6,7 @@ ESACCI-WATERVAPOUR-L3C-TCWV-meris-005deg-2002-2017-fv3.2.zarr: bnds: nv variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat ESACCI-WATERVAPOUR-L3C-TCWV-meris-005deg-200207-201712-fv3.2-kr1.1: @@ -16,6 +17,7 @@ ESACCI-WATERVAPOUR-L3C-TCWV-meris-005deg-200207-201712-fv3.2-kr1.1: bnds: nv variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat ESACCI-WATERVAPOUR-L3C-TCWV-meris-005deg-20020701-20171231-fv3.2-kr1.1: @@ -26,6 +28,7 @@ ESACCI-WATERVAPOUR-L3C-TCWV-meris-005deg-20020701-20171231-fv3.2-kr1.1: bnds: nv variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat ESACCI-WATERVAPOUR-L3C-TCWV-meris-05deg-200207-201712-fv3.2-kr1.1: @@ -36,6 +39,7 @@ ESACCI-WATERVAPOUR-L3C-TCWV-meris-05deg-200207-201712-fv3.2-kr1.1: bnds: nv variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat ESACCI-WATERVAPOUR-L3C-TCWV-meris-05deg-20020701-20171231-fv3.2-kr1.1: @@ -46,6 +50,7 @@ ESACCI-WATERVAPOUR-L3C-TCWV-meris-05deg-20020701-20171231-fv3.2-kr1.1: bnds: nv variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat esacci.WATERVAPOUR.day.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_005deg.3-2.r1: @@ -56,6 +61,7 @@ esacci.WATERVAPOUR.day.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_005deg.3-2 branded_variable: prw_tavg-u-hxy-u variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat esacci.WATERVAPOUR.day.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_05deg.3-2.r1: @@ -66,6 +72,7 @@ esacci.WATERVAPOUR.day.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_05deg.3-2. branded_variable: prw_tavg-u-hxy-u variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat esacci.WATERVAPOUR.mon.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_005deg.3-2.r1: @@ -76,6 +83,7 @@ esacci.WATERVAPOUR.mon.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_005deg.3-2 branded_variable: prw_tavg-u-hxy-u variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat esacci.WATERVAPOUR.mon.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_05deg.3-2.r1: @@ -86,5 +94,6 @@ esacci.WATERVAPOUR.mon.L3S.TCWV.multi-sensor.multi-platform.TCWV_land_05deg.3-2. branded_variable: prw_tavg-u-hxy-u variable_map: prw: tcwv + keep_global_attrs: true - function: fixer.fixes.flip_coordinate name: lat diff --git a/packages/fixer/fixer/fixes.py b/packages/fixer/fixer/fixes.py index 41cda3b..0e6e490 100644 --- a/packages/fixer/fixer/fixes.py +++ b/packages/fixer/fixer/fixes.py @@ -228,6 +228,8 @@ def to_dataset( ds: xr.Dataset, dim_map: dict[str, str] | None = None, variable_map: dict[str, str] | None = None, + *, + keep_global_attrs: bool = False, ) -> xr.Dataset: """Create a standardized dataset. @@ -241,6 +243,8 @@ def to_dataset( variable_map: An optional mapping from the variable names in the definition to the variable names in the resulting dataset. + keep_global_attrs: + Whether to keep the global attributes. Returns ------- @@ -280,7 +284,10 @@ def to_dataset( for c in self.coords if c.bounds is not None } - return xr.Dataset({self.name: var} | bounds, coords=coords) + result = xr.Dataset({self.name: var} | bounds, coords=coords) + if keep_global_attrs: + result.attrs = dict(ds.attrs) + return result def set_global_attrs( diff --git a/packages/fixer/fixer/tests/test_fixes.py b/packages/fixer/fixer/tests/test_fixes.py index bfd9135..8c33eb6 100644 --- a/packages/fixer/fixer/tests/test_fixes.py +++ b/packages/fixer/fixer/tests/test_fixes.py @@ -19,7 +19,8 @@ from fixer.protocol import FixFunction -def test_reformat() -> None: +@pytest.mark.parametrize("keep_global_attrs", [False, True]) +def test_reformat(*, keep_global_attrs: bool) -> None: """Test the reformat function on a small synthetic dataset.""" assert isinstance(set_units, FixFunction) ds = xr.Dataset.from_dict( @@ -72,7 +73,7 @@ def test_reformat() -> None: "data": 2.0, }, }, - "attrs": {}, + "attrs": {"test_attr": "test_value"}, "dims": {"time": 1, "bnds": 2, "y": 2, "x": 3}, "data_vars": { "time_bounds": { @@ -222,6 +223,7 @@ def test_reformat() -> None: "lon_bnds": "lon_bounds", "time_bnds": "time_bounds", }, + keep_global_attrs=keep_global_attrs, ) assert result is not ds print("Result:\n", result) @@ -240,6 +242,11 @@ def test_reformat() -> None: assert result.tas.dtype == np.dtype("float32") assert result.lat.dtype == np.dtype("float64") assert result.lat_bnds.dtype == np.dtype("float64") + if keep_global_attrs: + assert "test_attr" in result.attrs + assert result.attrs["test_attr"] == "test_value" + else: + assert not result.attrs @pytest.mark.parametrize(