from __future__ import annotations
import logging
from pathlib import Path
from typing import Literal
import ase
import numpy as np
from ase.calculators.calculator import all_properties
from ase.calculators.singlepoint import SinglePointCalculator
from ase.db import connect
from ase.db.core import Database
from ase.stress import full_3x3_to_voigt_6_stress
from torch.utils.data import Dataset
from typing_extensions import override
from ..registry import data_registry
from .base import DatasetConfigBase
log = logging.getLogger(__name__)
@data_registry.register
class DBDatasetConfig(DatasetConfigBase):
    """Configuration for a dataset stored in an ASE database."""

    type: Literal["db"] = "db"
    """Discriminator for the DB dataset."""

    src: Database | str | Path
    """Path to the ASE database file, or an already-open database object."""

    energy_key: str | None = None
    """Name of the database label that holds the energy; it is exposed under
    the standard ``energy`` property. If ``None``, no renaming is done."""

    forces_key: str | None = None
    """Name of the database label that holds the forces; it is exposed under
    the standard ``forces`` property. If ``None``, no renaming is done."""

    stress_key: str | None = None
    """Name of the database label that holds the stress; it is exposed under
    the standard ``stress`` property (converted to Voigt form if 3x3).
    If ``None``, no renaming is done."""

    preload: bool = True
    """If ``True``, convert every row when the dataset is created; otherwise
    rows are read from the database on each access."""

    @override
    def create_dataset(self) -> DBDataset:
        """Build a :class:`DBDataset` from this configuration."""
        return DBDataset(self)
 
class DBDataset(Dataset[ase.Atoms]):
    """Map-style dataset of :class:`ase.Atoms` objects read from an ASE database.

    Each row is converted with ``row.toatoms()``. Labels are collected from two
    places:

    * the calculator results that ``toatoms()`` attaches, if any;
    * the row's ``data`` dict, whose entries take precedence.

    The config's ``energy_key``/``forces_key``/``stress_key`` rename a custom
    label to the standard ASE property name. Labels that are ASE properties are
    attached through a :class:`SinglePointCalculator`. Every other label is
    stored in ``atoms.info``.
    """

    def __init__(self, config: DBDatasetConfig):
        """Open the database and, if ``config.preload`` is set, convert every row.

        Args:
            config: Dataset configuration. ``config.src`` is either an open
                :class:`Database` or a path passed to :func:`ase.db.connect`.
        """
        super().__init__()
        self.config = config
        if isinstance(config.src, Database):
            self.db = config.src
        else:
            self.db = connect(config.src)

        if self.config.preload:
            self.atoms_list = [
                self._load_atoms_from_row(row) for row in self.db.select()
            ]
        else:
            # ASE row ids are 1-based and need not be contiguous (deleted rows
            # leave gaps). Record each row's id once so __getitem__ can map a
            # positional index onto it.
            self.row_ids = [row.id for row in self.db.select(include_data=False)]

    @staticmethod
    def _take_label(labels: dict, row, key: str):
        """Return the value stored under ``key`` for ``row``.

        ``labels`` is checked first. A value found there is removed so it does
        not also end up in ``atoms.info``. Otherwise the row's key-value pairs
        are checked.

        Raises:
            KeyError: if ``key`` is found in neither place.
        """
        if key in labels:
            return labels.pop(key)
        key_value_pairs = row.key_value_pairs
        if key in key_value_pairs:
            return key_value_pairs[key]
        raise KeyError(
            f"Label key {key!r} not found in data or key-value pairs of row "
            f"{getattr(row, 'id', None)}"
        )

    def _load_atoms_from_row(self, row) -> ase.Atoms:
        """Convert a database row into :class:`ase.Atoms` with its labels attached.

        Raises:
            KeyError: if a configured ``*_key`` is missing from the row.
            ValueError: if the stress label is neither ``(3, 3)`` nor ``(6,)``.
        """
        atoms = row.toatoms()

        labels: dict = {}
        # toatoms() may already attach a calculator with results read from the
        # row. Keep them so they are not lost when atoms.calc is replaced
        # below. Entries in row.data override them.
        if atoms.calc is not None:
            labels.update(atoms.calc.results)
        labels.update(row.data)

        if self.config.energy_key:
            labels["energy"] = self._take_label(labels, row, self.config.energy_key)
        if self.config.forces_key:
            labels["forces"] = np.asarray(
                self._take_label(labels, row, self.config.forces_key)
            )
        if self.config.stress_key:
            labels["stress"] = self._take_label(labels, row, self.config.stress_key)

        # Normalize any stress label, not only renamed ones, to Voigt (6,) form.
        if "stress" in labels:
            stress = np.asarray(labels["stress"])
            if stress.shape == (3, 3):
                stress = full_3x3_to_voigt_6_stress(stress)
            elif stress.shape != (6,):
                raise ValueError(
                    f"Stress has unexpected shape: {stress.shape}, expected (3, 3) or (6,)"
                )
            labels["stress"] = stress

        # SinglePointCalculator only accepts recognized ASE property names;
        # every other label goes to atoms.info.
        results = {k: v for k, v in labels.items() if k in all_properties}
        extras = {k: v for k, v in labels.items() if k not in all_properties}
        atoms.calc = SinglePointCalculator(atoms, **results)
        atoms.info = extras
        return atoms

    @override
    def __getitem__(self, idx: int) -> ase.Atoms:
        """Return the ``idx``-th structure (0-based; negative indices allowed)."""
        if self.config.preload:
            return self.atoms_list[idx]
        row = self.db.get(id=self.row_ids[idx])
        return self._load_atoms_from_row(row)

    def __len__(self) -> int:
        """Number of rows captured when the dataset was created."""
        return len(self.atoms_list) if self.config.preload else len(self.row_ids)