# Training Configuration Guide

MatterTune uses a comprehensive configuration system to control all aspects of training. This guide covers the key components and how to use them effectively.
## Model Configuration

Control the model architecture and training parameters:

```python
import mattertune as mt

model = mt.configs.JMPBackboneConfig(
    # Specify pre-trained model checkpoint
    ckpt_path="path/to/pretrained/model.pt",
    # Define properties to predict
    properties=[
        mt.configs.EnergyPropertyConfig(
            loss=mt.configs.MAELossConfig(),
            loss_coefficient=1.0,
        ),
        mt.configs.ForcesPropertyConfig(
            loss=mt.configs.MAELossConfig(),
            loss_coefficient=10.0,
            conservative=True,  # Use energy-conserving force prediction
        ),
    ],
    # Configure optimizer
    optimizer=mt.configs.AdamWConfig(lr=1e-4),
    # Optional: configure learning rate scheduler
    lr_scheduler=mt.configs.CosineAnnealingLRConfig(
        T_max=100,     # Number of epochs
        eta_min=1e-6,  # Minimum learning rate
    ),
)
```
## Data Configuration

Configure data loading and processing:

```python
data = mt.configs.AutoSplitDataModuleConfig(
    # Specify dataset source
    dataset=mt.configs.XYZDatasetConfig(
        src="path/to/your/data.xyz"
    ),
    # Control data splitting
    train_split=0.8,  # 80% for training
    # Configure batch size and loading
    batch_size=32,
    num_workers=4,    # Number of data-loading workers
    pin_memory=True,  # Optimize GPU transfer
)
```
## Training Process Configuration

Control the training loop behavior:

```python
trainer = mt.configs.TrainerConfig(
    # Hardware configuration
    accelerator="gpu",
    devices=[0, 1],  # Use GPUs 0 and 1
    # Training stopping criterion
    max_epochs=100,
    # OR: max_steps=1000,  # Stop after 1000 steps
    # OR: max_time=datetime.timedelta(hours=1),  # Stop after 1 hour
    # Validation frequency
    check_val_every_n_epoch=1,
    # Gradient clipping to prevent exploding gradients
    gradient_clip_val=1.0,
    # Early stopping configuration
    early_stopping=mt.configs.EarlyStoppingConfig(
        monitor="val/energy_mae",
        patience=20,
        mode="min",
    ),
    # Model checkpointing
    checkpoint=mt.configs.ModelCheckpointConfig(
        monitor="val/energy_mae",
        save_top_k=1,
        mode="min",
    ),
    # Configure logging
    loggers=[
        mt.configs.WandbLoggerConfig(
            project="my-project",
            name="experiment-1",
        )
    ],
)
```
```python
# Combine all configurations
config = mt.configs.MatterTunerConfig(
    model=model,
    data=data,
    trainer=trainer,
)
```
## Configuration Management

MatterTune uses nshconfig for configuration management, which provides several ways to create and load configurations:
### 1. Direct Construction

```python
config = mt.configs.MatterTunerConfig(
    model=mt.configs.JMPBackboneConfig(...),
    data=mt.configs.AutoSplitDataModuleConfig(...),
    trainer=mt.configs.TrainerConfig(...),
)
```
### 2. Loading from Files/Dictionaries

```python
# Load from YAML
config = mt.configs.MatterTunerConfig.from_yaml("/path/to/config.yaml")

# Load from JSON
config = mt.configs.MatterTunerConfig.from_json("/path/to/config.json")

# Load from a dictionary
config = mt.configs.MatterTunerConfig.from_dict({
    "model": {...},
    "data": {...},
    "trainer": {...},
})
```
### 3. Using Draft Configs

```python
# Create a draft config
config = mt.configs.MatterTunerConfig.draft()

# Set values progressively
config.model = mt.configs.JMPBackboneConfig.draft()
config.model.ckpt_path = "path/to/model.pt"
# ... set other values ...

# Finalize the config
final_config = config.finalize()
```
For more advanced configuration management features, see the nshconfig documentation.