# Source code for mattertune.backbones.nequip_foundation.util

from __future__ import annotations

import os
import yaml
from pathlib import Path
import rich

import torch
from torch.package import PackageExporter, PackageImporter
from ase import Atoms

from mattertune.util import optional_import_error_message
from mattertune.backbones import NequIPBackboneModule

def nequip_model_package(
    ckpt_path: str | Path,
    example_atoms: Atoms,
    output_path: str | Path,
) -> None:
    """Package a trained NequIP checkpoint into a NequIP package file (``.nequip.zip``).

    A suggested NequIP workflow is:

    1. Train a NequIP model and save the checkpoint (``.ckpt``) file.
    2. Test the trained model using the checkpoint file if needed.
    3. Package the trained model into a NequIP package file (``.nequip.zip``).
    4. Compile the NequIP package file into a compiled model file
       (``.nequip.pth`` / ``.nequip.pt2``).

    This function performs step 3. It is based on the ``nequip-package``
    implementation in the NequIP repository. The ``.nequip.zip`` files it
    produces are intended to be compatible with the subsequent
    ``nequip-compile`` step.

    Args:
        ckpt_path: Path to a checkpoint loadable by
            ``NequIPBackboneModule.load_from_checkpoint``.
        example_atoms: Structure used to build the example input stored in the
            package (``model/example_data.pkl``). Must contain more than 3 atoms.
        output_path: Destination file. Must end with ``.nequip.zip``. Missing
            parent directories are created.

    Raises:
        AssertionError: If ``ckpt_path`` does not exist, ``example_atoms`` has
            3 or fewer atoms, or ``output_path`` does not end with ``.nequip.zip``.

    References:
        1. NequIP workflow:
           https://nequip.readthedocs.io/en/latest/guide/getting-started/workflow.html
        2. Example usage: TO-BE-ADDED
    """
    # Validate cheap preconditions before loading anything heavy.
    assert os.path.exists(ckpt_path), f"Checkpoint path {ckpt_path} does not exist."
    assert len(example_atoms) > 3, (
        f"Example atoms must contain more than 3 atoms, found {len(example_atoms)} atoms."
    )
    assert str(output_path).endswith(".nequip.zip"), (
        f"Output path must end with .nequip.zip, found {output_path}"
    )

    # nequip is optional; import it lazily (original import order kept).
    with optional_import_error_message("nequip"):
        from nequip.train.lightning import _SOLE_MODEL_KEY
        from nequip.data import AtomicDataDict
        from nequip.utils.global_dtype import _GLOBAL_DTYPE
        from nequip.utils.versions import get_current_code_versions
        from nequip.utils.versions.version_utils import get_version_safe
        from nequip.scripts.package import _CURRENT_NEQUIP_PACKAGE_VERSION
        from nequip.scripts._package_utils import (
            _EXTERNAL_MODULES,
            _MOCK_MODULES,
            _INTERNAL_MODULES,
        )
        from nequip.scripts._workflow_utils import set_workflow_state
        from nequip.model.saved_models.package import (
            _get_shared_importer,
            _suppress_package_importer_exporter_warnings,
        )
        from nequip.model.utils import _EAGER_MODEL_KEY

    # NOTE(review): mirrors `nequip-package`. The workflow state is presumably
    # process-global and is not restored when this function returns.
    set_workflow_state("package")

    cpu = torch.device("cpu")
    mt_module = NequIPBackboneModule.load_from_checkpoint(ckpt_path).to(cpu)
    # Wrap the backbone the same way nequip packages its models: a ModuleDict
    # keyed by nequip's sole-model key.
    eager_model = torch.nn.ModuleDict({_SOLE_MODEL_KEY: mt_module.backbone})

    # Build the example input with the module's own preprocessing pipeline.
    data = mt_module.atoms_to_data(example_atoms)
    data = mt_module.atomtype_transform(data)
    data = mt_module.neighbor_transform(data)
    if AtomicDataDict.CELL_KEY not in data:
        # No cell present (presumably a non-periodic structure). Add a large
        # dummy cell and all-zero edge cell shifts so the cell-related fields
        # exist in the example data.
        device = data[AtomicDataDict.POSITIONS_KEY].device
        data[AtomicDataDict.CELL_KEY] = 1e5 * torch.eye(
            3, dtype=_GLOBAL_DTYPE, device=device
        ).unsqueeze(0)
        num_edges = data[AtomicDataDict.EDGE_INDEX_KEY].size(1)
        data[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = torch.zeros(
            (num_edges, 3), dtype=_GLOBAL_DTYPE, device=device
        )
    data = AtomicDataDict.to_(data, device=cpu)

    models_to_package = {_EAGER_MODEL_KEY: eager_model}
    # Key order is preserved in the YAML output (sort_keys=False below).
    pkg_metadata = {
        "versions": get_current_code_versions(),
        "external_modules": {k: get_version_safe(k) for k in _EXTERNAL_MODULES},
        "package_version_id": _CURRENT_NEQUIP_PACKAGE_VERSION,
        "available_models": list(models_to_package),
        "atom_types": dict(enumerate(mt_module.type_names)),
    }

    # TODO: upstream NequIP embeds the full training config.yaml here. It does
    # not appear to be used by nequip-compile, so a minimal placeholder is
    # stored instead.
    dummy_config = {"generated_by": "MatterTune-nequip-export", "version": "0.1"}
    orig_config_yaml = yaml.safe_dump(dummy_config, sort_keys=False)
    pkg_metadata_yaml = yaml.safe_dump(pkg_metadata, sort_keys=False)

    # Importer chain as in upstream `nequip-package`: put nequip's shared
    # importer first when one exists, and always fall back to the system
    # importer. A missing shared importer is not an error.
    importers = (torch.package.importer.sys_importer,)
    shared_importer = _get_shared_importer()
    if shared_importer is not None:
        importers = (shared_importer, *importers)

    output_path = Path(output_path)
    # PackageExporter opens the file directly, so its directory must exist.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with _suppress_package_importer_exporter_warnings(), PackageExporter(
        str(output_path), importer=importers, debug=True
    ) as exp:
        exp.mock([f"{pkg}.**" for pkg in _MOCK_MODULES])
        exp.extern([f"{pkg}.**" for pkg in _EXTERNAL_MODULES])
        exp.intern([f"{pkg}.**" for pkg in _INTERNAL_MODULES])

        exp.save_pickle(
            package="model",
            resource="example_data.pkl",
            obj=data,
            dependencies=True,
        )
        exp.save_text("model", "config.yaml", orig_config_yaml)
        exp.save_text("model", "package_metadata.txt", pkg_metadata_yaml)
        for compile_mode, model in models_to_package.items():
            exp.save_pickle(
                package="model",
                resource=f"{compile_mode}_model.pkl",
                obj=model.to(cpu),
                dependencies=True,
            )
    rich.print("Saved package to", output_path)