from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sized
from typing import TYPE_CHECKING, Annotated, Any, Literal

import ase
import nshconfig as C
import numpy as np
from lightning.pytorch import LightningDataModule
from torch.utils.data import Dataset
from typing_extensions import TypeAliasType, TypedDict, override

from ..registry import data_registry
from .base import DatasetConfig
from .util.split_dataset import SplitDataset

if TYPE_CHECKING:
    from ..finetune.loader import DataLoaderKwargs

log = logging.getLogger(__name__)


class DatasetMapping(TypedDict, total=False):
    """Mapping from split name to the dataset created for that split."""

    train: Dataset[ase.Atoms]
    validation: Dataset[ase.Atoms]


class DataModuleBaseConfig(C.Config, ABC):
    batch_size: int
    """The batch size for the dataloaders."""

    num_workers: int | Literal["auto"] = "auto"
    """The number of workers for the dataloaders.

    This is the number of processes that generate batches in parallel.
    If set to "auto", the number of workers will be automatically
    set based on the number of available CPUs.
    Set to 0 to disable parallelism.
    """

    pin_memory: bool = True
    """Whether to pin memory in the dataloaders.

    This is useful for speeding up GPU data transfer.
    """

    def _num_workers_or_auto(self):
        if self.num_workers == "auto":
            import os

            if (cpu_count := os.cpu_count()) is not None:
                return cpu_count - 1
            else:
                return 1

        return self.num_workers

    def dataloader_kwargs(self) -> DataLoaderKwargs:
        return {
            "batch_size": self.batch_size,
            "num_workers": self._num_workers_or_auto(),
            "pin_memory": self.pin_memory,
        }
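
    # For example, a config with ``batch_size=32``, ``num_workers=4`` and the
    # default ``pin_memory=True`` makes ``dataloader_kwargs()`` return
    # ``{"batch_size": 32, "num_workers": 4, "pin_memory": True}``.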

    @abstractmethod
    def dataset_configs(self) -> Iterable[DatasetConfig]: ...

    @abstractmethod
    def create_datasets(self) -> DatasetMapping: ...


@data_registry.rebuild_on_registers
class ManualSplitDataModuleConfig(DataModuleBaseConfig):
    train: DatasetConfig
    """The configuration for the training data."""

    validation: DatasetConfig | None = None
    """The configuration for the validation data."""

    @override
    def dataset_configs(self):
        yield self.train
        if self.validation is not None:
            yield self.validation

    @override
    def create_datasets(self):
        datasets: DatasetMapping = {}
        datasets["train"] = self.train.create_dataset()
        if (val := self.validation) is not None:
            datasets["validation"] = val.create_dataset()
        return datasets
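
    # Usage sketch (illustrative): ``SomeDatasetConfig`` below is a placeholder
    # for a concrete ``DatasetConfig`` subclass; it is not defined in this module.
    #
    #     config = ManualSplitDataModuleConfig(
    #         batch_size=32,
    #         train=SomeDatasetConfig(...),
    #         validation=SomeDatasetConfig(...),
    #     )
    #     datasets = config.create_datasets()  # {"train": ..., "validation": ...}
    #
    # When ``validation`` is left as None, the returned mapping contains only
    # the "train" entry.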


@data_registry.rebuild_on_registers
class AutoSplitDataModuleConfig(DataModuleBaseConfig):
    dataset: DatasetConfig
    """The configuration for the dataset."""

    train_split: float
    """The proportion of the dataset to include in the training split."""

    validation_split: float | Literal["auto", "disable"] = "auto"
    """The proportion of the dataset to include in the validation split.

    If set to "auto", the validation split will be automatically determined as
    the complement of the training split, i.e. `validation_split = 1 - train_split`.
    If set to "disable", the validation split will be disabled.
    """

    shuffle: bool = True
    """Whether to shuffle the dataset before splitting."""

    shuffle_seed: int = 42
    """The seed to use for shuffling the dataset."""

    def _resolve_train_val_split(self):
        train_split = self.train_split
        match self.validation_split:
            case "auto":
                validation_split = 1.0 - train_split
            case "disable":
                validation_split = 0.0
            case _:
                validation_split = self.validation_split

        return train_split, validation_split

    @override
    def dataset_configs(self):
        yield self.dataset

    @override
    def create_datasets(self):
        # Create the full dataset.
        dataset = self.dataset.create_dataset()

        # If the validation split is disabled, return the full dataset.
        if self.validation_split == "disable":
            return DatasetMapping(train=dataset)

        if not isinstance(dataset, Sized):
            raise TypeError(
                f"The underlying dataset must be sized, but got {dataset!r}."
            )

        # Compute the indices for the training and validation splits.
        dataset_len = len(dataset)
        indices = np.arange(dataset_len)
        if self.shuffle:
            rng = np.random.default_rng(self.shuffle_seed)
            rng.shuffle(indices)

        train_split, validation_split = self._resolve_train_val_split()
        train_len = int(train_split * dataset_len)
        validation_len = int(validation_split * dataset_len)

        # Get indices for each split
        train_indices = indices[:train_len]
        validation_indices = indices[train_len : train_len + validation_len]

        # Create the training and validation datasets.
        train_dataset = SplitDataset(dataset, train_indices)
        validation_dataset = SplitDataset(dataset, validation_indices)

        return DatasetMapping(train=train_dataset, validation=validation_dataset)
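
    # Worked example (illustrative): with ``train_split=0.75``, the default
    # ``validation_split="auto"`` and a dataset of 1000 structures,
    # ``_resolve_train_val_split()`` returns (0.75, 0.25), so ``create_datasets()``
    # builds a 750-sample training split (``indices[:750]``) and a 250-sample
    # validation split (``indices[750:1000]``), shuffled with ``shuffle_seed``.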


DataModuleConfig = TypeAliasType(
    "DataModuleConfig",
    Annotated[
        ManualSplitDataModuleConfig | AutoSplitDataModuleConfig,
        C.Field(description="The configuration for the data module."),
    ],
)


class MatterTuneDataModule(LightningDataModule):
    """LightningDataModule that builds datasets and dataloaders from a `DataModuleConfig`."""

    hparams: DataModuleConfig  # pyright: ignore[reportIncompatibleMethodOverride]
    hparams_initial: DataModuleConfig  # pyright: ignore[reportIncompatibleMethodOverride]

    @override
    def __init__(self, hparams: DataModuleConfig | Mapping[str, Any]):
        # Validate & resolve the configuration.
        if not isinstance(hparams, C.Config):
            hparams = C.TypeAdapter(DataModuleConfig).validate_python(hparams)

        super().__init__()

        # Save the configuration for Lightning.
        self.save_hyperparameters(hparams)

    @override
    def prepare_data(self) -> None:
        for config in self.hparams.dataset_configs():
            config.prepare_data()

    @override
    def setup(self, stage: str):
        super().setup(stage)

        self.datasets = self.hparams.create_datasets()

        # PyTorch Lightning checks for the *existence* of the
        # `train_dataloader`, `val_dataloader`, `test_dataloader`,
        # and `predict_dataloader` methods to decide which dataloaders
        # to create, so we cannot simply return `None` from these methods
        # when a dataset is unavailable. Raising an exception is not an
        # option either, because that would crash the training loop.
        # Instead, we check here which datasets are available and remove
        # the dataloader methods for those that are missing.
        METHOD_NAME_MAPPING = {
            "train": "train_dataloader",
            "validation": "val_dataloader",
        }
        for dataset_name, method_name in METHOD_NAME_MAPPING.items():
            if dataset_name not in self.datasets:
                setattr(self, method_name, None)

    @property
    def lightning_module(self):
        if (trainer := self.trainer) is None:
            raise ValueError("No trainer found.")

        if (lightning_module := trainer.lightning_module) is None:
            raise ValueError("No LightningModule found.")

        from ..finetune.base import FinetuneModuleBase

        if not isinstance(lightning_module, FinetuneModuleBase):
            raise ValueError("The LightningModule is not a FinetuneModuleBase.")

        return lightning_module

    @override
    def train_dataloader(self):
        if (dataset := self.datasets.get("train")) is None:
            raise ValueError("No training dataset found.")

        return self.lightning_module.create_dataloader(
            dataset,
            has_labels=True,
            **self.hparams.dataloader_kwargs(),
        )

    @override
    def val_dataloader(self):
        if (dataset := self.datasets.get("validation")) is None:
            raise ValueError(
                "No validation dataset found, but `val_dataloader` was called. "
                "This should not happen. Report this as a bug."
            )

        return self.lightning_module.create_dataloader(
            dataset,
            has_labels=True,
            **self.hparams.dataloader_kwargs(),
        )
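
# Usage sketch (illustrative): ``SomeDatasetConfig`` and ``model`` are
# placeholders for a concrete ``DatasetConfig`` subclass and a
# ``FinetuneModuleBase`` instance; neither is defined in this module.
#
#     import lightning.pytorch as pl
#
#     config = AutoSplitDataModuleConfig(
#         batch_size=32,
#         dataset=SomeDatasetConfig(...),
#         train_split=0.9,
#     )
#     datamodule = MatterTuneDataModule(config)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=datamodule)
#
# ``MatterTuneDataModule`` also accepts a plain mapping, which ``__init__``
# validates against ``DataModuleConfig`` via ``C.TypeAdapter``:
#
#     datamodule = MatterTuneDataModule(
#         {"batch_size": 32, "train_split": 0.9, "dataset": {...}}
#     )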