from __future__ import annotations
import contextlib
import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic
import ase
import nshconfig as C
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
from torch.utils.data import Dataset
from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack, cast, override
from ..normalization import ComposeNormalizers, NormalizationContext, NormalizerConfig
from .loader import DataLoaderKwargs, create_dataloader
from .loss import compute_loss
from .lr_scheduler import LRSchedulerConfig, create_lr_scheduler
from .metrics import FinetuneMetrics
from .optimizer import OptimizerConfig, create_optimizer
from .properties import PropertyConfig
if TYPE_CHECKING:
from ase import Atoms
log = logging.getLogger(__name__)
class FinetuneModuleBaseConfig(C.Config, ABC):
properties: Sequence[PropertyConfig]
"""Properties to predict."""
optimizer: OptimizerConfig
"""Optimizer."""
lr_scheduler: LRSchedulerConfig | None = None
"""Learning Rate Scheduler"""
ignore_gpu_batch_transform_error: bool = True
"""Whether to ignore data processing errors during training."""
normalizers: Mapping[str, Sequence[NormalizerConfig]] = {}
"""Normalizers for the properties.
Any property can be associated with multiple normalizers. This is useful
for cases where we want to normalize the same property in different ways.
For example, we may want to normalize the energy by subtracting
the atomic reference energies, as well as by mean and standard deviation
normalization.
The normalizers are applied in the order they are defined in the list.
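Example (an illustrative sketch; the config class names below are placeholders
for whatever concrete ``NormalizerConfig`` subclasses you actually use):

    normalizers={
        "energy": [
            # Hypothetical: first subtract per-atom reference energies ...
            PerAtomReferencingNormalizerConfig(...),
            # ... then apply mean/standard-deviation normalization.
            MeanStdNormalizerConfig(mean=0.0, std=1.0),
        ]
    }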
"""
@classmethod
@abstractmethod
def ensure_dependencies(cls):
"""
Ensure that all dependencies are installed.
This method should raise an exception if any dependencies are missing,
with a message indicating which dependencies are missing and
how to install them.
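Example (an illustrative sketch only; ``some_backbone_pkg`` is a placeholder
for the backbone's real dependency):

    @classmethod
    def ensure_dependencies(cls):
        try:
            import some_backbone_pkg  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "some_backbone_pkg is required for this backbone. "
                "Install it with `pip install some-backbone-pkg`."
            ) from err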
"""
...
@abstractmethod
def create_model(self) -> FinetuneModuleBase:
"""
Creates an instance of the finetune module for this configuration.
"""
@override
def __post_init__(self):
# VALIDATION: Any key in `normalizers` must be the name of a configured property.
property_names = [prop.name for prop in self.properties]
for key in self.normalizers.keys():
    if key not in property_names:
        raise ValueError(
            f"Key '{key}' in 'normalizers' is not a valid property name."
        )
class _SkipBatchError(Exception):
"""
Exception to skip a batch in the forward pass. This is not a real error and
should not be logged.
Instead, this is basically a control flow mechanism to skip a batch
if an error occurs during the forward pass. This is useful for
handling edge cases where a batch may be invalid or cause an error
during the forward pass. In this case, we can raise this exception
anywhere in the forward pass and then catch it in the `_common_step`
method. If this exception is caught, we can just skip the batch
instead of logging an error.
This is primarily used to skip graph generation errors in messy data. E.g.,
if our dataset contains materials with so few atoms that we cannot
generate a graph, we can just skip these materials instead of
completely failing the training run.
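Example (an illustrative sketch of how a concrete module might raise it;
``build_graph`` is a placeholder for backbone-specific graph construction):

    def gpu_batch_transform(self, batch):
        try:
            return build_graph(batch)
        except ValueError as err:
            # Too few atoms to build a graph; skip this batch.
            raise _SkipBatchError() from err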
"""
pass
class ModelOutput(TypedDict):
predicted_properties: dict[str, torch.Tensor]
"""Predicted properties. This dictionary should be exactly
in the same shape/format as the output of `batch_to_labels`."""
backbone_output: NotRequired[Any]
"""Output of the backbone model. Only set if `return_backbone_output` is True."""
TData = TypeVar("TData")
TBatch = TypeVar("TBatch")
TFinetuneModuleConfig = TypeVar(
"TFinetuneModuleConfig",
bound=FinetuneModuleBaseConfig,
covariant=True,
)
class FinetuneModuleBase(
LightningModule,
ABC,
Generic[TData, TBatch, TFinetuneModuleConfig],
):
"""
Finetune module base class. Inherits ``lightning.pytorch.LightningModule``.
"""
@classmethod
@abstractmethod
def hparams_cls(cls) -> type[TFinetuneModuleConfig]:
"""Return the hyperparameters config class for this module."""
...
# region ABC methods for output heads and model forward pass
@abstractmethod
def create_model(self):
"""
Initialize both the pre-trained backbone and the output heads for the properties to predict.
You should also construct any other ``nn.Module`` instances
necessary for the forward pass here.
"""
...
@abstractmethod
def model_forward_context(self, data: TBatch) -> contextlib.AbstractContextManager:
"""
Context manager for the model forward pass.
This is used for any setup that needs to be done before the forward pass,
e.g., setting pos.requires_grad_() for gradient-based force prediction.
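Example (an illustrative sketch; assumes the batch stores positions in ``data.pos``):

    @contextlib.contextmanager
    def model_forward_context(self, data):
        with torch.enable_grad():
            data.pos.requires_grad_(True)
            yield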
"""
...
@abstractmethod
def requires_disabled_inference_mode(self) -> bool:
"""
Whether the model requires inference mode to be disabled.

This is typically needed when the forward pass itself relies on autograd
(e.g., computing forces as gradients of the energy), since autograd does not
work under ``torch.inference_mode``.
"""
...
@abstractmethod
def model_forward(
self,
batch: TBatch,
return_backbone_output: bool = False,
) -> ModelOutput:
"""
Forward pass of the model.
Args:
batch: Input batch.
return_backbone_output: Whether to return the output of the backbone model.
Returns:
Prediction of the model.
"""
...
@abstractmethod
def pretrained_backbone_parameters(self) -> Iterable[nn.Parameter]:
"""
Return the parameters of the backbone model.
"""
...
@abstractmethod
def output_head_parameters(self) -> Iterable[nn.Parameter]:
"""
Return the parameters of the output heads.
"""
...
# endregion
# region ABC methods for data processing
@abstractmethod
def collate_fn(self, data_list: list[TData]) -> TBatch:
"""
Collate function for the DataLoader. Combines a list of individual data
objects into a single batch in the format expected by the model.
"""
...
@abstractmethod
def batch_to_labels(self, batch: TBatch) -> dict[str, torch.Tensor]:
"""
Extract ground truth values from a batch. The output of this function
should be a dictionary with keys corresponding to the target names
and values corresponding to the ground truth values. The values should
be torch tensors that match, in shape, the output of the corresponding
output head.
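Example (illustrative; assumes "energy" and "forces" properties are configured):

    {
        "energy": torch.zeros(batch_size),          # shape: (batch_size,)
        "forces": torch.zeros(total_num_atoms, 3),  # shape: (total_num_atoms, 3)
    }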
"""
...
@abstractmethod
def atoms_to_data(self, atoms: Atoms, has_labels: bool) -> TData:
"""
Convert an ASE atoms object to a data object. This is used to convert
the input data to the format expected by the model.
Args:
atoms: ASE atoms object.
has_labels: Whether the atoms object contains labels.
"""
...
@abstractmethod
def create_normalization_context_from_batch(
self, batch: TBatch
) -> NormalizationContext:
"""
Create a normalization context from a batch. This is used to normalize
and denormalize the properties.
The normalization context contains all the information required to
normalize and denormalize the properties. Currently, this only
includes the compositions of the materials in the batch.
The compositions should be provided as an integer tensor of shape
(batch_size, num_elements), where each row (i.e., `compositions[i]`)
corresponds to the composition vector of the `i`-th material in the batch.
The composition vector is a vector that maps each element to the number of
atoms of that element in the material. For example, `compositions[:, 1]`
corresponds to the number of Hydrogen atoms in each material in the batch,
`compositions[:, 2]` corresponds to the number of Helium atoms, and so on.
Args:
batch: Input batch.
Returns:
Normalization context.
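A minimal sketch of building such a context (``atomic_numbers`` and
``batch_index`` are placeholders for whatever the concrete batch provides):

    num_elements = 120
    compositions = torch.zeros(batch_size, num_elements, dtype=torch.long)
    compositions.index_add_(
        0,
        batch_index,  # (num_atoms,) structure index of each atom
        torch.nn.functional.one_hot(atomic_numbers, num_elements),
    )
    # Assumes the context stores the tensor under `compositions`.
    return NormalizationContext(compositions=compositions)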
"""
...
# endregion
hparams: TFinetuneModuleConfig # pyright: ignore[reportIncompatibleMethodOverride]
hparams_initial: TFinetuneModuleConfig # pyright: ignore[reportIncompatibleMethodOverride]
def __init__(self, hparams: TFinetuneModuleConfig | Mapping[str, Any]):
hparams_cls = self.hparams_cls()
if not isinstance(hparams, hparams_cls):
hparams = hparams_cls.model_validate(hparams)
super().__init__()
# Save the hyperparameters
self.save_hyperparameters(hparams)
# Create the backbone model and output heads
self.create_model()
# Create metrics
self.create_metrics()
# Create normalization modules
self.create_normalizers()
# Ensure that some parameters require gradients
if not any(p.requires_grad for p in self.parameters()):
raise ValueError(
"No parameters require gradients. "
"Please ensure that some parts of the model are trainable."
)
def create_metrics(self):
self.train_metrics = FinetuneMetrics(self.hparams.properties)
self.val_metrics = FinetuneMetrics(self.hparams.properties)
self.test_metrics = FinetuneMetrics(self.hparams.properties)
def create_normalizers(self):
self.normalizers = nn.ModuleDict(
{
prop.name: ComposeNormalizers(
[
normalizer.create_normalizer_module()
for normalizer in normalizers
]
)
for prop in self.hparams.properties
if (normalizers := self.hparams.normalizers.get(prop.name))
}
)
def normalize(
self,
properties: dict[str, torch.Tensor],
ctx: NormalizationContext,
):
"""
Normalizes the ``properties`` dictionary. ``properties`` can either be
the predicted properties or the ground truth labels.
Args:
properties: Dictionary of properties to normalize. The dictionary
should have the same format as the output of ``batch_to_labels``.
ctx: Normalization context. This should be created using
``create_normalization_context_from_batch``.
Returns:
Normalized properties.
"""
normalized_properties = {}
for prop_name, prop_value in properties.items():
if prop_name in self.normalizers:
normalizer = cast(ComposeNormalizers, self.normalizers[prop_name])
normalized_properties[prop_name] = normalizer.normalize(prop_value, ctx)
else:
normalized_properties[prop_name] = prop_value
return normalized_properties
def denormalize(
self,
properties: dict[str, torch.Tensor],
ctx: NormalizationContext,
):
"""
Denormalizes the ``properties`` dictionary. ``properties`` can either be
the predicted properties or the ground truth labels.
Args:
properties: Dictionary of properties to denormalize. The dictionary
should have the same format as the output of ``batch_to_labels``.
ctx: Normalization context. This should be created using
``create_normalization_context_from_batch``.
Returns:
Denormalized properties.
"""
denormalized_properties = {}
for prop_name, prop_value in properties.items():
if prop_name in self.normalizers:
normalizer = cast(ComposeNormalizers, self.normalizers[prop_name])
denormalized_properties[prop_name] = normalizer.denormalize(
prop_value, ctx
)
else:
denormalized_properties[prop_name] = prop_value
return denormalized_properties
@override
def forward(
self,
batch: TBatch,
return_backbone_output: bool = False,
ignore_gpu_batch_transform_error: bool | None = None,
) -> ModelOutput:
if ignore_gpu_batch_transform_error is None:
ignore_gpu_batch_transform_error = (
self.hparams.ignore_gpu_batch_transform_error
)
with self.model_forward_context(batch):
# Generate graph/etc
if ignore_gpu_batch_transform_error:
try:
batch = self.gpu_batch_transform(batch)
except Exception as e:
log.warning("Error in forward pass. Skipping batch.", exc_info=e)
raise _SkipBatchError() from e
else:
batch = self.gpu_batch_transform(batch)
# Run the model
model_output = self.model_forward(
batch, return_backbone_output=return_backbone_output
)
model_output["predicted_properties"] = {
prop_name: prop_value.contiguous()
for prop_name, prop_value in model_output[
"predicted_properties"
].items()
}
return model_output
def _compute_loss(
self,
predictions: dict[str, torch.Tensor],
labels: dict[str, torch.Tensor],
log: bool = True,
log_prefix: str = "",
):
losses: list[torch.Tensor] = []
for prop in self.hparams.properties:
# Get the target and prediction
prediction = predictions[prop.name]
label = labels[prop.name]
# Compute the loss
loss = compute_loss(prop.loss, prediction, label) * prop.loss_coefficient
# Log the loss
if log:
self.log(f"{log_prefix}{prop.name}_loss", loss)
losses.append(loss)
# Sum the losses
loss = cast(torch.Tensor, sum(losses))
# Log the total loss & return
if log:
self.log(f"{log_prefix}total_loss", loss)
return loss
def _common_step(
self,
batch: TBatch,
name: str,
metrics: FinetuneMetrics | None,
log: bool = True,
):
try:
output: ModelOutput = self(batch)
except _SkipBatchError:
def _zero_output():
return {
"predicted_properties": {},
"backbone_output": None,
}
def _zero_loss():
# Return a zero loss tensor that is still attached to all
# parameters so that the optimizer can still update them.
# This prevents DDP unused parameter errors.
return cast(torch.Tensor, sum(p.sum() * 0.0 for p in self.parameters()))
return _zero_output(), _zero_loss()
# Extract labels from the batch
labels = self.batch_to_labels(batch)
predictions = output["predicted_properties"]
# Create the normalization context required for normalization/referencing.
# We only need to create the context once per batch.
normalization_ctx = self.create_normalization_context_from_batch(batch)
# Normalize the properties.
# NOTE: We normalize the target properties before computing the loss.
# This ensures that the model is effectively learning the normalized
# properties.
# NOTE: Some other implementations may instead choose to denormalize
# the predictions before computing the loss. We don't do this because
# it can lead to numerical instability in the loss computation.
normalized_labels = self.normalize(labels, normalization_ctx)
for key, value in normalized_labels.items():
normalized_labels[key] = value.contiguous()
# Compute loss
loss = self._compute_loss(
predictions,
normalized_labels,
log=log,
log_prefix=f"{name}/",
)
# NOTE: After computing the loss, we denormalize the predictions.
# This is done so that the values that are output by the model
# (and those used for metric computation) are measured in the
# denormalized units, which are more interpretable.
# NOTE: Again, we only denormalize the predictions, not the labels.
# This is because the labels are already in the denormalized units.
predictions = self.denormalize(predictions, normalization_ctx)
# Log metrics
if log and metrics is not None:
self.log_dict(
{
f"{name}/{metric_name}": metric
for metric_name, metric in metrics(predictions, labels).items()
}
)
return output, loss
@override
def training_step(self, batch: TBatch, batch_idx: int):
_, loss = self._common_step(
batch,
"train",
self.train_metrics,
)
return loss
@override
def validation_step(self, batch: TBatch, batch_idx: int):
_ = self._common_step(batch, "val", self.val_metrics)
@override
def test_step(self, batch: TBatch, batch_idx: int):
_ = self._common_step(batch, "test", self.test_metrics)
@override
def predict_step(self, batch: TBatch, batch_idx: int):
output: ModelOutput = self(batch, ignore_gpu_batch_transform_error=False)
predictions = output["predicted_properties"]
normalization_ctx = self.create_normalization_context_from_batch(batch)
denormalized_predictions = self.denormalize(predictions, normalization_ctx)
return denormalized_predictions
def create_dataloader(
self,
dataset: Dataset[ase.Atoms],
has_labels: bool,
**kwargs: Unpack[DataLoaderKwargs],
):
"""
Creates a wrapped DataLoader for the given dataset.
This will wrap the dataset with the CPU data transform and the model's
collate function.
NOTE about `has_labels`: This is used to determine whether our data
loading pipeline should expect labels in the dataset. This should
be `True` for train/val/test datasets (as we compute loss and metrics
on these datasets) and `False` for prediction datasets.
Args:
dataset: Dataset to wrap.
has_labels: Whether the dataset contains labels. This should be
`True` for train/val/test datasets and `False` for prediction datasets.
**kwargs: Additional keyword arguments to pass to the DataLoader.
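Example (illustrative; assumes standard DataLoader keyword arguments such as
``batch_size`` and ``num_workers`` are accepted):

    >>> train_loader = module.create_dataloader(
    ...     train_dataset, has_labels=True, batch_size=4, num_workers=0
    ... )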
"""
return create_dataloader(dataset, has_labels, lightning_module=self, **kwargs)
def property_predictor(
self, lightning_trainer_kwargs: dict[str, Any] | None = None
):
"""Return a wrapper for easy prediction without explicitly setting up a lightning trainer.
This method provides a high-level interface for making predictions with the trained model.
This can be used for various prediction tasks including but not limited to:
- Interatomic potential energy and forces
- Material property prediction
- Structure-property relationships
Parameters
----------
lightning_trainer_kwargs : dict[str, Any] | None, optional
Additional keyword arguments to pass to the PyTorch Lightning Trainer.
If None, default trainer settings will be used.
Returns
-------
MatterTunePropertyPredictor
A wrapper class that provides simplified prediction functionality without requiring
direct interaction with the Lightning Trainer.
Examples
--------
>>> model = MyModel()
>>> property_predictor = model.property_predictor()
>>> atoms_1 = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], cell=[10, 10, 10], pbc=True)
>>> atoms_2 = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], cell=[10, 10, 10], pbc=True)
>>> atoms = [atoms_1, atoms_2]
>>> predictions = property_predictor.predict(atoms, ["energy", "forces"])
>>> print("Atoms 1 energy:", predictions[0]["energy"])
>>> print("Atoms 1 forces:", predictions[0]["forces"])
>>> print("Atoms 2 energy:", predictions[1]["energy"])
>>> print("Atoms 2 forces:", predictions[1]["forces"])
"""
from ..wrappers.property_predictor import MatterTunePropertyPredictor
return MatterTunePropertyPredictor(
self,
lightning_trainer_kwargs=lightning_trainer_kwargs,
)
def ase_calculator(self, lightning_trainer_kwargs: dict[str, Any] | None = None):
"""Returns an ASE calculator wrapper for the interatomic potential.
This method creates an ASE (Atomic Simulation Environment) calculator that can be used
to compute energies and forces using the trained interatomic potential model.
The calculator integrates with ASE's standard interfaces for molecular dynamics
and structure optimization.
Parameters
----------
lightning_trainer_kwargs : dict[str, Any] | None, optional
Keyword arguments to pass to the PyTorch Lightning trainer used for inference.
If None, default trainer settings will be used.
Returns
-------
MatterTuneCalculator
An ASE calculator wrapper around the trained potential that can be used
for energy and force calculations via ASE's interfaces.
Examples
--------
>>> model = MyModel()
>>> calc = model.ase_calculator()
>>> atoms = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], cell=[10, 10, 10], pbc=True)
>>> atoms.calc = calc
>>> energy = atoms.get_potential_energy()
>>> forces = atoms.get_forces()
"""
from ..wrappers.ase_calculator import MatterTuneCalculator
property_predictor = self.property_predictor(
lightning_trainer_kwargs=lightning_trainer_kwargs
)
return MatterTuneCalculator(property_predictor)