mattertune.wrappers.property_predictor

Classes

MatterTunePropertyPredictor(lightning_module)

A wrapper class for handling predictions using a fine-tuned MatterTune model.

class mattertune.wrappers.property_predictor.MatterTunePropertyPredictor(lightning_module, lightning_trainer_kwargs=None)

A wrapper class for handling predictions using a fine-tuned MatterTune model.

This class provides an interface to make predictions using a trained neural network. It wraps a PyTorch Lightning module and handles the necessary setup for making predictions on atomic systems.

lightning_module : FinetuneModuleBase

The trained PyTorch Lightning module that will be used for predictions.

lightning_trainer_kwargs : dict[str, Any] | None, optional

Additional keyword arguments to pass to the PyTorch Lightning Trainer. Defaults to None.
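The `None` default means no extra Trainer arguments are supplied. The sketch below shows one plausible way an optional kwargs dictionary like this can be combined with a wrapper's own defaults. The helper `merge_trainer_kwargs` and its default values are hypothetical and are not taken from MatterTune. The keys themselves (`accelerator`, `devices`, `logger`, `enable_progress_bar`) are genuine `lightning.pytorch.Trainer` arguments.

```python
from __future__ import annotations

from typing import Any


def merge_trainer_kwargs(user_kwargs: dict[str, Any] | None) -> dict[str, Any]:
    """Hypothetical helper: combine wrapper defaults with user overrides."""
    # Hypothetical defaults a prediction wrapper might choose; NOT MatterTune's.
    defaults: dict[str, Any] = {
        "accelerator": "auto",
        "devices": 1,
        "logger": False,
        "enable_progress_bar": False,
    }
    # Treat None the same as "no overrides".
    overrides = user_kwargs or {}
    # Later keys win, so user-supplied values replace the defaults.
    return {**defaults, **overrides}


# Example: run prediction on GPU index 1 and show a progress bar.
kwargs = merge_trainer_kwargs(
    {"accelerator": "gpu", "devices": [1], "enable_progress_bar": True}
)
```

In this sketch, the merged dictionary is intended to be unpacked as `lightning.pytorch.Trainer(**kwargs)`.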

Examples

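>>> from ase.build import bulk
>>> from mattertune.wrappers import MatterTunePropertyPredictor
>>> atoms_list = [bulk("Cu", "fcc", a=3.6)]  # any list of ase.Atoms
>>> # `trained_model` is a fine-tuned FinetuneModuleBase instance
>>> predictor = MatterTunePropertyPredictor(trained_model)  # OR `predictor = trained_model.property_predictor()`
>>> predictions = predictor.predict(atoms_list)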

The class provides a simplified interface for making predictions with trained models, handling the necessary setup of trainers and dataloaders internally.

__init__(lightning_module, lightning_trainer_kwargs=None)
predict(atoms_list, properties=None, *, batch_size=1)

Predicts properties for a list of atomic systems using the trained model.

This method processes a list of atomic structures through the model and returns predicted properties for each system.

Parameters:
  • atoms_list (list[ase.Atoms]) – List of atomic systems to predict properties for.

  • properties (Sequence[str | PropertyConfig] | None, optional) – Properties to predict. Can be specified as strings or PropertyConfig objects. If None, predicts all properties supported by the model.

  • batch_size (int) – Number of atomic systems processed together in each forward pass. Keyword-only. Defaults to 1.
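The `properties` argument accepts strings, config objects, or `None`. Below is a minimal sketch of how such a mixed argument can be resolved to a list of property names, assuming each config object exposes a `name` attribute. `StubPropertyConfig` and `resolve_property_names` are hypothetical stand-ins, not MatterTune classes. Raising `ValueError` for an unsupported name is this sketch's choice, not documented behavior.

```python
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass


@dataclass(frozen=True)
class StubPropertyConfig:
    """Hypothetical stand-in for a PropertyConfig; only `name` is modeled."""

    name: str


def resolve_property_names(
    properties: Sequence[str | StubPropertyConfig] | None,
    supported: Sequence[StubPropertyConfig],
) -> list[str]:
    """Map the `properties` argument to a list of property names."""
    supported_names = [p.name for p in supported]
    # None means "every property the model supports".
    if properties is None:
        return supported_names
    names: list[str] = []
    for prop in properties:
        # Accept either a plain string or a config object carrying a name.
        name = prop if isinstance(prop, str) else prop.name
        if name not in supported_names:
            raise ValueError(f"Property {name!r} is not supported by this model")
        names.append(name)
    return names


model_props = [
    StubPropertyConfig("energy"),
    StubPropertyConfig("forces"),
    StubPropertyConfig("stress"),
]
```

For example, `resolve_property_names(["energy", StubPropertyConfig("forces")], model_props)` yields `["energy", "forces"]`, and passing `None` yields all three names.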

Returns:

List of dictionaries containing predicted properties for each system. Each dictionary maps property names to torch.Tensor values.

Return type:

list[dict[str, torch.Tensor]]

Notes

  • Creates a temporary trainer instance for prediction

  • Converts input atoms to a dataloader compatible with the model

  • Returns raw prediction outputs from the model
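The notes above describe a batch-then-predict flow. The following is a library-free sketch of that flow. It assumes a stand-in model that maps a batch of systems to one value per system for each property. Plain Python lists replace `ase.Atoms` and `torch.Tensor`, and the names `predict_in_batches` and `toy_model` are hypothetical. MatterTune itself delegates these steps to a temporary Lightning Trainer and a dataloader, which this sketch does not reproduce.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

# A "system" is just a list of atomic numbers, standing in for ase.Atoms.
System = list[int]
# A batch model returns, for each property name, one value per system in the batch.
BatchModel = Callable[[list[System]], dict[str, list[Any]]]


def predict_in_batches(
    systems: Sequence[System], model: BatchModel, batch_size: int = 1
) -> list[dict[str, Any]]:
    """Hypothetical sketch: chunk inputs, run the model per batch, unbatch results."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    results: list[dict[str, Any]] = []
    for start in range(0, len(systems), batch_size):
        batch = list(systems[start : start + batch_size])
        outputs = model(batch)
        # Split the batched outputs back into one dict per input system.
        for i in range(len(batch)):
            results.append({name: values[i] for name, values in outputs.items()})
    return results


def toy_model(batch: list[System]) -> dict[str, list[Any]]:
    """Toy stand-in: 'energy' is minus the sum of atomic numbers, 'n_atoms' the count."""
    return {
        "energy": [-float(sum(s)) for s in batch],
        "n_atoms": [len(s) for s in batch],
    }


systems = [[1, 1, 8], [6, 8, 8], [29], [26, 26]]
preds = predict_in_batches(systems, toy_model, batch_size=3)
```

In this sketch, `batch_size` only controls how many systems go through each model call. The returned list always has one dictionary per input system, in input order. Here `preds[0]` is `{"energy": -10.0, "n_atoms": 3}`.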