Source code for mattertune.backbones.mace_foundation.model

from __future__ import annotations

import contextlib
import importlib.util
import logging
from typing import TYPE_CHECKING, Any, Literal, cast

import torch
import torch.nn.functional as F
from typing_extensions import final, override

from ...finetune import properties as props
from ...finetune.base import FinetuneModuleBase, FinetuneModuleBaseConfig, ModelOutput
from ...normalization import NormalizationContext
from ...registry import backbone_registry
from ...util import optional_import_error_message

if TYPE_CHECKING:
    from mace.tools.torch_geometric import Batch, Data

log = logging.getLogger(__name__)


@backbone_registry.register
class MACEBackboneConfig(FinetuneModuleBaseConfig):
    name: Literal["mace"] = "mace"
    """The type of the backbone."""

    pretrained_model: str
    """
    The name of the pretrained model to load. Pass the name in the format
    mace-<model_name>, where the supported values of <model_name> are:
    [
        "small", "medium", "large",
        "small-0b", "medium-0b", "large-0b",
        "small-0b2", "medium-0b2", "medium-0b3", "large-0b2",
        "medium-omat-0",
        "small_off", "medium_off", "large_off",
    ]
    """

    @override
    def create_model(self):
        return MACEBackboneModule(self)

    @override
    @classmethod
    def ensure_dependencies(cls):
        # Make sure the mace package is available.
        if importlib.util.find_spec("mace") is None:
            raise ImportError(
                "The mace package is not installed. Please install it by following our installation guide."
            )
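

# Example usage (an illustrative sketch, not part of the original source; any
# additional required fields inherited from FinetuneModuleBaseConfig, and the
# exact constructor arguments of the property configs in
# mattertune.finetune.properties, are elided here):
#
#     config = MACEBackboneConfig(
#         pretrained_model="mace-medium",
#         properties=[props.EnergyPropertyConfig(...)],
#     )
#     module = config.create_model()
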

@final
class MACEBackboneModule(FinetuneModuleBase["Data", "Batch", MACEBackboneConfig]):
    @override
    @classmethod
    def hparams_cls(cls):
        return MACEBackboneConfig

    def _should_enable_grad(self):
        # MACE requires gradients to be enabled for force and stress calculations.
        return self.calc_forces or self.calc_stress

    @override
    def requires_disabled_inference_mode(self):
        return self._should_enable_grad()

    @override
    def setup(self, stage: str):
        super().setup(stage)
        if self._should_enable_grad():
            for loop in (
                self.trainer.validate_loop,
                self.trainer.test_loop,
                self.trainer.predict_loop,
            ):
                if loop.inference_mode:
                    raise ValueError(
                        "MACE computes forces and stresses, which requires gradients to be enabled. "
                        "Please set `inference_mode` to False in the trainer configuration."
                    )
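
    # With a PyTorch Lightning trainer this corresponds to constructing the
    # trainer with `inference_mode=False`, e.g. (an illustrative sketch, not
    # part of the original source):
    #
    #     import lightning.pytorch as pl
    #
    #     trainer = pl.Trainer(inference_mode=False)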

    @override
    def create_model(self):
        with optional_import_error_message("mace"):
            from mace.calculators.foundations_models import mace_mp, mace_off
            from mace.modules.models import ScaleShiftMACE
            from mace.tools import utils as mace_utils

        model_name = self.hparams.pretrained_model.replace("mace-", "")
        if model_name in (
            "small",
            "medium",
            "large",
            "small-0b",
            "medium-0b",
            "large-0b",
            "small-0b2",
            "medium-0b2",
            "medium-0b3",
            "large-0b2",
            "medium-omat-0",
        ):
            calc = mace_mp(model=model_name)
            model_foundation = calc.models[0]
        elif model_name in ("small_off", "medium_off", "large_off"):
            calc = mace_off(model=model_name.split("_")[0])
            model_foundation = calc.models[0]
        else:
            # Load from a local file.
            model_foundation = torch.load(model_name, map_location="cpu")

        # TODO: As of May 19th, 2025, all of these pretrained MACE models are
        # ScaleShiftMACE models.
        assert isinstance(
            model_foundation, ScaleShiftMACE
        ), f"Model {model_name} is not a ScaleShiftMACE model"
        self.backbone = model_foundation.float().train()
        for p in self.backbone.parameters():
            p.requires_grad_(True)
        self.z_table = mace_utils.AtomicNumberTable(
            [int(z) for z in self.backbone.atomic_numbers]  # type: ignore
        )
        self.cutoff = self.backbone.r_max.cpu().item()  # type: ignore

        self.energy_prop_name = "energy"
        self.forces_prop_name = "forces"
        self.stress_prop_name = "stresses"
        self.calc_forces = False
        self.calc_stress = False
        for prop in self.hparams.properties:
            match prop:
                case props.EnergyPropertyConfig():
                    self.energy_prop_name = prop.name
                case props.ForcesPropertyConfig():
                    assert prop.conservative, (
                        "Only conservative forces are supported for MACE"
                    )
                    self.forces_prop_name = prop.name
                    self.calc_forces = True
                case props.StressesPropertyConfig():
                    assert prop.conservative, (
                        "Only conservative stresses are supported for MACE"
                    )
                    self.stress_prop_name = prop.name
                    self.calc_stress = True
                case _:
                    raise ValueError(
                        f"Unsupported property config: {prop} for MACE. "
                        "Please ask the maintainers of MatterTune or MACE for support."
                    )
        if not self.calc_forces and self.calc_stress:
            raise ValueError(
                "Stress calculation requires force calculation; "
                "stress cannot be computed without forces."
            )
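
    # For example, to fine-tune on energies together with conservative forces
    # and stresses, the configured properties would look like this (an
    # illustrative sketch; the exact constructor signatures live in
    # mattertune.finetune.properties):
    #
    #     properties = [
    #         props.EnergyPropertyConfig(...),
    #         props.ForcesPropertyConfig(conservative=True, ...),
    #         props.StressesPropertyConfig(conservative=True, ...),
    #     ]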

    @override
    def trainable_parameters(self):
        for name, param in self.backbone.named_parameters():
            yield name, param

    @override
    @contextlib.contextmanager
    def model_forward_context(self, data, mode: str):
        with contextlib.ExitStack() as stack:
            if self.calc_forces or self.calc_stress:
                stack.enter_context(torch.enable_grad())
            yield

    @override
    def model_forward(self, batch: Batch, mode: str):
        output = self.backbone(
            batch.to_dict(),
            compute_force=True,
            compute_stress=True,
            training=mode == "train",
        )
        output_pred = {}
        output_pred[self.energy_prop_name] = output.get("energy", torch.zeros(1))
        if self.calc_forces:
            output_pred[self.forces_prop_name] = output.get("forces")
        if self.calc_stress:
            output_pred[self.stress_prop_name] = output.get("stresses")
        pred: ModelOutput = {"predicted_properties": output_pred}
        return pred
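
    # The returned ModelOutput maps the configured property names to tensors,
    # e.g. {"predicted_properties": {"energy": ..., "forces": ..., "stresses": ...}}
    # when force and stress calculation are enabled in create_model above.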

    @override
    def pretrained_backbone_parameters(self):
        return self.backbone.parameters()

    @override
    def output_head_parameters(self):
        return []

    @override
    def cpu_data_transform(self, data):
        return data

    @override
    def collate_fn(self, data_list):
        with optional_import_error_message("mace"):
            from mace.tools.torch_geometric import Batch

        return Batch.from_data_list(data_list)

    @override
    def gpu_batch_transform(self, batch):
        return batch

    @override
    def batch_to_labels(self, batch):
        labels: dict[str, torch.Tensor] = {}
        for prop in self.hparams.properties:
            labels[prop.name] = getattr(batch, prop.name)
        return labels

    @override
    def atoms_to_data(self, atoms, has_labels):
        with optional_import_error_message("mace"):
            from mace import data as mace_data

        data_config = mace_data.config_from_atoms(atoms)
        data = mace_data.AtomicData.from_config(
            data_config,
            z_table=self.z_table,
            cutoff=self.cutoff,
        )
        setattr(data, "atomic_numbers", torch.tensor(atoms.get_atomic_numbers()))
        if has_labels:
            for prop in self.hparams.properties:
                value = prop._from_ase_atoms_to_torch(atoms)
                # For stress, we should make sure it is (3, 3), not the
                # flattened (6,) Voigt form that ASE returns.
                if isinstance(prop, props.StressesPropertyConfig):
                    from ase.constraints import voigt_6_to_full_3x3_stress

                    value = voigt_6_to_full_3x3_stress(value.float().numpy())
                    value = torch.from_numpy(value).float().reshape(1, 3, 3)
                setattr(data, prop.name, value)
        return data
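
    # Example round trip from ASE atoms to a collated MACE batch (an
    # illustrative sketch; assumes an already-constructed `module` of this
    # class):
    #
    #     from ase.build import bulk
    #
    #     atoms = bulk("Cu", "fcc", a=3.6)
    #     data = module.atoms_to_data(atoms, has_labels=False)
    #     batch = module.collate_fn([data])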

    @override
    def create_normalization_context_from_batch(self, batch):
        with optional_import_error_message("mace"):
            from mace.tools.scatter import scatter_sum

        atomic_numbers: torch.Tensor = batch["atomic_numbers"].long()  # type: ignore # (n_atoms,)
        batch_idx: torch.Tensor = batch["batch"]  # type: ignore # (n_atoms,)

        # Get the number of atoms per sample.
        all_ones = torch.ones_like(atomic_numbers)
        num_atoms = scatter_sum(
            all_ones,
            batch_idx,
            dim=0,
            dim_size=batch.num_graphs,
            reduce="sum",
        )

        # Convert atomic numbers to one-hot encodings and sum them per graph
        # to obtain per-sample compositions.
        atom_types_onehot = F.one_hot(atomic_numbers, num_classes=120)
        compositions = scatter_sum(
            atom_types_onehot,
            batch_idx,
            dim=0,
            dim_size=batch.num_graphs,
            reduce="sum",
        )
        compositions = compositions[:, 1:]  # Remove the zeroth (Z = 0) column.
        return NormalizationContext(num_atoms=num_atoms, compositions=compositions)
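
    # scatter_sum here is a per-graph segment sum: for batch_idx = [0, 0, 1]
    # and all_ones = [1, 1, 1], num_atoms == [2, 1]. The same semantics in
    # plain torch (an illustrative sketch, not part of the original source):
    #
    #     idx = torch.tensor([0, 0, 1])
    #     ones = torch.ones(3, dtype=torch.long)
    #     out = torch.zeros(2, dtype=torch.long).scatter_add_(0, idx, ones)
    #     assert out.tolist() == [2, 1]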

    @override
    def apply_callable_to_backbone(self, fn):
        return fn(self.backbone)