Source code for mattertune.data.mp

from __future__ import annotations

import logging
from typing import Literal

import ase
from ase import Atoms
from torch.utils.data import Dataset
from typing_extensions import override

from ..registry import data_registry
from ..util import optional_import_error_message
from .base import DatasetConfigBase

log = logging.getLogger(__name__)


[docs] @data_registry.register class MPDatasetConfig(DatasetConfigBase): """Configuration for a dataset stored in the Materials Project database.""" type: Literal["mp"] = "mp" """Discriminator for the MP dataset.""" api: str """Input API key for the Materials Project database.""" fields: list[str] """Fields to retrieve from the Materials Project database.""" query: dict """Query to filter the data from the Materials Project database."""
[docs] @override def create_dataset(self): return MPDataset(self)
[docs] class MPDataset(Dataset[ase.Atoms]):
[docs] def __init__(self, config: MPDatasetConfig): super().__init__() self.config = config with optional_import_error_message("mp_api"): from mp_api.client import MPRester # type: ignore[reportMissingImports] self.mpr = MPRester(config.api) if "material_id" not in config.fields: config.fields.append("material_id") self.docs = self.mpr.summary.search(fields=config.fields, **config.query)
@override def __getitem__(self, idx: int) -> Atoms: from pymatgen.io.ase import AseAtomsAdaptor doc = self.docs[idx] mid = doc.material_id structure = self.mpr.get_structure_by_material_id(mid) adaptor = AseAtomsAdaptor() atoms = adaptor.get_atoms(structure) assert isinstance(atoms, Atoms), "Expected an Atoms object" doc_labels = dict(doc) atoms.info = { key: doc_labels[key] for key in self.config.fields if key in doc_labels and key != "material_id" } return atoms def __len__(self) -> int: return len(self.docs)