mattertune.backbones.m3gnet.model

Classes

M3GNetBackboneConfig(*, properties, optimizer)

M3GNetBackboneModule(hparams)

Implementation of the M3GNet backbone that fits into the MatterTune framework.

M3GNetGraphComputerConfig(*[, ...])

Configuration for initializing a MatGL Atoms2Graph converter.

MatGLBatch(g, lg, state_attr, lattice, ...)

MatGLData(g, lg, state_attr, lattice, labels)

class mattertune.backbones.m3gnet.model.MatGLData(g: 'DGLGraph', lg: 'DGLGraph | None', state_attr: 'torch.Tensor', lattice: 'torch.Tensor', labels: 'dict[str, torch.Tensor]')[source]
Parameters:
  • g (DGLGraph)

  • lg (DGLGraph | None)

  • state_attr (torch.Tensor)

  • lattice (torch.Tensor)

  • labels (dict[str, torch.Tensor])

g: DGLGraph

The DGL node graph for pairwise interactions.

lg: DGLGraph | None

The DGL line graph (edge graph) for three-body interactions.

state_attr: torch.Tensor

The global state attributes.

lattice: torch.Tensor

The lattice vectors.

labels: dict[str, torch.Tensor]

The ground truth labels.

__init__(g, lg, state_attr, lattice, labels)
Parameters:
  • g (DGLGraph)

  • lg (DGLGraph | None)

  • state_attr (torch.Tensor)

  • lattice (torch.Tensor)

  • labels (dict[str, torch.Tensor])

Return type:

None
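
For orientation, here is a minimal, hand-built MatGLData for a diatomic system. In practice the graphs come from the MatGL Atoms2Graph converter configured by M3GNetGraphComputerConfig; the tensor shapes below are illustrative assumptions, not requirements documented on this page.

import dgl
import torch

from mattertune.backbones.m3gnet.model import MatGLData

# Two atoms connected by one bidirectional edge (pairwise interaction graph).
g = dgl.graph(([0, 1], [1, 0]))

data = MatGLData(
    g=g,
    lg=None,  # line graph for three-body terms omitted here
    state_attr=torch.zeros(2),                # global state attributes (assumed shape)
    lattice=torch.eye(3).unsqueeze(0),        # lattice vectors (assumed shape (1, 3, 3))
    labels={"energy": torch.tensor([-1.0])},  # ground truth targets
)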

class mattertune.backbones.m3gnet.model.MatGLBatch(g: 'DGLGraph', lg: 'DGLGraph | None', state_attr: 'torch.Tensor', lattice: 'torch.Tensor', strain: 'torch.Tensor', labels: 'dict[str, torch.Tensor]')[source]
Parameters:
  • g (DGLGraph)

  • lg (DGLGraph | None)

  • state_attr (torch.Tensor)

  • lattice (torch.Tensor)

  • strain (torch.Tensor)

  • labels (dict[str, torch.Tensor])

g: DGLGraph

The DGL node graph for pairwise interactions.

lg: DGLGraph | None

The DGL line graph (edge graph) for three-body interactions.

state_attr: torch.Tensor

The global state attributes.

lattice: torch.Tensor

The lattice vectors.

strain: torch.Tensor

The strain tensor.

labels: dict[str, torch.Tensor]

The ground truth labels.

__init__(g, lg, state_attr, lattice, strain, labels)
Parameters:
  • g (DGLGraph)

  • lg (DGLGraph | None)

  • state_attr (torch.Tensor)

  • lattice (torch.Tensor)

  • strain (torch.Tensor)

  • labels (dict[str, torch.Tensor])

Return type:

None

class mattertune.backbones.m3gnet.model.M3GNetGraphComputerConfig(*, element_types=<factory>, cutoff=None, threebody_cutoff=None, pre_compute_line_graph=False, graph_labels=None)[source]

Configuration for initializing a MatGL Atoms2Graph converter.

Parameters:
  • element_types (tuple[str, ...])

  • cutoff (float | None)

  • threebody_cutoff (float | None)

  • pre_compute_line_graph (bool)

  • graph_labels (list[int | float] | None)

element_types: tuple[str, ...]

The element types to consider. Defaults to all elements.

cutoff: float | None

The cutoff distance for the neighbor list. If None, the cutoff is loaded from the checkpoint.

threebody_cutoff: float | None

The cutoff distance for the three-body interactions. If None, the cutoff is loaded from the checkpoint.

pre_compute_line_graph: bool

Whether to pre-compute the line graph for three-body interactions in data preparation.

graph_labels: list[int | float] | None

The graph labels to consider. Defaults to None.
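
A minimal configuration sketch: leaving cutoff and threebody_cutoff unset (None) defers both cutoffs to the values stored in the checkpoint.

from mattertune.backbones.m3gnet.model import M3GNetGraphComputerConfig

graph_computer = M3GNetGraphComputerConfig(
    cutoff=None,                   # neighbor-list cutoff; None -> read from checkpoint
    threebody_cutoff=None,         # three-body cutoff; None -> read from checkpoint
    pre_compute_line_graph=False,  # build line graphs on the fly instead of in data prep
)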

class mattertune.backbones.m3gnet.model.M3GNetBackboneConfig(*, properties, optimizer, lr_scheduler=None, ignore_gpu_batch_transform_error=True, normalizers={}, name='m3gnet', ckpt_path, graph_computer)[source]
Parameters:
  • properties (Sequence[PropertyConfig])

  • optimizer (OptimizerConfig)

  • lr_scheduler (LRSchedulerConfig | None)

  • ignore_gpu_batch_transform_error (bool)

  • normalizers (Mapping[str, Sequence[NormalizerConfig]])

  • name (Literal['m3gnet'])

  • ckpt_path (str | Path)

  • graph_computer (M3GNetGraphComputerConfig)

name: Literal['m3gnet']

The type of the backbone.

ckpt_path: str | Path

The path to the pre-trained model checkpoint.

graph_computer: M3GNetGraphComputerConfig

Configuration for the graph computer.

create_model()[source]

Creates an instance of the finetune module for this configuration.

classmethod ensure_dependencies()[source]

Ensure that all dependencies are installed.

This method should raise an exception if any dependencies are missing, with a message indicating which dependencies are missing and how to install them.
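
A hedged end-to-end sketch of building the backbone config and instantiating the finetune module. The property and optimizer config classes (EnergyPropertyConfig, MAELossConfig, AdamWConfig) and the checkpoint name are assumptions drawn from the wider MatterTune API, not defined on this page.

import mattertune as mt

from mattertune.backbones.m3gnet.model import (
    M3GNetBackboneConfig,
    M3GNetGraphComputerConfig,
)

config = M3GNetBackboneConfig(
    ckpt_path="M3GNet-MP-2021.2.8-PES",  # assumed MatGL pre-trained checkpoint name
    graph_computer=M3GNetGraphComputerConfig(),
    properties=[
        # assumed property config: fit total energies with an MAE loss
        mt.configs.EnergyPropertyConfig(loss=mt.configs.MAELossConfig()),
    ],
    optimizer=mt.configs.AdamWConfig(lr=1e-4),  # assumed optimizer config
)

M3GNetBackboneConfig.ensure_dependencies()  # raises if e.g. matgl/dgl are missing
module = config.create_model()              # -> M3GNetBackboneModule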

class mattertune.backbones.m3gnet.model.M3GNetBackboneModule(hparams)[source]

Implementation of the M3GNet backbone that fits into the MatterTune framework. Follows the MatGL implementation of M3GNet. Paper: https://www.nature.com/articles/s43588-022-00349-3. MatGL repo: https://github.com/materialsvirtuallab/matgl

Parameters:

hparams (TFinetuneModuleConfig)

classmethod hparams_cls()[source]

Return the hyperparameters config class for this module.

requires_disabled_inference_mode()[source]

Whether the model requires inference mode to be disabled.

setup(stage)[source]

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: don't assign model state inside prepare_data
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

create_model()[source]

Initialize both the pre-trained backbone and the output heads for the properties to predict.

You should also construct any other nn.Module instances necessary for the forward pass here.

model_forward_context(data)[source]

Context manager for the model forward pass.

This is used for any setup that needs to be done before the forward pass, e.g., setting pos.requires_grad_() for gradient-based force prediction.
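
The general shape of such a context manager, sketched for illustration. The actual M3GNet implementation may instead set up strain/lattice gradients for stress prediction, and the "pos" node-feature key is an assumption.

from contextlib import contextmanager

import torch

@contextmanager
def model_forward_context(data):
    # Forces come from dE/d(pos), so gradients must be enabled even in eval mode.
    with torch.enable_grad():
        data.g.ndata["pos"].requires_grad_(True)  # assumed node-feature key
        yield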

model_forward(batch, return_backbone_output=False)[source]

Forward pass of the model.

Parameters:
  • batch (MatGLBatch) – Input batch.

  • return_backbone_output (bool) – Whether to return the output of the backbone model.

Returns:

Prediction of the model.

pretrained_backbone_parameters()[source]

Return the parameters of the backbone model.

output_head_parameters()[source]

Return the parameters of the output heads.

cpu_data_transform(data)[source]

Transform data (on the CPU) before being batched and sent to the GPU.

collate_fn(data_list)[source]

Collate function for the DataLoader.

gpu_batch_transform(batch)[source]

Transform batch (on the GPU) before being fed to the model.

This will mainly be used to compute the (radius or knn) graph from the atomic positions.

Parameters:

batch (MatGLBatch)

Return type:

MatGLBatch

batch_to_labels(batch)[source]

Extract ground truth values from a batch. The output of this function should be a dictionary with keys corresponding to the target names and values corresponding to the ground truth values. The values should be torch tensors that match, in shape, the output of the corresponding output head.
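
As a sketch of the expected shape, with an energy head the returned dictionary might look like the following (key names depend on the configured properties and are assumptions here):

import torch

labels = {
    "energy": torch.tensor([-3.1, -2.7]),  # shape (batch_size,), matching the energy head
}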

atoms_to_data(atoms, has_labels)[source]

Convert an ASE atoms object to a data object. This is used to convert the input data to the format expected by the model.

Parameters:
  • atoms (Atoms) – ASE atoms object.

  • has_labels (bool) – Whether the atoms object contains labels.

Return type:

MatGLData
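
A usage sketch, assuming module is an M3GNetBackboneModule instance (e.g. from create_model() above): convert a structure for inference, then collate it into a batch for model_forward.

from ase.build import bulk

atoms = bulk("Cu", "fcc", a=3.6)                      # a simple ASE test structure
data = module.atoms_to_data(atoms, has_labels=False)  # -> MatGLData (no targets)
batch = module.collate_fn([data])                     # -> MatGLBatch for model_forward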

create_normalization_context_from_batch(batch)[source]

Create a normalization context from a batch. This is used to normalize and denormalize the properties.

The normalization context contains all the information required to normalize and denormalize the properties. Currently, this only includes the compositions of the materials in the batch. The compositions should be provided as an integer tensor of shape (batch_size, num_elements), where each row (i.e., compositions[i]) corresponds to the composition vector of the i-th material in the batch.

The composition vector is a vector that maps each element to the number of atoms of that element in the material. For example, compositions[:, 1] corresponds to the number of Hydrogen atoms in each material in the batch, compositions[:, 2] corresponds to the number of Helium atoms, and so on.

Parameters:

batch – Input batch.

Returns:

Normalization context.
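
An illustration of the composition-vector convention described above, built with torch.bincount over atomic numbers so that column index equals atomic number (column 1 counts hydrogen). This is a sketch of the convention, not the module's own code; the table size is an assumption.

import torch

NUM_ELEMENTS = 120  # assumed table size, large enough to index by atomic number

def composition_vector(atomic_numbers: torch.Tensor) -> torch.Tensor:
    # Count atoms of each element; index i holds the count for atomic number i.
    return torch.bincount(atomic_numbers, minlength=NUM_ELEMENTS)

# A batch of two materials: H2O and NaCl.
h2o = torch.tensor([1, 1, 8])
nacl = torch.tensor([11, 17])
compositions = torch.stack([composition_vector(h2o), composition_vector(nacl)])
# compositions has shape (batch_size=2, NUM_ELEMENTS); compositions[0, 1] == 2
# because H2O contains two hydrogen atoms.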