diff --git a/imap_processing/ena_maps/ena_maps.py b/imap_processing/ena_maps/ena_maps.py index a8c165743..44a191039 100644 --- a/imap_processing/ena_maps/ena_maps.py +++ b/imap_processing/ena_maps/ena_maps.py @@ -618,7 +618,7 @@ def update_az_el_points(self) -> None: The values stored in the "hae_longitude" and "hae_latitude" variables are used to construct the azimuth and elevation coordinates. """ - logger.info( + logger.debug( "Updating az/el points based on data in hae_longitude and" "hae_latitude variables." ) diff --git a/imap_processing/ena_maps/utils/corrections.py b/imap_processing/ena_maps/utils/corrections.py index ff7402b31..6346b8198 100644 --- a/imap_processing/ena_maps/utils/corrections.py +++ b/imap_processing/ena_maps/utils/corrections.py @@ -775,6 +775,7 @@ def interpolate_map_flux_to_helio_frame( esa_energies: xr.DataArray, helio_energies: xr.DataArray, vars_to_interpolate: list[str], + update_sys_err: bool = True, ) -> xr.Dataset: """ Interpolate flux from spacecraft frame to heliocentric frame energies. @@ -806,6 +807,9 @@ def interpolate_map_flux_to_helio_frame( dataset and will be interpolated as well. For example, if ["ena_intensity"] is input, then the variables "ena_intensity", "ena_intensity_stat_uncert", and "ena_intensity_sys_err" will be interpolated. + update_sys_err : bool, optional + Flag indicating whether to update the systematic error variables as part + of the flux interpolation. Defaults to True. 
Returns ------- @@ -912,7 +916,8 @@ def interpolate_map_flux_to_helio_frame( # Update the dataset with interpolated values map_ds[var_name] = flux_helio map_ds[f"{var_name}_stat_uncert"] = stat_unc_helio - map_ds[f"{var_name}_sys_err"] = sys_err_helio + if update_sys_err: + map_ds[f"{var_name}_sys_err"] = sys_err_helio return map_ds diff --git a/imap_processing/hi/hi_l2.py b/imap_processing/hi/hi_l2.py index d69d21acf..03a5e3903 100644 --- a/imap_processing/hi/hi_l2.py +++ b/imap_processing/hi/hi_l2.py @@ -29,12 +29,11 @@ "counts", "exposure_factor", "bg_rate", - "bg_rate_sys_err", "obs_date", } HELIO_FRAME_VARS_TO_PROJECT = SC_FRAME_VARS_TO_PROJECT | {"energy_sc"} # TODO: is an exposure time weighted average for obs_date appropriate? -FULL_EXPOSURE_TIME_AVERAGE_SET = {"bg_rate", "bg_rate_sys_err", "obs_date", "energy_sc"} +FULL_EXPOSURE_TIME_AVERAGE_SET = {"bg_rate", "obs_date", "energy_sc"} # ============================================================================= @@ -90,21 +89,25 @@ def hi_l2( ) logger.info(f"Step 1: Creating sky map from {len(psets)} pointing sets") - sky_map = create_sky_map_from_psets( + sky_maps = create_sky_map_from_psets( psets, l2_ancillary_path_dict, map_descriptor, ) logger.info("Step 2: Calculating rates and intensities") - sky_map.data_1d = calculate_all_rates_and_intensities( - sky_map.data_1d, - l2_ancillary_path_dict, - map_descriptor, - ) + for sky_map in sky_maps.values(): + sky_map.data_1d = calculate_all_rates_and_intensities( + sky_map.data_1d, + l2_ancillary_path_dict, + map_descriptor, + ) - logger.info("Step 3: Finalizing dataset with attributes") - l2_ds = sky_map.build_cdf_dataset( + logger.info("Step 3: Combining maps (if needed)") + final_map = combine_maps(sky_maps) + + logger.info("Step 4: Finalizing dataset with attributes") + l2_ds = final_map.build_cdf_dataset( "hi", "l2", descriptor, @@ -124,7 +127,7 @@ def create_sky_map_from_psets( psets: list[str | Path], l2_ancillary_path_dict: dict[str, Path], 
descriptor: MapDescriptor, -) -> RectangularSkyMap: +) -> dict[str, RectangularSkyMap]: """ Project Hi PSET data into a sky map. @@ -141,18 +144,35 @@ def create_sky_map_from_psets( Returns ------- - sky_map : RectangularSkyMap - The sky map with all the PSET data projected into the map. Includes + sky_maps : dict[str, RectangularSkyMap] + Dictionary mapping spin phase keys ("ram", "anti", or "full") + to sky maps with all the PSET data projected. Includes an energy coordinate and energy_delta_minus and energy_delta_plus - variables from ESA energy calibration data. + variables from ESA energy calibration data. For helioframe full-spin + maps, contains "ram" and "anti" keys; otherwise contains a single key + matching the descriptor's spin_phase. """ if len(psets) == 0: raise ValueError("No PSETs provided for map creation") - output_map = descriptor.to_empty_map() - - if not isinstance(output_map, RectangularSkyMap): + # If we are making a full-spin, helio-frame map, we need to make one ram + # and one anti-ram map that get combined at the final L2 step. + if descriptor.frame_descriptor == "hf" and descriptor.spin_phase == "full": + # The spin-phase of the descriptor has no effect on the output map, so + # we can just use descriptor.to_empty_map() to generate both. 
+ output_maps = { + "ram": descriptor.to_empty_map(), + "anti": descriptor.to_empty_map(), + } + else: + output_maps = {descriptor.spin_phase: descriptor.to_empty_map()} + + if not all([isinstance(map, RectangularSkyMap) for map in output_maps.values()]): raise NotImplementedError("Healpix map output not supported for Hi") + # Needed for mypy type narrowing + rect_maps: dict[str, RectangularSkyMap] = { + k: v for k, v in output_maps.items() if isinstance(v, RectangularSkyMap) + } vars_to_bin = ( HELIO_FRAME_VARS_TO_PROJECT @@ -187,31 +207,31 @@ def create_sky_map_from_psets( vars_to_exposure_time_average, ) - # Project (bin) the PSET variables into the map pixels - directional_mask = get_pset_directional_mask( - pset_processed, descriptor.spin_phase - ) hi_pset = HiPointingSet(pset_processed) - output_map.project_pset_values_to_map( - hi_pset, list(vars_to_bin), pset_valid_mask=directional_mask - ) - # Finish the exposure time weighted mean calculation of backgrounds - # Allow divide by zero to fill set pixels with zero exposure time to NaN - with np.errstate(divide="ignore"): - for var in vars_to_exposure_time_average: - output_map.data_1d[var] /= output_map.data_1d["exposure_factor"] - - # Add ESA energy data to the map dataset for use in rate/intensity calculations - energy_delta = esa_ds["bandpass_fwhm"] / 2 - output_map.data_1d["energy_delta_minus"] = energy_delta - output_map.data_1d["energy_delta_plus"] = energy_delta - # Add energy as an auxiliary coordinate (keV values indexed by esa_energy_step) - output_map.data_1d = output_map.data_1d.assign_coords( - energy=("esa_energy_step", esa_ds["nominal_central_energy"].values) - ) + for spin_phase, map in rect_maps.items(): + # Project (bin) the PSET variables into the map pixels + directional_mask = get_pset_directional_mask(pset_processed, spin_phase) + map.project_pset_values_to_map( + hi_pset, list(vars_to_bin), pset_valid_mask=directional_mask + ) - return output_map + for map in rect_maps.values(): + # 
Finish the exposure time weighted mean calculation of backgrounds + # Allow divide by zero to fill set pixels with zero exposure time to NaN + with np.errstate(divide="ignore"): + map.data_1d[vars_to_exposure_time_average] /= map.data_1d["exposure_factor"] + + # Add ESA energy data to the map dataset for use in rate/intensity calculations + energy_delta = esa_ds["bandpass_fwhm"] / 2 + map.data_1d["energy_delta_minus"] = energy_delta + map.data_1d["energy_delta_plus"] = energy_delta + # Add energy as an auxiliary coordinate (keV values indexed by esa_energy_step) + map.data_1d = map.data_1d.assign_coords( + energy=("esa_energy_step", esa_ds["nominal_central_energy"].values) + ) + + return rect_maps # ============================================================================= @@ -323,13 +343,11 @@ def calculate_all_rates_and_intensities( # TODO: Handle variable types correctly in RectangularSkyMap.build_cdf_dataset obs_date = map_ds["obs_date"] # Replace non-finite values with the int64 sentinel before casting - obs_date_filled = xr.where( + map_ds["obs_date"] = xr.where( np.isfinite(obs_date), - obs_date, + obs_date.astype("int64"), np.int64(-9223372036854775808), ) - map_ds["obs_date"] = obs_date_filled.astype("int64") - # TODO: Figure out how to compute obs_date_range (stddev of obs_date) map_ds["obs_date_range"] = xr.zeros_like(map_ds["obs_date"]) # Step 4: Swap esa_energy_step dimension for energy coordinate @@ -343,15 +361,21 @@ def calculate_all_rates_and_intensities( logger.debug("Applying Compton-Getting interpolation for heliocentric frame") # Convert energy coordinate from keV to eV for interpolation esa_energy_ev = map_ds["energy"] * 1000 + + # Hi does not want to apply the flux correction to the systematic error. 
map_ds = interpolate_map_flux_to_helio_frame( map_ds, esa_energy_ev, # ESA energies in eV esa_energy_ev, # heliocentric energies (same as ESA energies) ["ena_intensity"], + update_sys_err=False, # Hi does not update the systematic error ) # Drop any esa_energy_step_label that may have been re-added map_ds = map_ds.drop_vars(["esa_energy_step_label"], errors="ignore") + # Step 6: Clean up intermediate variables + map_ds = cleanup_intermediate_variables(map_ds) + return map_ds @@ -525,6 +549,117 @@ def combine_calibration_products( return map_ds +def combine_maps(sky_maps: dict[str, RectangularSkyMap]) -> RectangularSkyMap: + """ + Combine ram and anti-ram sky maps using appropriate weighting. + + For full-spin heliocentric frame maps, ram and anti-ram maps are processed + separately through the CG correction pipeline, then combined here using + inverse-variance weighting for intensity and appropriate methods for other + variables. + + Parameters + ---------- + sky_maps : dict[str, RectangularSkyMap] + Dictionary of sky maps to combine. Expected to contain either 1 map + (no combination needed) or 2 maps with "ram" and "anti" keys. + + Returns + ------- + RectangularSkyMap + Combined sky map. + """ + if len(sky_maps) not in [1, 2]: + raise ValueError(f"Expected 1 or 2 sky maps, got {len(sky_maps)}") + if len(sky_maps) == 1: + logger.debug("Only one sky map provided, returning it unchanged") + return next(iter(sky_maps.values())) + if "ram" not in sky_maps or "anti" not in sky_maps: + raise ValueError( + f"Expected sky maps with 'ram' and 'anti' keys." 
+ f" Instead got: {sky_maps.keys()}" + ) + + logger.info("Combining ram and anti-ram maps using inverse-variance weighting.") + + ram_ds = sky_maps["ram"].data_1d + anti_ds = sky_maps["anti"].data_1d + + # Use the ram sky map as the base for the combined result + combined_map = sky_maps["ram"] + combined = ram_ds.copy() + + # Additive variables: counts and exposure_factor + combined["counts"] = ram_ds["counts"] + anti_ds["counts"] + combined["exposure_factor"] = ram_ds["exposure_factor"] + anti_ds["exposure_factor"] + + # Inverse-variance weighted average for ena_intensity + weight_ram = xr.where( + ram_ds["ena_intensity_stat_uncert"] > 0, + 1 / ram_ds["ena_intensity_stat_uncert"] ** 2, + 0, + ) + weight_anti = xr.where( + anti_ds["ena_intensity_stat_uncert"] > 0, + 1 / anti_ds["ena_intensity_stat_uncert"] ** 2, + 0, + ) + total_weight = weight_ram + weight_anti + + with np.errstate(divide="ignore", invalid="ignore"): + combined["ena_intensity"] = ( + ram_ds["ena_intensity"] * weight_ram + + anti_ds["ena_intensity"] * weight_anti + ) / total_weight + + combined["ena_intensity_stat_uncert"] = np.sqrt(1 / total_weight) + + # Exposure-weighted average for systematic error + total_exp = combined["exposure_factor"] + combined["ena_intensity_sys_err"] = ( + ram_ds["ena_intensity_sys_err"] * ram_ds["exposure_factor"] + + anti_ds["ena_intensity_sys_err"] * anti_ds["exposure_factor"] + ) / total_exp + + # Exposure-weighted average for obs_date + with np.errstate(divide="ignore", invalid="ignore"): + combined["obs_date"] = ( + ram_ds["obs_date"] * ram_ds["exposure_factor"] + + anti_ds["obs_date"] * anti_ds["exposure_factor"] + ) / total_exp + + # Combined obs_date_range accounts for within-group and between-group variance + # var_combined = (w1*var1 + w2*var2)/(w1+w2) + w1*w2*(mean1-mean2)^2/(w1+w2)^2 + within_variance = ( + ram_ds["exposure_factor"] * ram_ds["obs_date_range"] ** 2 + + anti_ds["exposure_factor"] * anti_ds["obs_date_range"] ** 2 + ) / total_exp + 
between_variance = ( + ram_ds["exposure_factor"] + * anti_ds["exposure_factor"] + * (ram_ds["obs_date"] - anti_ds["obs_date"]) ** 2 + ) / (total_exp**2) + combined["obs_date_range"] = np.sqrt(within_variance + between_variance) + + # Re-cast obs_date and obs_date_range back to int64 after float arithmetic. + # Replace non-finite values with the int64 sentinel value. + # TODO: Handle variable types correctly in RectangularSkyMap.build_cdf_dataset + int64_sentinel = np.int64(-9223372036854775808) + combined["obs_date"] = xr.where( + np.isfinite(combined["obs_date"]), + combined["obs_date"].astype("int64"), + int64_sentinel, + ) + combined["obs_date_range"] = xr.where( + np.isfinite(combined["obs_date_range"]), + combined["obs_date_range"].astype("int64"), + int64_sentinel, + ) + + combined_map.data_1d = combined + return combined_map + + def _calculate_improved_stat_variance( map_ds: xr.Dataset, geometric_factors: xr.DataArray, @@ -600,6 +735,33 @@ def _calculate_improved_stat_variance( return improved_variance +def cleanup_intermediate_variables(dataset: xr.Dataset) -> xr.Dataset: + """ + Remove intermediate variables that were only needed for calculations. + + Parameters + ---------- + dataset : xarray.Dataset + Dataset containing intermediate calculation variables. + + Returns + ------- + xarray.Dataset + Cleaned dataset with intermediate variables removed. 
+ """ + # Remove the intermediate variables from the map + potential_vars = [ + "bg_rate", + "energy_sc", + "ena_signal_rates", + "ena_signal_rate_stat_unc", + ] + + vars_to_remove = [var for var in potential_vars if var in dataset.data_vars] + + return dataset.drop_vars(vars_to_remove) + + # ============================================================================= # SETUP AND INITIALIZATION HELPERS # ============================================================================= diff --git a/imap_processing/tests/ena_maps/test_corrections.py b/imap_processing/tests/ena_maps/test_corrections.py index 82ab6a619..f88f6e87a 100644 --- a/imap_processing/tests/ena_maps/test_corrections.py +++ b/imap_processing/tests/ena_maps/test_corrections.py @@ -1291,6 +1291,25 @@ def test_systematic_uncertainty_propagation(self): assert np.all(rel_sys_err > 0) assert np.all(rel_sys_err < 0.5) # Should be reasonable + def test_systematic_uncertainty_update_flag(self): + """Test that systematic error is unchanged when flag is set False.""" + + map_ds, esa_energies, helio_energies = self.create_test_map_dataset( + n_energy=3, n_spatial=1 + ) + sys_err_input = map_ds["ena_intensity_sys_err"].copy() + + result_ds = interpolate_map_flux_to_helio_frame( + map_ds, + esa_energies, + helio_energies, + ["ena_intensity"], + update_sys_err=False, + ) + + # Systematic uncertainty should be unchanged from the input + xr.testing.assert_equal(result_ds["ena_intensity_sys_err"], sys_err_input) + def test_energy_scaling_transformation(self): """Test Liouville theorem: flux_helio = flux_sc * (E_helio / E_sc).""" diff --git a/imap_processing/tests/hi/test_hi_l2.py b/imap_processing/tests/hi/test_hi_l2.py index e28e54aca..bdf414b5d 100644 --- a/imap_processing/tests/hi/test_hi_l2.py +++ b/imap_processing/tests/hi/test_hi_l2.py @@ -8,14 +8,16 @@ import xarray as xr from imap_processing.cdf.utils import load_cdf, write_cdf -from imap_processing.ena_maps.ena_maps import RectangularSkyMap +from 
imap_processing.ena_maps.ena_maps import HealpixSkyMap, RectangularSkyMap from imap_processing.ena_maps.utils.naming import MapDescriptor from imap_processing.hi.hi_l2 import ( _calculate_improved_stat_variance, calculate_all_rates_and_intensities, calculate_ena_intensity, calculate_ena_signal_rates, + cleanup_intermediate_variables, combine_calibration_products, + combine_maps, create_sky_map_from_psets, esa_energy_df, hi_l2, @@ -198,8 +200,8 @@ def test_hi_l2_uses_descriptor_to_setup_map( pset_path = hi_l1_test_data_path / "imap_hi_l1c_45sensor-pset_20250415_v999.cdf" descriptor_str = "h90-ena-h-sf-nsp-full-hnu-2deg-3mo" rect_map = MapDescriptor.from_string(descriptor_str).to_empty_map() - # create_sky_map_from_psets returns just the sky_map - mock_create_sky_map_from_psets.return_value = rect_map + # create_sky_map_from_psets returns a dict with spin_phase key + mock_create_sky_map_from_psets.return_value = {"full": rect_map} # calculate_all_rates_and_intensities modifies and returns the map data mock_calculate_all_rates_and_intensities.side_effect = lambda ds, *args: ds mock_map_build_cdf_dataset.return_value = xr.Dataset() @@ -215,11 +217,12 @@ def test_hi_l2_uses_descriptor_to_setup_map( @pytest.mark.parametrize( - "descriptor_str", + "descriptor_str, expected_keys", [ - "h90-ena-h-sf-nsp-full-gcs-6deg-3mo", - "h90-ena-h-sf-nsp-ram-gcs-6deg-3mo", - "h90-ena-h-hf-nsp-ram-gcs-6deg-3mo", + ("h90-ena-h-sf-nsp-full-gcs-6deg-3mo", ["full"]), + ("h90-ena-h-sf-nsp-ram-gcs-6deg-3mo", ["ram"]), + ("h90-ena-h-hf-nsp-ram-gcs-6deg-3mo", ["ram"]), + ("h90-ena-h-hf-nsp-full-gcs-6deg-3mo", ["ram", "anti"]), ], ) @pytest.mark.external_test_data @@ -228,6 +231,7 @@ def test_create_sky_map_from_psets( anc_path_dict, furnish_kernels, descriptor_str, + expected_keys, ): """Test coverage for create_sky_map_from_psets()""" kernels = [ @@ -241,39 +245,67 @@ def test_create_sky_map_from_psets( pset_path = hi_l1_test_data_path / "imap_hi_l1c_45sensor-pset_20250415_v999.cdf" 
map_descriptor = MapDescriptor.from_string(descriptor_str) - sky_map = create_sky_map_from_psets( + sky_maps = create_sky_map_from_psets( [pset_path], anc_path_dict, map_descriptor, ) - assert isinstance(sky_map, RectangularSkyMap) - assert sky_map.spacing_deg == 6 - assert sky_map.spice_reference_frame == SpiceFrame.IMAP_GCS - - # Check that ESA energy data was added to the map - assert "energy_delta_minus" in sky_map.data_1d - assert "energy_delta_plus" in sky_map.data_1d - assert "energy" in sky_map.data_1d.coords - - # Test that we got some non-zero values - for var_name in ["counts", "exposure_factor", "obs_date"]: - assert var_name in sky_map.data_1d.data_vars - assert np.nanmax(sky_map.data_1d[var_name].data) > 0 + + # Check that returned dict has expected keys + assert isinstance(sky_maps, dict) + assert set(sky_maps.keys()) == set(expected_keys) + + # Check each map in the dict + for sky_map in sky_maps.values(): + assert isinstance(sky_map, RectangularSkyMap) + assert sky_map.spacing_deg == 6 + assert sky_map.spice_reference_frame == SpiceFrame.IMAP_GCS + + # Check that ESA energy data was added to the map + assert "energy_delta_minus" in sky_map.data_1d + assert "energy_delta_plus" in sky_map.data_1d + assert "energy" in sky_map.data_1d.coords + + # Test that we got some non-zero values + for var_name in ["counts", "exposure_factor", "obs_date"]: + assert var_name in sky_map.data_1d.data_vars + assert np.nanmax(sky_map.data_1d[var_name].data) > 0 + + # If the CG correction ran, check that the energy_sc variable is present + if "-hf-" in descriptor_str: + assert "energy_sc" in sky_map.data_1d.data_vars + assert np.nanmax(sky_map.data_1d["energy_sc"].data) > 0 + # With a single PSET input, the valid obs_date values should be very close # to the PSET midpoint. Convert to seconds to set reasonable comparison # tolerance. 
+ first_map = next(iter(sky_maps.values())) pset = load_cdf(pset_path) pset_midpoint = (pset["epoch"].values[0] + pset["epoch_delta"].values[0] / 2) / 1e9 np.testing.assert_allclose( - np.nanmax(sky_map.data_1d["obs_date"].data) / 1e9, + np.nanmax(first_map.data_1d["obs_date"].data) / 1e9, pset_midpoint, atol=60, ) - # If the CG correction ran, check that the energy_sc variable is present - # in the map - if "-hf-" in descriptor_str: - assert "energy_sc" in sky_map.data_1d.data_vars - assert np.nanmax(sky_map.data_1d["energy_sc"].data) > 0 + + +def test_create_sky_map_from_psets_healpix_not_supported(): + """Test that NotImplementedError is raised when HealpixSkyMap is returned.""" + # Create a mock descriptor that returns a HealpixSkyMap + mock_descriptor = mock.Mock() + mock_descriptor.frame_descriptor = "sf" + mock_descriptor.spin_phase = "full" + + # Create a mock HealpixSkyMap + mock_healpix_map = mock.Mock(spec=HealpixSkyMap) + mock_descriptor.to_empty_map.return_value = mock_healpix_map + + with pytest.raises(NotImplementedError, match="Healpix map output not supported"): + create_sky_map_from_psets( + ["fake_pset.cdf"], # non-empty psets list + {}, # empty ancillary dict + mock_descriptor, + ) def test_calculate_ena_signal_rates(empty_rectangular_map_dataset): @@ -1005,10 +1037,6 @@ def test_calculate_all_rates_and_intensities_basic( descriptor, ) - # Check that signal rates were calculated - assert "ena_signal_rates" in result - assert "ena_signal_rate_stat_unc" in result - # Check that intensities were calculated assert "ena_intensity" in result assert "ena_intensity_stat_uncert" in result @@ -1077,12 +1105,14 @@ def test_calculate_all_rates_and_intensities_adds_obs_date_range( assert "obs_date_range" in result -@mock.patch("imap_processing.hi.hi_l2.interpolate_map_flux_to_helio_frame") +@mock.patch( + "imap_processing.hi.hi_l2.interpolate_map_flux_to_helio_frame", autospec=True +) def test_calculate_all_rates_and_intensities_cg_correction( 
mock_interp_flux, mock_map_dataset_for_rates, anc_path_dict ): """Test that CG interpolation is applied for heliocentric frame.""" - mock_interp_flux.side_effect = lambda ds, *args: ds + mock_interp_flux.side_effect = lambda ds, *args, **kwargs: ds descriptor = MapDescriptor.from_string("h90-ena-h-hf-nsp-full-gcs-6deg-3mo") @@ -1113,3 +1143,286 @@ def test_calculate_all_rates_and_intensities_no_cg_for_sf( # interpolate_map_flux_to_helio_frame should NOT have been called for sf frame mock_interp_flux.assert_not_called() + + +# ============================================================================= +# CLEANUP INTERMEDIATE VARIABLES TESTS +# ============================================================================= + + +def test_cleanup_intermediate_variables(): + """Test that cleanup_intermediate_variables removes expected variables.""" + # Create a dataset with intermediate variables + ds = xr.Dataset( + { + "bg_rate": xr.DataArray([1, 2, 3], dims=["x"]), + "energy_sc": xr.DataArray([4, 5, 6], dims=["x"]), + "ena_signal_rates": xr.DataArray([7, 8, 9], dims=["x"]), + "ena_signal_rate_stat_unc": xr.DataArray([0.1, 0.2, 0.3], dims=["x"]), + "ena_intensity": xr.DataArray([10, 20, 30], dims=["x"]), + "exposure_factor": xr.DataArray([100, 200, 300], dims=["x"]), + } + ) + + result = cleanup_intermediate_variables(ds) + + # Intermediate variables should be removed + assert "bg_rate" not in result + assert "energy_sc" not in result + assert "ena_signal_rates" not in result + assert "ena_signal_rate_stat_unc" not in result + + # Non-intermediate variables should remain + assert "ena_intensity" in result + assert "exposure_factor" in result + + +def test_cleanup_intermediate_variables_missing_vars(): + """Test cleanup works when some intermediate variables don't exist.""" + # Create a dataset without all intermediate variables + ds = xr.Dataset( + { + "bg_rate": xr.DataArray([1, 2, 3], dims=["x"]), + "ena_intensity": xr.DataArray([10, 20, 30], dims=["x"]), + } + ) + 
+ # Should not raise an error + result = cleanup_intermediate_variables(ds) + + assert "bg_rate" not in result + assert "ena_intensity" in result + + +# ============================================================================= +# COMBINE MAPS TESTS +# ============================================================================= + + +@pytest.fixture +def mock_sky_map_for_combine(): + """Create a mock RectangularSkyMap for testing combine_maps.""" + + def _create_map(intensity_offset=0, exposure_offset=0): + """Helper to create a map with configurable values.""" + descriptor = MapDescriptor.from_string("h90-ena-h-hf-nsp-full-gcs-6deg-3mo") + sky_map = descriptor.to_empty_map() + + # Create simple test data + shape = (1, 3, 4, 2) # epoch, energy, lon, lat + sky_map.data_1d = xr.Dataset( + { + "counts": xr.DataArray( + np.ones(shape) * (100 + intensity_offset), + dims=["epoch", "energy", "longitude", "latitude"], + ), + "exposure_factor": xr.DataArray( + np.ones(shape) * (10 + exposure_offset), + dims=["epoch", "energy", "longitude", "latitude"], + ), + "obs_date": xr.DataArray( + np.ones(shape) * (1e18 + intensity_offset * 1e15), + dims=["epoch", "energy", "longitude", "latitude"], + ), + "obs_date_range": xr.DataArray( + np.ones(shape) * (1e14 + intensity_offset * 1e13), + dims=["epoch", "energy", "longitude", "latitude"], + ), + "ena_intensity": xr.DataArray( + np.ones(shape) * (50 + intensity_offset), + dims=["epoch", "energy", "longitude", "latitude"], + ), + "ena_intensity_stat_uncert": xr.DataArray( + np.ones(shape) * 5.0, + dims=["epoch", "energy", "longitude", "latitude"], + ), + "ena_intensity_sys_err": xr.DataArray( + np.ones(shape) * 2.0, + dims=["epoch", "energy", "longitude", "latitude"], + ), + }, + coords={ + "epoch": [0], + "energy": [0.5, 0.75, 1.1], + "longitude": np.arange(4), + "latitude": np.arange(2), + }, + ) + return sky_map + + return _create_map + + +def test_combine_maps_single_map(mock_sky_map_for_combine): + """Test combine_maps with a 
single map returns it unchanged.""" + sky_map = mock_sky_map_for_combine() + sky_maps = {"full": sky_map} + + result = combine_maps(sky_maps) + + # Should return the same map + assert result is sky_map + + +def test_combine_maps_two_maps(mock_sky_map_for_combine): + """Test combine_maps properly combines ram and anti-ram maps.""" + ram_map = mock_sky_map_for_combine(intensity_offset=0, exposure_offset=0) + anti_map = mock_sky_map_for_combine(intensity_offset=20, exposure_offset=5) + sky_maps = {"ram": ram_map, "anti": anti_map} + + result = combine_maps(sky_maps) + + # Check that result is a RectangularSkyMap + assert isinstance(result, RectangularSkyMap) + + # Check additive variables + expected_counts = 100 + 120 # 100 + (100 + 20) + np.testing.assert_array_almost_equal( + result.data_1d["counts"].values, + np.ones_like(result.data_1d["counts"].values) * expected_counts, + ) + + expected_exposure = 10 + 15 # 10 + (10 + 5) + np.testing.assert_array_almost_equal( + result.data_1d["exposure_factor"].values, + np.ones_like(result.data_1d["exposure_factor"].values) * expected_exposure, + ) + + +def test_combine_maps_intensity_weighting(mock_sky_map_for_combine): + """Test that ena_intensity is combined with inverse-variance weighting.""" + ram_map = mock_sky_map_for_combine() + anti_map = mock_sky_map_for_combine(intensity_offset=20) + + # Give anti map higher uncertainty (lower weight) + anti_map.data_1d["ena_intensity_stat_uncert"] = xr.full_like( + anti_map.data_1d["ena_intensity_stat_uncert"], 10.0 + ) + + sky_maps = {"ram": ram_map, "anti": anti_map} + result = combine_maps(sky_maps) + + # Ram has uncertainty 5, anti has uncertainty 10 + # Weights: ram = 1/25 = 0.04, anti = 1/100 = 0.01 + # Weighted average: (50 * 0.04 + 70 * 0.01) / (0.04 + 0.01) = (2 + 0.7) / 0.05 = 54 + expected_intensity = (50 * 0.04 + 70 * 0.01) / (0.04 + 0.01) + np.testing.assert_array_almost_equal( + result.data_1d["ena_intensity"].values.flat[0], + expected_intensity, + decimal=5, + ) + + 
+def test_combine_maps_sys_err_exposure_weighted(mock_sky_map_for_combine): + """Test that systematic errors are combined with exposure weighting.""" + ram_map = mock_sky_map_for_combine() + anti_map = mock_sky_map_for_combine() + + # Set specific sys_err and exposure_factor values + ram_map.data_1d["ena_intensity_sys_err"] = xr.full_like( + ram_map.data_1d["ena_intensity_sys_err"], 5.0 + ) + ram_map.data_1d["exposure_factor"] = xr.full_like( + ram_map.data_1d["exposure_factor"], 1.0 + ) + anti_map.data_1d["ena_intensity_sys_err"] = xr.full_like( + anti_map.data_1d["ena_intensity_sys_err"], 5.0 + ) + anti_map.data_1d["exposure_factor"] = xr.full_like( + anti_map.data_1d["exposure_factor"], 4.0 + ) + + sky_maps = {"ram": ram_map, "anti": anti_map} + result = combine_maps(sky_maps) + + # Exposure-weighted average: (5 * 1 + 5 * 4) / (1 + 4) = 5 + expected_sys_err = 5.0 + np.testing.assert_array_almost_equal( + result.data_1d["ena_intensity_sys_err"].values.flat[0], + expected_sys_err, + decimal=10, + ) + + +def test_combine_maps_obs_date_exposure_weighted(mock_sky_map_for_combine): + """Test that obs_date is combined with exposure weighting.""" + ram_map = mock_sky_map_for_combine() + anti_map = mock_sky_map_for_combine() + + # Ram: obs_date=1000, exposure=10 + ram_map.data_1d["obs_date"] = xr.full_like(ram_map.data_1d["obs_date"], 1000) + ram_map.data_1d["exposure_factor"] = xr.full_like( + ram_map.data_1d["exposure_factor"], 10 + ) + + # Anti: obs_date=2000, exposure=30 + anti_map.data_1d["obs_date"] = xr.full_like(anti_map.data_1d["obs_date"], 2000) + anti_map.data_1d["exposure_factor"] = xr.full_like( + anti_map.data_1d["exposure_factor"], 30 + ) + + sky_maps = {"ram": ram_map, "anti": anti_map} + result = combine_maps(sky_maps) + + # Exposure-weighted average: (1000*10 + 2000*30) / (10+30) = 70000/40 = 1750 + expected_obs_date = (1000 * 10 + 2000 * 30) // 40 # Integer division + + # obs_date is cast to int64 after combining + assert result.data_1d["obs_date"].dtype == 
np.int64 + np.testing.assert_array_equal( + result.data_1d["obs_date"].values.flat[0], + expected_obs_date, + ) + + +def test_combine_maps_obs_date_range(mock_sky_map_for_combine): + """Test that obs_date_range accounts for within and between-group variance.""" + ram_map = mock_sky_map_for_combine() + anti_map = mock_sky_map_for_combine() + + # Ram: obs_date=1000, obs_date_range=100, exposure=10 + ram_map.data_1d["obs_date"] = xr.full_like(ram_map.data_1d["obs_date"], 1000) + ram_map.data_1d["obs_date_range"] = xr.full_like( + ram_map.data_1d["obs_date_range"], 100 + ) + ram_map.data_1d["exposure_factor"] = xr.full_like( + ram_map.data_1d["exposure_factor"], 10 + ) + + # Anti: obs_date=2000, obs_date_range=200, exposure=30 + anti_map.data_1d["obs_date"] = xr.full_like(anti_map.data_1d["obs_date"], 2000) + anti_map.data_1d["obs_date_range"] = xr.full_like( + anti_map.data_1d["obs_date_range"], 200 + ) + anti_map.data_1d["exposure_factor"] = xr.full_like( + anti_map.data_1d["exposure_factor"], 30 + ) + + sky_maps = {"ram": ram_map, "anti": anti_map} + result = combine_maps(sky_maps) + + # Calculate expected combined variance + w1, w2 = 10, 30 + sigma1, sigma2 = 100, 200 + mu1, mu2 = 1000, 2000 + total_exp = w1 + w2 + + within_variance = (w1 * sigma1**2 + w2 * sigma2**2) / total_exp + between_variance = (w1 * w2 * (mu1 - mu2) ** 2) / (total_exp**2) + expected_range = np.sqrt(within_variance + between_variance) + + # obs_date_range is cast to int64 after combining, so compare with truncated value + assert result.data_1d["obs_date_range"].dtype == np.int64 + np.testing.assert_array_equal( + result.data_1d["obs_date_range"].values.flat[0], + int(expected_range), + ) + + +def test_combine_maps_invalid_length(): + """Test that combine_maps raises error for invalid number of maps.""" + descriptor = MapDescriptor.from_string("h90-ena-h-hf-nsp-full-gcs-6deg-3mo") + sky_map = descriptor.to_empty_map() + + with pytest.raises(ValueError, match="Expected 1 or 2 sky maps"): + 
combine_maps({"a": sky_map, "b": sky_map, "c": sky_map})