Source code for mattertune.normalization

from __future__ import annotations

import argparse
import json
import logging
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Any, Literal, Protocol, cast, runtime_checkable

import ase
import nshconfig as C
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from typing_extensions import TypeAliasType, assert_never, override

from .finetune.properties import PropertyConfig

log = logging.getLogger(__name__)


@dataclass(frozen=True)
class NormalizationContext:
    """
    The normalization context contains all the information required to
    normalize and denormalize the properties. Currently, this only includes
    the compositions of the materials in the batch.

    This flexibility allows the `Normalizer` interface to be used for other
    types of normalization beyond simple mean and standard deviation
    normalization. For example, subtracting linear references from total
    energies can be implemented using this interface.
    """

    compositions: torch.Tensor  # (num_samples, num_elements)
    """
    The compositions should be provided as an integer tensor of shape
    `(batch_size, num_elements)`, where each row (i.e., `compositions[i]`)
    corresponds to the composition vector of the `i`-th material in the batch.

    The composition vector maps each element to the number of atoms of that
    element in the material. For example, `compositions[:, 1]` corresponds to
    the number of hydrogen atoms in each material in the batch,
    `compositions[:, 2]` corresponds to the number of helium atoms, and so on.
    """
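
# Illustrative sketch (not part of the original module): one way to build the
# `compositions` tensor for a batch of `ase.Atoms`. The element count of 120 is
# an assumed upper bound on atomic numbers, not something the module requires.
#
#     atoms_list = [ase.Atoms("H2O"), ase.Atoms("CO2")]
#     compositions = torch.stack([
#         torch.bincount(
#             torch.tensor(atoms.get_atomic_numbers()), minlength=120
#         )
#         for atoms in atoms_list
#     ])  # shape (2, 120); column z counts atoms with atomic number z
#     ctx = NormalizationContext(compositions=compositions)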
@runtime_checkable
class NormalizerModule(Protocol):
    def normalize(
        self,
        value: torch.Tensor,
        ctx: NormalizationContext,
    ) -> torch.Tensor:
        """Normalizes the input tensor using the normalizer's parameters and context.

        Args:
            value (torch.Tensor): The input tensor to be normalized.
            ctx (NormalizationContext): Context containing composition information.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        ...

    def denormalize(
        self,
        value: torch.Tensor,
        ctx: NormalizationContext,
    ) -> torch.Tensor:
        """Denormalizes the input tensor using the normalizer's parameters and context.

        Args:
            value (torch.Tensor): The normalized tensor to be denormalized.
            ctx (NormalizationContext): Context containing composition information.

        Returns:
            torch.Tensor: The denormalized tensor.
        """
        ...
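
# Illustrative sketch (not part of the original module): because the protocol is
# `@runtime_checkable`, any object exposing `normalize` and `denormalize` methods
# passes an `isinstance` check; `IdentityNormalizer` is a made-up name.
#
#     class IdentityNormalizer:
#         def normalize(self, value, ctx):
#             return value
#
#         def denormalize(self, value, ctx):
#             return value
#
#     assert isinstance(IdentityNormalizer(), NormalizerModule)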
class NormalizerConfigBase(C.Config, ABC):
    @abstractmethod
    def create_normalizer_module(self) -> NormalizerModule: ...
class MeanStdNormalizerConfig(NormalizerConfigBase):
    mean: float
    """The mean of the property values."""

    std: float
    """The standard deviation of the property values."""

    @override
    def create_normalizer_module(self) -> NormalizerModule:
        return MeanStdNormalizerModule(self)
class MeanStdNormalizerModule(nn.Module, NormalizerModule, ABC):
    mean: torch.Tensor
    std: torch.Tensor

    @override
    def __init__(self, config: MeanStdNormalizerConfig):
        super().__init__()

        self.register_buffer("mean", torch.tensor(config.mean))
        self.register_buffer("std", torch.tensor(config.std))

    @override
    def normalize(self, value, ctx):
        return (value - self.mean) / self.std

    @override
    def denormalize(self, value, ctx):
        return value * self.std + self.mean
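
# Illustrative usage sketch (not part of the original module): the mean/std values
# and energies below are placeholders; normalize followed by denormalize recovers
# the original tensor.
#
#     normalizer = MeanStdNormalizerModule(MeanStdNormalizerConfig(mean=-3.5, std=1.2))
#     energies = torch.tensor([-4.0, -2.9])
#     ctx = NormalizationContext(compositions=torch.zeros(2, 120, dtype=torch.long))
#     normalized = normalizer.normalize(energies, ctx)
#     assert torch.allclose(normalizer.denormalize(normalized, ctx), energies)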
class RMSNormalizerConfig(NormalizerConfigBase):
    rms: float
    """The root mean square of the property values."""

    @override
    def create_normalizer_module(self) -> NormalizerModule:
        return RMSNormalizerModule(self)
class RMSNormalizerModule(nn.Module, NormalizerModule, ABC):
    rms: torch.Tensor

    @override
    def __init__(self, config: RMSNormalizerConfig):
        super().__init__()

        self.register_buffer("rms", torch.tensor(config.rms))

    @override
    def normalize(self, value, ctx):
        return value / self.rms

    @override
    def denormalize(self, value, ctx):
        return value * self.rms
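
# Illustrative usage sketch (not part of the original module): dividing by the RMS
# rescales values toward unit magnitude; the numbers are placeholders.
#
#     normalizer = RMSNormalizerModule(RMSNormalizerConfig(rms=2.0))
#     ctx = NormalizationContext(compositions=torch.zeros(1, 120, dtype=torch.long))
#     normalizer.normalize(torch.tensor([4.0]), ctx)  # tensor([2.0])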
class PerAtomReferencingNormalizerConfig(NormalizerConfigBase):
    per_atom_references: Mapping[int, float] | Sequence[float] | Path
    """The reference values for each element.

    - If a dictionary is provided, it maps atomic numbers to reference values.
    - If a list is provided, it is a list of reference values indexed by atomic number.
    - If a path is provided, it should point to a JSON file containing the references.
    """

    def _references_as_dict(self) -> dict[int, float]:
        if isinstance(self.per_atom_references, Mapping):
            return dict(self.per_atom_references)
        elif isinstance(self.per_atom_references, Sequence):
            return {z: ref for z, ref in enumerate(self.per_atom_references)}
        else:
            with open(self.per_atom_references, "r") as f:
                per_atom_references = json.load(f)
            per_atom_references = {int(k): v for k, v in per_atom_references.items()}
            return per_atom_references

    @override
    def create_normalizer_module(self) -> NormalizerModule:
        return PerAtomReferencingNormalizerModule(self)
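
# Illustrative sketch (not part of the original module): the three accepted forms
# of `per_atom_references`; the reference values are placeholders.
#
#     # 1. A dict mapping atomic number to reference value:
#     PerAtomReferencingNormalizerConfig(per_atom_references={1: -0.6, 8: -4.9})
#
#     # 2. A list indexed by atomic number (index 0 is dropped by the module):
#     PerAtomReferencingNormalizerConfig(
#         per_atom_references=[0.0, -0.6] + [0.0] * 6 + [-4.9]
#     )
#
#     # 3. A path to a JSON file containing the same mapping:
#     PerAtomReferencingNormalizerConfig(per_atom_references=Path("references.json"))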
class PerAtomReferencingNormalizerModule(nn.Module, NormalizerModule):
    references: torch.Tensor

    def __init__(self, config: PerAtomReferencingNormalizerConfig):
        super().__init__()

        references_dict = config._references_as_dict()
        max_atomic_number = max(references_dict.keys()) + 1
        references = torch.zeros(max_atomic_number)
        for z, ref in references_dict.items():
            references[z] = ref
        # Drop the reference for atomic number 0 (no element has Z = 0).
        references = references[1:]
        self.register_buffer("references", references)

    @override
    def normalize(self, value: torch.Tensor, ctx: NormalizationContext) -> torch.Tensor:
        # Compute references for each composition in the batch
        references = self.references
        max_atomic_number = len(references)
        compositions = ctx.compositions[:, :max_atomic_number].to(references.dtype)
        references = torch.einsum("ij,j->i", compositions, references).reshape(
            value.shape
        )
        # Subtract references from values
        return value - references

    @override
    def denormalize(
        self, value: torch.Tensor, ctx: NormalizationContext
    ) -> torch.Tensor:
        # Add references back to get original values
        references = self.references
        max_atomic_number = len(references)
        compositions = ctx.compositions[:, :max_atomic_number].to(references.dtype)
        references = torch.einsum("ij,j->i", compositions, references).reshape(
            value.shape
        )
        return value + references
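
# Illustrative usage sketch (not part of the original module): removing elemental
# reference energies from a total energy; all numbers are placeholders.
#
#     config = PerAtomReferencingNormalizerConfig(per_atom_references={1: -0.6, 8: -4.9})
#     normalizer = config.create_normalizer_module()
#     compositions = torch.zeros(1, 120, dtype=torch.long)
#     compositions[0, 1] = 2  # two hydrogen atoms
#     compositions[0, 8] = 1  # one oxygen atom
#     ctx = NormalizationContext(compositions=compositions)
#     referenced = normalizer.normalize(torch.tensor([-14.2]), ctx)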
class ComposeNormalizers(nn.Module, NormalizerModule):
    def __init__(self, normalizers: Sequence[NormalizerModule]):
        super().__init__()

        self.normalizers = nn.ModuleList(cast(list[nn.Module], normalizers))

    @override
    def normalize(self, value: torch.Tensor, ctx: NormalizationContext) -> torch.Tensor:
        for normalizer in self.normalizers:
            value = normalizer.normalize(value, ctx)
        return value

    @override
    def denormalize(
        self, value: torch.Tensor, ctx: NormalizationContext
    ) -> torch.Tensor:
        for normalizer in reversed(self.normalizers):
            value = normalizer.denormalize(value, ctx)
        return value
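
# Illustrative usage sketch (not part of the original module): chaining a per-atom
# reference subtraction with a mean/std rescaling, reusing `ctx` from the sketch
# above; denormalize undoes the chain in reverse order. All numbers are placeholders.
#
#     composed = ComposeNormalizers([
#         PerAtomReferencingNormalizerConfig(
#             per_atom_references={1: -0.6, 8: -4.9}
#         ).create_normalizer_module(),
#         MeanStdNormalizerModule(MeanStdNormalizerConfig(mean=0.0, std=2.0)),
#     ])
#     normalized = composed.normalize(torch.tensor([-14.2]), ctx)
#     recovered = composed.denormalize(normalized, ctx)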
def compute_per_atom_references(
    dataset: Dataset[ase.Atoms],
    property: PropertyConfig,
    reference_model: Literal["linear", "ridge"],
    reference_model_kwargs: dict[str, Any] = {},
):
    property_values: list[float] = []
    compositions: list[Counter[int]] = []

    # Iterate through the dataset to extract all labels.
    for atoms in dataset:
        # Extract the composition from the `ase.Atoms` object
        composition = Counter(atoms.get_atomic_numbers())

        # Get the property value
        label = property._from_ase_atoms_to_torch(atoms)

        # Make sure label is a scalar and convert to float
        assert (
            label.numel() == 1
        ), f"Label for property {property.name} is not a scalar. Shape: {label.shape}"
        property_values.append(float(label.item()))
        compositions.append(composition)

    # Convert the compositions to a matrix
    num_samples = len(compositions)
    num_elements = max(max(c.keys()) for c in compositions) + 1
    compositions_matrix = np.zeros((num_samples, num_elements))
    for i, composition in enumerate(compositions):
        for z, count in composition.items():
            compositions_matrix[i, z] = count

    # Fit the linear model
    match reference_model:
        case "linear":
            from sklearn.linear_model import LinearRegression

            model = LinearRegression(fit_intercept=False, **reference_model_kwargs)
        case "ridge":
            from sklearn.linear_model import Ridge

            model = Ridge(fit_intercept=False, **reference_model_kwargs)
        case _:
            assert_never(reference_model)

    references = model.fit(compositions_matrix, torch.tensor(property_values)).coef_
    # references: (num_elements,)

    # Convert the references to a dict[int, float]
    references_dict = {int(z): ref for z, ref in enumerate(references.tolist())}

    # Drop the reference for atomic number 0 (no element has Z = 0).
    del references_dict[0]

    return references_dict
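
# Illustrative sketch (not part of the original module) of the underlying fit:
# regressing property values against composition counts without an intercept
# yields one reference coefficient per element; the numbers are placeholders.
#
#     import numpy as np
#     from sklearn.linear_model import LinearRegression
#
#     compositions_matrix = np.array([[2.0, 1.0], [0.0, 2.0]])  # per-sample element counts
#     property_values = np.array([-14.2, -10.0])
#     references = LinearRegression(fit_intercept=False).fit(
#         compositions_matrix, property_values
#     ).coef_  # one value per column of the compositions matrix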
def compute_per_atom_references_cli_main(
    args: argparse.Namespace,
    parser: argparse.ArgumentParser,
):
    # Extract the necessary arguments
    config_arg: Path = args.config
    property_name_arg: str = args.property
    dest_arg: Path = args.dest

    # Load the fine-tuning config
    from .main import MatterTunerConfig

    with open(config_arg, "r") as f:
        config = MatterTunerConfig.model_validate_json(f.read())

    # Extract the property config from the model config
    if (
        property := next(
            (p for p in config.model.properties if p.name == property_name_arg),
            None,
        )
    ) is None:
        parser.error(f"Property {property_name_arg} not found in the model config.")

    # Load the dataset based on the config
    from .main import MatterTuneDataModule

    data_module = MatterTuneDataModule(config.data)
    data_module.prepare_data()
    data_module.setup("fit")

    # Get the train dataset or throw
    if (dataset := data_module.datasets.get("train")) is None:
        parser.error("The data module does not have a train dataset.")

    # Compute the reference values
    references_dict = compute_per_atom_references(
        dataset,
        property,
        args.reference_model,
        args.reference_model_kwargs,
    )

    # Print the reference values
    log.info(f"Computed reference values:\n{references_dict}")

    # Save the reference values to a JSON file
    with open(dest_arg, "w") as f:
        json.dump(references_dict, f)
NormalizerConfig = TypeAliasType(
    "NormalizerConfig",
    Annotated[
        MeanStdNormalizerConfig
        | RMSNormalizerConfig
        | PerAtomReferencingNormalizerConfig,
        C.Field(
            description="Configuration for normalizing and denormalizing a property."
        ),
    ],
)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compute per-atom references for a property using a linear model."
    )
    parser.add_argument(
        "config",
        type=Path,
        help="The path to the MatterTune config JSON file.",
    )
    parser.add_argument(
        "property",
        type=str,
        help="The name of the property for which to compute the per-atom references.",
    )
    parser.add_argument(
        "dest",
        type=Path,
        help="The path to save the computed per-atom references JSON file.",
    )
    parser.add_argument(
        "--reference-model",
        type=str,
        default="linear",
        choices=["linear", "ridge"],
        help="The type of reference model to use.",
    )
    parser.add_argument(
        "--reference-model-kwargs",
        type=json.loads,
        default={},
        help="The keyword arguments to pass to the reference model constructor.",
    )
    args = parser.parse_args()

    compute_per_atom_references_cli_main(args, parser)
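
# Illustrative CLI invocation (not part of the original module); the file names and
# the property name are placeholders, and this assumes the script is run as
# `python -m mattertune.normalization`:
#
#     python -m mattertune.normalization config.json energy references.json \
#         --reference-model ridge --reference-model-kwargs '{"alpha": 0.1}'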