Source code for mattertune.data.xyz

from __future__ import annotations

import logging
from pathlib import Path
from typing import Literal

import ase
from ase import Atoms
from ase.io import read
from torch.utils.data import Dataset
from typing_extensions import override

from ..registry import data_registry
from .base import DatasetConfigBase

log = logging.getLogger(__name__)


@data_registry.register
class XYZDatasetConfig(DatasetConfigBase):
    """Configuration for a dataset of structures loaded from an XYZ file.

    Registered with ``data_registry`` under the ``"xyz"`` discriminator.
    """

    type: Literal["xyz"] = "xyz"
    """Discriminator for the XYZ dataset."""

    src: str | Path
    """The path to the XYZ dataset."""

    @override
    def create_dataset(self):
        """Construct the :class:`XYZDataset` described by this configuration."""
        dataset = XYZDataset(self)
        return dataset
class XYZDataset(Dataset[ase.Atoms]):
    """Map-style dataset over every structure stored in an XYZ file.

    All frames are read eagerly in ``__init__`` and held in memory as
    :class:`ase.Atoms` objects; indexing returns them unchanged.
    """

    def __init__(self, config: XYZDatasetConfig):
        """Read all structures from ``config.src``.

        Args:
            config: Dataset configuration; ``config.src`` is the file to read.

        Raises:
            TypeError: If ``ase.io.read`` does not return a list of frames.
        """
        super().__init__()
        self.config = config

        # index=":" requests every frame in the file, not just a single image.
        atoms_list = read(str(self.config.src), index=":")
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(atoms_list, list):
            raise TypeError(
                "Expected a list of Atoms objects, "
                f"got {type(atoms_list).__name__}"
            )
        self.atoms_list: list[Atoms] = atoms_list
        # Lazy %-formatting: the message is only built if INFO is enabled.
        log.info("Loaded %d atoms from %s", len(self.atoms_list), self.config.src)

    @override
    def __getitem__(self, idx: int) -> ase.Atoms:
        """Return the structure stored at position ``idx``."""
        return self.atoms_list[idx]

    def __len__(self) -> int:
        """Return the number of structures loaded from the file."""
        return len(self.atoms_list)