# Advanced: Lightning Integration
MatterTune uses PyTorch Lightning as its core training framework. This document outlines how Lightning is integrated and what functionality it provides.
## Core Components
### LightningModule Integration

The base model class `FinetuneModuleBase` inherits from `LightningModule` and provides:
- Automatic device management (GPU/CPU handling)
- Distributed training support
- Built-in training/validation/test loops
- Logging and metrics tracking
- Checkpoint management
```python
class FinetuneModuleBase(LightningModule):
    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch))
        return loss

    def validation_step(self, batch, batch_idx):
        output = self(batch)
        self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch))

    def test_step(self, batch, batch_idx):
        output = self(batch)
        self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch))

    def configure_optimizers(self):
        return create_optimizer(self.hparams.optimizer, self.parameters())
```
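The steps above assume that calling the module (`self(batch)`) returns a dictionary with a `"predicted_properties"` entry that can be compared against `batch_to_labels(batch)`. As an orientation sketch only, a backbone-specific subclass might look roughly like this; the `backbone` and property-head attributes are hypothetical, not MatterTune's actual structure:

```python
class MyBackboneModule(FinetuneModuleBase):
    def forward(self, batch):
        # Hypothetical: run the pretrained backbone, then per-property output heads.
        features = self.backbone(batch)
        return {
            "predicted_properties": {
                "energy": self.energy_head(features),
                "forces": self.forces_head(features),
            }
        }
```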
### Data Handling
MatterTune uses Lightning’s DataModule system for standardized data loading:
```python
class MatterTuneDataModule(LightningDataModule):
    def prepare_data(self):
        # Download data if needed
        pass

    def setup(self, stage):
        # Create train/val/test splits
        self.datasets = self.hparams.create_datasets()

    def train_dataloader(self):
        return self.lightning_module.create_dataloader(
            self.datasets["train"],
            has_labels=True,
        )

    def val_dataloader(self):
        return self.lightning_module.create_dataloader(
            self.datasets["validation"],
            has_labels=True,
        )
```
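Note that the dataloaders are built by the model's `create_dataloader` method, so the datamodule needs access to the Lightning module during `fit`. A minimal usage sketch, assuming a concrete `FinetuneModuleBase` subclass and hypothetical config objects (how MatterTune wires the `lightning_module` reference internally is not shown here):

```python
import lightning.pytorch as pl

model = MyBackboneModule(model_config)          # hypothetical FinetuneModuleBase subclass
datamodule = MatterTuneDataModule(data_config)  # data_config is a placeholder

trainer = pl.Trainer(max_epochs=100, accelerator="auto")
# Lightning calls datamodule.setup() and the *_dataloader hooks during fit;
# those hooks delegate to the model's create_dataloader method.
trainer.fit(model, datamodule=datamodule)
```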
## Key Features
### 1. Checkpoint Management
Lightning automatically handles model checkpointing:
```python
checkpoint_callback = ModelCheckpointConfig(
    monitor="val/forces_mae",
    dirpath="./checkpoints",
    filename="best-model",
    save_top_k=1,
    mode="min",
).create_callback()

trainer = Trainer(callbacks=[checkpoint_callback])
```
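Assuming `create_callback()` returns a standard Lightning `ModelCheckpoint`, the best weights can be restored after training from the path the callback records (the model class name below is a placeholder):

```python
# After trainer.fit(...), the callback remembers where the best checkpoint was written.
best_path = checkpoint_callback.best_model_path

# Any LightningModule subclass can be reloaded from a Lightning checkpoint file.
model = MyBackboneModule.load_from_checkpoint(best_path)
```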
### 2. Early Stopping
Built-in early stopping support:
```python
early_stopping = EarlyStoppingConfig(
    monitor="val/forces_mae",
    patience=20,
    mode="min",
).create_callback()

trainer = Trainer(callbacks=[early_stopping])
```
### 3. Multi-GPU Training
Lightning handles distributed training with minimal code changes:
```python
# Single GPU
trainer = Trainer(accelerator="gpu", devices=[0])

# Multiple GPUs with DDP
trainer = Trainer(accelerator="gpu", devices=[0, 1], strategy="ddp")
```
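Under DDP each process sees a different shard of the data, so metrics logged inside the step methods can be averaged across ranks by passing `sync_dist=True` to `self.log` (a standard Lightning option); for example, inside `validation_step`:

```python
# Inside validation_step: average the metric over all DDP ranks before logging.
self.log("val/loss", loss, sync_dist=True)  # "val/loss" is an illustrative key
```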
### 4. Logging
Lightning provides unified logging interfaces:
```python
def training_step(self, batch, batch_idx):
    loss = ...
    self.log("train_loss", loss)
    self.log_dict({
        "energy_mae": energy_mae,
        "forces_mae": forces_mae,
    })
```
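`self.log` also accepts standard Lightning options that control when a metric is reduced and where it is displayed; for example (metric name and value are illustrative):

```python
def training_step(self, batch, batch_idx):
    ...
    # Log per step, also aggregate per epoch, and show the value in the progress bar.
    self.log("train/energy_mae", energy_mae, on_step=True, on_epoch=True, prog_bar=True)
```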
### 5. Precision Settings
Easy configuration of precision:
```python
# 32-bit training
trainer = Trainer(precision="32-true")

# Mixed precision training
trainer = Trainer(precision="16-mixed")
```
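On hardware with bfloat16 support (e.g. Ampere-or-newer NVIDIA GPUs), Lightning also accepts `bf16-mixed`, which avoids the loss-scaling pitfalls of fp16 mixed precision:

```python
# bfloat16 mixed precision (requires hardware support)
trainer = Trainer(precision="bf16-mixed")
```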
## Available Trainer Configurations
The `TrainerConfig` class exposes common Lightning `Trainer` settings:
```python
trainer_config = TrainerConfig(
    # Hardware
    accelerator="gpu",
    devices=[0, 1],
    precision="16-mixed",

    # Training
    max_epochs=100,
    gradient_clip_val=1.0,

    # Validation
    val_check_interval=1.0,
    check_val_every_n_epoch=1,

    # Callbacks
    early_stopping=EarlyStoppingConfig(...),
    checkpoint=ModelCheckpointConfig(...),

    # Logging
    loggers=["tensorboard", "wandb"],
)
```
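Most of these fields correspond directly to arguments of Lightning's `Trainer`; as a rough sketch of the equivalent plain-Lightning call (how MatterTune performs this mapping internally is not shown here, and the logger objects below are stand-ins for `loggers=["tensorboard", "wandb"]`):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

trainer = Trainer(
    accelerator="gpu",
    devices=[0, 1],
    precision="16-mixed",
    max_epochs=100,
    gradient_clip_val=1.0,
    val_check_interval=1.0,
    check_val_every_n_epoch=1,
    callbacks=[early_stopping, checkpoint_callback],  # callbacks built earlier on this page
    logger=[TensorBoardLogger("logs"), WandbLogger()],
)
```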
## Best Practices
- Use `self.log()` for tracking metrics during training (see the sketch after this list)
- Enable checkpointing to save model states
- Set appropriate early stopping criteria
- Use appropriate precision settings for your hardware
- Configure multi-GPU training based on available resources
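One point worth making explicit: the name passed to `monitor=` in `ModelCheckpointConfig` / `EarlyStoppingConfig` must match a key actually logged during validation. A minimal sketch using the `val/forces_mae` key from the examples above; the `"forces"` keys and the MAE computation are illustrative, not MatterTune's exact implementation:

```python
def validation_step(self, batch, batch_idx):
    output = self(batch)
    predictions = output["predicted_properties"]
    labels = self.batch_to_labels(batch)
    # Illustrative per-batch force MAE.
    forces_mae = (predictions["forces"] - labels["forces"]).abs().mean()
    # This logged key is what monitor="val/forces_mae" refers to.
    self.log("val/forces_mae", forces_mae)
```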
## Advanced Usage
For advanced use cases:
```python
# Custom training loop
# (@override comes from typing_extensions, or typing on Python 3.12+)
@override
def training_step(self, batch, batch_idx):
    if self.trainer.global_rank == 0:
        # Do something only on the main process
        pass

    # Access trainer properties
    if self.trainer.is_last_batch:
        # Special handling for the last batch
        pass

# Custom validation epoch hook (Lightning >= 2.0 replaced validation_epoch_end
# with on_validation_epoch_end, which takes no outputs argument)
@override
def on_validation_epoch_end(self):
    # Compute epoch-level metrics
    pass
```
This integration provides a robust foundation for training atomistic models while handling common ML engineering concerns automatically.