# Source code for mattertune.backbones.nequip_foundation.nequip_model

from __future__ import annotations

import contextlib
import importlib.util
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
from pathlib import Path

import nshconfig as C
import torch
import torch.nn.functional as F
from ase.units import GPa
from typing_extensions import final, override

from ...finetune import properties as props
from ...finetune.base import FinetuneModuleBase, FinetuneModuleBaseConfig, ModelOutput
from ...normalization import NormalizationContext
from ...registry import backbone_registry
from ...util import optional_import_error_message

if TYPE_CHECKING:
    from nequip.data import AtomicDataDict
    

log = logging.getLogger(__name__)

# Every supported checkpoint is a file attached to the same Zenodo record, so
# its download URL is fully determined by the model name.
_ZENODO_FILES_URL = "https://zenodo.org/api/records/16980200/files"
MODEL_URLS = {
    model_name: f"{_ZENODO_FILES_URL}/{model_name}.nequip.zip/content"
    for model_name in (
        "NequIP-OAM-L-0.1",
        "NequIP-MP-L-0.1",
        "Allegro-OAM-L-0.1",
        "Allegro-MP-L-0.1",
    )
}

# Downloaded model packages are cached in a sub-directory of the torch.hub
# directory. The directory is created eagerly, at import time.
CACHE_DIR = Path(torch.hub.get_dir()).joinpath("nequip_checkpoints")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Property name (as used by the property configs in this module) -> key used
# to index the backbone output and the batch for that property.
PROPERTY_KEY_MAP = dict(
    energy="total_energy",
    forces="forces",
    stresses="stress",
)


@backbone_registry.register
class NequIPBackboneConfig(FinetuneModuleBaseConfig):
    """Configuration for fine-tuning a pretrained NequIP/Allegro model.

    Registered with ``backbone_registry`` under the discriminator
    ``name == "nequip"``.
    """

    name: Literal["nequip"] = "nequip"
    """The type of the backbone."""

    pretrained_model: str = "NequIP-OAM-L-0.1"
    """
    The name of the pretrained model to load.
    - NequIP-OAM-L-0.1: NequIP foundational potential model for materials, pretrained on OAM dataset.
    - NequIP-MP-L-0.1: NequIP foundational potential model pretrained on MP dataset.
    - Allegro-OAM-L-0.1: Allegro foundational potential model for materials, pretrained on OAM dataset.
    - Allegro-MP-L-0.1: Allegro foundational potential model pretrained on MP dataset.
    """

    @override
    def create_model(self):
        """Build the fine-tuning module described by this configuration.

        Returns:
            A ``NequIPBackboneModule`` constructed from ``self``.

        Raises:
            AssertionError: If ``freeze_backbone`` is not ``False``.
        """
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as-is
        # so the exception type seen by existing callers does not change.
        assert self.freeze_backbone is False, "Freezing the NequIP backbone is not supported, since there is no output heads for NequIP."
        return NequIPBackboneModule(self)

    @override
    @classmethod
    def ensure_dependencies(cls):
        """Check that the optional ``nequip`` package can be imported.

        Raises:
            ImportError: If no import spec for ``nequip`` can be found.
        """
        # Make sure the nequip package is available
        if importlib.util.find_spec("nequip") is None:
            raise ImportError(
                "The nequip is not installed. Please install it by following our installation guide."
            )
@final
class NequIPBackboneModule(
    FinetuneModuleBase["AtomicDataDict.Type", "AtomicDataDict.Type", NequIPBackboneConfig]
):
    """Fine-tuning module whose backbone is a packaged NequIP/Allegro model.

    The module has no separate output heads (``output_head_parameters``
    returns an empty list). Energy, forces and stresses are read from the
    backbone output through ``PROPERTY_KEY_MAP``.
    """

    @override
    @classmethod
    def hparams_cls(cls):
        """Return the configuration class used by this module."""
        return NequIPBackboneConfig

    def _should_enable_grad(self):
        """Return whether gradients must stay enabled; always ``True`` here."""
        return True

    @override
    def requires_disabled_inference_mode(self):
        """Return ``True`` when the trainer must not use inference mode."""
        return self._should_enable_grad()

    @override
    def setup(self, stage: str):
        """Run the base setup and reject trainers that use inference mode.

        Args:
            stage: Stage name, forwarded unchanged to the base class.

        Raises:
            ValueError: If the validate, test or predict loop of the trainer
                has ``inference_mode`` enabled.
        """
        super().setup(stage)
        if not self._should_enable_grad():
            return

        evaluation_loops = (
            self.trainer.validate_loop,
            self.trainer.test_loop,
            self.trainer.predict_loop,
        )
        for loop in evaluation_loops:
            if loop.inference_mode:
                raise ValueError(
                    "Cannot run inference mode with forces or stress calculation. "
                    "Please set `inference_mode` to False in the trainer configuration."
                )

    @override
    def create_model(self):
        """Load the pretrained package and build the data transforms.

        The package named by ``hparams.pretrained_model`` is downloaded into
        ``CACHE_DIR`` on first use and loaded from there afterwards. Sets
        ``backbone``, ``metadata``, ``r_max``, ``type_names``,
        ``neighbor_transform`` and ``atomtype_transform`` on ``self``.

        Raises:
            ValueError: If ``hparams.pretrained_model`` is not a key of
                ``MODEL_URLS``.
            AssertionError: If a configured property is not an energy, forces
                or stresses property, or if a forces/stresses property is not
                conservative.
        """
        with optional_import_error_message("nequip"):
            from nequip.model.saved_models.package import ModelFromPackage
            from nequip.nn.graph_model import GraphModel
            from nequip.ase.nequip_calculator import _create_neighbor_transform
            from nequip.data.transforms import (
                ChemicalSpeciesToAtomTypeMapper,
                NeighborListTransform,
            )
            from nequip.nn import graph_model

        pretrained_model = self.hparams.pretrained_model
        # Guard clause: fail before touching the cache for unknown names.
        if pretrained_model not in MODEL_URLS:
            raise ValueError(
                f"Unknown pretrained model: {pretrained_model}, available models: {MODEL_URLS.keys()}"
            )

        ckpt_path = CACHE_DIR / f"{pretrained_model}.nequip.zip"
        if not ckpt_path.exists():
            log.info(
                f"Downloading the pretrained model from {MODEL_URLS[pretrained_model]}"
            )
            torch.hub.download_url_to_file(
                MODEL_URLS[pretrained_model], str(ckpt_path)
            )

        model = ModelFromPackage(package_path=str(ckpt_path))
        self.backbone: GraphModel = model["sole_model"]
        self.metadata = self.backbone.metadata
        self.r_max = float(self.metadata[graph_model.R_MAX_KEY])
        # The metadata stores the type names as one space-separated string.
        self.type_names = self.metadata[graph_model.TYPE_NAMES_KEY].split(" ")
        self.neighbor_transform: NeighborListTransform = _create_neighbor_transform(
            metadata=self.metadata,
            r_max=self.r_max,
            type_names=self.type_names,
        )
        # Identity mapping: every model type name is used as its own species.
        chemical_species_to_atom_type_map = {sym: sym for sym in self.type_names}
        self.atomtype_transform: ChemicalSpeciesToAtomTypeMapper = ChemicalSpeciesToAtomTypeMapper(
            model_type_names=self.type_names,
            chemical_species_to_atom_type_map=chemical_species_to_atom_type_map,
        )

        supported_property_types = (
            props.EnergyPropertyConfig,
            props.ForcesPropertyConfig,
            props.StressesPropertyConfig,
        )
        for prop in self.hparams.properties:
            assert isinstance(prop, supported_property_types), \
                f"Unsupported property {prop.name} for NequIP backbone. Supported properties are energy, forces, and stresses."
            if isinstance(prop, (props.ForcesPropertyConfig, props.StressesPropertyConfig)):
                assert prop.conservative is True, f"Non-conservative {prop.name} is not supported for NequIP backbone."

    @override
    def trainable_parameters(self):
        """Yield ``(name, parameter)`` pairs for every backbone parameter."""
        yield from self.backbone.named_parameters()

    @override
    @contextlib.contextmanager
    def model_forward_context(self, data, mode: str):
        """Context manager that enables autograd around the forward pass.

        ``data`` and ``mode`` are accepted for interface compatibility and are
        not used.
        """
        with torch.enable_grad():
            yield

    @override
    def model_forward(
        self,
        batch: AtomicDataDict.Type,
        mode: str,
    ):
        """Run the backbone and collect the configured properties.

        Args:
            batch: Input passed directly to ``self.backbone``.
            mode: Unused.

        Returns:
            A ``ModelOutput`` dict whose ``"predicted_properties"`` entry maps
            each configured property name to the matching backbone output,
            converted to ``torch.float32``.

        Raises:
            ValueError: If a property has no entry in ``PROPERTY_KEY_MAP`` or
                its key is missing from the backbone output.
        """
        output = self.backbone(batch)

        predicted_properties: dict[str, torch.Tensor] = {}
        for prop in self.hparams.properties:
            key = PROPERTY_KEY_MAP.get(prop.name)
            if key is None or key not in output:
                raise ValueError(f"Property {prop.name} not found in the model output.")
            predicted_properties[prop.name] = output[key].to(torch.float32)

        pred: ModelOutput = {"predicted_properties": predicted_properties}
        return pred

    @override
    def pretrained_backbone_parameters(self):
        """Return an iterator over all parameters of the backbone."""
        return self.backbone.parameters()

    @override
    def output_head_parameters(self):
        """Return an empty list; this module defines no output heads."""
        return []

    @override
    def cpu_data_transform(self, data):
        """Return ``data`` unchanged."""
        return data

    @override
    def collate_fn(self, data_list):
        """Batch a list of samples with ``AtomicDataDict.batched_from_list``."""
        with optional_import_error_message("nequip"):
            from nequip.data import AtomicDataDict

        return AtomicDataDict.batched_from_list(data_list)

    @override
    def gpu_batch_transform(self, batch):
        """Apply the atom-type transform, then the neighbor-list transform."""
        batch = self.atomtype_transform(batch)
        batch = self.neighbor_transform(batch)
        return batch

    @override
    def batch_to_labels(self, batch):
        """Extract the label tensor of every configured property from ``batch``.

        Returns:
            Dict mapping property name to ``batch[PROPERTY_KEY_MAP[name]]``.
        """
        labels: dict[str, torch.Tensor] = {
            prop.name: batch[PROPERTY_KEY_MAP[prop.name]]
            for prop in self.hparams.properties
        }
        return labels

    @override
    def atoms_to_data(self, atoms, has_labels: bool = True):
        """Convert an ASE ``Atoms`` object with ``nequip.data.ase.from_ase``.

        Args:
            atoms: The structure to convert.
            has_labels: Currently ignored; this method attaches no label
                tensors of its own.

        Returns:
            The object returned by ``from_ase(atoms)``.
        """
        with optional_import_error_message("nequip"):
            from nequip.data.ase import from_ase

        return from_ase(atoms)

    @override
    def create_normalization_context_from_batch(self, batch):
        """Build per-graph atom counts and compositions for normalization.

        Reads ``batch["atomic_numbers"]`` and ``batch["batch"]`` (the graph
        index of every atom).

        Returns:
            A ``NormalizationContext`` with ``num_atoms`` of shape
            ``(num_graphs,)`` and ``compositions`` of shape
            ``(num_graphs, 119)``, holding per-graph counts of atomic numbers
            1 to 119.
        """
        atomic_numbers: torch.Tensor = batch["atomic_numbers"].long()  # (n_atoms,)
        batch_idx: torch.Tensor = batch["batch"]  # (n_atoms,)
        device = atomic_numbers.device
        num_graphs = int(batch_idx.max().item()) + 1

        # Number of atoms per graph: add a one for each atom to its graph's slot.
        num_atoms = torch.zeros(num_graphs, device=device, dtype=torch.long)
        num_atoms.index_add_(0, batch_idx, torch.ones_like(atomic_numbers))

        # Per-graph element counts: sum the one-hot rows of each graph's atoms.
        atom_types_onehot = F.one_hot(atomic_numbers, num_classes=120)
        compositions = torch.zeros((num_graphs, 120), device=device, dtype=torch.long)
        compositions.index_add_(0, batch_idx, atom_types_onehot)
        compositions = compositions[:, 1:]  # Remove the zeroth element

        return NormalizationContext(num_atoms=num_atoms, compositions=compositions)

    @override
    def apply_callable_to_backbone(self, fn):
        """Call ``fn`` with the backbone and return its result."""
        return fn(self.backbone)

    @override
    def batch_to_device(
        self,
        batch: AtomicDataDict.Type,
        device: torch.device | str,
    ):
        """Move ``batch`` to ``device`` with ``AtomicDataDict.to_``.

        Args:
            batch: The batch to move.
            device: Target device; strings are converted to ``torch.device``.

        Returns:
            The result of ``AtomicDataDict.to_(batch, device)``.
        """
        with optional_import_error_message("nequip"):
            from nequip.data import AtomicDataDict

        # isinstance (rather than an exact type check) also accepts str subclasses.
        if isinstance(device, str):
            device = torch.device(device)
        return AtomicDataDict.to_(batch, device)  # type: ignore

    @override
    def apply_pruning_message_passing(self, message_passing_steps: int | None):
        """
        Apply message passing for early stopping.

        Raises:
            NotImplementedError: Always; not supported for NequIP/Allegro.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")

    @override
    def get_connectivity_from_atoms(self, atoms):
        """
        Get the connectivity from the atoms object.
        This is useful for message passing and other graph-based operations.

        Returns:
            edge_index: Array of shape (2, num_edges) containing the src and dst indices of the edges.

        Raises:
            NotImplementedError: Always; not supported for NequIP/Allegro.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")

    @override
    def get_connectivity_from_data(self, data) -> torch.Tensor:
        """
        Get the connectivity from the data object.
        This is useful for message passing and other graph-based operations.

        Returns:
            edge_index: Tensor of shape (2, num_edges) containing the src and dst indices of the edges.

        Raises:
            NotImplementedError: Always; not supported for NequIP/Allegro.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")

    @override
    def model_forward_partition(
        self,
        batch,
        mode: str,
        using_partition: bool = False,
    ) -> ModelOutput:
        """
        Forward pass of the model under partitioning.

        Args:
            batch: Input batch.
            mode: Unused.
            using_partition: Unused.

        Returns:
            Prediction of the model.

        Raises:
            NotImplementedError: Always; not supported for NequIP/Allegro.
        """
        raise NotImplementedError("For now, NequIP/Allegro models do not support pruning and partition acceleration")