from __future__ import annotations
import contextlib
import importlib.util
import logging
from contextlib import ExitStack
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, cast
import nshconfig as C
import nshconfig_extra as CE
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ase import Atoms
from typing_extensions import final, override
from ...finetune import properties as props
from ...finetune.base import (
FinetuneModuleBase,
FinetuneModuleBaseConfig,
ModelOutput,
_SkipBatchError,
)
from ...normalization import NormalizationContext
from ...registry import backbone_registry
from ...util import optional_import_error_message
from .util import get_activation_cls
if TYPE_CHECKING:
from torch_geometric.data import Batch, Data # type: ignore[reportMissingImports] # noqa
from torch_geometric.data.data import BaseData # type: ignore[reportMissingImports] # noqa
log = logging.getLogger(__name__)
class CutoffsConfig(C.Config):
main: float
aeaint: float
qint: float
aint: float
@classmethod
def from_constant(cls, value: float):
return cls(main=value, aeaint=value, qint=value, aint=value)
class MaxNeighborsConfig(C.Config):
main: int
aeaint: int
qint: int
aint: int
@classmethod
def from_goc_base_proportions(cls, max_neighbors: int):
"""
GOC base proportions:
max_neighbors: 30
max_neighbors_qint: 8
max_neighbors_aeaint: 20
max_neighbors_aint: 1000
"""
return cls(
main=max_neighbors,
aeaint=int(max_neighbors * 20 / 30),
qint=int(max_neighbors * 8 / 30),
aint=int(max_neighbors * 1000 / 30),
)
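    # Example (a sketch based on the proportions documented above): scaling from a
    # main neighbor count of 30 reproduces the GOC base values exactly:
    #   MaxNeighborsConfig.from_goc_base_proportions(30)
    #   # -> MaxNeighborsConfig(main=30, aeaint=20, qint=8, aint=1000)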
class JMPGraphComputerConfig(C.Config):
pbc: bool
"""Whether to use periodic boundary conditions."""
cutoffs: CutoffsConfig = CutoffsConfig.from_constant(12.0)
"""The cutoff for the radius graph."""
max_neighbors: MaxNeighborsConfig = MaxNeighborsConfig.from_goc_base_proportions(30)
"""The maximum number of neighbors for the radius graph."""
per_graph_radius_graph: bool = False
"""Whether to compute the radius graph per graph."""
def _to_jmp_graph_computer_config(self):
with optional_import_error_message("jmp"):
from jmp.models.gemnet.graph import CutoffsConfig, GraphComputerConfig, MaxNeighborsConfig # type: ignore[reportMissingImports] # noqa # fmt: skip
return GraphComputerConfig(
pbc=self.pbc,
cutoffs=CutoffsConfig(
main=self.cutoffs.main,
aeaint=self.cutoffs.aeaint,
qint=self.cutoffs.qint,
aint=self.cutoffs.aint,
),
max_neighbors=MaxNeighborsConfig(
main=self.max_neighbors.main,
aeaint=self.max_neighbors.aeaint,
qint=self.max_neighbors.qint,
aint=self.max_neighbors.aint,
),
per_graph_radius_graph=self.per_graph_radius_graph,
)
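    # Example (sketch): only `pbc` must be provided; the cutoffs and neighbor counts
    # fall back to the defaults declared above.
    #   graph_computer = JMPGraphComputerConfig(pbc=True)
    #   assert graph_computer.cutoffs.main == 12.0
    #   assert graph_computer.max_neighbors.aint == 1000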
@backbone_registry.register
class JMPBackboneConfig(FinetuneModuleBaseConfig):
name: Literal["jmp"] = "jmp"
"""The type of the backbone."""
ckpt_path: Path | CE.CachedPath
"""The path to the pre-trained model checkpoint."""
graph_computer: JMPGraphComputerConfig
"""The configuration for the graph computer."""
@override
def create_model(self):
return JMPBackboneModule(self)
@override
@classmethod
def ensure_dependencies(cls):
# Make sure the jmp module is available
if importlib.util.find_spec("jmp") is None:
raise ImportError(
"The jmp module is not installed. Please install it by running"
" pip install jmp."
)
# Make sure torch-geometric is available
if importlib.util.find_spec("torch_geometric") is None:
raise ImportError(
"The torch-geometric module is not installed. Please install it by running"
" pip install torch-geometric."
)
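    # Usage sketch: this can be called up front to surface missing dependencies early,
    # e.g. `JMPBackboneConfig.ensure_dependencies()` raises ImportError with an install
    # hint if either `jmp` or `torch_geometric` cannot be found.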
@final
class JMPBackboneModule(FinetuneModuleBase["Data", "Batch", JMPBackboneConfig]):
@override
@classmethod
def hparams_cls(cls):
return JMPBackboneConfig
@override
def requires_disabled_inference_mode(self):
return False
def _find_potential_energy_prop_name(self):
for prop in self.hparams.properties:
if isinstance(prop, props.EnergyPropertyConfig):
return prop.name
raise ValueError("No energy property found in the property list")
def _create_output_head(self, prop: props.PropertyConfig):
activation_cls = get_activation_cls(self.backbone.hparams.activation)
match prop:
case props.EnergyPropertyConfig():
with optional_import_error_message("jmp"):
from jmp.nn.energy_head import EnergyTargetConfig # type: ignore[reportMissingImports] # noqa
return EnergyTargetConfig(
max_atomic_number=self.backbone.hparams.num_elements
).create_model(
self.backbone.hparams.emb_size_atom,
self.backbone.hparams.emb_size_edge,
activation_cls,
)
case props.ForcesPropertyConfig():
if not prop.conservative:
with optional_import_error_message("jmp"):
from jmp.nn.force_head import ForceTargetConfig # type: ignore[reportMissingImports] # noqa
return ForceTargetConfig().create_model(
self.backbone.hparams.emb_size_edge, activation_cls
)
else:
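                    # Conservative forces are derived from the predicted energy (typically
                    # as its negative gradient with respect to positions), so this head only
                    # needs to know which predicted property holds the energy.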
with optional_import_error_message("jmp"):
from jmp.nn.force_head import ConservativeForceTargetConfig # type: ignore[reportMissingImports] # noqa
force_config = ConservativeForceTargetConfig(
energy_prop_name=self._find_potential_energy_prop_name()
)
return force_config.create_model()
case props.StressesPropertyConfig():
if not prop.conservative:
with optional_import_error_message("jmp"):
from jmp.nn.stress_head import StressTargetConfig # type: ignore[reportMissingImports] # noqa
return StressTargetConfig().create_model(
self.backbone.hparams.emb_size_edge, activation_cls
)
else:
with optional_import_error_message("jmp"):
from jmp.nn.stress_head import ConservativeStressTargetConfig # type: ignore[reportMissingImports] # noqa
stress_config = ConservativeStressTargetConfig(
energy_prop_name=self._find_potential_energy_prop_name()
)
return stress_config.create_model()
case props.GraphPropertyConfig():
with optional_import_error_message("jmp"):
from jmp.nn.graph_scaler import GraphScalarTargetConfig # type: ignore[reportMissingImports] # noqa
return GraphScalarTargetConfig(reduction=prop.reduction).create_model(
self.backbone.hparams.emb_size_atom,
activation_cls,
)
case _:
raise ValueError(
f"Unsupported property config: {prop} for JMP"
"Please ask the maintainers of JMP for support"
)
@override
def create_model(self):
# Resolve the checkpoint path
if isinstance(ckpt_path := self.hparams.ckpt_path, CE.CachedPath):
ckpt_path = ckpt_path.resolve()
# Load the backbone from the checkpoint
with optional_import_error_message("jmp"):
from jmp.models.gemnet import GemNetOCBackbone # type: ignore[reportMissingImports] # noqa
from jmp.models.gemnet.graph import GraphComputer # type: ignore[reportMissingImports] # noqa
self.backbone = GemNetOCBackbone.from_pretrained_ckpt(ckpt_path)
log.info(
f"Loaded the model from the checkpoint at {ckpt_path}. The model "
f"has {sum(p.numel() for p in self.backbone.parameters()):,} parameters."
)
# Create the graph computer
self.graph_computer = GraphComputer(
self.hparams.graph_computer._to_jmp_graph_computer_config(),
self.backbone.hparams,
)
# Create the output heads
self.output_heads = nn.ModuleDict()
        ## Rearrange the properties so that the energy property comes first and stress second
self.hparams.properties = sorted(
self.hparams.properties,
key=lambda prop: (
not isinstance(prop, props.EnergyPropertyConfig),
not isinstance(prop, props.StressesPropertyConfig),
),
)
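        ## The sort key is a (bool, bool) tuple and False sorts before True, so energy
        ## heads come first, stress heads second, and all other heads keep their relative
        ## order (sorted() is stable).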
for prop in self.hparams.properties:
self.output_heads[prop.name] = self._create_output_head(prop)
@override
@contextlib.contextmanager
def model_forward_context(self, data):
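        # Enter every output head's forward_context for the duration of the forward pass
        # (conservative heads may, for example, need gradients enabled here); the contexts
        # are unwound in reverse order when the ExitStack closes.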
with ExitStack() as stack:
for head in self.output_heads.values():
stack.enter_context(head.forward_context(data))
yield
@override
def model_forward(self, batch, return_backbone_output=False):
# Run the backbone
backbone_output = self.backbone(batch)
# Feed the backbone output to the output heads
predicted_properties: dict[str, torch.Tensor] = {}
head_input: dict[str, Any] = {
"data": batch,
"backbone_output": backbone_output,
"predicted_props": predicted_properties,
}
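        # Each head sees the raw batch, the shared backbone output, and the properties
        # predicted so far, so heads evaluated later (e.g. conservative force/stress heads)
        # can look up the already-predicted energy in `predicted_props`.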
for name, head in self.output_heads.items():
output = head(head_input)
if torch.isnan(output).any() or torch.isinf(output).any():
raise _SkipBatchError("NaN or inf detected in the output")
head_input["predicted_props"][name] = output
pred: ModelOutput = {"predicted_properties": predicted_properties}
if return_backbone_output:
pred["backbone_output"] = backbone_output
return pred
@override
def pretrained_backbone_parameters(self):
return self.backbone.parameters()
@override
def output_head_parameters(self):
for head in self.output_heads.values():
yield from head.parameters()
@override
def collate_fn(self, data_list):
with optional_import_error_message("torch_geometric"):
from torch_geometric.data import Batch # type: ignore[reportMissingImports] # noqa
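        # Batch.from_data_list stacks the per-structure Data objects into a single batched
        # graph and adds the `batch` index vector used later for per-graph reductions.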
return Batch.from_data_list(cast("list[BaseData]", data_list))
@override
def batch_to_labels(self, batch):
labels: dict[str, torch.Tensor] = {}
for prop in self.hparams.properties:
labels[prop.name] = getattr(batch, prop.name)
return labels
@override
def atoms_to_data(self, atoms, has_labels):
with optional_import_error_message("torch_geometric"):
from torch_geometric.data import Data # type: ignore[reportMissingImports] # noqa
# For JMP, your PyG object should have the following attributes:
# - pos: Node positions (shape: (N, 3))
# - atomic_numbers: Atomic numbers (shape: (N,))
# - natoms: Number of atoms (shape: (), i.e. a scalar)
# - tags: Atom tags (shape: (N,)), this is used to distinguish between
# surface and adsorbate atoms in datasets like OC20.
# Set this to 2 if you don't have this information.
# - fixed: Boolean tensor indicating whether an atom is fixed
# in the relaxation (shape: (N,)), set this to False
# if you don't have this information.
# - cell: The cell vectors (shape: (1, 3, 3))
# - pbc: The periodic boundary conditions (shape: (1, 3))
data_dict: dict[str, torch.Tensor] = {
"pos": torch.tensor(atoms.positions, dtype=torch.float32),
"atomic_numbers": torch.tensor(atoms.numbers, dtype=torch.long),
"natoms": torch.tensor(len(atoms), dtype=torch.long),
"tags": torch.full((len(atoms),), 2, dtype=torch.long),
"fixed": torch.from_numpy(_get_fixed(atoms)).bool(),
"cell": torch.from_numpy(np.array(atoms.cell, dtype=np.float32))
.float()
.unsqueeze(0),
"pbc": torch.tensor(atoms.pbc, dtype=torch.bool).unsqueeze(0),
}
if has_labels:
# Also, pass along any other targets/properties. This includes:
# - energy: The total energy of the system
# - forces: The forces on each atom
# - stress: The stress tensor of the system
# - anything else you want to predict
for prop in self.hparams.properties:
value = prop._from_ase_atoms_to_torch(atoms)
                # For stress, convert the flattened Voigt (6,) form that ASE returns
                # into the full (1, 3, 3) tensor expected downstream.
if isinstance(prop, props.StressesPropertyConfig):
from ase.constraints import voigt_6_to_full_3x3_stress
value = voigt_6_to_full_3x3_stress(value.float().numpy())
value = torch.from_numpy(value).float().reshape(1, 3, 3)
data_dict[prop.name] = value
return Data.from_dict(data_dict)
@override
def create_normalization_context_from_batch(self, batch):
with optional_import_error_message("torch_scatter"):
from torch_scatter import scatter # type: ignore[reportMissingImports] # noqa
atomic_numbers: torch.Tensor = batch["atomic_numbers"].long() # (n_atoms,)
batch_idx: torch.Tensor = batch["batch"] # (n_atoms,)
# Convert atomic numbers to one-hot encoding
atom_types_onehot = F.one_hot(atomic_numbers, num_classes=120) # (n_atoms, 120)
compositions = scatter(
atom_types_onehot,
batch_idx,
dim=0,
dim_size=batch.num_graphs,
reduce="sum",
)
compositions = compositions[:, 1:] # Remove the zeroth element
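        # Shape sketch: for a batch of G graphs, `compositions` is (G, 119); entry
        # [g, z - 1] counts the atoms with atomic number z in graph g (column 0, the
        # unused Z=0 slot, was dropped above).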
return NormalizationContext(compositions=compositions)
def _get_fixed(atoms: Atoms):
"""Gets the fixed atom constraint mask from an Atoms object."""
fixed = np.zeros(len(atoms), dtype=np.bool_)
if (constraints := getattr(atoms, "constraints", None)) is None:
raise ValueError("Atoms object does not have a constraints attribute")
from ase.constraints import FixAtoms
for constraint in constraints:
if not isinstance(constraint, FixAtoms):
continue
fixed[constraint.index] = True
return fixed
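
# Usage sketch (hypothetical paths and property list; any additional fields required by
# FinetuneModuleBaseConfig are omitted here):
#   config = JMPBackboneConfig(
#       ckpt_path=Path("/path/to/jmp-checkpoint.pt"),
#       graph_computer=JMPGraphComputerConfig(pbc=True),
#       properties=[...],  # e.g. an EnergyPropertyConfig plus a ForcesPropertyConfig
#   )
#   module = config.create_model()  # equivalent to constructing JMPBackboneModule(config)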