Source code for mattertune.data.db

from __future__ import annotations

import logging
from pathlib import Path
from typing import Literal

import ase
import numpy as np
from ase.calculators.calculator import all_properties
from ase.calculators.singlepoint import SinglePointCalculator
from ase.db import connect
from ase.db.core import Database
from ase.stress import full_3x3_to_voigt_6_stress
from torch.utils.data import Dataset
from typing_extensions import override

from ..registry import data_registry
from .base import DatasetConfigBase

log = logging.getLogger(__name__)


@data_registry.register
class DBDatasetConfig(DatasetConfigBase):
    """Configuration for a dataset stored in an ASE database.

    The optional ``*_key`` fields name entries of each row's ``data`` mapping
    that :class:`DBDataset` exposes as the ``energy``, ``forces`` and
    ``stress`` labels. A key left as ``None`` is not remapped.
    """

    type: Literal["db"] = "db"
    """Discriminator for the DB dataset."""

    src: Database | str | Path
    """Path to the ASE database file or a database object."""

    energy_key: str | None = None
    """Key for the energy label in the database."""

    forces_key: str | None = None
    """Key for the force label in the database."""

    stress_key: str | None = None
    """Key for the stress label in the database."""

    preload: bool = True
    """Whether to load all the data at once or not."""

    @override
    def create_dataset(self) -> DBDataset:
        """Build the :class:`DBDataset` described by this configuration."""
        return DBDataset(config=self)
class DBDataset(Dataset[ase.Atoms]):
    """Map-style dataset that yields ``ase.Atoms`` objects from an ASE database.

    Every row is converted with ``row.toatoms()``. Entries of ``row.data``
    whose names are ASE calculator properties are attached to the atoms via a
    ``SinglePointCalculator``; all remaining entries end up in ``atoms.info``.
    """

    def __init__(self, config: DBDatasetConfig):
        """Open the database and, if ``config.preload`` is set, load all rows.

        Args:
            config: Dataset configuration. ``config.src`` is either an
                already-open ``Database`` (used as is) or a path that is
                passed to ``ase.db.connect``.
        """
        super().__init__()
        self.config = config
        if isinstance(config.src, Database):
            self.db = config.src
        else:
            self.db = connect(config.src)

        if self.config.preload:
            self.atoms_list = [
                self._load_atoms_from_row(row) for row in self.db.select()
            ]

    def _load_atoms_from_row(self, row) -> ase.Atoms:
        """Convert one database row into a labelled ``ase.Atoms`` object.

        Args:
            row: A row object as returned by ``Database.select``/``get``.

        Returns:
            The atoms with a ``SinglePointCalculator`` holding the recognized
            labels and ``atoms.info`` holding the unrecognized ones.

        Raises:
            KeyError: If a configured label key is missing from ``row.data``.
            ValueError: If the stress label is neither ``(3, 3)`` nor ``(6,)``.
        """
        atoms = row.toatoms()
        labels = dict(row.data)

        # Rename the user-chosen keys to the property names ASE expects.
        if self.config.energy_key:
            labels["energy"] = labels.pop(self.config.energy_key)
        if self.config.forces_key:
            labels["forces"] = np.array(labels.pop(self.config.forces_key))
        if self.config.stress_key:
            stress = np.array(labels.pop(self.config.stress_key))
            if stress.shape == (3, 3):
                stress = full_3x3_to_voigt_6_stress(stress)
            elif stress.shape != (6,):
                raise ValueError(
                    f"Stress has unexpected shape: {stress.shape}, expected (3, 3) or (6,)"
                )
            labels["stress"] = stress

        # Only calculator properties may be passed to SinglePointCalculator;
        # everything else is preserved as free-form metadata on the atoms.
        recognized = {k: v for k, v in labels.items() if k in all_properties}
        unrecognized = {k: v for k, v in labels.items() if k not in all_properties}

        atoms.calc = SinglePointCalculator(atoms, **recognized)
        atoms.info = unrecognized
        return atoms

    @override
    def __getitem__(self, idx):
        """Return the atoms at position ``idx``.

        In preload mode the item is served from memory. Otherwise the row is
        fetched from the database on demand.

        NOTE: the lazy path maps position ``idx`` to database row id
        ``idx + 1`` (ASE row ids start at 1), which assumes the ids are
        contiguous, i.e. no rows have been deleted -- TODO confirm for the
        databases in use. A missing id surfaces as the ``KeyError`` raised by
        ``Database.get``.
        """
        if self.config.preload:
            return self.atoms_list[idx]

        # Mirror list semantics for negative positions.
        if idx < 0:
            idx += len(self)
        # ``get(idx=...)`` would filter on a user-defined key-value pair named
        # "idx"; selecting by the built-in ``id`` column addresses the row itself.
        row = self.db.get(id=int(idx) + 1)
        return self._load_atoms_from_row(row)

    def __len__(self) -> int:
        """Return the number of items this dataset can serve."""
        if self.config.preload:
            # Report the size of the collection that __getitem__ indexes,
            # without issuing a database query.
            return len(self.atoms_list)
        return len(self.db)