# Source code for mattertune.data.omat24

from __future__ import annotations

from pathlib import Path
from typing import Literal

import ase
from torch.utils.data import Dataset
from typing_extensions import override

from ..registry import data_registry
from ..util import optional_import_error_message
from .base import DatasetConfigBase


@data_registry.register
class OMAT24DatasetConfig(DatasetConfigBase):
    """Configuration describing where an OMAT24 dataset lives on disk.

    The class is registered with ``data_registry`` and is identified by the
    literal ``type`` value ``"omat24"``.
    """

    type: Literal["omat24"] = "omat24"
    """Discriminator for the OMAT24 dataset."""

    src: Path
    """The path to the OMAT24 dataset."""

    @override
    def create_dataset(self):
        """Build the ``OMAT24Dataset`` that is driven by this configuration.

        Returns:
            A new ``OMAT24Dataset`` constructed with this config instance.
        """
        dataset = OMAT24Dataset(config=self)
        return dataset
class OMAT24Dataset(Dataset[ase.Atoms]):
    """Map-style dataset that serves ``ase.Atoms`` objects from OMAT24 data.

    Item access and length are delegated to a fairchem ``AseDBDataset`` that
    is created from the ``src`` path of the supplied configuration.
    """

    def __init__(self, config: OMAT24DatasetConfig):
        """Store the config and open the underlying fairchem dataset.

        Args:
            config: Configuration whose ``src`` path is handed to fairchem's
                ``AseDBDataset`` (converted to ``str``).
        """
        super().__init__()
        self.config = config

        # fairchem is imported lazily inside the helper's context; presumably
        # the helper turns a missing package into a friendlier error message
        # -- confirm against ``..util.optional_import_error_message``.
        with optional_import_error_message("fairchem"):
            from fairchem.core.datasets import AseDBDataset  # type: ignore[reportMissingImports] # noqa

        source_path = str(self.config.src)
        self.dataset = AseDBDataset(config={"src": source_path})

    @override
    def __getitem__(self, idx: int) -> ase.Atoms:
        """Return the atoms object stored at position ``idx``.

        Args:
            idx: Index forwarded unchanged to ``AseDBDataset.get_atoms``.

        Returns:
            Whatever ``get_atoms`` yields for ``idx`` (annotated as
            ``ase.Atoms``).
        """
        return self.dataset.get_atoms(idx)

    def __len__(self) -> int:
        """Return the number of entries reported by the wrapped dataset."""
        return len(self.dataset)