Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/src/api/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ Sample is build from assemblies.

sample

Project
=======
Project provides a higher-level interface for managing models, experiments, and ORSO import.

.. toctree::
:maxdepth: 1

project

Assemblies
==========
Assemblies are collections of layers that are used to represent a specific physical setup.
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ classifiers = [
requires-python = ">=3.11,<3.13"

dependencies = [
#"easyscience @ git+https://github.com/easyscience/corelib.git@dict_size_changed_bug",
"easyscience",
"easyscience @ git+https://github.com/easyscience/corelib.git@develop",
#"easyscience",
"scipp",
"refnx",
"refl1d>=1.0.0rc0",
"refl1d>=1.0.0",
"orsopy",
"svglib<1.6 ; platform_system=='Linux'",
"xhtml2pdf",
Expand Down
2 changes: 1 addition & 1 deletion src/easyreflectometry/calculators/calculator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from easyscience.fitting.calculators.interface_factory import ItemContainer
from easyscience.io import SerializerComponent

#if TYPE_CHECKING:
# if TYPE_CHECKING:
from easyreflectometry.model import Model
from easyreflectometry.sample import BaseAssembly
from easyreflectometry.sample import Layer
Expand Down
4 changes: 2 additions & 2 deletions src/easyreflectometry/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
y: Optional[Union[np.ndarray, list]] = None,
ye: Optional[Union[np.ndarray, list]] = None,
xe: Optional[Union[np.ndarray, list]] = None,
model: Optional['Model'] = None, # delay type checking until runtime (quotes)
model: Optional['Model'] = None, # delay type checking until runtime (quotes)
x_label: str = 'x',
y_label: str = 'y',
):
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
self._color = None

@property
def model(self) -> 'Model': # delay type checking until runtime (quotes)
def model(self) -> 'Model': # delay type checking until runtime (quotes)
return self._model

@model.setter
Expand Down
6 changes: 3 additions & 3 deletions src/easyreflectometry/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
variances = data['data'][f'R_{i}'].variances

# Find points with non-zero variance
zero_variance_mask = (variances == 0.0)
zero_variance_mask = variances == 0.0
num_zero_variance = np.sum(zero_variance_mask)

if num_zero_variance > 0:
warnings.warn(
f"Masked {num_zero_variance} data point(s) in reflectivity {i} due to zero variance during fitting.",
UserWarning
f'Masked {num_zero_variance} data point(s) in reflectivity {i} due to zero variance during fitting.',
UserWarning,
)

# Keep only points with non-zero variances
Expand Down
16 changes: 8 additions & 8 deletions src/easyreflectometry/orso_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@

return sample, data


def load_data_from_orso_file(fname: str) -> sc.DataGroup:
"""Load data from an ORSO file."""
try:
orso_data = orso.load_orso(fname)
except Exception as e:
raise ValueError(f"Error loading ORSO file: {e}")
raise ValueError(f'Error loading ORSO file: {e}')
return load_orso_data(orso_data)


def load_orso_model(orso_str: str) -> Sample:
"""
Load a model from an ORSO file and return a Sample object.
Expand Down Expand Up @@ -64,9 +66,9 @@

# Handle case where layers are not resolved correctly
if not orso_layers:
raise ValueError("Could not resolve ORSO layers.")
raise ValueError('Could not resolve ORSO layers.')

Check warning on line 69 in src/easyreflectometry/orso_utils.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/orso_utils.py#L69

Added line #L69 was not covered by tests

logger.debug(f"Resolved layers: {orso_layers}")
logger.debug(f'Resolved layers: {orso_layers}')

# Convert ORSO layers to EasyReflectometry layers
erl_layers = []
Expand Down Expand Up @@ -98,7 +100,7 @@
material=Material(sld=m_sld, isld=m_isld, name=m_name),
thickness=layer.thickness.magnitude if layer.thickness is not None else 0.0,
roughness=layer.roughness.magnitude if layer.roughness is not None else 0.0,
name=layer.original_name if layer.original_name is not None else m_name
name=layer.original_name if layer.original_name is not None else m_name,
)


Expand All @@ -107,10 +109,7 @@
if material.sld is None and material.mass_density is not None:
# Calculate SLD from mass density
m_density = material.mass_density.magnitude
density = MaterialDensity(
chemical_structure=material_name,
density=m_density
)
density = MaterialDensity(chemical_structure=material_name, density=m_density)
m_sld = density.sld.value
m_isld = density.isld.value
else:
Expand All @@ -123,6 +122,7 @@

return m_sld, m_isld


def load_orso_data(orso_str: str) -> DataSet1D:
data = {}
coords = {}
Expand Down
51 changes: 50 additions & 1 deletion src/easyreflectometry/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from easyreflectometry.data import DataSet1D
from easyreflectometry.data import load_as_dataset
from easyreflectometry.fitting import MultiFitter

# from easyreflectometry.model import LinearSpline
from easyreflectometry.model import Model
from easyreflectometry.model import ModelCollection
from easyreflectometry.model import PercentageFwhm
Expand Down Expand Up @@ -268,10 +270,53 @@ def load_orso_file(self, path: Union[Path, str]) -> None:
self._with_experiments = True
pass

def set_sample_from_orso(self, sample) -> None:
def set_sample_from_orso(self, sample: Sample) -> None:
"""Replace the current project model collection with a single model built from an ORSO-parsed sample.

This is a convenience helper for the ORSO import pipeline where a complete
:class:`~easyreflectometry.sample.Sample` is constructed elsewhere.

:param sample: Sample to set as the project's (single) model.
:type sample: easyreflectometry.sample.Sample
:return: ``None``.
:rtype: None
"""
model = Model(sample=sample)
self.models = ModelCollection([model])

def add_sample_from_orso(self, sample: Sample) -> None:
"""Add a new model with the given sample to the existing model collection.

The created model is appended to :attr:`models`, its calculator interface is
set to the project's current calculator, and any materials referenced in the
sample are added to the project's material collection.

After adding the model, :attr:`current_model_index` is updated to point to
the newly added model.

:param sample: Sample to add as a new model.
:type sample: easyreflectometry.sample.Sample
:return: ``None``.
:rtype: None
"""
model = Model(sample=sample)
self.models.add_model(model)
# Set interface after adding to collection
model.interface = self._calculator
# Extract materials from the new model and add to project materials
self._materials.extend(self._get_materials_from_model(model))
# Switch to the newly added model so its data is visible in the UI
self.current_model_index = len(self._models) - 1

def _get_materials_from_model(self, model: Model) -> 'MaterialCollection':
"""Get all materials from a single model's sample."""
materials_in_model = MaterialCollection(populate_if_none=False)
for assembly in model.sample:
for layer in assembly.layers:
if layer.material not in materials_in_model:
materials_in_model.append(layer.material)
return materials_in_model

def load_new_experiment(self, path: Union[Path, str]) -> None:
new_experiment = load_as_dataset(str(path))
new_index = len(self._experiments)
Expand All @@ -291,6 +336,10 @@ def load_new_experiment(self, path: Union[Path, str]) -> None:
q_error = new_experiment.xe
# TODO: set resolution function based on value of control in GUI
resolution_function = Pointwise(q_data_points=[q, reflectivity, q_error])
# resolution_function = LinearSpline(
# q_data_points=self._experiments[new_index].y,
# fwhm_values=np.sqrt(self._experiments[new_index].ye),
# )
self.models[model_index].resolution_function = resolution_function

def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Optional[int] = 0) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/calculators/refl1d/test_refl1d_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_reflectity_profile(self):
5.7605e-07,
2.3775e-07,
1.3093e-07,
1.0520e-07
1.0520e-07,
]
assert_almost_equal(p.reflectity_profile(q, 'MyModel'), expected, decimal=4)

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_calculate2(self):
1.0968e-06,
4.5635e-07,
3.4120e-07,
2.7505e-07
2.7505e-07,
]
assert_almost_equal(actual, expected, decimal=4)

Expand Down
6 changes: 3 additions & 3 deletions tests/calculators/refl1d/test_refl1d_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_calculate(self):
5.7605e-07,
2.3775e-07,
1.3093e-07,
1.0520e-07
1.0520e-07,
]
assert_almost_equal(p.calculate(q, 'MyModel'), expected, decimal=4)

Expand Down Expand Up @@ -276,7 +276,7 @@ def test_calculate_three_items(self):
1.0968e-06,
4.5635e-07,
3.4120e-07,
2.7505e-07
2.7505e-07,
]
assert_almost_equal(p.calculate(q, 'MyModel'), expected, decimal=4)

Expand Down Expand Up @@ -396,7 +396,7 @@ def test_get_polarized_probe_oversampling():
probe = _get_polarized_probe(q_array=q, dq_array=dq, model_name=model_name, storage=storage, oversampling_factor=2)

# Then
assert len(probe.xs[0].calc_Qo) == 2*len(q)
assert len(probe.xs[0].calc_Qo) == 2 * len(q)


def test_get_polarized_probe_polarization():
Expand Down
19 changes: 4 additions & 15 deletions tests/data/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@ def test_constructor_default_values(self):
def test_constructor_with_values(self):
# When
data = DataSet1D(
x=[1, 2, 3],
y=[4, 5, 6],
ye=[7, 8, 9],
xe=[10, 11, 12],
x_label='label_x',
y_label='label_y',
name='MyDataSet1D'
x=[1, 2, 3], y=[4, 5, 6], ye=[7, 8, 9], xe=[10, 11, 12], x_label='label_x', y_label='label_y', name='MyDataSet1D'
)

# Then
Expand Down Expand Up @@ -116,19 +110,15 @@ def test_is_simulation_property(self):

def test_data_points(self):
# When
data = DataSet1D(
x=[1, 2, 3], y=[4, 5, 6], ye=[7, 8, 9], xe=[10, 11, 12]
)
data = DataSet1D(x=[1, 2, 3], y=[4, 5, 6], ye=[7, 8, 9], xe=[10, 11, 12])

# Then
points = list(data.data_points())
assert points == [(1, 4, 7, 10), (2, 5, 8, 11), (3, 6, 9, 12)]

def test_repr(self):
# When
data = DataSet1D(
x=[1, 2, 3], y=[4, 5, 6], x_label='Q', y_label='R'
)
data = DataSet1D(x=[1, 2, 3], y=[4, 5, 6], x_label='Q', y_label='R')

# Then
expected = "1D DataStore of 'Q' Vs 'R' with 3 data points"
Expand Down Expand Up @@ -194,7 +184,7 @@ def test_setitem(self):
item1 = DataSet1D(name='item1')
item2 = DataSet1D(name='item2')
store = DataStore(item1)

# When
store[0] = item2

Expand Down Expand Up @@ -314,4 +304,3 @@ def test_constructor_with_custom_datastores(self):
assert project.sim_data == sim_store
assert project.exp_data.name == 'CustomExp'
assert project.sim_data.name == 'CustomSim'

2 changes: 1 addition & 1 deletion tests/summary/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_save_sld_plot(self, project: Project, tmp_path) -> None:
# Expect
assert os.path.exists(file_path)

@pytest.mark.skip(reason="Matplotlib issue with headless CI environments")
@pytest.mark.skip(reason='Matplotlib issue with headless CI environments')
def test_save_fit_experiment_plot(self, project: Project, tmp_path) -> None:
# When
summary = Summary(project)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ def test_load_txt_three_columns(self):
assert coords_name in er_data['coords']

# xe should be zeros for 3-column file
assert_almost_equal(er_data['coords'][coords_name].variances,
np.zeros_like(er_data['coords'][coords_name].values))
assert_almost_equal(er_data['coords'][coords_name].variances, np.zeros_like(er_data['coords'][coords_name].values))

def test_load_txt_with_zero_errors(self):
fpath = os.path.join(PATH_STATIC, 'ref_zero_var.txt')
Expand All @@ -246,6 +245,7 @@ def test_load_txt_file_not_found(self):
def test_load_txt_insufficient_columns(self):
# Create a temporary file with insufficient columns
import tempfile

with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write('1.0 2.0\n') # Only 2 columns
temp_path = f.name
Expand All @@ -272,7 +272,7 @@ def test_load_orso_multiple_datasets(self):
if data_key.replace('R_', '') in coord_key:
coord_key_found = True
break
assert coord_key_found, f"No corresponding coord found for {data_key}"
assert coord_key_found, f'No corresponding coord found for {data_key}'

def test_load_orso_with_attrs(self):
fpath = os.path.join(PATH_STATIC, 'test_example1.ort')
Expand Down
Loading
Loading