mattertune.wrappers.property_predictor
Classes
- MatterTunePropertyPredictor: A wrapper class for handling predictions using a fine-tuned MatterTune model.
- class mattertune.wrappers.property_predictor.MatterTunePropertyPredictor(lightning_module, lightning_trainer_kwargs=None)
A wrapper class for handling predictions using a fine-tuned MatterTune model.
This class provides an interface to make predictions using a trained neural network. It wraps a PyTorch Lightning module and handles the necessary setup for making predictions on atomic systems.
Parameters
- lightning_module : FinetuneModuleBase
The trained PyTorch Lightning module that will be used for predictions.
- lightning_trainer_kwargs : dict[str, Any] | None, optional
Additional keyword arguments to pass to the PyTorch Lightning Trainer. Defaults to None.
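For example, to force CPU inference and keep console output quiet, the dictionary can carry standard PyTorch Lightning `Trainer` arguments. This is a sketch assuming Lightning 2.x argument names (`accelerator`, `devices`, `logger`, `enable_progress_bar`); the values are illustrative only, and `trained_model` stands for an already fine-tuned `FinetuneModuleBase` instance.

```python
from typing import Any

# Illustrative keyword arguments for the Lightning Trainer (Lightning 2.x names).
# Per the parameter description, they are passed to the Trainer used for prediction.
trainer_kwargs: dict[str, Any] = {
    "accelerator": "cpu",          # run inference on CPU even if a GPU is available
    "devices": 1,                  # a single device, so no distributed prediction
    "logger": False,               # no experiment logger is needed for prediction
    "enable_progress_bar": False,  # keep console output quiet
}

# With a fine-tuned module `trained_model` (assumed to exist):
# predictor = MatterTunePropertyPredictor(
#     trained_model, lightning_trainer_kwargs=trainer_kwargs
# )
```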
Examples
>>> from mattertune.wrappers import MatterTunePropertyPredictor
>>> predictor = MatterTunePropertyPredictor(trained_model)
>>> # OR: predictor = trained_model.property_predictor()
>>> predictions = predictor.predict(atoms_list)
The class provides a simplified interface for making predictions with trained models, handling the necessary setup of trainers and dataloaders internally.
- Parameters:
lightning_module (FinetuneModuleBase[Any, Any, FinetuneModuleBaseConfig])
lightning_trainer_kwargs (dict[str, Any] | None)
- __init__(lightning_module, lightning_trainer_kwargs=None)
- predict(atoms_list, properties=None, *, batch_size=1)
Predicts properties for a list of atomic systems using the trained model.
This method processes a list of atomic structures through the model and returns predicted properties for each system.
- Parameters:
atoms_list (list[ase.Atoms]) – List of atomic systems to predict properties for.
properties (Sequence[str | PropertyConfig] | None, optional) – Properties to predict. Can be specified as strings or PropertyConfig objects. If None, predicts all properties supported by the model.
batch_size (int, keyword-only) – Number of atomic systems per prediction batch. Defaults to 1.
- Returns:
List of dictionaries containing predicted properties for each system. Each dictionary maps property names to torch.Tensor values.
- Return type:
list[dict[str, torch.Tensor]]
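Because each value is a `torch.Tensor`, results usually need converting before they are written to JSON or compared with reference data. A minimal sketch, relying only on `Tensor.tolist()` (a standard PyTorch method that returns a Python scalar for a 0-dimensional tensor and nested lists otherwise); the helper name `predictions_to_records` is hypothetical and not part of MatterTune.

```python
from typing import Any, Mapping, Sequence


def predictions_to_records(
    predictions: Sequence[Mapping[str, Any]],
) -> list[dict[str, Any]]:
    """Convert the list[dict[str, torch.Tensor]] returned by predict() to plain Python.

    Hypothetical helper: any value exposing .tolist() (e.g. torch.Tensor or
    numpy.ndarray) is converted; other values are passed through unchanged.
    """
    records = []
    for system_prediction in predictions:
        record = {}
        for name, value in system_prediction.items():
            record[name] = value.tolist() if hasattr(value, "tolist") else value
        records.append(record)
    return records


# Usage, assuming `predictions = predictor.predict(atoms_list)`:
# records = predictions_to_records(predictions)
```

The records can then be passed to `json.dumps`, since they contain only lists, numbers, and whatever non-tensor values the model returned.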
Notes
- Creates a temporary trainer instance for prediction.
- Converts the input atoms to a dataloader compatible with the model.
- Returns the raw prediction outputs from the model.
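The effect of `batch_size` on how systems are grouped can be sketched without MatterTune. This assumes the internal dataloader behaves like a PyTorch `DataLoader` with its defaults (`shuffle=False`, `drop_last=False`), which the notes above do not state explicitly: systems stay in input order and the last batch may be smaller. `plan_batches` is a hypothetical helper for illustration only.

```python
import math


def plan_batches(num_systems: int, batch_size: int = 1) -> list[range]:
    """Return the index range covered by each prediction batch.

    Hypothetical illustration: assumes sequential, non-shuffled batching where
    the final partial batch is kept (the PyTorch DataLoader defaults).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    num_batches = math.ceil(num_systems / batch_size)
    return [
        range(i * batch_size, min((i + 1) * batch_size, num_systems))
        for i in range(num_batches)
    ]


# Ten structures with batch_size=4 -> batches of sizes 4, 4 and 2.
print([len(batch) for batch in plan_batches(10, batch_size=4)])
```

A larger `batch_size` means fewer forward passes but more memory per pass; the default of 1 processes one system at a time.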