diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index c240ec236..f67e9fa20 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -53,6 +53,10 @@ jobs: run: | echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}' | tr '[:lower:]' '[:upper:]')" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append + - name: Lint with mypy + shell: bash -l {0} + run: python -m mypy mp_api/ + - name: Test with pytest env: MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }} diff --git a/mp_api/__init__.py b/mp_api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mp_api/client/__init__.py b/mp_api/client/__init__.py index 7895061b3..1fd6d23f8 100644 --- a/mp_api/client/__init__.py +++ b/mp_api/client/__init__.py @@ -10,4 +10,4 @@ try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover - __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION") + __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "") diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 0545d78dd..33820207f 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -35,7 +35,7 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry -from mp_api.client.core.exceptions import MPRestError +from mp_api.client.core.exceptions import MPRestError, _emit_status_warning from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS from mp_api.client.core.utils import ( load_json, @@ -46,8 +46,10 @@ try: import flask + + _flask_is_installed = True except ImportError: - flask = None + _flask_is_installed = False if TYPE_CHECKING: from typing import Any, Callable @@ -59,7 +61,7 @@ try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover - __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION") + __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "") class _DictLikeAccess(BaseModel): @@ -83,7 +85,7 @@ class BaseRester: """Base client class with core stubs.""" suffix: str = "" - document_model: type[BaseModel] | None = None + document_model: type[BaseModel] = _DictLikeAccess primary_key: str = "material_id" def __init__( @@ -222,7 +224,10 @@ def _get_database_version(endpoint): Returns: database version as a string """ - return requests.get(url=endpoint + "heartbeat").json()["db_version"] + if (get_resp := requests.get(url=endpoint + "heartbeat")).status_code == 403: + _emit_status_warning() + return + return get_resp.json()["db_version"] def _post_resource( self, @@ -407,7 +412,7 @@ def _query_resource( num_chunks: int | None = None, chunk_size: int | None = None, timeout: int | None = None, - ) -> dict: + ) -> dict[str, Any]: """Query the endpoint for a Resource containing a list of documents and meta information about pagination and total document count. 
@@ -539,12 +544,15 @@ def _query_resource( for docs, _, _ in byte_data: unzipped_data.extend(docs) - data = {"data": unzipped_data, "meta": {}} - - if self.use_document_model: - data["data"] = self._convert_to_model(data["data"]) + data: dict[str, Any] = { + "data": ( + self._convert_to_model(unzipped_data) # type: ignore[arg-type] + if self.use_document_model + else unzipped_data + ), + "meta": {"total_doc": len(unzipped_data)}, + } - data["meta"]["total_doc"] = len(data["data"]) else: data = self._submit_requests( url=url, @@ -672,7 +680,7 @@ def _submit_requests( # noqa new_limits = [chunk_size] total_num_docs = 0 - total_data: dict[str, list[Any]] = {"data": []} + total_data: dict[str, Any] = {"data": []} # Obtain first page of results and get pagination information. # Individual total document limits (subtotal) will potentially @@ -871,7 +879,7 @@ def _multi_thread( func: Callable, params_list: list[dict], progress_bar: tqdm | None = None, - ): + ) -> list[tuple[Any, int, int]]: """Handles setting up a threadpool and sending parallel requests. Arguments: @@ -962,7 +970,7 @@ def _submit_request_and_process( Tuple with data and total number of docs in matching the query in the database. """ headers = None - if flask is not None and flask.has_request_context(): + if _flask_is_installed and flask.has_request_context(): headers = flask.request.headers try: @@ -1015,7 +1023,9 @@ def _submit_request_and_process( f"on URL {response.url} with message:\n{message}" ) - def _convert_to_model(self, data: list[dict]): + def _convert_to_model( + self, data: list[dict[str, Any]] + ) -> list[BaseModel] | list[dict[str, Any]]: """Converts dictionary documents to instantiated MPDataDoc objects. Args: @@ -1028,7 +1038,7 @@ def _convert_to_model(self, data: list[dict]): if len(data) > 0: data_model, set_fields, _ = self._generate_returned_model(data[0]) - data = [ + return [ data_model( **{ field: value @@ -1043,7 +1053,7 @@ def _convert_to_model(self, data: list[dict]): def _generate_returned_model( self, doc: dict[str, Any] - ) -> tuple[BaseModel, list[str], list[str]]: + ) -> tuple[type[BaseModel], list[str], list[str]]: model_fields = self.document_model.model_fields set_fields = [k for k in doc if k in model_fields] unset_fields = [field for field in model_fields if field not in set_fields] @@ -1059,13 +1069,13 @@ def _generate_returned_model( ): vars(import_module(self.document_model.__module__)) - include_fields: dict[str, tuple[type, FieldInfo]] = {} + include_fields: dict[str, tuple[Any, FieldInfo]] = {} for name in set_fields: field_copy = model_fields[name]._copy() if not field_copy.default_factory: # Fields with a default_factory cannot also have a default in pydantic>=2.12.3 field_copy.default = None - include_fields[name] = ( + include_fields[name] = ( # type: ignore[assignment] Optional[model_fields[name].annotation], field_copy, ) @@ -1202,7 +1212,7 @@ def get_data_by_id( self, document_id: str, fields: list[str] | None = None, - ) -> BaseModel | dict: + ) -> BaseModel | dict[str, Any] | None: warnings.warn( "get_data_by_id is deprecated and will be removed soon. 
Please use the search method instead.", DeprecationWarning, @@ -1221,7 +1231,7 @@ def get_data_by_id( if isinstance(fields, str): # pragma: no cover fields = (fields,) # type: ignore - docs = self._search( # type: ignore + docs = self._search( **{self.primary_key + "s": document_id}, num_chunks=1, chunk_size=1, diff --git a/mp_api/client/core/exceptions.py b/mp_api/client/core/exceptions.py index fa9f87937..4f2d8d5ca 100644 --- a/mp_api/client/core/exceptions.py +++ b/mp_api/client/core/exceptions.py @@ -1,6 +1,8 @@ """Define custom exceptions and warnings for the client.""" from __future__ import annotations +import warnings + class MPRestError(Exception): """Raised when the query has problems, e.g., bad query format.""" @@ -8,3 +10,13 @@ class MPRestError(Exception): class MPRestWarning(Warning): """Raised when a query is malformed but interpretable.""" + + +def _emit_status_warning() -> None: + """Emit a warning if client can't hear a heartbeat.""" + warnings.warn( + "Cannot listen to heartbeat, check Materials Project " + "status page: https://status.materialsproject.org/", + category=MPRestWarning, + stacklevel=2, + ) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index f5818d57a..75bee4526 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -14,7 +14,7 @@ _MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000) _MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000) -_EMMET_SETTINGS = EmmetSettings() +_EMMET_SETTINGS = EmmetSettings() # type: ignore[call-arg] _DEFAULT_ENDPOINT = "https://api.materialsproject.org/" @@ -109,4 +109,4 @@ def _get_endpoint_from_env(cls, v: str | None) -> str: return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT -MAPI_CLIENT_SETTINGS = MAPIClientSettings() +MAPI_CLIENT_SETTINGS: MAPIClientSettings = MAPIClientSettings() # type: ignore[call-arg] diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 935e1453a..57510bac7 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -163,11 +163,11 @@ def __init__( import_str : str A dot-separated, import-like string. 
""" - if len(split_import_str := import_str.rsplit(".", 1)) > 1: - self._module_name, self._class_name = split_import_str + if len(split_import_str := import_str.rsplit(".", 1)) == 1: + self._module_name: str = split_import_str[0] + self._class_name: str | None = None else: - self._module_name = split_import_str[0] - self._class_name = None + self._module_name, self._class_name = split_import_str self._imported: Any | None = None self._obj: Any | None = None @@ -216,9 +216,9 @@ def __call__(self, *args, **kwargs) -> Any: if isinstance(self._imported, type): self._obj = self._imported(*args, **kwargs) return self._obj - else: + elif callable(self._imported): self._obj = self._imported - return self._obj(*args, **kwargs) + return self._obj(*args, **kwargs) # type: ignore[misc] def __getattr__(self, v: str) -> Any: """Get an attribute on a super lazy object.""" diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 1d9afc5ca..f5d873272 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -20,8 +20,13 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get -from mp_api.client.core import BaseRester, MPRestError, MPRestWarning +from mp_api.client.core import BaseRester from mp_api.client.core._oxygen_evolution import OxygenEvolution +from mp_api.client.core.exceptions import ( + MPRestError, + MPRestWarning, + _emit_status_warning, +) from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS from mp_api.client.core.utils import ( LazyImport, @@ -35,12 +40,22 @@ from mp_api.client.routes.molecules import MOLECULES_RESTERS if TYPE_CHECKING: + from collections.abc import Sequence from typing import Any, Literal + import numpy as np from emmet.core.tasks import CoreTaskDoc + from packaging.version import Version from pymatgen.analysis.phase_diagram import PDEntry - from pymatgen.entries.computed_entries import ComputedEntry + from pymatgen.analysis.pourbaix_diagram import PourbaixEntry + from pymatgen.entries.compatibility import Compatibility + from pymatgen.entries.computed_entries import ( + ComputedEntry, + GibbsComputedStructureEntry, + ) + from pymatgen.util.typing import SpeciesLike + from mp_api.client.core.client import _DictLikeAccess DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U"]} @@ -172,14 +187,15 @@ def __init__( ) # Check if emmet version of server is compatible - emmet_version = MPRester.get_emmet_version(self.endpoint) - - if version.parse(emmet_version.base_version) < version.parse( - MAPI_CLIENT_SETTINGS.MIN_EMMET_VERSION + if (emmet_version := MPRester.get_emmet_version(self.endpoint)) and ( + version.parse(emmet_version.base_version) + < version.parse(MAPI_CLIENT_SETTINGS.MIN_EMMET_VERSION) ): warnings.warn( "The installed version of the mp-api client may not be compatible with the API server. " - "Please install a previous version if any problems occur." + "Please install a previous version if any problems occur.", + category=MPRestWarning, + stacklevel=2, ) if notify_db_version: @@ -322,7 +338,7 @@ def get_structure_by_material_id( return structure_data - def get_database_version(self): + def get_database_version(self) -> str | None: """The Materials Project database is periodically updated and has a database version associated with it. 
When the database is updated, consolidated data (information about "a material") may and does @@ -335,20 +351,27 @@ def get_database_version(self): Returns: database version as a string """ - return get(url=self.endpoint + "heartbeat").json()["db_version"] + if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403: + _emit_status_warning() + return None + return get_resp.json()["db_version"] @staticmethod @cache - def get_emmet_version(endpoint): + def get_emmet_version(endpoint) -> Version | None: """Get the latest version emmet-core and emmet-api used in the current API service. Returns: version as a string """ - response = get(url=endpoint + "heartbeat").json() + get_resp = get(url=endpoint + "heartbeat") + + if get_resp.status_code == 403: + _emit_status_warning() + return None - error = response.get("error", None) - if error: + response = get_resp.json() + if error := response.get("error", None): raise MPRestError(error) return version.parse(response["version"]) @@ -507,7 +530,7 @@ def get_entries( ) -> list[ComputedStructureEntry]: """Get a list of ComputedStructureEntry from a chemical system, or formula, or MPID. - This returns ComputedStructureEntries with final structures for all thermo types + This returns a list of ComputedStructureEntry with final structures for all thermo types represented in the database. Each type corresponds to a different mixing scheme (i.e. GGA/GGA+U, GGA/GGA+U/R2SCAN, R2SCAN). By default the thermo_type of the entry is also returned. @@ -613,7 +636,7 @@ def get_pourbaix_entries( chemsys: str | list[str] | list[ComputedEntry | ComputedStructureEntry], solid_compat="MaterialsProject2020Compatibility", use_gibbs: Literal[300] | None = None, - ): + ) -> list[PourbaixEntry]: """A helper function to get all entries necessary to generate a Pourbaix diagram from the rest interface. @@ -637,6 +660,9 @@ def get_pourbaix_entries( cases. Default: None. Note that temperatures other than 300K are not permitted here, because MaterialsProjectAqueousCompatibility corrections, used in Pourbaix diagram construction, are calculated based on 300 K data. 
+ + Returns: + list of PourbaixEntry """ # imports are not top-level due to expense from pymatgen.analysis.pourbaix_diagram import PourbaixEntry @@ -653,19 +679,20 @@ def get_pourbaix_entries( if isinstance(chemsys, list) and all( isinstance(v, ComputedEntry | ComputedStructureEntry) for v in chemsys ): - user_entries = [ce.copy() for ce in chemsys] + user_entries = [ce.copy() for ce in chemsys] # type: ignore[union-attr] - elements = set() - for entry in user_entries: - elements.update(entry.elements) - chemsys = [ele.name for ele in elements] - - user_run_types = set( - [ - entry.parameters.get("run_type", "unknown").lower() + chemsys = sorted( + { + ele.name # type: ignore[misc] for entry in user_entries - ] + for ele in entry.elements + } ) + + user_run_types = { + entry.parameters.get("run_type", "unknown").lower() + for entry in user_entries + } if any("r2scan" in rt for rt in user_run_types): thermo_types = ["GGA_GGA+U_R2SCAN"] @@ -673,9 +700,7 @@ def get_pourbaix_entries( solid_compat = MaterialsProjectCompatibility() elif solid_compat == "MaterialsProject2020Compatibility": solid_compat = MaterialsProject2020Compatibility() - elif isinstance(solid_compat, Compatibility): - pass - else: + elif not isinstance(solid_compat, Compatibility): raise ValueError( "Solid compatibility can only be 'MaterialsProjectCompatibility', " "'MaterialsProject2020Compatibility', or an instance of a Compatibility class" @@ -686,13 +711,13 @@ def get_pourbaix_entries( if isinstance(chemsys, str): chemsys = chemsys.split("-") # capitalize and sort the elements - chemsys = sorted(e.capitalize() for e in chemsys) + sorted_chemsys: list[str] = sorted(e.capitalize() for e in chemsys) # type: ignore[union-attr] # Get ion entries first, because certain ions have reference # solids that aren't necessarily in the chemsys (Na2SO4) # download the ion reference data from MPContribs - ion_data = self.get_ion_reference_data_for_chemsys(chemsys) + ion_data = self.get_ion_reference_data_for_chemsys(sorted_chemsys) # build the PhaseDiagram for get_ion_entries ion_ref_comps = [ @@ -704,7 +729,9 @@ def get_pourbaix_entries( # TODO - would be great if the commented line below would work # However for some reason you cannot process GibbsComputedStructureEntry with # MaterialsProjectAqueousCompatibility - ion_ref_entries = ( + ion_ref_entries: Sequence[ + ComputedEntry | ComputedStructureEntry | GibbsComputedStructureEntry + ] = ( self.get_entries_in_chemsys( list([str(e) for e in ion_ref_elts] + ["O", "H"]), additional_criteria={"thermo_types": thermo_types}, @@ -739,12 +766,15 @@ def get_pourbaix_entries( ion_ref_pd = PhaseDiagram(ion_ref_entries) # type: ignore ion_entries = self.get_ion_entries(ion_ref_pd, ion_ref_data=ion_data) - pbx_entries = [PourbaixEntry(e, f"ion-{n}") for n, e in enumerate(ion_entries)] + pbx_entries = [ + PourbaixEntry(e, f"ion-{n}") # type: ignore[arg-type] + for n, e in enumerate(ion_entries) + ] # Construct the solid pourbaix entries from filtered ion_ref entries extra_elts = ( set(ion_ref_elts) - - {Element(s) for s in chemsys} + - {Element(s) for s in sorted_chemsys} - {Element("H"), Element("O")} ) for entry in ion_ref_entries: @@ -875,10 +905,7 @@ def get_ion_entries( f" diagram chemical system is {chemsys}." 
) - if not ion_ref_data: - ion_data = self.get_ion_reference_data_for_chemsys(chemsys) - else: - ion_data = ion_ref_data + ion_data = ion_ref_data or self.get_ion_reference_data_for_chemsys(chemsys) # position the ion energies relative to most stable reference state ion_entries = [] @@ -971,7 +998,7 @@ def get_entries_in_chemsys( conventional_unit_cell: bool = False, additional_criteria: dict = DEFAULT_THERMOTYPE_CRITERIA, **kwargs, - ): + ) -> list[ComputedStructureEntry] | list[GibbsComputedStructureEntry]: """Helper method to get a list of ComputedEntries in a chemical system. For example, elements = ["Li", "Fe", "O"] will return a list of all entries in the parent Li-Fe-O chemical system, as well as all subsystems @@ -1008,7 +1035,7 @@ def get_entries_in_chemsys( in entry data kwargs : Other kwargs to pass to `get_entries` Returns: - List of ComputedStructureEntries. + List of ComputedStructureEntry. """ if isinstance(elements, str): elements = elements.split("-") @@ -1038,7 +1065,7 @@ def get_entries_in_chemsys( # replace the entries with GibbsComputedStructureEntry from pymatgen.entries.computed_entries import GibbsComputedStructureEntry - entries = GibbsComputedStructureEntry.from_entries(entries, temp=use_gibbs) + return GibbsComputedStructureEntry.from_entries(entries, temp=use_gibbs) return entries @@ -1115,7 +1142,14 @@ def get_wulff_shape(self, material_id: str): from pymatgen.analysis.wulff import WulffShape from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - structure = self.get_structure_by_material_id(material_id) + if isinstance( + _structure := self.get_structure_by_material_id(material_id, final=True), + list, + ): + structure: Structure = _structure[0] + else: + structure = _structure + doc = self.materials.surface_properties.search(material_ids=material_id) if not doc: @@ -1180,7 +1214,7 @@ def get_charge_density_from_material_id( if not task_ids: return None - results: list[CoreTaskDoc] = self.materials.tasks.search( + results: list[_DictLikeAccess] = self.materials.tasks.search( task_ids=task_ids, fields=["last_updated", "task_id"] ) # type: ignore @@ -1299,7 +1333,7 @@ def get_cohesive_energy( self, material_ids: list[MPID | str], normalization: Literal["atom", "formula_unit"] = "atom", - ) -> float | dict[str, float]: + ) -> dict[str, float | None]: """Obtain the cohesive energy of the structure(s) corresponding to multiple MPIDs. 
Args: @@ -1318,18 +1352,20 @@ def get_cohesive_energy( } run_type_to_dfa = {"GGA": "PBE", "GGA_U": "PBE", "R2SCAN": "r2SCAN"} - energies = {mp_id: {} for mp_id in material_ids} + energies: dict[MPID | str, dict[str, dict[str, Any]]] = { + mp_id: {} for mp_id in material_ids + } entries = self.get_entries( material_ids, compatible_only=False, property_data=None, conventional_unit_cell=False, ) - for entry in entries: - entry = { - "data": entry.data, - "uncorrected_energy_per_atom": entry.uncorrected_energy_per_atom, - "composition": entry.composition, + for cse in entries: + entry: dict[str, Any] = { + "data": cse.data, + "uncorrected_energy_per_atom": cse.uncorrected_energy_per_atom, + "composition": cse.composition, } mp_id = entry["data"]["material_id"] @@ -1351,16 +1387,18 @@ def get_cohesive_energy( atomic_energies = self.get_atom_reference_data() - e_coh_per_atom = {} - for mp_id, entries in energies.items(): - if not entries: + e_coh_per_atom: dict[str, float | None] = {} + for mp_id, energy_entries in energies.items(): + if not energy_entries: e_coh_per_atom[str(mp_id)] = None continue # take entry from most reliable and available functional - prefered_func = sorted(list(entries), key=lambda k: entry_preference[k])[-1] + prefered_func = sorted( + list(energy_entries), key=lambda k: entry_preference[k] + )[-1] e_coh_per_atom[str(mp_id)] = self._get_cohesive_energy( - entries[prefered_func]["composition"], - entries[prefered_func]["total_energy_per_atom"], + energy_entries[prefered_func]["composition"], + energy_entries[prefered_func]["total_energy_per_atom"], atomic_energies[run_type_to_dfa.get(prefered_func, prefered_func)], normalization=normalization, ) @@ -1369,7 +1407,7 @@ def get_cohesive_energy( @lru_cache def get_atom_reference_data( self, - funcs: tuple[str] = ( + funcs: tuple[str, ...] = ( "PBE", "SCAN", "r2SCAN", @@ -1434,19 +1472,19 @@ def _get_cohesive_energy( def get_stability( self, - entries: ComputedEntry | ComputedStructureEntry | PDEntry, + entries: list[ComputedEntry | ComputedStructureEntry | PDEntry], thermo_type: str | ThermoType = ThermoType.GGA_GGA_U, ) -> list[dict[str, Any]] | None: - chemsys = set() - for entry in entries: - chemsys.update(entry.composition.elements) + chemsys: set[SpeciesLike] = { + ele for entry in entries for ele in entry.composition.elements + } chemsys_str = "-".join(sorted(str(ele) for ele in chemsys)) thermo_type = ( ThermoType(thermo_type) if isinstance(thermo_type, str) else thermo_type ) - corrector = None + corrector: Compatibility | None = None if thermo_type == ThermoType.GGA_GGA_U: from pymatgen.entries.compatibility import MaterialsProject2020Compatibility @@ -1469,18 +1507,22 @@ def get_stability( f"No phase diagram data available for chemical system {chemsys_str} " f"and thermo type {thermo_type}." 
) - return + return None - if corrector: - corrected_entries = corrector.process_entries(entries + pd.all_entries) - else: - corrected_entries = [*entries, *pd.all_entries] + joint_entries: Sequence[ComputedEntry | ComputedStructureEntry | PDEntry] = [ + *entries, + *pd.all_entries, + ] - new_pd = PhaseDiagram(corrected_entries) + new_pd = PhaseDiagram( + corrector.process_entries(joint_entries) # type: ignore[arg-type] + if corrector + else joint_entries # type: ignore[list-item] + ) return [ { - "e_above_hull": new_pd.get_e_above_hull(entry), + "e_above_hull": new_pd.get_e_above_hull(entry), # type: ignore[arg-type] "composition": entry.composition.as_dict(), "energy": entry.energy, "entry_id": getattr(entry, "entry_id", f"user-entry-{idx}"), @@ -1493,8 +1535,10 @@ def get_oxygen_evolution( material_id: str | MPID | AlphaID, working_ion: str | Element, thermo_type: str | ThermoType = ThermoType.GGA_GGA_U, - ): - working_ion = Element(working_ion) + ) -> dict[str, np.ndarray]: + working_ion = ( + Element[working_ion] if isinstance(working_ion, str) else working_ion + ) formatted_mpid = AlphaID(material_id).string electrode_docs = self.materials.insertion_electrodes.search( battery_ids=[f"{formatted_mpid}_{working_ion.value}"], diff --git a/mp_api/client/routes/__init__.py b/mp_api/client/routes/__init__.py index a025534dd..c3a40ce5d 100644 --- a/mp_api/client/routes/__init__.py +++ b/mp_api/client/routes/__init__.py @@ -2,8 +2,8 @@ from mp_api.client.core.utils import LazyImport -GENERIC_RESTERS = { - k: LazyImport(f"mp_api.client.routes.{k}.{v}") +GENERIC_RESTERS: dict[str, LazyImport] = { + k: LazyImport(f"mp_api.client.routes._server.{v}") for k, v in { "_general_store": "GeneralStoreRester", "_messages": "MessagesRester", diff --git a/mp_api/client/routes/_general_store.py b/mp_api/client/routes/_general_store.py deleted file mode 100644 index 2ed73097d..000000000 --- a/mp_api/client/routes/_general_store.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from emmet.core._general_store import GeneralStoreDoc - -from mp_api.client.core import BaseRester - - -class GeneralStoreRester(BaseRester): # pragma: no cover - suffix = "_general_store" - document_model = GeneralStoreDoc # type: ignore - primary_key = "submission_id" - use_document_model = False - - def add_item(self, kind: str, markdown: str, meta: dict): # pragma: no cover - """Set general store data. - - Args: - kind: Data type description - markdown: Markdown data - meta: Metadata - Returns: - Dictionary with written data and submission id. - - - Raises: - MPRestError. - """ - return self._post_resource( - body=meta, params={"kind": kind, "markdown": markdown} - ).get("data") - - def get_items(self, kind): # pragma: no cover - """Get general store data. - - Args: - kind: Data type description - Returns: - List of dictionaries with kind, markdown, metadata, and submission_id. - - - Raises: - MPRestError. 
- """ - return self.search(kind=kind) diff --git a/mp_api/client/routes/_messages.py b/mp_api/client/routes/_messages.py deleted file mode 100644 index a1e85c85c..000000000 --- a/mp_api/client/routes/_messages.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -from datetime import datetime - -from emmet.core._messages import MessagesDoc, MessageType - -from mp_api.client.core import BaseRester - - -class MessagesRester(BaseRester): # pragma: no cover - suffix = "_messages" - document_model = MessagesDoc # type: ignore - primary_key = "title" - use_document_model = False - - def set_message( - self, - title: str, - body: str, - type: MessageType = MessageType.generic, - authors: list[str] = None, - ): # pragma: no cover - """Set user settings. - - Args: - title: Message title - body: Message text body - type: Message type - authors: Message authors - Returns: - Dictionary with updated message data - - - Raises: - MPRestError. - """ - d = {"title": title, "body": body, "type": type.value, "authors": authors or []} - - return self._post_resource(body=d).get("data") - - def get_messages( - self, - last_updated: datetime, - sort_fields: list[str] | None = None, - num_chunks: int | None = None, - chunk_size: int = 1000, - all_fields: bool = True, - fields: list[str] | None = None, - ): # pragma: no cover - """Get user settings. - - Args: - last_updated (datetime): Datetime to use to query for newer messages - sort_fields (List[str]): Fields used to sort results. Prefix with '-' to sort in descending order. - num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. - chunk_size (int): Number of data entries per chunk. - all_fields (bool): Whether to return all fields in the document. Defaults to True. - fields (List[str]): List of fields to project. - - Returns: - Dictionary with messages data - - - Raises: - MPRestError. - """ - query_params = {} - - if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) - - return self._search( - last_updated=last_updated, - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params, - ) diff --git a/mp_api/client/routes/_server.py b/mp_api/client/routes/_server.py new file mode 100644 index 000000000..7974e393c --- /dev/null +++ b/mp_api/client/routes/_server.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from emmet.core._general_store import GeneralStoreDoc +from emmet.core._messages import MessagesDoc, MessageType +from emmet.core._user_settings import UserSettingsDoc + +from mp_api.client.core import BaseRester + +if TYPE_CHECKING: + from datetime import datetime + + +class GeneralStoreRester(BaseRester): # pragma: no cover + suffix = "_general_store" + document_model = GeneralStoreDoc # type: ignore + primary_key = "submission_id" + use_document_model = False + + def add_item(self, kind: str, markdown: str, meta: dict): # pragma: no cover + """Set general store data. + + Args: + kind: Data type description + markdown: Markdown data + meta: Metadata + Returns: + Dictionary with written data and submission id. + + + Raises: + MPRestError. + """ + return self._post_resource( + body=meta, params={"kind": kind, "markdown": markdown} + ).get("data") + + def get_items(self, kind): # pragma: no cover + """Get general store data. + + Args: + kind: Data type description + Returns: + List of dictionaries with kind, markdown, metadata, and submission_id. 
+ + + Raises: + MPRestError. + """ + return self.search(kind=kind) + + +class MessagesRester(BaseRester): # pragma: no cover + suffix = "_messages" + document_model = MessagesDoc # type: ignore + primary_key = "title" + use_document_model = False + + def set_message( + self, + title: str, + body: str, + type: MessageType = MessageType.generic, + authors: list[str] | None = None, + ): # pragma: no cover + """Set user settings. + + Args: + title: Message title + body: Message text body + type: Message type + authors: Message authors + Returns: + Dictionary with updated message data + + + Raises: + MPRestError. + """ + d = {"title": title, "body": body, "type": type.value, "authors": authors or []} + + return self._post_resource(body=d).get("data") + + def get_messages( + self, + last_updated: datetime, + sort_fields: list[str] | None = None, + num_chunks: int | None = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: list[str] | None = None, + ): # pragma: no cover + """Get user settings. + + Args: + last_updated (datetime): Datetime to use to query for newer messages + sort_fields (List[str]): Fields used to sort results. Prefix with '-' to sort in descending order. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields to project. + + Returns: + Dictionary with messages data + + + Raises: + MPRestError. + """ + query_params = {} + + if sort_fields: + query_params.update( + {"_sort_fields": ",".join([s.strip() for s in sort_fields])} + ) + + return self._search( + last_updated=last_updated, + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, + ) + + +class UserSettingsRester(BaseRester): # pragma: no cover + suffix = "_user_settings" + document_model = UserSettingsDoc # type: ignore + primary_key = "consumer_id" + use_document_model = False + + def create_user_settings(self, consumer_id, settings): + """Create user settings. + + Args: + consumer_id: Consumer ID for the user + settings: Dictionary with user settings that + use UserSettingsDoc schema + Returns: + Dictionary with consumer_id and write status. + """ + return self._post_resource( + body=settings, params={"consumer_id": consumer_id} + ).get("data") + + def patch_user_settings(self, consumer_id, settings): # pragma: no cover + """Patch user settings. + + Args: + consumer_id: Consumer ID for the user + settings: Dictionary with user settings + Returns: + Dictionary with consumer_id and write status. + + + Raises: + MPRestError. + """ + body = dict() + valid_fields = [ + "institution", + "sector", + "job_role", + "is_email_subscribed", + "agreed_terms", + "message_last_read", + ] + for key in settings: + if key not in valid_fields: + raise ValueError( + f"Invalid setting key {key}. Must be one of {valid_fields}" + ) + body[f"settings.{key}"] = settings[key] + + return self._patch_resource(body=body, params={"consumer_id": consumer_id}).get( + "data" + ) + + def patch_user_time_settings(self, consumer_id, time): # pragma: no cover + """Set user settings last_read_message field. + + Args: + consumer_id: Consumer ID for the user + time: utc datetime object for when the user last see messages + Returns: + Dictionary with consumer_id and write status. + + + Raises: + MPRestError. 
+ """ + return self._patch_resource( + body={"settings.message_last_read": time.isoformat()}, + params={"consumer_id": consumer_id}, + ).get("data") + + def get_user_settings(self, consumer_id, fields): # pragma: no cover + """Get user settings. + + Args: + consumer_id: Consumer ID for the user + fields: List of fields to project + Returns: + Dictionary with consumer_id and settings. + + + Raises: + MPRestError. + """ + return self._query_resource( + suburl=f"{consumer_id}", fields=fields, num_chunks=1, chunk_size=1 + ).get("data") diff --git a/mp_api/client/routes/_user_settings.py b/mp_api/client/routes/_user_settings.py deleted file mode 100644 index a1eea3041..000000000 --- a/mp_api/client/routes/_user_settings.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import annotations - -from emmet.core._user_settings import UserSettingsDoc - -from mp_api.client.core import BaseRester - - -class UserSettingsRester(BaseRester): # pragma: no cover - suffix = "_user_settings" - document_model = UserSettingsDoc # type: ignore - primary_key = "consumer_id" - use_document_model = False - - def create_user_settings(self, consumer_id, settings): - """Create user settings. - - Args: - consumer_id: Consumer ID for the user - settings: Dictionary with user settings that - use UserSettingsDoc schema - Returns: - Dictionary with consumer_id and write status. - """ - return self._post_resource( - body=settings, params={"consumer_id": consumer_id} - ).get("data") - - def patch_user_settings(self, consumer_id, settings): # pragma: no cover - """Patch user settings. - - Args: - consumer_id: Consumer ID for the user - settings: Dictionary with user settings - Returns: - Dictionary with consumer_id and write status. - - - Raises: - MPRestError. - """ - body = dict() - valid_fields = [ - "institution", - "sector", - "job_role", - "is_email_subscribed", - "agreed_terms", - "message_last_read", - ] - for key in settings: - if key not in valid_fields: - raise ValueError( - f"Invalid setting key {key}. Must be one of {valid_fields}" - ) - body[f"settings.{key}"] = settings[key] - - return self._patch_resource(body=body, params={"consumer_id": consumer_id}).get( - "data" - ) - - def patch_user_time_settings(self, consumer_id, time): # pragma: no cover - """Set user settings last_read_message field. - - Args: - consumer_id: Consumer ID for the user - time: utc datetime object for when the user last see messages - Returns: - Dictionary with consumer_id and write status. - - - Raises: - MPRestError. - """ - return self._patch_resource( - body={"settings.message_last_read": time.isoformat()}, - params={"consumer_id": consumer_id}, - ).get("data") - - def get_user_settings(self, consumer_id, fields): # pragma: no cover - """Get user settings. - - Args: - consumer_id: Consumer ID for the user - fields: List of fields to project - Returns: - Dictionary with consumer_id and settings. - - - Raises: - MPRestError. 
- """ - return self._query_resource( - suburl=f"{consumer_id}", fields=fields, num_chunks=1, chunk_size=1 - ).get("data") diff --git a/mp_api/client/routes/materials/absorption.py b/mp_api/client/routes/materials/absorption.py index 68eb3fbe9..cc02a28f4 100644 --- a/mp_api/client/routes/materials/absorption.py +++ b/mp_api/client/routes/materials/absorption.py @@ -93,7 +93,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/alloys.py b/mp_api/client/routes/materials/alloys.py index eb6a2234f..9e5996eaa 100644 --- a/mp_api/client/routes/materials/alloys.py +++ b/mp_api/client/routes/materials/alloys.py @@ -54,7 +54,7 @@ def search( query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) - return super()._search( + return super()._search( # type: ignore[return-value] formulae=formulae, num_chunks=num_chunks, chunk_size=chunk_size, diff --git a/mp_api/client/routes/materials/bonds.py b/mp_api/client/routes/materials/bonds.py index 82436d2bb..eaf826a56 100644 --- a/mp_api/client/routes/materials/bonds.py +++ b/mp_api/client/routes/materials/bonds.py @@ -94,7 +94,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/chemenv.py b/mp_api/client/routes/materials/chemenv.py index 2fd10c36c..ffec91b65 100644 --- a/mp_api/client/routes/materials/chemenv.py +++ b/mp_api/client/routes/materials/chemenv.py @@ -120,7 +120,7 @@ def search( for chemenv_var_name, (chemenv_var, literals) in chemenv_literals.items(): if chemenv_var: t_types = {t if isinstance(t, str) else t.value for t in chemenv_var} - valid_types = {*map(str, literals.__args__)} + valid_types = {*map(str, literals.__args__)} # type: ignore[attr-defined] if invalid_types := t_types - valid_types: raise ValueError( f"Invalid type(s) passed for {chemenv_var_name}: {invalid_types}, valid types are: {valid_types}" @@ -140,7 +140,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/dielectric.py b/mp_api/client/routes/materials/dielectric.py index 5d122a43d..ccd03bdfc 100644 --- a/mp_api/client/routes/materials/dielectric.py +++ b/mp_api/client/routes/materials/dielectric.py @@ -74,7 +74,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/doi.py b/mp_api/client/routes/materials/doi.py index 25fde8e54..c55e37582 100644 --- a/mp_api/client/routes/materials/doi.py +++ b/mp_api/client/routes/materials/doi.py @@ -47,7 +47,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/elasticity.py b/mp_api/client/routes/materials/elasticity.py index 74672cd29..59f622c8c 100644 --- a/mp_api/client/routes/materials/elasticity.py +++ b/mp_api/client/routes/materials/elasticity.py @@ 
-104,7 +104,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/electrodes.py b/mp_api/client/routes/materials/electrodes.py index 9724d7698..65d4065e7 100644 --- a/mp_api/client/routes/materials/electrodes.py +++ b/mp_api/client/routes/materials/electrodes.py @@ -151,7 +151,7 @@ def search( # pragma: ignore if query_params[entry] is not None } - return super()._search(**query_params) + return super()._search(**query_params) # type: ignore[return-value] class ElectrodeRester(BaseElectrodeRester): diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index d52622418..28b02ed20 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -264,7 +264,7 @@ def get_bandstructure_from_task_id(self, task_id: str): Returns: bandstructure (BandStructure): BandStructure or BandStructureSymmLine object """ - return self._query_open_data( + return self._query_open_data( # type: ignore[call-overload] bucket="materialsproject-parsed", key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", decoder=lambda x: load_json(x, deser=True), @@ -293,13 +293,13 @@ def get_bandstructure_from_material_id( if not bs_doc: raise MPRestError("No electronic structure data found.") - if (bs_data := bs_doc[0]["bandstructure"]) is None: + if (_bs_data := bs_doc[0]["bandstructure"]) is None: raise MPRestError( f"No {path_type.value} band structure data found for {material_id}" ) - bs_data: dict = ( - bs_data.model_dump() if self.use_document_model else bs_data # type: ignore + bs_data = ( + _bs_data.model_dump() if self.use_document_model else _bs_data # type: ignore ) if bs_data.get(path_type.value, None) is None: @@ -316,14 +316,12 @@ def get_bandstructure_from_material_id( ): raise MPRestError("No electronic structure data found.") - if (bs_data := bs_doc[0]["dos"]) is None: + if (_bs_data := bs_doc[0]["dos"]) is None: raise MPRestError( f"No uniform band structure data found for {material_id}" ) - bs_data: dict = ( - bs_data.model_dump() if self.use_document_model else bs_data # type: ignore - ) + bs_data = _bs_data.model_dump() if self.use_document_model else _bs_data if bs_data.get("total", None) is None: raise MPRestError( @@ -462,7 +460,7 @@ def get_dos_from_task_id(self, task_id: str) -> CompleteDos: Returns: bandstructure (CompleteDos): CompleteDos object """ - return self._query_open_data( + return self._query_open_data( # type: ignore[call-overload] bucket="materialsproject-parsed", key=f"dos/{validate_ids([task_id])[0]}.json.gz", decoder=lambda x: load_json(x, deser=True), diff --git a/mp_api/client/routes/materials/eos.py b/mp_api/client/routes/materials/eos.py index 604459b74..0182eb6fc 100644 --- a/mp_api/client/routes/materials/eos.py +++ b/mp_api/client/routes/materials/eos.py @@ -60,7 +60,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/grain_boundaries.py b/mp_api/client/routes/materials/grain_boundaries.py index b5ccb7d42..6949b9deb 100644 --- a/mp_api/client/routes/materials/grain_boundaries.py +++ b/mp_api/client/routes/materials/grain_boundaries.py @@ -113,7 +113,7 @@ 
def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/magnetism.py b/mp_api/client/routes/materials/magnetism.py index ae093be6c..8321f1e41 100644 --- a/mp_api/client/routes/materials/magnetism.py +++ b/mp_api/client/routes/materials/magnetism.py @@ -115,7 +115,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index 38f83b4ca..16d42f65d 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -15,8 +15,6 @@ if TYPE_CHECKING: from typing import Any - from pymatgen.entries.computed_entries import ComputedStructureEntry - class MaterialsRester(CoreRester): suffix = "materials/core" @@ -26,7 +24,7 @@ class MaterialsRester(CoreRester): def get_structure_by_material_id( self, material_id: str, final: bool = True - ) -> Structure | list[Structure]: + ) -> Structure | list[Structure] | None: """Get a structure for a given Materials Project ID. Arguments: @@ -42,19 +40,15 @@ def get_structure_by_material_id( response = self.search(material_ids=material_id, fields=[field]) - if response and response[0]: - response = response[0] + if response and (r := response[0][field]): # type: ignore[index] # Ensure that return type is a Structure regardless of `model_dump` - if isinstance(response[field], dict): - response[field] = Structure.from_dict(response[field]) - elif isinstance(response[field], list) and any( - isinstance(struct, dict) for struct in response[field] - ): - response[field] = [ - Structure.from_dict(struct) for struct in response[field] - ] + if isinstance(r, dict): + return Structure.from_dict(r) + elif isinstance(r, list) and any(isinstance(struct, dict) for struct in r): + return [Structure.from_dict(struct) for struct in r] + return r - return response[field] if response else response # type: ignore + return None def search( self, @@ -168,7 +162,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, @@ -242,7 +236,7 @@ def get_blessed_entries( uncorrected_energy: tuple[float | None, float | None] | float | None = None, num_chunks: int | None = None, chunk_size: int = 1000, - ) -> list[dict[str, str | dict | ComputedStructureEntry]]: + ) -> list[dict[str, Any]]: """Get blessed calculation entries for a given material and run type. 
Args: diff --git a/mp_api/client/routes/materials/oxidation_states.py b/mp_api/client/routes/materials/oxidation_states.py index a31e2a020..7ee455683 100644 --- a/mp_api/client/routes/materials/oxidation_states.py +++ b/mp_api/client/routes/materials/oxidation_states.py @@ -73,7 +73,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index b47317a48..0373cd0dd 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -3,7 +3,6 @@ from collections import defaultdict from typing import TYPE_CHECKING -import numpy as np from emmet.core.phonon import PhononBS, PhononBSDOSDoc, PhononDOS from mp_api.client.core import BaseRester, MPRestError @@ -61,7 +60,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, @@ -86,10 +85,11 @@ def get_bandstructure_from_material_id( key=f"ph-bandstructures/{phonon_method}/{material_id}.json.gz", )[0][0] - if self.use_document_model: - return PhononBS(**result) - - return result + return ( + PhononBS(**result) # type: ignore[arg-type] + if self.use_document_model + else result # type: ignore[return-value] + ) def get_dos_from_material_id( self, material_id: str, phonon_method: str @@ -108,10 +108,11 @@ def get_dos_from_material_id( key=f"ph-dos/{phonon_method}/{material_id}.json.gz", )[0][0] - if self.use_document_model: - return PhononDOS(**result) - - return result + return ( + PhononDOS(**result) # type: ignore[arg-type] + if self.use_document_model + else result # type: ignore[return-value] + ) def get_forceconstants_from_material_id( self, material_id: str @@ -124,7 +125,7 @@ def get_forceconstants_from_material_id( Returns: force constants (list[list[Matrix3D]]): PhononDOS object """ - return self._query_open_data( + return self._query_open_data( # type: ignore[return-value] bucket="materialsproject-parsed", key=f"ph-force-constants/{material_id}.json.gz", )[0][0] @@ -146,9 +147,10 @@ def compute_thermo_quantities(self, material_id: str, phonon_method: str): raise MPRestError("No phonon document found") self.use_document_model = True - docs[0]["phonon_dos"] = self.get_dos_from_material_id( + docs[0]["phonon_dos"] = self.get_dos_from_material_id( # type: ignore[index] material_id, phonon_method ) - doc = PhononBSDOSDoc(**docs[0]) + doc = PhononBSDOSDoc(**docs[0]) # type: ignore[arg-type] self.use_document_model = use_document_model - return doc.compute_thermo_quantities(np.linspace(0, 800, 100)) + # below: same as numpy.linspace(0,800,100) but written out for mypy + return doc.compute_thermo_quantities([i * 800 / 99 for i in range(100)]) diff --git a/mp_api/client/routes/materials/piezo.py b/mp_api/client/routes/materials/piezo.py index ed5199e81..c2f8380a8 100644 --- a/mp_api/client/routes/materials/piezo.py +++ b/mp_api/client/routes/materials/piezo.py @@ -60,7 +60,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/provenance.py b/mp_api/client/routes/materials/provenance.py index 1d3894ff4..fe80ea28a 100644 --- 
a/mp_api/client/routes/materials/provenance.py +++ b/mp_api/client/routes/materials/provenance.py @@ -48,7 +48,7 @@ def search( for entry in query_params if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/robocrys.py b/mp_api/client/routes/materials/robocrys.py index 41ef0029e..ece5b665f 100644 --- a/mp_api/client/routes/materials/robocrys.py +++ b/mp_api/client/routes/materials/robocrys.py @@ -77,7 +77,7 @@ def search_docs( for entry in query_params if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/similarity.py b/mp_api/client/routes/materials/similarity.py index ac4a59b12..aa6cab71f 100644 --- a/mp_api/client/routes/materials/similarity.py +++ b/mp_api/client/routes/materials/similarity.py @@ -70,7 +70,7 @@ def search( for entry in query_params if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, @@ -111,7 +111,7 @@ def find_similar( docs = self.search(material_ids=[fmt_idx], fields=["feature_vector"]) if not docs: raise MPRestError(f"No similarity data available for {fmt_idx}") - feature_vector = docs[0]["feature_vector"] + feature_vector = docs[0]["feature_vector"] # type: ignore[index] elif isinstance(structure_or_mpid, Structure): feature_vector = self.fingerprint_structure(structure_or_mpid) diff --git a/mp_api/client/routes/materials/substrates.py b/mp_api/client/routes/materials/substrates.py index 68ca7854f..62eaa6762 100644 --- a/mp_api/client/routes/materials/substrates.py +++ b/mp_api/client/routes/materials/substrates.py @@ -83,7 +83,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] **query_params, num_chunks=num_chunks, chunk_size=chunk_size, diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index f6aee2ecb..8874b3727 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -206,7 +206,9 @@ def search( # noqa: D417 # Check to see if user specified _search fields using **kwargs, # or if any of the **kwargs are unparsable - db_keys = {k: [] for k in ("duplicate", "warn", "unknown")} + db_keys: dict[str, list[str]] = { + k: [] for k in ("duplicate", "warn", "unknown") + } for k, v in kwargs.items(): category = "unknown" if non_db_k := mmnd_inv.get(k): @@ -325,13 +327,11 @@ def _csrc(x): "spacegroup_symbol": 230, } for k, cardinality in symm_cardinality.items(): - if hasattr(symm_vals := locals().get(k), "__len__") and not isinstance( - symm_vals, str - ): + if isinstance(symm_vals := locals().get(k), list | tuple | set): if len(symm_vals) < cardinality // 2: query_params.update({k: ",".join(str(v) for v in symm_vals)}) else: - raise ValueError( + raise MPRestError( f"Querying `{k}` by a list of values is only " f"supported for up to {cardinality//2 - 1} values. " f"For your query, retrieve all data first and then filter on `{k}`." 
@@ -376,7 +376,7 @@ def _csrc(x): if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/surface_properties.py b/mp_api/client/routes/materials/surface_properties.py index 2205ef364..76d9e60ce 100644 --- a/mp_api/client/routes/materials/surface_properties.py +++ b/mp_api/client/routes/materials/surface_properties.py @@ -96,7 +96,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 4e8498c93..bdef8b0d3 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -31,7 +31,7 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> list[dict[str, Any]]: Returns: list of dict representing emmet.core.trajectory.Trajectory """ - traj_data = self._query_resource_data( + traj_data = self._query_resource_data( # type: ignore[union-attr] {"task_ids": [AlphaID(task_id).string]}, suburl="trajectory/", use_document_model=False, @@ -100,7 +100,7 @@ def search( } ) - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 7e125a888..29c95e013 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -133,7 +133,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, @@ -166,7 +166,7 @@ def get_phase_diagram_from_chemsys( phdiag_id = f"thermo_type={t_type}/chemsys={sorted_chemsys}" version = self.db_version.replace(".", "-") obj_key = f"objects/{version}/phase-diagrams/{phdiag_id}.jsonl.gz" - pd = self._query_open_data( + pd = self._query_open_data( # type: ignore[union-attr] bucket="materialsproject-build", key=obj_key, decoder=lambda x: load_json(x, deser=True), diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index 88cc2bd4e..0a8efa36f 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -9,6 +9,8 @@ from mp_api.client.core.utils import validate_ids if TYPE_CHECKING: + from typing import Any + from emmet.core.types.enums import XasEdge, XasType @@ -56,7 +58,7 @@ def search( Returns: ([MaterialsDoc]) List of material documents """ - query_params = {} + query_params: dict[str, Any] = {} if edge: query_params.update({"edge": edge}) diff --git a/mp_api/mcp/server.py b/mp_api/mcp/server.py index 735dd0f93..2e78e2396 100644 --- a/mp_api/mcp/server.py +++ b/mp_api/mcp/server.py @@ -71,5 +71,10 @@ def parse_server_args(args: Sequence[str] | None = None) -> dict[str, Any]: mcp = get_core_mcp() -if __name__ == "__main__": + +def _run_mp_mcp_server() -> None: mcp.run(**parse_server_args()) + + +if __name__ == "__main__": + _run_mp_mcp_server() diff --git a/mp_api/mcp/tools.py b/mp_api/mcp/tools.py index 451ecd966..167c2d174 100644 --- a/mp_api/mcp/tools.py +++ b/mp_api/mcp/tools.py @@ -101,10 +101,7 @@ def search(self, query: str) -> SearchOutput: return SearchOutput( results=[ - 
FetchResult( - id=doc["material_id"], - text=doc["description"], - ) + FetchResult(id=doc["material_id"], text=doc["description"]) # type: ignore[call-arg] for doc in robo_docs ] ) @@ -146,14 +143,16 @@ def fetch(self, idx: str) -> FetchResult: # Assume this is a chemical formula or chemical system if "mp-" not in idx: - summ_kwargs = {"fields": ["energy_above_hull", "material_id"]} + summ_kwargs: dict[str, list[str] | str] = { + "fields": ["energy_above_hull", "material_id"] + } if "-" in idx: summ_kwargs["chemsys"] = "-".join(sorted(idx.split("-"))) else: summ_kwargs["formula"] = idx if not (summ_docs := self.client.materials.summary.search(**summ_kwargs)): - return FetchResult(id=idx) + return FetchResult(id=idx) # type: ignore[call-arg] idx = min(summ_docs, key=lambda doc: doc["energy_above_hull"])[ "material_id" @@ -171,7 +170,7 @@ def fetch(self, idx: str) -> FetchResult: robo_desc = robo_docs[0]["description"] if not robo_desc: - return FetchResult(id=idx) + return FetchResult(id=idx) # type: ignore[call-arg] metadata: dict[str, str] = {} @@ -201,7 +200,7 @@ def fetch(self, idx: str) -> FetchResult: # simple str or numeric type summary_doc = summary_docs[0] - return FetchResult( + return FetchResult( # type: ignore[call-arg] id=idx, text=robo_desc, metadata=MaterialMetadata.from_summary_data(summary_doc, **metadata), @@ -210,7 +209,7 @@ def fetch(self, idx: str) -> FetchResult: def get_phase_diagram_from_elements( self, elements: list[str], - thermo_type: Literal[ + thermo_type: Literal[ # type: ignore[valid-type] *[x.value for x in ThermoType.__members__.values() if x.value != "UNKNOWN"] ] | str = "GGA_GGA+U_R2SCAN", diff --git a/mp_api/mcp/utils.py b/mp_api/mcp/utils.py index 16bc60911..f3601e58f 100644 --- a/mp_api/mcp/utils.py +++ b/mp_api/mcp/utils.py @@ -52,7 +52,9 @@ def reset_client(self) -> None: ) self.client.session.headers["user-agent"] = self.client.session.headers[ "user-agent" - ].replace("mp-api", "mp-mcp") + ].replace( + "mp-api", "mp-mcp" # type: ignore[arg-type] + ) def update_user_api_key(self, api_key: str) -> None: """Change the API key used in the client. 
diff --git a/mp_api/py.typed b/mp_api/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/pyproject.toml b/pyproject.toml index d91f40298..2fc2c272b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ description = "API Client for the Materials Project" readme = "README.md" requires-python = ">=3.11" -license = { text = "modified BSD" } +license = "BSD-3-Clause-LBNL" classifiers = [ "Programming Language :: Python :: 3", "Development Status :: 4 - Beta", @@ -56,6 +56,9 @@ test = [ ] docs = ["sphinx"] +[project.scripts] +mpmcp = "mp_api.mcp.server:_run_mp_mcp_server" + [tool.setuptools.packages.find] include = ["mp_api*"] namespaces = true @@ -111,3 +114,7 @@ isort.required-imports = ["from __future__ import annotations"] [tool.ruff.per-file-ignores] "*/__init__.py" = ["F401"] # F401: imported but unused + +[tool.mypy] +namespace_packages = true +ignore_missing_imports = true diff --git a/tests/client/materials/test_phonon.py b/tests/client/materials/test_phonon.py index 8805176c2..0b5aae754 100644 --- a/tests/client/materials/test_phonon.py +++ b/tests/client/materials/test_phonon.py @@ -76,6 +76,5 @@ def test_phonon_thermo(use_document_model): num_vals = 100 assert all( - isinstance(v, np.ndarray if k == "temperature" else list) and len(v) == num_vals - for k, v in thermo_props.items() + isinstance(v, list) and len(v) == num_vals for k, v in thermo_props.items() ) diff --git a/tests/client/materials/test_summary.py b/tests/client/materials/test_summary.py index 12613e19b..9d5b63989 100644 --- a/tests/client/materials/test_summary.py +++ b/tests/client/materials/test_summary.py @@ -106,13 +106,13 @@ def test_list_like_input(): } == set(crys_sys) # should fail - we don't support querying by so many list values - with pytest.raises(ValueError, match="retrieve all data first and then filter"): + with pytest.raises(MPRestError, match="retrieve all data first and then filter"): _ = search_method(spacegroup_number=list(range(1, 231))) - with pytest.raises(ValueError, match="retrieve all data first and then filter"): + with pytest.raises(MPRestError, match="retrieve all data first and then filter"): _ = search_method(spacegroup_number=["null" for _ in range(230)]) - with pytest.raises(ValueError, match="retrieve all data first and then filter"): + with pytest.raises(MPRestError, match="retrieve all data first and then filter"): _ = search_method(crystal_system=list(CrystalSystem)) diff --git a/tests/client/materials/test_xas.py b/tests/client/materials/test_xas.py index f9c13bd39..c31a5d9f0 100644 --- a/tests/client/materials/test_xas.py +++ b/tests/client/materials/test_xas.py @@ -47,6 +47,10 @@ def rester(): @requires_api_key +@pytest.mark.xfail( + reason="XAS endpoint often too slow to respond.", + strict=False, +) def test_client(rester): client_search_testing( search_method=rester.search, diff --git a/tests/client/test_client.py b/tests/client/test_client.py index dd94a910f..fda8323e2 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -3,11 +3,14 @@ import pytest from mp_api.client import MPRester -from mp_api.client.routes.materials.tasks import TaskRester -from mp_api.client.routes.materials.provenance import ProvenanceRester from .conftest import requires_api_key +try: + import pymatgen.analysis.alloys as pmg_alloys +except ImportError: + pmg_alloys = None + # -- Rester name data for generic tests key_only_resters = { @@ -45,14 +48,16 @@ # "summary", ] # temp - mpr = MPRester() # Temporarily ignore molecules resters 
while molecules query operators are changed resters_to_test = [ rester for rester in mpr._all_resters - if "molecule" not in rester._class_name.lower() + if ( + "molecule" not in rester._class_name.lower() + and not (pmg_alloys is None and "alloys" in str(rester).lower()) + ) ] diff --git a/tests/client/test_core_client.py b/tests/client/test_core_client.py index 83a4fafdd..2b449bd56 100644 --- a/tests/client/test_core_client.py +++ b/tests/client/test_core_client.py @@ -44,4 +44,4 @@ def test_count(mpr): def test_available_fields(rester, mpr): assert len(mpr.materials.available_fields) > 0 - assert rester.available_fields == ["Unknown fields."] + assert rester.available_fields == [] diff --git a/tests/client/test_heartbeat.py b/tests/client/test_heartbeat.py new file mode 100644 index 000000000..3b17eabed --- /dev/null +++ b/tests/client/test_heartbeat.py @@ -0,0 +1,31 @@ +import requests +import pytest +from unittest.mock import patch, Mock + +import mp_api.client.mprester + +from .conftest import requires_api_key + + +@pytest.fixture +def mock_403(): + with patch("mp_api.client.mprester.get") as mock_get: + mock_response = Mock() + mock_response.status_code = 403 + mock_get.return_value = mock_response + yield mock_get + + +@requires_api_key +@pytest.mark.xfail( + reason="Works in isolation, appear to be contamination from other test imports.", + strict=False, +) +def test_heartbeat_403(mock_403): + from mp_api.client.mprester import MPRester + from mp_api.client.core import MPRestWarning + + with pytest.warns(MPRestWarning, match="heartbeat, check Materials Project status"): + with MPRester() as mpr: + # Ensure that client can still work if heartbeat is unreachable + assert mpr.get_structure_by_material_id("mp-149") is not None diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index 63774a49f..94d8eedf7 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -2,6 +2,7 @@ import os import random import importlib +import requests from tempfile import NamedTemporaryFile import numpy as np diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index 972d17320..f7603af69 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -1,6 +1,14 @@ import asyncio import pytest +try: + import fastmcp +except ImportError: + pytest.skip( + "Please `pip install fastmcp` to test the MCP server directly.", + allow_module_level=True, + ) + from mp_api.client.core.exceptions import MPRestError from mp_api.mcp.server import get_core_mcp, parse_server_args