from __future__ import annotations
import contextlib
import importlib.util
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
from pathlib import Path
import nshconfig as C
import torch
import torch.nn.functional as F
from ase.units import GPa
from typing_extensions import final, override
from ...finetune import properties as props
from ...finetune.base import FinetuneModuleBase, FinetuneModuleBaseConfig, ModelOutput
from ...normalization import NormalizationContext
from ...registry import backbone_registry
from ...util import optional_import_error_message
if TYPE_CHECKING:
from nequip.data import AtomicDataDict
log = logging.getLogger(__name__)
# Download locations of the packaged pretrained checkpoints, keyed by the
# model name accepted in ``NequIPBackboneConfig.pretrained_model``.
MODEL_URLS: dict[str, str] = {
    "NequIP-OAM-L-0.1": "https://zenodo.org/api/records/16980200/files/NequIP-OAM-L-0.1.nequip.zip/content",
    "NequIP-MP-L-0.1": "https://zenodo.org/api/records/16980200/files/NequIP-MP-L-0.1.nequip.zip/content",
    "Allegro-OAM-L-0.1": "https://zenodo.org/api/records/16980200/files/Allegro-OAM-L-0.1.nequip.zip/content",
    "Allegro-MP-L-0.1": "https://zenodo.org/api/records/16980200/files/Allegro-MP-L-0.1.nequip.zip/content",
}

# Local directory (inside the torch hub directory) where downloaded
# checkpoints are stored; it is created at import time if missing.
CACHE_DIR: Path = Path(torch.hub.get_dir()) / "nequip_checkpoints"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Maps property names used by this module's property configs to the keys
# read from the NequIP batch / model output dictionaries.
PROPERTY_KEY_MAP: dict[str, str] = {
    "energy": "total_energy",
    "forces": "forces",
    "stresses": "stress",
}
@backbone_registry.register
class NequIPBackboneConfig(FinetuneModuleBaseConfig):
    """Configuration for fine-tuning a pretrained NequIP/Allegro backbone.

    The class is registered in ``backbone_registry`` and builds a
    ``NequIPBackboneModule`` through :meth:`create_model`.
    """

    name: Literal["nequip"] = "nequip"
    """The type of the backbone."""

    pretrained_model: str = "NequIP-OAM-L-0.1"
    """
    The name of the pretrained model to load.
    - NequIP-OAM-L-0.1: NequIP foundational potential model for materials, pretrained on OAM dataset.
    - NequIP-MP-L-0.1: NequIP foundational potential model pretrained on MP dataset.
    - Allegro-OAM-L-0.1: Allegro foundational potential model for materials, pretrained on OAM dataset.
    - Allegro-MP-L-0.1: Allegro foundational potential model pretrained on MP dataset.
    """

    @override
    def create_model(self):
        """Build the fine-tuning module described by this config.

        Returns:
            A ``NequIPBackboneModule`` constructed from this config.

        Raises:
            AssertionError: If ``freeze_backbone`` is not ``False``; the
                module exposes no output heads, so nothing would remain
                trainable with a frozen backbone.
        """
        assert self.freeze_backbone is False, "Freezing the NequIP backbone is not supported, since there is no output heads for NequIP."
        return NequIPBackboneModule(self)

    @override
    @classmethod
    def ensure_dependencies(cls):
        """Check that the ``nequip`` package can be imported.

        Raises:
            ImportError: If no import spec can be found for ``nequip``.
        """
        # Make sure the nequip package is available
        if importlib.util.find_spec("nequip") is None:
            raise ImportError(
                "The nequip is not installed. Please install it by following our installation guide."
            )
@final
class NequIPBackboneModule(
    FinetuneModuleBase["AtomicDataDict.Type", "AtomicDataDict.Type", NequIPBackboneConfig]
):
    """Fine-tuning module that uses a packaged NequIP/Allegro model as its backbone.

    The backbone is loaded from one of the checkpoints listed in
    ``MODEL_URLS``. Only energy, forces and stresses properties are accepted,
    and the module defines no separate output heads: predictions are read
    directly from the backbone output using ``PROPERTY_KEY_MAP``.
    """

    @override
    @classmethod
    def hparams_cls(cls):
        """Return the configuration class used by this module."""
        return NequIPBackboneConfig

    def _should_enable_grad(self):
        """Return whether gradients must stay enabled in the forward pass.

        Always ``True`` for this backbone; ``setup`` uses it to reject
        inference-mode loops and ``requires_disabled_inference_mode``
        forwards it unchanged.
        """
        return True

    @override
    def requires_disabled_inference_mode(self):
        """Return ``True`` when the trainer must not use inference mode."""
        return self._should_enable_grad()

    @override
    def setup(self, stage: str):
        """Run the base-class setup and validate the trainer's loop settings.

        Args:
            stage: Stage name forwarded to the base-class ``setup``.

        Raises:
            ValueError: If the validate, test or predict loop of the trainer
                has ``inference_mode`` enabled while gradients are required.
        """
        super().setup(stage)
        if not self._should_enable_grad():
            return

        loops = (
            self.trainer.validate_loop,
            self.trainer.test_loop,
            self.trainer.predict_loop,
        )
        for loop in loops:
            if loop.inference_mode:
                raise ValueError(
                    "Cannot run inference mode with forces or stress calculation. "
                    "Please set `inference_mode` to False in the trainer configuration."
                )

    @override
    def create_model(self):
        """Load the pretrained backbone and build the data transforms.

        The checkpoint named by ``hparams.pretrained_model`` is downloaded
        into ``CACHE_DIR`` on first use and reused afterwards. The method
        sets ``backbone``, ``metadata``, ``r_max``, ``type_names``,
        ``neighbor_transform`` and ``atomtype_transform`` on the module, then
        validates the configured properties.

        Raises:
            ValueError: If ``hparams.pretrained_model`` is not a key of
                ``MODEL_URLS``.
            AssertionError: If a configured property is not an energy, forces
                or stresses property, or if a forces/stresses property is not
                marked conservative.
        """
        with optional_import_error_message("nequip"):
            from nequip.model.saved_models.package import ModelFromPackage
            from nequip.nn.graph_model import GraphModel
            from nequip.ase.nequip_calculator import _create_neighbor_transform
            from nequip.data.transforms import (
                ChemicalSpeciesToAtomTypeMapper,
                NeighborListTransform,
            )
            from nequip.nn import graph_model

        pretrained_model = self.hparams.pretrained_model
        if pretrained_model not in MODEL_URLS:
            raise ValueError(
                f"Unknown pretrained model: {pretrained_model}, available models: {MODEL_URLS.keys()}"
            )

        # Download the checkpoint only when it is not already cached.
        ckpt_path = CACHE_DIR / f"{pretrained_model}.nequip.zip"
        if not ckpt_path.exists():
            url = MODEL_URLS[pretrained_model]
            log.info(f"Downloading the pretrained model from {url}")
            torch.hub.download_url_to_file(url, str(ckpt_path))

        model = ModelFromPackage(package_path=str(ckpt_path))
        self.backbone: GraphModel = model["sole_model"]

        # Cutoff radius and type names are read from the packaged metadata.
        self.metadata = self.backbone.metadata
        self.r_max = float(self.metadata[graph_model.R_MAX_KEY])
        self.type_names = self.metadata[graph_model.TYPE_NAMES_KEY].split(" ")

        self.neighbor_transform: NeighborListTransform = _create_neighbor_transform(
            metadata=self.metadata,
            r_max=self.r_max,
            type_names=self.type_names,
        )
        # Identity mapping: every type name maps to itself.
        chemical_species_to_atom_type_map = {sym: sym for sym in self.type_names}
        self.atomtype_transform: ChemicalSpeciesToAtomTypeMapper = ChemicalSpeciesToAtomTypeMapper(
            model_type_names=self.type_names,
            chemical_species_to_atom_type_map=chemical_species_to_atom_type_map,
        )

        supported = (
            props.EnergyPropertyConfig,
            props.ForcesPropertyConfig,
            props.StressesPropertyConfig,
        )
        for prop in self.hparams.properties:
            assert isinstance(prop, supported), \
                f"Unsupported property {prop.name} for NequIP backbone. Supported properties are energy, forces, and stresses."
            if isinstance(prop, (props.ForcesPropertyConfig, props.StressesPropertyConfig)):
                assert prop.conservative is True, f"Non-conservative {prop.name} is not supported for NequIP backbone."

    @override
    def trainable_parameters(self):
        """Yield ``(name, parameter)`` pairs for every backbone parameter."""
        yield from self.backbone.named_parameters()

    @override
    @contextlib.contextmanager
    def model_forward_context(self, data, mode: str):
        """Context manager that keeps gradients enabled around the forward pass.

        Args:
            data: Unused.
            mode: Unused.
        """
        with torch.enable_grad():
            yield

    @override
    def model_forward(
        self, batch: AtomicDataDict.Type, mode: str
    ):
        """Run the backbone and collect the configured properties.

        Args:
            batch: Input batch passed directly to the backbone.
            mode: Unused.

        Returns:
            A ``ModelOutput`` dict whose ``"predicted_properties"`` entry maps
            each configured property name to its prediction cast to
            ``torch.float32``.

        Raises:
            ValueError: If a property has no entry in ``PROPERTY_KEY_MAP`` or
                its mapped key is absent from the backbone output.
        """
        output = self.backbone(batch)

        predicted_properties: dict[str, torch.Tensor] = {}
        for prop in self.hparams.properties:
            key = PROPERTY_KEY_MAP.get(prop.name)
            if key is None or key not in output:
                raise ValueError(f"Property {prop.name} not found in the model output.")
            predicted_properties[prop.name] = output[key].to(torch.float32)

        pred: ModelOutput = {"predicted_properties": predicted_properties}
        return pred

    @override
    def pretrained_backbone_parameters(self):
        """Return an iterator over all parameters of the backbone."""
        return self.backbone.parameters()

    @override
    def output_head_parameters(self):
        """Return an empty list; this module has no separate output heads."""
        return []

    @override
    def collate_fn(self, data_list):
        """Combine a list of per-structure data dicts into one batch.

        Args:
            data_list: Items to batch, passed to
                ``AtomicDataDict.batched_from_list``.
        """
        with optional_import_error_message("nequip"):
            from nequip.data import AtomicDataDict

        return AtomicDataDict.batched_from_list(data_list)

    @override
    def batch_to_labels(self, batch):
        """Extract the label tensor of every configured property from ``batch``.

        Returns:
            Dict mapping each property name to ``batch[PROPERTY_KEY_MAP[name]]``.

        Raises:
            KeyError: If a property name is missing from ``PROPERTY_KEY_MAP``
                or the mapped key is missing from ``batch``.
        """
        return {
            prop.name: batch[PROPERTY_KEY_MAP[prop.name]]
            for prop in self.hparams.properties
        }

    @override
    def atoms_to_data(self, atoms, has_labels: bool = True):
        """Convert an ASE atoms object into NequIP's data representation.

        Args:
            atoms: Atoms object handed to ``nequip.data.ase.from_ase``.
            has_labels: Currently unused; no labels are attached explicitly.

        Returns:
            The object returned by ``from_ase(atoms)``.
        """
        with optional_import_error_message("nequip"):
            from nequip.data.ase import from_ase

        # NOTE(review): ``batch_to_labels`` later reads labels under the keys in
        # ``PROPERTY_KEY_MAP``, so they are presumably populated by ``from_ase``
        # itself -- confirm against the nequip version in use.
        return from_ase(atoms)

    @override
    def create_normalization_context_from_batch(self, batch):
        """Build the normalization context (atom counts and compositions) of a batch.

        Returns:
            ``NormalizationContext`` with ``num_atoms`` of shape
            ``(num_graphs,)`` and ``compositions`` of shape
            ``(num_graphs, 119)`` holding per-graph counts for atomic numbers
            1 to 119.
        """
        num_slots = 120  # one-hot width; column 0 is dropped below

        atomic_numbers: torch.Tensor = batch["atomic_numbers"].long()  # (n_atoms,)
        batch_idx: torch.Tensor = batch["batch"]  # (n_atoms,)
        device = atomic_numbers.device
        # The number of graphs is inferred from the largest graph index.
        num_graphs = int(batch_idx.max().item()) + 1

        # Number of atoms per graph: add 1 for every atom into its graph slot.
        num_atoms = torch.zeros(num_graphs, device=device, dtype=torch.long)
        num_atoms.index_add_(0, batch_idx, torch.ones_like(atomic_numbers))

        # Per-graph element counts: sum the one-hot rows of each graph's atoms.
        atom_types_onehot = F.one_hot(atomic_numbers, num_classes=num_slots)
        compositions = torch.zeros((num_graphs, num_slots), device=device, dtype=torch.long)
        compositions.index_add_(0, batch_idx, atom_types_onehot)
        compositions = compositions[:, 1:]  # Remove the zeroth element

        return NormalizationContext(num_atoms=num_atoms, compositions=compositions)

    @override
    def apply_callable_to_backbone(self, fn):
        """Call ``fn`` with the backbone and return its result."""
        return fn(self.backbone)

    @override
    def batch_to_device(
        self,
        batch: AtomicDataDict.Type,
        device: torch.device | str,
    ):
        """Move ``batch`` to ``device`` via ``AtomicDataDict.to_``.

        Args:
            batch: Batch to move.
            device: Target device; a string is converted to ``torch.device``.

        Returns:
            The value returned by ``AtomicDataDict.to_``.
        """
        with optional_import_error_message("nequip"):
            from nequip.data import AtomicDataDict

        if isinstance(device, str):
            device = torch.device(device)
        return AtomicDataDict.to_(batch, device)  # type: ignore

    @override
    def apply_pruning_message_passing(self, message_passing_steps: int | None):
        """
        Prune the message passing steps of the backbone.

        Not supported for NequIP/Allegro models.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")

    @override
    def get_connectivity_from_atoms(self, atoms):
        """
        Get the connectivity (edge indices) from an atoms object.

        Not supported for NequIP/Allegro models.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")

    @override
    def get_connectivity_from_data(self, data) -> torch.Tensor:
        """
        Get the connectivity (edge indices) from a data object.

        Not supported for NequIP/Allegro models.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")

    @override
    def model_forward_partition(
        self,
        batch,
        mode: str,
        using_partition: bool = False,
    ) -> ModelOutput:
        """
        Forward pass of the model under partitioning.

        Not supported for NequIP/Allegro models.

        Args:
            batch: Input batch.
            mode: Unused.
            using_partition: Unused.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")