from __future__ import annotations
import contextlib
import importlib.util
import logging
from typing import TYPE_CHECKING, Literal
from ase import Atoms
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import final, override
from ...finetune import properties as props
from ...finetune.base import FinetuneModuleBase, FinetuneModuleBaseConfig, ModelOutput
from ...normalization import NormalizationContext
from ...registry import backbone_registry
from ...util import optional_import_error_message
from ..eqV2.model import FAIRChemAtomsToGraphSystemConfig
if TYPE_CHECKING:
from fairchem.core.datasets.atomic_data import AtomicData
log = logging.getLogger(__name__)
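# Map our property config types to the tensor keys emitted by fairchem's UMA
# output heads ("energy", "forces", "stress").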
HARDCODED_NAMES: dict[type[props.PropertyConfigBase], str] = {
props.EnergyPropertyConfig: "energy",
props.ForcesPropertyConfig: "forces",
props.StressesPropertyConfig: "stress",
}
@backbone_registry.register
class UMABackboneConfig(FinetuneModuleBaseConfig):
name: Literal["uma"] = "uma"
"""The name of the backbone model to use. Should be "uma"."""
model_name: str
"""
The specific UMA model variant to use.
Options include:
- "uma-s-1"
- "uma-s-1.1"
- "uma-m-1.1"
- "uma-l"
"""
atoms_to_graph: FAIRChemAtomsToGraphSystemConfig = FAIRChemAtomsToGraphSystemConfig(radius=6.0)
"""Configuration for converting atomic data to graph representations."""
task_name: str | None = None
"""The task name for the dataset, e.g., 'oc20', 'omol', 'omat', 'odac', 'omc'. If None, it will be inferred from the data."""
@override
@classmethod
def ensure_dependencies(cls):
# Make sure the fairchem module is available
if importlib.util.find_spec("fairchem") is None:
raise ImportError(
"The fairchem module is not installed. Please install it by running"
" pip install fairchem-core."
)
@override
def create_model(self):
return UMABackboneModule(self)
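# Usage sketch (illustrative; the exact property-config constructors belong to
# the finetune API and are assumptions here, not shown in this module):
#
#   config = UMABackboneConfig(
#       model_name="uma-s-1",
#       task_name="omat",
#       properties=[...],  # e.g. energy/forces configs from `props`
#   )
#   module = config.create_model()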
@final
class UMABackboneModule(FinetuneModuleBase["AtomicData", "AtomicData", UMABackboneConfig]):
@override
@classmethod
def hparams_cls(cls):
return UMABackboneConfig
@override
def requires_disabled_inference_mode(self):
return False
@override
def create_model(self):
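        """Load the pretrained UMA predictor and attach fine-tuning heads.

        If any requested forces/stresses are conservative (i.e. derived from
        the energy via autograd), a single combined MLP_EFS_Head handles
        energy, forces, and stress; all remaining properties get their own
        direct output heads.
        """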
with optional_import_error_message("fairchem-core"):
            from fairchem.core import pretrained_mlip
            from fairchem.core.models.uma.escn_md import (
                Linear_Force_Head,
                MLP_EFS_Head,
                MLP_Energy_Head,
                MLP_Stress_Head,
            )
            from fairchem.core.models.uma.escn_moe import eSCNMDMoeBackbone
        # Load the pretrained predictor for the configured UMA variant.
        predictor = pretrained_mlip.get_predict_unit(self.hparams.model_name)
backbone: eSCNMDMoeBackbone = predictor.model.module.backbone # type: ignore[reportGeneralTypeIssues]
self.backbone = backbone.float()
self.output_heads = nn.ModuleDict()
        # If any conservative forces or stresses are requested, use the
        # combined MLP_EFS_Head, which differentiates the predicted energy.
        f_conservative = False
        s_conservative = False
        for prop in self.hparams.properties:
            if isinstance(prop, props.ForcesPropertyConfig) and prop.conservative:
                f_conservative = True
            elif isinstance(prop, props.StressesPropertyConfig) and prop.conservative:
                s_conservative = True
        if f_conservative or s_conservative:
            head = MLP_EFS_Head(backbone=self.backbone, wrap_property=False)
            head.regress_forces = f_conservative
            head.regress_stress = s_conservative
            self.output_heads["efs"] = head
        # All remaining (non-conservative) properties get their own direct heads.
        for prop in self.hparams.properties:
            if isinstance(prop, props.EnergyPropertyConfig):
                if not f_conservative and not s_conservative:
                    self.output_heads[prop.name] = MLP_Energy_Head(
                        backbone=self.backbone, reduce="sum"
                    )
            elif isinstance(prop, props.ForcesPropertyConfig):
                if not prop.conservative:
                    self.output_heads[prop.name] = Linear_Force_Head(
                        backbone=self.backbone
                    )
            elif isinstance(prop, props.StressesPropertyConfig):
                if not prop.conservative:
                    self.output_heads[prop.name] = MLP_Stress_Head(
                        backbone=self.backbone, reduce="mean"
                    )
            else:
                raise ValueError(
                    f"Unsupported property type: {type(prop)}. UMA currently only"
                    " supports energy, forces, and stresses."
                )
        for key in self.output_heads:
            self.output_heads[key] = self.output_heads[key].float()
@override
def trainable_parameters(self):
if not self.hparams.freeze_backbone:
yield from self.backbone.named_parameters()
for head in self.output_heads.values():
yield from head.named_parameters()
@override
@contextlib.contextmanager
def model_forward_context(self, data, mode: str):
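        # UMA needs no special forward context (no grad toggling or autocast
        # here), so this context manager is a deliberate no-op.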
yield
@override
def model_forward(self, batch, mode: str):
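        """Run the backbone once, feed its embeddings to every output head,
        and collect the tensors requested by the configured properties.

        In "predict" mode the module is temporarily switched to eval() so that
        training-only layers behave deterministically, then switched back to
        train() afterwards.
        """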
if mode == "predict":
self.eval()
emb: dict[str, torch.Tensor] = self.backbone(batch)
output_pred: dict[str, torch.Tensor] = {}
for name, head in self.output_heads.items():
out = head(batch, emb)
output_pred.update(out)
predicted_properties: dict[str, torch.Tensor] = {}
for prop in self.hparams.properties:
predicted_properties[prop.name] = output_pred[HARDCODED_NAMES[type(prop)]]
if mode == "predict":
self.train()
return ModelOutput(predicted_properties=predicted_properties)
@override
def apply_callable_to_backbone(self, fn):
return fn(self.backbone)
@override
def pretrained_backbone_parameters(self):
return self.backbone.parameters()
@override
def output_head_parameters(self):
for head in self.output_heads.values():
yield from head.parameters()
@override
def collate_fn(self, data_list):
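        """Collate a list of AtomicData samples into a single fairchem batch."""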
with optional_import_error_message("fairchem"):
from fairchem.core.datasets.atomic_data import atomicdata_list_to_batch
return atomicdata_list_to_batch(data_list)
@override
def batch_to_labels(self, batch):
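        """Extract label tensors from the batch, mapping fairchem's hardcoded
        keys ("energy", "forces", "stress") back to our property names."""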
labels: dict[str, torch.Tensor] = {}
for prop in self.hparams.properties:
batch_prop_name = HARDCODED_NAMES.get(type(prop), prop.name)
labels[prop.name] = batch[batch_prop_name] # type: ignore[index]
return labels
@override
    def atoms_to_data(self, atoms: Atoms, has_labels):
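        """Convert an ASE Atoms object into a fairchem AtomicData graph.

        Reference labels (energy/forces/stress) are only read off the Atoms
        object when has_labels is True and the corresponding property was
        requested. UMA also requires a dataset task name and optionally reads
        per-structure "charge"/"spin" entries from atoms.info. A typical
        caller-side setup (sketch):

            atoms.info["task_name"] = "omat"  # or set hparams.task_name once
            atoms.info["charge"] = 0
            data = module.atoms_to_data(atoms, has_labels=False)
        """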
with optional_import_error_message("fairchem"):
from fairchem.core.datasets.atomic_data import AtomicData
energy = False
forces = False
stress = False
if has_labels:
energy = any(
isinstance(prop, props.EnergyPropertyConfig)
for prop in self.hparams.properties
)
forces = any(
isinstance(prop, props.ForcesPropertyConfig)
for prop in self.hparams.properties
)
stress = any(
isinstance(prop, props.StressesPropertyConfig)
for prop in self.hparams.properties
)
        task_name = (
            self.hparams.task_name
            if self.hparams.task_name is not None
            else atoms.info.get("task_name", None)
        )
        if task_name is None:
            raise ValueError(
                "task_name must be provided for UMA models, either via the config"
                " or atoms.info['task_name']. Choices include"
                " ['oc20', 'omol', 'omat', 'odac', 'omc']."
            )
        info_keys = [key.lower() for key in atoms.info.keys()]
        data_keys = []
        if "charge" in info_keys:
            data_keys.append("charge")
        if "spin" in info_keys:
            data_keys.append("spin")
data = AtomicData.from_ase(
input_atoms=atoms,
radius=self.hparams.atoms_to_graph.radius,
max_neigh=self.hparams.atoms_to_graph.max_num_neighbors,
r_energy=energy,
r_forces=forces,
r_stress=stress,
r_data_keys=data_keys,
task_name=task_name,
)
return data
@override
def create_normalization_context_from_batch(self, batch):
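        """Build a NormalizationContext (per-sample atom counts and elemental
        compositions) from a collated AtomicData batch.

        For example, batch_idx = [0, 0, 1] over three atoms yields
        num_atoms = [2, 1].
        """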
atomic_numbers: torch.Tensor = batch["atomic_numbers"].long() # (n_atoms,) # type: ignore[index]
batch_idx: torch.Tensor = batch["batch"] # (n_atoms,) # type: ignore[index]
        # Count atoms per sample by summing ones into per-graph slots
        # (a dependency-free equivalent of torch_scatter's scatter-sum).
        all_ones = torch.ones_like(atomic_numbers)
        num_atoms = torch.zeros(
            batch.num_graphs, device=atomic_numbers.device, dtype=torch.long
        )
        num_atoms.index_add_(0, batch_idx, all_ones)
        # Convert atomic numbers to one-hot and sum per graph to get
        # per-sample elemental compositions.
        atom_types_onehot = F.one_hot(atomic_numbers, num_classes=120)
        compositions = torch.zeros(
            (batch.num_graphs, 120), device=atomic_numbers.device, dtype=torch.long
        )
        compositions.index_add_(0, batch_idx, atom_types_onehot)
        compositions = compositions[:, 1:]  # Remove the zeroth (atomic number 0) column.
return NormalizationContext(num_atoms=num_atoms, compositions=compositions)
def merge_MOLE_model(self, atoms: Atoms):
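        """Merge a mixture-of-experts (MOLE) UMA backbone into a plain,
        system-specialized backbone.

        The given structure is converted to a single-sample batch, which the
        MoE backbone uses to merge its per-expert weights; the merged backbone
        then replaces self.backbone. Only supported for eSCNMDMoeBackbone.
        """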
with optional_import_error_message("fairchem-core"):
from fairchem.core.models.uma.escn_moe import eSCNMDMoeBackbone
        assert isinstance(self.backbone, eSCNMDMoeBackbone), (  # type: ignore[reportGeneralTypeIssues]
            "Merging MOLE models is only supported for eSCNMDMoeBackbone."
        )
data = self.atoms_to_data(atoms, has_labels=False)
batch = self.collate_fn([data])
batch = batch.to(self.device) # type: ignore[reportGeneralTypeIssues]
new_backbone = self.backbone.merge_MOLE_model(batch) # type: ignore[reportGeneralTypeIssues]
self.backbone = new_backbone.float().to(self.device) # type: ignore[reportGeneralTypeIssues]