# Source code for mattertune.wrappers.ase_calculator

from __future__ import annotations

import copy
from typing import TYPE_CHECKING

import numpy as np
import torch
from ase import Atoms
from ase.calculators.calculator import Calculator
from typing_extensions import override

if TYPE_CHECKING:
    from ..finetune.properties import PropertyConfig
    from .property_predictor import MatterTunePropertyPredictor
    from ..finetune.base import FinetuneModuleBase


class MatterTuneCalculator(Calculator):
    """ASE calculator backed by a fine-tuned MatterTune model.

    A fast variant that calls the model's ``predict_step`` directly,
    without constructing a trainer.
    """

    @override
    def __init__(self, model: FinetuneModuleBase, device: torch.device):
        """Wrap ``model``, move it to ``device``, and register ASE properties.

        Only properties whose config declares an ASE calculator property
        name (via ``ase_calculator_property_name()``) are exposed through
        ``implemented_properties``.
        """
        super().__init__()
        self.model = model.to(device)
        self.implemented_properties: list[str] = []
        self._ase_prop_to_config: dict[str, PropertyConfig] = {}
        for prop in self.model.hparams.properties:
            # Ignore properties not marked as ASE calculator properties.
            if (ase_prop_name := prop.ase_calculator_property_name()) is None:
                continue
            self.implemented_properties.append(ase_prop_name)
            self._ase_prop_to_config[ase_prop_name] = prop

    @override
    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: list[str] | None = None,
        system_changes: list[str] | None = None,
    ):
        """Run a single prediction and populate ``self.results``.

        Parameters
        ----------
        atoms:
            Structure to evaluate; stored on ``self.atoms`` by the parent
            class.
        properties:
            ASE property names to compute; defaults to every implemented
            property.
        system_changes:
            Accepted for ASE API compatibility; not used here.
        """
        if properties is None:
            # Strings are immutable, so a shallow copy is sufficient
            # (a deepcopy here would be wasted work).
            properties = list(self.implemented_properties)

        # Call the parent class to set `self.atoms`.
        Calculator.calculate(self, atoms)

        # Make sure `self.atoms` is set.
        assert self.atoms is not None, (
            "`MatterTuneCalculator.atoms` is not set. "
            "This should have been set by the parent class. "
            "Please report this as a bug."
        )
        assert isinstance(self.atoms, Atoms), (
            "`MatterTuneCalculator.atoms` is not an `ase.Atoms` object. "
            "This should have been set by the parent class. "
            "Please report this as a bug."
        )

        prop_configs = [self._ase_prop_to_config[prop] for prop in properties]

        # Wrap fractional coordinates into [0, 1) on a copy so the caller's
        # Atoms object is never mutated.
        normalized_atoms = copy.deepcopy(self.atoms)
        scaled_pos = normalized_atoms.get_scaled_positions()
        scaled_pos = np.mod(scaled_pos, 1.0)
        normalized_atoms.set_scaled_positions(scaled_pos)

        # Single-structure batch fed straight through `predict_step`
        # (no trainer involved). NOTE(review): intentionally NOT wrapped in
        # `torch.no_grad()` — some properties (e.g. forces) may be computed
        # via autograd inside `predict_step`.
        data = self.model.atoms_to_data(normalized_atoms, has_labels=False)
        batch = self.model.collate_fn([data])
        batch = batch.to(self.model.device)
        pred = self.model.predict_step(
            batch=batch,
            batch_idx=0,
        )
        pred = pred[0]  # type: ignore[index]

        for prop in prop_configs:
            ase_prop_name = prop.ase_calculator_property_name()
            assert ase_prop_name is not None, (
                f"Property '{prop.name}' does not have an ASE calculator property name. "
                "This should have been checked when creating the MatterTuneCalculator. "
                "Please report this as a bug."
            )
            # Detach from the graph, convert to the property's numpy dtype,
            # and reshape/scalarize as ASE expects.
            value = pred[prop.name].detach().to(torch.float32).cpu().numpy()  # type: ignore[union-attr]
            value = value.astype(prop._numpy_dtype())
            value = prop.prepare_value_for_ase_calculator(value)
            self.results[ase_prop_name] = value