2 changes: 1 addition & 1 deletion imap_processing/ena_maps/ena_maps.py
@@ -618,7 +618,7 @@ def update_az_el_points(self) -> None:
The values stored in the "hae_longitude" and "hae_latitude" variables
are used to construct the azimuth and elevation coordinates.
"""
-        logger.info(
+        logger.debug(
"Updating az/el points based on data in hae_longitude and"
"hae_latitude variables."
)
7 changes: 6 additions & 1 deletion imap_processing/ena_maps/utils/corrections.py
@@ -775,6 +775,7 @@ def interpolate_map_flux_to_helio_frame(
esa_energies: xr.DataArray,
helio_energies: xr.DataArray,
vars_to_interpolate: list[str],
update_sys_err: bool = True,
) -> xr.Dataset:
"""
Interpolate flux from spacecraft frame to heliocentric frame energies.
@@ -806,6 +807,9 @@
        dataset and will be interpolated as well. For example, if ["ena_intensity"]
dataset and will be interpolated as well. For example, if ["ena_intensity"]
is input, then the variables "ena_intensity", "ena_intensity_stat_uncert",
and "ena_intensity_sys_err" will be interpolated.
update_sys_err : bool, optional
Flag indicating whether to update the systematic error variables as part
of the flux interpolation. Defaults to True.

Returns
-------
@@ -912,7 +916,8 @@ def interpolate_map_flux_to_helio_frame(
# Update the dataset with interpolated values
map_ds[var_name] = flux_helio
map_ds[f"{var_name}_stat_uncert"] = stat_unc_helio
map_ds[f"{var_name}_sys_err"] = sys_err_helio
if update_sys_err:
map_ds[f"{var_name}_sys_err"] = sys_err_helio

return map_ds

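As a quick illustration of the naming convention described in the docstring above (a sketch only; the real function's suffix handling may differ):

    # Hypothetical sketch: variables touched for vars_to_interpolate=["ena_intensity"]
    base = "ena_intensity"
    interpolated_vars = [base, f"{base}_stat_uncert", f"{base}_sys_err"]
    # With update_sys_err=False, the "_sys_err" value is computed but not
    # written back to the returned dataset (see the guarded assignment above).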
248 changes: 205 additions & 43 deletions imap_processing/hi/hi_l2.py
@@ -29,12 +29,11 @@
"counts",
"exposure_factor",
"bg_rate",
"bg_rate_sys_err",
"obs_date",
}
HELIO_FRAME_VARS_TO_PROJECT = SC_FRAME_VARS_TO_PROJECT | {"energy_sc"}
# TODO: is an exposure time weighted average for obs_date appropriate?
-FULL_EXPOSURE_TIME_AVERAGE_SET = {"bg_rate", "bg_rate_sys_err", "obs_date", "energy_sc"}
+FULL_EXPOSURE_TIME_AVERAGE_SET = {"bg_rate", "obs_date", "energy_sc"}


# =============================================================================
@@ -90,21 +89,25 @@ def hi_l2(
)

logger.info(f"Step 1: Creating sky map from {len(psets)} pointing sets")
-    sky_map = create_sky_map_from_psets(
+    sky_maps = create_sky_map_from_psets(
psets,
l2_ancillary_path_dict,
map_descriptor,
)

logger.info("Step 2: Calculating rates and intensities")
-    sky_map.data_1d = calculate_all_rates_and_intensities(
-        sky_map.data_1d,
-        l2_ancillary_path_dict,
-        map_descriptor,
-    )
+    for sky_map in sky_maps.values():
+        sky_map.data_1d = calculate_all_rates_and_intensities(
+            sky_map.data_1d,
+            l2_ancillary_path_dict,
+            map_descriptor,
+        )

-    logger.info("Step 3: Finalizing dataset with attributes")
-    l2_ds = sky_map.build_cdf_dataset(
+    logger.info("Step 3: Combining maps (if needed)")
+    final_map = combine_maps(sky_maps)
+
+    logger.info("Step 4: Finalizing dataset with attributes")
+    l2_ds = final_map.build_cdf_dataset(
"hi",
"l2",
descriptor,
@@ -124,7 +127,7 @@ def create_sky_map_from_psets(
psets: list[str | Path],
l2_ancillary_path_dict: dict[str, Path],
descriptor: MapDescriptor,
-) -> RectangularSkyMap:
+) -> dict[str, RectangularSkyMap]:
"""
Project Hi PSET data into a sky map.

@@ -141,18 +144,35 @@

Returns
-------
-    sky_map : RectangularSkyMap
-        The sky map with all the PSET data projected into the map. Includes
+    sky_maps : dict[str, RectangularSkyMap]
+        Dictionary mapping spin phase keys ("ram", "anti", or "full")
+        to sky maps with all the PSET data projected. Includes
         an energy coordinate and energy_delta_minus and energy_delta_plus
-        variables from ESA energy calibration data.
+        variables from ESA energy calibration data. For helio-frame full-spin
+        maps, contains "ram" and "anti" keys; otherwise contains a single key
+        matching the descriptor's spin_phase.
"""
if len(psets) == 0:
raise ValueError("No PSETs provided for map creation")

-    output_map = descriptor.to_empty_map()
-
-    if not isinstance(output_map, RectangularSkyMap):
+    # If we are making a full-spin, helio-frame map, we need to make one ram
+    # and one anti-ram map that get combined at the final L2 step.
Comment on lines +158 to +159

Collaborator: This reminds me of Lo needing to make an extra "Oxygen" map for sputtering calculations and needing to carry it through to some extent as well. I'm wondering if this dictionary-like mapping should be carried over to the Lo calculations as well to keep things consistent, or if there is another way of organizing this multi-map projection work across all instruments.

Just noting this as a comment, nothing to do here.

Contributor (author): Yeah, good point. I was toying with the idea of generating a dataset that had yet another dimension added for "spin_phase" which could later be combined. I think that some of these things would come easier if we did a refactor of the mapping code to leverage the xarray custom accessor functions.
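For what it's worth, the extra "spin_phase" dimension idea could look roughly like this (a hedged sketch, not project code; assumes both maps share dimensions and variables):

    import pandas as pd
    import xarray as xr

    # Stack the ram/anti datasets along a new "spin_phase" dimension.
    stacked = xr.concat(
        [sky_maps["ram"].data_1d, sky_maps["anti"].data_1d],
        dim=pd.Index(["ram", "anti"], name="spin_phase"),
    )
    # Downstream combination could then reduce over that dimension, e.g.:
    total_counts = stacked["counts"].sum(dim="spin_phase")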

+    if descriptor.frame_descriptor == "hf" and descriptor.spin_phase == "full":
+        # The spin-phase of the descriptor has no effect on the output map, so
+        # we can just use descriptor.to_empty_map() to generate both.
+        output_maps = {
+            "ram": descriptor.to_empty_map(),
+            "anti": descriptor.to_empty_map(),
+        }
+    else:
+        output_maps = {descriptor.spin_phase: descriptor.to_empty_map()}
+
+    if not all([isinstance(map, RectangularSkyMap) for map in output_maps.values()]):
        raise NotImplementedError("Healpix map output not supported for Hi")
+    # Needed for mypy type narrowing
+    rect_maps: dict[str, RectangularSkyMap] = {
+        k: v for k, v in output_maps.items() if isinstance(v, RectangularSkyMap)
+    }

vars_to_bin = (
HELIO_FRAME_VARS_TO_PROJECT
@@ -187,31 +207,31 @@
vars_to_exposure_time_average,
)

-        # Project (bin) the PSET variables into the map pixels
-        directional_mask = get_pset_directional_mask(
-            pset_processed, descriptor.spin_phase
-        )
        hi_pset = HiPointingSet(pset_processed)
-        output_map.project_pset_values_to_map(
-            hi_pset, list(vars_to_bin), pset_valid_mask=directional_mask
-        )
+        for spin_phase, map in rect_maps.items():
+            # Project (bin) the PSET variables into the map pixels
+            directional_mask = get_pset_directional_mask(pset_processed, spin_phase)
+            map.project_pset_values_to_map(
+                hi_pset, list(vars_to_bin), pset_valid_mask=directional_mask
+            )

-    # Finish the exposure time weighted mean calculation of backgrounds
-    # Allow divide by zero to fill set pixels with zero exposure time to NaN
-    with np.errstate(divide="ignore"):
-        for var in vars_to_exposure_time_average:
-            output_map.data_1d[var] /= output_map.data_1d["exposure_factor"]
-
-    # Add ESA energy data to the map dataset for use in rate/intensity calculations
-    energy_delta = esa_ds["bandpass_fwhm"] / 2
-    output_map.data_1d["energy_delta_minus"] = energy_delta
-    output_map.data_1d["energy_delta_plus"] = energy_delta
-    # Add energy as an auxiliary coordinate (keV values indexed by esa_energy_step)
-    output_map.data_1d = output_map.data_1d.assign_coords(
-        energy=("esa_energy_step", esa_ds["nominal_central_energy"].values)
-    )
-
-    return output_map
+    for map in rect_maps.values():
+        # Finish the exposure time weighted mean calculation of backgrounds
+        # Allow divide by zero to set pixels with zero exposure time to NaN
+        with np.errstate(divide="ignore"):
+            map.data_1d[vars_to_exposure_time_average] /= map.data_1d["exposure_factor"]
+
+        # Add ESA energy data to the map dataset for use in rate/intensity calculations
+        energy_delta = esa_ds["bandpass_fwhm"] / 2
+        map.data_1d["energy_delta_minus"] = energy_delta
+        map.data_1d["energy_delta_plus"] = energy_delta
+        # Add energy as an auxiliary coordinate (keV values indexed by esa_energy_step)
+        map.data_1d = map.data_1d.assign_coords(
+            energy=("esa_energy_step", esa_ds["nominal_central_energy"].values)
+        )
+
+    return rect_maps

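The division above finishes a per-pixel exposure-time weighted mean. Restated (assuming the projection step accumulates exposure-weighted sums, as the comment implies):

$$\bar{v}_{\mathrm{pix}} = \frac{\sum_p t_p\, v_p}{\sum_p t_p}$$

where $p$ runs over the PSET samples falling in a pixel, $t_p$ is the exposure factor, and $v_p$ the variable being averaged (e.g. bg_rate or obs_date).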

# =============================================================================
@@ -323,13 +343,11 @@ def calculate_all_rates_and_intensities(
# TODO: Handle variable types correctly in RectangularSkyMap.build_cdf_dataset
obs_date = map_ds["obs_date"]
# Replace non-finite values with the int64 sentinel before casting
-    obs_date_filled = xr.where(
+    map_ds["obs_date"] = xr.where(
         np.isfinite(obs_date),
-        obs_date,
+        obs_date.astype("int64"),
         np.int64(-9223372036854775808),
     )
-    map_ds["obs_date"] = obs_date_filled.astype("int64")
# TODO: Figure out how to compute obs_date_range (stddev of obs_date)
map_ds["obs_date_range"] = xr.zeros_like(map_ds["obs_date"])

# Step 4: Swap esa_energy_step dimension for energy coordinate
@@ -343,15 +361,21 @@
logger.debug("Applying Compton-Getting interpolation for heliocentric frame")
# Convert energy coordinate from keV to eV for interpolation
esa_energy_ev = map_ds["energy"] * 1000

# Hi does not want to apply the flux correction to the systematic error.
map_ds = interpolate_map_flux_to_helio_frame(
map_ds,
esa_energy_ev, # ESA energies in eV
esa_energy_ev, # heliocentric energies (same as ESA energies)
["ena_intensity"],
update_sys_err=False, # Hi does not update the systematic error
)
# Drop any esa_energy_step_label that may have been re-added
map_ds = map_ds.drop_vars(["esa_energy_step_label"], errors="ignore")

# Step 6: Clean up intermediate variables
map_ds = cleanup_intermediate_variables(map_ds)

return map_ds


@@ -525,6 +549,117 @@ def combine_calibration_products(
return map_ds


def combine_maps(sky_maps: dict[str, RectangularSkyMap]) -> RectangularSkyMap:
"""
Combine ram and anti-ram sky maps using appropriate weighting.

For full-spin heliocentric frame maps, ram and anti-ram maps are processed
separately through the CG correction pipeline, then combined here using
inverse-variance weighting for intensity and appropriate methods for other
variables.

Parameters
----------
sky_maps : dict[str, RectangularSkyMap]
Dictionary of sky maps to combine. Expected to contain either 1 map
(no combination needed) or 2 maps with "ram" and "anti" keys.

Returns
-------
RectangularSkyMap
Combined sky map.
"""
if len(sky_maps) not in [1, 2]:
raise ValueError(f"Expected 1 or 2 sky maps, got {len(sky_maps)}")
if len(sky_maps) == 1:
logger.debug("Only one sky map provided, returning it unchanged")
return next(iter(sky_maps.values()))
if "ram" not in sky_maps or "anti" not in sky_maps:
raise ValueError(
f"Expected sky maps with 'ram' and 'anti' keys."
f"Instead got: {sky_maps.keys()}"
)

logger.info("Combining ram and anti-ram maps using inverse-variance weighting.")

ram_ds = sky_maps["ram"].data_1d
anti_ds = sky_maps["anti"].data_1d
Comment on lines +585 to +586

Collaborator: I'm not sure if this would be useful to have more generally, but I think a lot of the math below could be generalized to something where you are working with a list of input datasets (not just ram/anti, but maybe combining multiple 7-day maps or something like that in the future):

    datasets = [ds.data_1d for ds in sky_maps.values()]
    weights = [
        xr.where(
            ds["ena_intensity_stat_uncert"] > 0,
            1 / ds["ena_intensity_stat_uncert"] ** 2,
            0,
        )
        for ds in datasets
    ]
    ...

Again, not needed here and I'm just writing it down as a quick note. It should be another ticket if this would even be useful.
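Completing that thought as a standalone sketch (hedged, not project code; assumes every input dataset carries the variables named in the snippet above):

    import numpy as np
    import xarray as xr

    def combine_inverse_variance(datasets: list[xr.Dataset]) -> xr.Dataset:
        """Inverse-variance combine ena_intensity across N datasets."""
        weights = [
            xr.where(
                ds["ena_intensity_stat_uncert"] > 0,
                1 / ds["ena_intensity_stat_uncert"] ** 2,
                0,
            )
            for ds in datasets
        ]
        total_weight = sum(weights)
        combined = datasets[0].copy()
        with np.errstate(divide="ignore", invalid="ignore"):
            combined["ena_intensity"] = (
                sum(ds["ena_intensity"] * w for ds, w in zip(datasets, weights))
                / total_weight
            )
            combined["ena_intensity_stat_uncert"] = np.sqrt(1 / total_weight)
        return combined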


# Use the ram sky map as the base for the combined result
combined_map = sky_maps["ram"]
combined = ram_ds.copy()

# Additive variables: counts and exposure_factor
combined["counts"] = ram_ds["counts"] + anti_ds["counts"]
combined["exposure_factor"] = ram_ds["exposure_factor"] + anti_ds["exposure_factor"]

# Inverse-variance weighted average for ena_intensity
weight_ram = xr.where(
ram_ds["ena_intensity_stat_uncert"] > 0,
1 / ram_ds["ena_intensity_stat_uncert"] ** 2,
0,
)
weight_anti = xr.where(
anti_ds["ena_intensity_stat_uncert"] > 0,
1 / anti_ds["ena_intensity_stat_uncert"] ** 2,
0,
)
total_weight = weight_ram + weight_anti

with np.errstate(divide="ignore", invalid="ignore"):
combined["ena_intensity"] = (
ram_ds["ena_intensity"] * weight_ram
+ anti_ds["ena_intensity"] * weight_anti
) / total_weight

combined["ena_intensity_stat_uncert"] = np.sqrt(1 / total_weight)

# Exposure-weighted average for systematic error
total_exp = combined["exposure_factor"]
combined["ena_intensity_sys_err"] = (
ram_ds["ena_intensity_sys_err"] * ram_ds["exposure_factor"]
+ anti_ds["ena_intensity_sys_err"] * anti_ds["exposure_factor"]
) / total_exp

# Exposure-weighted average for obs_date
with np.errstate(divide="ignore", invalid="ignore"):
combined["obs_date"] = (
ram_ds["obs_date"] * ram_ds["exposure_factor"]
+ anti_ds["obs_date"] * anti_ds["exposure_factor"]
) / total_exp

# Combined obs_date_range accounts for within-group and between-group variance
# var_combined = (w1*var1 + w2*var2)/(w1+w2) + w1*w2*(mean1-mean2)^2/(w1+w2)^2
within_variance = (
ram_ds["exposure_factor"] * ram_ds["obs_date_range"] ** 2
+ anti_ds["exposure_factor"] * anti_ds["obs_date_range"] ** 2
) / total_exp
between_variance = (
ram_ds["exposure_factor"]
* anti_ds["exposure_factor"]
* (ram_ds["obs_date"] - anti_ds["obs_date"]) ** 2
) / (total_exp**2)
combined["obs_date_range"] = np.sqrt(within_variance + between_variance)

# Re-cast obs_date and obs_date_range back to int64 after float arithmetic.
# Replace non-finite values with the int64 sentinel value.
# TODO: Handle variable types correctly in RectangularSkyMap.build_cdf_dataset
int64_sentinel = np.int64(-9223372036854775808)
combined["obs_date"] = xr.where(
np.isfinite(combined["obs_date"]),
combined["obs_date"].astype("int64"),
int64_sentinel,
)
combined["obs_date_range"] = xr.where(
np.isfinite(combined["obs_date_range"]),
combined["obs_date_range"].astype("int64"),
int64_sentinel,
)

combined_map.data_1d = combined
return combined_map

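Written out, the two combination rules implemented above are (restating the code's own comments in standard notation, with subscripts $r$/$a$ for ram/anti):

$$I = \frac{w_r I_r + w_a I_a}{w_r + w_a}, \qquad w_i = \frac{1}{\sigma_{i,\mathrm{stat}}^2}, \qquad \sigma_{\mathrm{stat}} = \sqrt{\frac{1}{w_r + w_a}}$$

and, for obs_date_range with exposure weights $t_r, t_a$, the pooled within-group plus between-group variance:

$$\sigma_{\mathrm{range}}^2 = \frac{t_r \sigma_r^2 + t_a \sigma_a^2}{t_r + t_a} + \frac{t_r t_a\,(\mu_r - \mu_a)^2}{(t_r + t_a)^2}$$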

def _calculate_improved_stat_variance(
map_ds: xr.Dataset,
geometric_factors: xr.DataArray,
@@ -600,6 +735,33 @@ def _calculate_improved_stat_variance(
return improved_variance


def cleanup_intermediate_variables(dataset: xr.Dataset) -> xr.Dataset:
"""
Remove intermediate variables that were only needed for calculations.

Parameters
----------
dataset : xarray.Dataset
Dataset containing intermediate calculation variables.

Returns
-------
xarray.Dataset
Cleaned dataset with intermediate variables removed.
"""
# Remove the intermediate variables from the map
potential_vars = [
"bg_rate",
"energy_sc",
"ena_signal_rates",
"ena_signal_rate_stat_unc",
]

vars_to_remove = [var for var in potential_vars if var in dataset.data_vars]

return dataset.drop_vars(vars_to_remove)


# =============================================================================
# SETUP AND INITIALIZATION HELPERS
# =============================================================================