Fine-Tuning a Pre-trained Model
This guide will walk you through fine-tuning a pre-trained model for predicting molecular properties. We’ll use a complete example with detailed explanations.
import mattertune as mt
from pathlib import Path

# Step 1: Define the configuration for our fine-tuning process
config = mt.configs.MatterTunerConfig(
    # Configure the model and its training parameters
    model=mt.configs.JMPBackboneConfig(
        # Path to the pre-trained model checkpoint you want to fine-tune
        ckpt_path=Path("YOUR_CHECKPOINT_PATH"),
        # Configure how atomic structures are processed
        graph_computer=mt.configs.JMPGraphComputerConfig(
            # Set pbc=True for periodic systems (crystals), False for molecules
            pbc=True
        ),
        # Define which properties to predict and how to train them
        properties=[
            mt.configs.EnergyPropertyConfig(
                # Use Mean Absolute Error (MAE) loss for energy prediction
                loss=mt.configs.MAELossConfig(),
                # Weight of this loss in the total loss function
                loss_coefficient=1.0
            )
        ],
        # Configure the optimizer for training
        optimizer=mt.configs.AdamWConfig(
            # Learning rate - adjust based on your dataset size and needs
            lr=1e-4
        ),
    ),
    # Configure how data is loaded and processed
    data=mt.configs.AutoSplitDataModuleConfig(
        # Specify the source dataset (XYZ file format in this case)
        dataset=mt.configs.XYZDatasetConfig(
            src=Path("YOUR_XYZFILE_PATH")
        ),
        # Fraction of data to use for training (0.8 = 80% training, 20% validation)
        train_split=0.8,
        # Number of structures to process at once
        batch_size=4,
    ),
    # Configure the training process
    trainer=mt.configs.TrainerConfig(
        # Maximum number of training epochs
        max_epochs=10,
        # Use GPU for training
        accelerator="gpu",
        # Specify which GPU(s) to use (0 = first GPU)
        devices=[0],
    ),
)
# Step 2: Initialize the MatterTuner with our configuration
tuner = mt.MatterTuner(config)

# Step 3: Start the fine-tuning process
# This returns both the trained model and the trainer object
model, trainer = tuner.tune()

# Step 4: Save the fine-tuned model for later use
trainer.save_checkpoint("finetuned_model.ckpt")

# Step 5: Make predictions with the fine-tuned model
# Create a property predictor interface
property_predictor = model.property_predictor()
# Example: predict energy for a structure
from ase import Atoms
# Create a water molecule as an example. Note that ASE expands the formula
# 'H2O' in the order H, H, O, so the positions must be listed accordingly.
water = Atoms('H2O',
              positions=[[0, 0, 0.96],  # H atom
                         [0.93, 0, 0],  # H atom
                         [0, 0, 0]],    # O atom
              cell=[10, 10, 10],
              pbc=True)
# Make predictions
predictions = property_predictor.predict([water], ["energy"])
print(f"Predicted energy: {predictions[0]['energy']} eV")
Key Components Explained
Configuration Structure:
MatterTunerConfig: The main configuration container
JMPBackboneConfig: Specifies the model architecture and training parameters
AutoSplitDataModuleConfig: Handles data loading and splitting
TrainerConfig: Controls the training process
Property Prediction:
Define what properties to predict using PropertyConfig objects
Each property can have its own loss function and weight
Common properties: energy, forces, and stress tensors (a hedged stress example follows this list)
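As a hedged sketch, stress training could look like the following; StressesPropertyConfig and its fields are assumptions modeled on ForcesPropertyConfig (shown under Common Customizations below) and may differ in your MatterTune version.

# Sketch only: StressesPropertyConfig and its fields are assumed, not verified
properties=[
    mt.configs.EnergyPropertyConfig(
        loss=mt.configs.MAELossConfig(),
        loss_coefficient=1.0
    ),
    mt.configs.StressesPropertyConfig(
        loss=mt.configs.MAELossConfig(),
        loss_coefficient=0.1,  # stress terms are usually weighted below energy
        conservative=True  # assumed flag: derive stress from the energy, as for forces
    )
]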
Data Handling:
Supports various input formats (XYZ files, ASE databases, etc.; see the file-preparation sketch after this list)
Automatic train/validation splitting
Configurable batch size for memory management
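For example, a labeled XYZ file can be prepared ahead of time with plain ASE. This sketch uses only ASE's public API (SinglePointCalculator and ase.io.write), not MatterTune; the energy value is a placeholder for your own reference calculation:

from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.io import write

# Attach a reference energy (e.g. from DFT) so it is stored with the structure
water = Atoms('H2O', positions=[[0, 0, 0.96], [0.93, 0, 0], [0, 0, 0]])  # H, H, O
water.calc = SinglePointCalculator(water, energy=-14.2)  # placeholder value
write("train.xyz", [water])  # the extxyz format keeps energy/forces when present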
Common Customizations
# Add force prediction
properties=[
    mt.configs.EnergyPropertyConfig(
        loss=mt.configs.MAELossConfig(),
        loss_coefficient=1.0
    ),
    mt.configs.ForcesPropertyConfig(
        loss=mt.configs.MAELossConfig(),
        loss_coefficient=0.1,  # Usually smaller than the energy coefficient
        conservative=True  # Compute forces as the negative gradient of the energy
    )
]
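In most machine-learned interatomic potentials, conservative=True means forces are obtained by differentiating the predicted energy rather than from a separate output head; this guarantees a conservative force field but adds some memory and compute overhead during training and inference.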
# Use multiple GPUs
trainer=mt.configs.TrainerConfig(
    max_epochs=10,
    accelerator="gpu",
    devices=[0, 1],  # Use GPUs 0 and 1
    strategy="ddp"  # Distributed data parallel training
)
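With strategy="ddp", each GPU runs its own copy of the model and consumes its own batches, so the effective batch size is batch_size multiplied by the number of devices; you may need to adjust batch_size or the learning rate to compensate.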
# Add logging with Weights & Biases
trainer=mt.configs.TrainerConfig(
    # ... other settings ...
    loggers=[
        mt.configs.WandbLoggerConfig(
            project="my-project",
            name="experiment-1"
        )
    ]
)
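The Weights & Biases logger requires the wandb package to be installed and authenticated (pip install wandb, then wandb login) before training starts.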
Tips for Successful Fine-Tuning
Data Quality:
Ensure your training data is clean and properly formatted
Use a reasonable train/validation split (80/20 is common)
Consider normalizing your target properties (a quick label-statistics check is sketched after this list)
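A quick way to inspect label statistics before training, using only ASE and NumPy; this sketch assumes your XYZ file stores reference energies:

import numpy as np
from ase.io import read

# Load every structure in the file and compute per-atom energy statistics
frames = read("YOUR_XYZFILE_PATH", index=":")
per_atom = np.array([atoms.get_potential_energy() / len(atoms) for atoms in frames])
print(f"per-atom energy: mean={per_atom.mean():.3f} eV, std={per_atom.std():.3f} eV")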
Training Parameters:
Start with a small learning rate (1e-4 to 1e-5)
Monitor validation loss for signs of overfitting
Use early stopping to prevent overfitting (a hedged config sketch follows this list)
Adjust batch size based on your GPU memory
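A hedged sketch of early stopping: EarlyStoppingConfig and its field names here are assumptions modeled on PyTorch Lightning's EarlyStopping callback and may differ in your MatterTune version.

# Sketch only: the EarlyStoppingConfig fields and the metric name are assumed
trainer=mt.configs.TrainerConfig(
    max_epochs=100,
    early_stopping=mt.configs.EarlyStoppingConfig(
        monitor="val/total_loss",  # check your run's actual logged metric names
        patience=20,  # stop after 20 epochs without improvement
        mode="min"
    )
)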
Model Selection:
Choose a pre-trained model suitable for your task
Consider the model’s original training data and your use case
Test different backbones if possible
Monitoring Training:
Use logging to track training progress (see the CSV-logger sketch below)
Monitor both training and validation metrics
Save checkpoints regularly
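For quick local runs without an external service, a CSV logger can stand in for Weights & Biases. A minimal sketch, assuming a CSVLoggerConfig exists alongside WandbLoggerConfig (name and fields unverified):

# Sketch only: CSVLoggerConfig and its save_dir field are assumed
trainer=mt.configs.TrainerConfig(
    max_epochs=10,
    loggers=[
        mt.configs.CSVLoggerConfig(save_dir="./logs")
    ]
)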