mattertune.finetune.base

Classes

  • FinetuneModuleBase(hparams) – Finetune module base class.

  • FinetuneModuleBaseConfig(*, properties, ...)

  • ModelOutput

class mattertune.finetune.base.FinetuneModuleBaseConfig(*, properties, optimizer, lr_scheduler=None, ignore_gpu_batch_transform_error=True, normalizers={})
Parameters:
  • properties (Sequence[PropertyConfig])

  • optimizer (OptimizerConfig)

  • lr_scheduler (LRSchedulerConfig | None)

  • ignore_gpu_batch_transform_error (bool)

  • normalizers (Mapping[str, Sequence[NormalizerConfig]])

properties: Sequence[PropertyConfig]

Properties to predict.

optimizer: OptimizerConfig

Optimizer configuration.

lr_scheduler: LRSchedulerConfig | None

Learning rate scheduler configuration. Optional; defaults to None.

ignore_gpu_batch_transform_error: bool

Whether to ignore data processing errors during training.

normalizers: Mapping[str, Sequence[NormalizerConfig]]

Normalizers for the properties.

Any property can be associated with multiple normalizers. This is useful for cases where we want to normalize the same property in different ways. For example, we may want to normalize the energy by subtracting the atomic reference energies, as well as by mean and standard deviation normalization.

The normalizers are applied in the order they are defined in the list.
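The ordering rule can be made concrete with a standalone PyTorch sketch. Everything in it is hypothetical: the classes ReferenceEnergyNormalizer and MeanStdNormalizer, the helpers apply_normalizers and undo_normalizers, and all numeric values. None of them are MatterTune's NormalizerConfig implementations. The sketch also assumes that denormalization walks the list in reverse order, which an exact inverse of an ordered chain requires.

```python
import torch


class ReferenceEnergyNormalizer:
    """Hypothetical: subtracts composition-weighted per-element reference energies."""

    def __init__(self, references: torch.Tensor) -> None:
        self.references = references  # shape (num_elements,)

    def normalize(self, value: torch.Tensor, compositions: torch.Tensor) -> torch.Tensor:
        return value - compositions.to(value.dtype) @ self.references

    def denormalize(self, value: torch.Tensor, compositions: torch.Tensor) -> torch.Tensor:
        return value + compositions.to(value.dtype) @ self.references


class MeanStdNormalizer:
    """Hypothetical: standardizes a value with a fixed mean and standard deviation."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean, self.std = mean, std

    def normalize(self, value, compositions):
        return (value - self.mean) / self.std

    def denormalize(self, value, compositions):
        return value * self.std + self.mean


def apply_normalizers(value, compositions, normalizers):
    # Normalizers run in the order they appear in the list ...
    for n in normalizers:
        value = n.normalize(value, compositions)
    return value


def undo_normalizers(value, compositions, normalizers):
    # ... so the exact inverse runs them in reverse order (assumption of this sketch)
    for n in reversed(normalizers):
        value = n.denormalize(value, compositions)
    return value


# Two H2O structures: column 1 counts H atoms, column 8 counts O atoms
compositions = torch.zeros(2, 9, dtype=torch.long)
compositions[:, 1] = 2
compositions[:, 8] = 1
references = torch.zeros(9)
references[1], references[8] = -0.5, -2.0  # illustrative reference energies
energy = torch.tensor([-14.2, -14.0])

chain = [ReferenceEnergyNormalizer(references), MeanStdNormalizer(mean=-11.0, std=0.5)]
normalized = apply_normalizers(energy, compositions, chain)
restored = undo_normalizers(normalized, compositions, chain)
```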

abstract classmethod ensure_dependencies()

Ensure that all dependencies are installed.

This method should raise an exception if any dependencies are missing, with a message indicating which dependencies are missing and how to install them.
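A config subclass might implement this check as sketched below. The sketch uses importlib.util.find_spec, so nothing is actually imported. ExampleBackboneConfig and its required_modules mapping are hypothetical. A real config would subclass FinetuneModuleBaseConfig and list its own backbone packages.

```python
import importlib.util


class ExampleBackboneConfig:
    """Hypothetical config; a real one would subclass FinetuneModuleBaseConfig."""

    # Hypothetical mapping: importable module name -> installation hint
    required_modules = {"torch": "pip install torch"}

    @classmethod
    def ensure_dependencies(cls) -> None:
        # find_spec returns None for top-level modules that are not installed
        missing = [
            (name, hint)
            for name, hint in cls.required_modules.items()
            if importlib.util.find_spec(name) is None
        ]
        if missing:
            details = "; ".join(f"'{name}' (install with: {hint})" for name, hint in missing)
            raise ImportError(f"{cls.__name__} is missing dependencies: {details}")
```

The error message names every missing package and says how to install it, as the docstring asks.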

abstract create_model()

Creates an instance of the finetune module for this configuration.

Return type:

FinetuneModuleBase

class mattertune.finetune.base.ModelOutput
predicted_properties: dict[str, Tensor]

Predicted properties. This dictionary should have exactly the same keys, shapes, and format as the output of batch_to_labels.

backbone_output: NotRequired[Any]

Output of the backbone model. Only set if return_backbone_output is True.
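ModelOutput works like a typed dictionary with one required key and one optional key. The sketch below reproduces those two documented fields with TypedDict and NotRequired. It shows that the optional key is present only when requested. It is an illustration, not the library's source, and make_output is a hypothetical helper.

```python
import sys
from typing import Any

import torch

if sys.version_info >= (3, 11):
    from typing import NotRequired, TypedDict
else:  # older Pythons: use the typing_extensions backports
    from typing_extensions import NotRequired, TypedDict


class ModelOutput(TypedDict):
    predicted_properties: dict[str, torch.Tensor]
    backbone_output: NotRequired[Any]  # present only when requested


def make_output(predictions, backbone_output=None, return_backbone_output=False) -> ModelOutput:
    """Hypothetical helper that builds a ModelOutput the way model_forward might."""
    output: ModelOutput = {"predicted_properties": predictions}
    if return_backbone_output:
        output["backbone_output"] = backbone_output
    return output
```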

class mattertune.finetune.base.FinetuneModuleBase(hparams)

Finetune module base class. Inherits from lightning.pytorch.LightningModule.

Parameters:

hparams (TFinetuneModuleConfig)

abstract classmethod hparams_cls()

Return the hyperparameters config class for this module.

Return type:

type[TFinetuneModuleConfig]

abstract create_model()

Initialize both the pre-trained backbone and the output heads for the properties to predict.

You should also construct any other nn.Module instances necessary for the forward pass here.

abstract model_forward_context(data)

Context manager for the model forward pass.

This is used for any setup that needs to be done before the forward pass, e.g., setting pos.requires_grad_() for gradient-based force prediction.

Parameters:

data (TBatch)

Return type:

AbstractContextManager
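For gradient-based forces, such a context manager might look like the sketch below. It enables autograd and marks the positions as requiring gradients, then computes forces as the negative gradient of the energy. force_context, energy_and_forces, and the harmonic toy energy are all hypothetical. A real implementation would receive a TBatch and find the position tensor inside it.

```python
import contextlib

import torch


@contextlib.contextmanager
def force_context(pos: torch.Tensor):
    """Hypothetical model_forward_context: track gradients w.r.t. atomic positions."""
    # enable_grad re-enables autograd even if the caller is inside torch.no_grad()
    with torch.enable_grad():
        # Mark positions as differentiable (modifies the caller's tensor in place)
        pos.requires_grad_(True)
        yield


def energy_and_forces(pos: torch.Tensor):
    """Toy harmonic pair energy E = 0.5 * sum_{i<j} |r_i - r_j|^2 and forces F = -dE/dr."""
    with force_context(pos):
        diff = pos.unsqueeze(0) - pos.unsqueeze(1)  # (N, N, 3), diff[i, j] = r_j - r_i
        # Summing over all ordered pairs counts each pair twice, hence 0.25 instead of 0.5
        energy = 0.25 * diff.pow(2).sum()
        (grad,) = torch.autograd.grad(energy, pos)
    return energy.detach(), -grad
```

Forces come out correctly even when the caller has disabled gradients with torch.no_grad(). They do not if the caller is under torch.inference_mode().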

abstract requires_disabled_inference_mode()

Whether the model requires inference mode to be disabled.

Return type:

bool
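This flag matters mainly for models that differentiate through the forward pass at prediction time. Computations run under torch.inference_mode() are not recorded by autograd. Lightning's Trainer uses inference mode for validation, testing, and prediction unless it is constructed with inference_mode=False. The helper trainer_kwargs_for is a hypothetical sketch of how the flag could become Trainer arguments. graph_is_recorded_under only demonstrates the autograd behavior.

```python
from typing import Any, Dict, Optional

import torch


class ToyForceModel:
    """Hypothetical module that predicts forces by differentiating its energy."""

    def requires_disabled_inference_mode(self) -> bool:
        return True


def trainer_kwargs_for(module, base_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: turn the flag into lightning.pytorch.Trainer keyword arguments."""
    kwargs = dict(base_kwargs or {})
    if module.requires_disabled_inference_mode():
        # Trainer(inference_mode=False) evaluates under torch.no_grad() instead of
        # torch.inference_mode(); torch.enable_grad() can then switch autograd back on
        kwargs["inference_mode"] = False
    return kwargs


def graph_is_recorded_under(ctx) -> bool:
    """Return True if autograd records operations performed inside `ctx`."""
    x = torch.ones(3, requires_grad=True)
    with ctx:
        y = (x * 2).sum()
    return y.requires_grad
```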

abstract model_forward(batch, return_backbone_output=False)

Forward pass of the model.

Parameters:
  • batch (TBatch) – Input batch.

  • return_backbone_output (bool) – Whether to return the output of the backbone model.

Returns:

Prediction of the model.

Return type:

ModelOutput
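Below is a minimal sketch of a model_forward implementation on a toy module with a backbone and one output head. ToyFinetuneModel, the batch key "x", and the property name "energy" are hypothetical. The point is the return format: predicted_properties is always filled, and backbone_output is added only when return_backbone_output is True.

```python
import torch
from torch import nn


class ToyFinetuneModel(nn.Module):
    """Hypothetical: a tiny backbone plus one output head per property."""

    def __init__(self, in_dim: int = 4, hidden: int = 8) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, hidden), nn.SiLU())
        self.heads = nn.ModuleDict({"energy": nn.Linear(hidden, 1)})

    def model_forward(self, batch: dict[str, torch.Tensor], return_backbone_output: bool = False):
        features = self.backbone(batch["x"])  # (num_structures, hidden)
        # One entry per property; squeeze so a scalar property has shape (num_structures,)
        predicted = {name: head(features).squeeze(-1) for name, head in self.heads.items()}
        output = {"predicted_properties": predicted}
        if return_backbone_output:
            output["backbone_output"] = features
        return output
```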

abstract pretrained_backbone_parameters()

Return the parameters of the backbone model.

Return type:

Iterable[Parameter]

abstract output_head_parameters()

Return the parameters of the output heads.

Return type:

Iterable[Parameter]
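Keeping the two parameter sets separate lets an optimizer treat them differently. For example, pre-trained weights can get a smaller learning rate than freshly initialized heads. The sketch below builds such parameter groups with torch.optim.AdamW. The stand-in modules and learning rates are hypothetical, and this section does not say whether MatterTune's own configure_optimizers uses these methods this way.

```python
import torch
from torch import nn

# Hypothetical stand-ins for a module's pre-trained backbone and output head
backbone = nn.Linear(4, 8)
head = nn.Linear(8, 1)


def pretrained_backbone_parameters():
    return backbone.parameters()


def output_head_parameters():
    return head.parameters()


# Smaller learning rate for pre-trained weights, larger for the new heads
optimizer = torch.optim.AdamW(
    [
        {"params": list(pretrained_backbone_parameters()), "lr": 1e-5},
        {"params": list(output_head_parameters()), "lr": 1e-3},
    ],
    weight_decay=0.01,
)
```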

abstract cpu_data_transform(data)

Transform data (on the CPU) before being batched and sent to the GPU.

Parameters:

data (TData)

Return type:

TData

abstract collate_fn(data_list)

Collate function for the DataLoader: combines a list of data objects into a single batch.

Parameters:

data_list (list[TData])

Return type:

TBatch
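Atomistic structures have different numbers of atoms. A collate function therefore typically concatenates per-atom arrays and records which structure each atom came from. The sketch assumes a hypothetical data format: a dict with atomic_numbers, pos, and energy. It shows one common batching scheme, not MatterTune's actual TData or TBatch types.

```python
import torch


def collate_structures(data_list):
    """Hypothetical collate_fn: concatenate per-atom arrays, stack per-structure labels."""
    atomic_numbers = torch.cat([d["atomic_numbers"] for d in data_list])
    pos = torch.cat([d["pos"] for d in data_list])
    # batch[i] is the index of the structure that atom i belongs to
    batch = torch.cat(
        [torch.full((len(d["atomic_numbers"]),), i, dtype=torch.long) for i, d in enumerate(data_list)]
    )
    energy = torch.stack([d["energy"] for d in data_list])  # one value per structure
    return {"atomic_numbers": atomic_numbers, "pos": pos, "batch": batch, "energy": energy}
```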

abstract gpu_batch_transform(batch)

Transform batch (on the GPU) before being fed to the model.

This will mainly be used to compute the (radius or k-nearest-neighbor) graph from the atomic positions.

Parameters:

batch (TBatch)

Return type:

TBatch
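The radius-graph step can be sketched with torch.cdist. Compute all pairwise distances, keep pairs within the cutoff that belong to the same structure, and drop self-loops. radius_graph is a hypothetical helper that expects concatenated pos and batch tensors. It ignores periodic boundary conditions and uses O(N²) memory, so it is only a conceptual sketch of what gpu_batch_transform might compute.

```python
import torch


def radius_graph(pos: torch.Tensor, batch: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Hypothetical: directed edges between atoms of the same structure within `cutoff`.

    Ignores periodic boundary conditions; builds an (N, N) matrix, so small batches only.
    """
    dist = torch.cdist(pos, pos)  # (N, N) pairwise distances
    same_structure = batch.unsqueeze(0) == batch.unsqueeze(1)
    no_self_loops = torch.eye(len(pos), device=pos.device) == 0
    mask = (dist <= cutoff) & same_structure & no_self_loops
    src, dst = mask.nonzero(as_tuple=True)
    return torch.stack([src, dst])  # edge_index, shape (2, num_edges)
```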

abstract batch_to_labels(batch)

Extract ground truth values from a batch. The output of this function should be a dictionary with keys corresponding to the target names and values corresponding to the ground truth values. The values should be torch tensors that match, in shape, the output of the corresponding output head.

Parameters:

batch (TBatch)

Return type:

dict[str, Tensor]
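Below is a sketch of batch_to_labels for a hypothetical dict batch with per-structure energies and per-atom forces. It comes with a shape check that enforces the documented contract. The check is useful because shapes such as (N, 1) and (N,) silently broadcast to (N, N) in element-wise arithmetic. Both helpers are illustrative, not part of the library.

```python
import torch


def batch_to_labels(batch: dict) -> dict:
    """Hypothetical: extract ground-truth targets from a collated batch dict."""
    return {
        "energy": batch["energy"],  # (num_structures,): one value per structure
        "forces": batch["forces"],  # (num_atoms, 3): one vector per atom
    }


def check_label_shapes(labels: dict, predictions: dict) -> None:
    """Raise if any prediction does not match its label's shape."""
    for name, target in labels.items():
        if predictions[name].shape != target.shape:
            raise ValueError(
                f"{name}: prediction shape {tuple(predictions[name].shape)} "
                f"!= label shape {tuple(target.shape)}"
            )
```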

abstract atoms_to_data(atoms, has_labels)

Convert an ASE Atoms object to a data object. This is used to convert the input data to the format expected by the model.

Parameters:
  • atoms (Atoms) – ASE Atoms object.

  • has_labels (bool) – Whether the atoms object contains labels.

Return type:

TData

abstract create_normalization_context_from_batch(batch)

Create a normalization context from a batch. This is used to normalize and denormalize the properties.

The normalization context contains all the information required to normalize and denormalize the properties. Currently, this only includes the compositions of the materials in the batch. The compositions should be provided as an integer tensor of shape (batch_size, num_elements), where each row (i.e., compositions[i]) corresponds to the composition vector of the i-th material in the batch.

The composition vector maps each element, indexed by atomic number, to the number of atoms of that element in the material. For example, compositions[:, 1] corresponds to the number of hydrogen atoms in each material in the batch, compositions[:, 2] corresponds to the number of helium atoms, and so on.

Parameters:

batch (TBatch) – Input batch.

Returns:

Normalization context.

Return type:

NormalizationContext
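The composition matrix can be built with a scatter-add over per-atom atomic numbers and a per-atom structure index. compositions_from_batch is a hypothetical helper that assumes the batch stores those two tensors. It uses 120 columns as an illustrative upper bound on the number of elements. Columns are indexed directly by atomic number, so column 0 stays empty.

```python
import torch


def compositions_from_batch(atomic_numbers, batch, num_structures, num_elements=120):
    """Hypothetical: count atoms of each element per structure; column Z holds element Z."""
    compositions = torch.zeros(
        num_structures, num_elements, dtype=torch.long, device=atomic_numbers.device
    )
    # Add 1 at (structure index, atomic number) for every atom; repeated indices accumulate
    compositions.index_put_((batch, atomic_numbers), torch.ones_like(atomic_numbers), accumulate=True)
    return compositions


# Structure 0 is H2O, structure 1 is a single He atom
comp = compositions_from_batch(torch.tensor([1, 1, 8, 2]), torch.tensor([0, 0, 0, 1]), num_structures=2)
```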

__init__(hparams)
Parameters:

hparams (TFinetuneModuleConfig | Mapping[str, Any])
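__init__ accepts either a config instance or a plain mapping, such as hyperparameters restored from a checkpoint. A base class can therefore use hparams_cls to coerce the mapping into the right config type. The sketch below shows that pattern with a dataclass stand-in. ToyConfig and ToyModule are hypothetical, and MatterTune's real config classes may validate their input differently.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Union


@dataclass
class ToyConfig:
    """Hypothetical hyperparameter config."""

    learning_rate: float = 1e-4
    cutoff: float = 5.0


class ToyModule:
    @classmethod
    def hparams_cls(cls) -> type[ToyConfig]:
        return ToyConfig

    def __init__(self, hparams: Union[ToyConfig, Mapping[str, Any]]) -> None:
        # Coerce a plain mapping into the config class; pass config instances through
        if not isinstance(hparams, self.hparams_cls()):
            hparams = self.hparams_cls()(**hparams)
        self.config = hparams
```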

create_metrics()
create_normalizers()
normalize(properties, ctx)

Normalizes the properties dictionary. properties can either be the predicted properties or the ground truth labels.

Parameters:
  • properties (dict[str, Tensor]) – Dictionary of properties to normalize. The dictionary should have the same format as the output of batch_to_labels.

  • ctx (NormalizationContext) – Normalization context. This should be created using create_normalization_context_from_batch.

Returns:

Normalized properties.

denormalize(properties, ctx)

Denormalizes the properties dictionary. properties can either be the predicted properties or the ground truth labels.

Parameters:
  • properties (dict[str, Tensor]) – Dictionary of properties to denormalize. The dictionary should have the same format as the output of batch_to_labels.

  • ctx (NormalizationContext) – Normalization context. This should be created using create_normalization_context_from_batch.

Returns:

Denormalized properties.

forward(batch, return_backbone_output=False, ignore_gpu_batch_transform_error=None)

Same as torch.nn.Module.forward().

Parameters:
  • batch (TBatch)

  • return_backbone_output (bool)

  • ignore_gpu_batch_transform_error (bool | None)

Returns:

The model output.

Return type:

ModelOutput

training_step(batch, batch_idx)

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters:
  • batch (TBatch) – The output of your data iterable, normally a DataLoader.

  • batch_idx (int) – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model-specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

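def __init__(self):
    super().__init__()
    # Turn off automatic optimization so each optimizer is stepped by hand
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...  # compute loss1
    opt1.zero_grad()
    self.manual_backward(loss1)
    opt1.step()
    # do training_step with decoder
    ...  # compute loss2
    opt2.zero_grad()
    self.manual_backward(loss2)
    opt2.step()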

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

validation_step(batch, batch_idx)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

Parameters:
  • batch (TBatch) – The output of your data iterable, normally a DataLoader.

  • batch_idx (int) – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

test_step(batch, batch_idx)

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters:
  • batch (TBatch) – The output of your data iterable, normally a DataLoader.

  • batch_idx (int) – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don’t need to test you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

predict_step(batch, batch_idx)

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference on multi-devices.

To prevent an OOM error, it is possible to use BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.

The BasePredictionWriter should be used with a spawn-based strategy, such as Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), because predictions won’t be returned in those cases.

Parameters:
  • batch (TBatch) – The output of your data iterable, normally a DataLoader.

  • batch_idx (int) – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

Predicted output (optional).

Example

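from lightning.pytorch import LightningModule, Trainer


class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Default behavior: run the forward pass on the batch
        return self(batch)


dm = ...  # your LightningDataModule or prediction DataLoader
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

configure_optimizers()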

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of these 5 options:

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
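For instance, pairing an optimizer with torch.optim.lr_scheduler.ReduceLROnPlateau requires the "monitor" key. The function below is a standalone sketch of what a configure_optimizers method could return. The stand-in model and hyperparameters are hypothetical. Inside a LightningModule you would use self.parameters() and log "val_loss" with self.log.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)  # hypothetical stand-in for the LightningModule's parameters


def configure_optimizers():
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # Halve the learning rate after 5 epochs without improvement in the monitored metric
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # Must name a metric that is logged, e.g. via self.log("val_loss", ...)
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1,
        },
    }
```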

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

create_dataloader(dataset, has_labels, **kwargs)

Creates a wrapped DataLoader for the given dataset.

This will wrap the dataset with the CPU data transform and the model’s collate function.

NOTE about has_labels: This is used to determine whether our data loading pipeline should expect labels in the dataset. This should be True for train/val/test datasets (as we compute loss and metrics on these datasets) and False for prediction datasets.

Parameters:
  • dataset (Dataset[Atoms]) – Dataset to wrap.

  • has_labels (bool) – Whether the dataset contains labels. This should be True for train/val/test datasets and False for prediction datasets.

  • **kwargs (Unpack[DataLoaderKwargs]) – Additional keyword arguments to pass to the DataLoader.
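The wrapping described above can be sketched using only torch.utils.data. TransformedDataset and make_loader are hypothetical names. In the real method, the conversion step would presumably call atoms_to_data(atoms, has_labels) and then cpu_data_transform, with the module's collate_fn passed to the DataLoader.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TransformedDataset(Dataset):
    """Hypothetical wrapper: converts and transforms each sample lazily on access."""

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.transform(self.dataset[idx])


def make_loader(dataset, to_data, cpu_transform, collate_fn, **kwargs):
    """Hypothetical create_dataloader: Atoms -> data object -> CPU transform -> collate."""
    wrapped = TransformedDataset(dataset, lambda atoms: cpu_transform(to_data(atoms)))
    # Extra kwargs (batch_size, shuffle, num_workers, ...) pass straight to DataLoader
    return DataLoader(wrapped, collate_fn=collate_fn, **kwargs)
```

The lambdas are fine with the default num_workers=0. With worker processes started via spawn, the transforms would need to be picklable, for example module-level functions.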

property_predictor(lightning_trainer_kwargs=None)

Return a wrapper for easy prediction without explicitly setting up a lightning trainer.

This method provides a high-level interface for making predictions with the trained model.

This can be used for various prediction tasks, including but not limited to:

  • Interatomic potential energy and forces

  • Material property prediction

  • Structure-property relationships

Parameters:

lightning_trainer_kwargs (dict[str, Any] | None, optional) – Additional keyword arguments to pass to the PyTorch Lightning Trainer. If None, default trainer settings will be used.

Returns:

A wrapper class that provides simplified prediction functionality without requiring direct interaction with the Lightning Trainer.

Return type:

MatterTunePropertyPredictor

Examples

>>> import ase
>>> model = MyModel()
>>> property_predictor = model.property_predictor()
>>> atoms_1 = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], cell=[10, 10, 10], pbc=True)
>>> atoms_2 = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], cell=[10, 10, 10], pbc=True)
>>> atoms = [atoms_1, atoms_2]
>>> predictions = property_predictor.predict(atoms, ["energy", "forces"])
>>> print("Atoms 1 energy:", predictions[0]["energy"])
>>> print("Atoms 1 forces:", predictions[0]["forces"])
>>> print("Atoms 2 energy:", predictions[1]["energy"])
>>> print("Atoms 2 forces:", predictions[1]["forces"])

ase_calculator(lightning_trainer_kwargs=None)

Returns an ASE calculator wrapper for the interatomic potential.

This method creates an ASE (Atomic Simulation Environment) calculator that can be used to compute energies and forces using the trained interatomic potential model.

The calculator integrates with ASE’s standard interfaces for molecular dynamics and structure optimization.

Parameters:

lightning_trainer_kwargs (dict[str, Any] | None, optional) – Keyword arguments to pass to the PyTorch Lightning trainer used for inference. If None, default trainer settings will be used.

Returns:

An ASE calculator wrapper around the trained potential that can be used for energy and force calculations via ASE’s interfaces.

Return type:

MatterTuneCalculator

Examples

>>> import ase
>>> model = MyModel()
>>> calc = model.ase_calculator()
>>> atoms = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], cell=[10, 10, 10], pbc=True)
>>> atoms.calc = calc
>>> energy = atoms.get_potential_energy()
>>> forces = atoms.get_forces()