# Advanced: Lightning Integration

MatterTune uses PyTorch Lightning as its core training framework. This document outlines how Lightning is integrated and what functionality it provides.

## Core Components

### LightningModule Integration

The base model class `FinetuneModuleBase` inherits from `LightningModule` and provides:

- Automatic device management (GPU/CPU handling)
- Distributed training support
- Built-in training/validation/test loops
- Logging and metrics tracking
- Checkpoint management

```python
class FinetuneModuleBase(LightningModule):
    def training_step(self, batch, batch_idx):
        # Forward pass, then compute the supervised loss against the batch labels
        output = self(batch)
        loss = self._compute_loss(
            output["predicted_properties"], self.batch_to_labels(batch)
        )
        return loss

    def validation_step(self, batch, batch_idx):
        # Same forward/loss computation; Lightning runs this without gradients
        output = self(batch)
        self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch))

    def test_step(self, batch, batch_idx):
        output = self(batch)
        self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch))

    def configure_optimizers(self):
        # Build the optimizer from the user-supplied optimizer config
        return create_optimizer(self.hparams.optimizer, self.parameters())
```

### Data Handling

MatterTune uses Lightning's DataModule system for standardized data loading:

```python
class MatterTuneDataModule(LightningDataModule):
    def prepare_data(self):
        # Download data if needed
        pass

    def setup(self, stage):
        # Create train/val/test splits
        self.datasets = self.hparams.create_datasets()

    def train_dataloader(self):
        return self.lightning_module.create_dataloader(
            self.datasets["train"], has_labels=True
        )

    def val_dataloader(self):
        return self.lightning_module.create_dataloader(
            self.datasets["validation"], has_labels=True
        )
```

## Key Features

### 1. Checkpoint Management

Lightning automatically handles model checkpointing:

```python
checkpoint_callback = ModelCheckpointConfig(
    monitor="val/forces_mae",
    dirpath="./checkpoints",
    filename="best-model",
    save_top_k=1,
    mode="min",
).create_callback()

trainer = Trainer(callbacks=[checkpoint_callback])
```

### 2. Early Stopping

Built-in early stopping support:

```python
early_stopping = EarlyStoppingConfig(
    monitor="val/forces_mae",
    patience=20,
    mode="min",
).create_callback()

trainer = Trainer(callbacks=[early_stopping])
```

### 3. Multi-GPU Training

Lightning handles distributed training with minimal code changes:

```python
# Single GPU
trainer = Trainer(accelerator="gpu", devices=[0])

# Multiple GPUs with DDP
trainer = Trainer(accelerator="gpu", devices=[0, 1], strategy="ddp")
```

### 4. Logging

Lightning provides a unified logging interface:

```python
def training_step(self, batch, batch_idx):
    loss = ...
    self.log("train_loss", loss)
    self.log_dict({
        "energy_mae": energy_mae,
        "forces_mae": forces_mae,
    })
```

### 5. Precision Settings

Easy configuration of training precision:

```python
# 32-bit training
trainer = Trainer(precision="32-true")

# Mixed precision training
trainer = Trainer(precision="16-mixed")
```

## Available Trainer Configurations

The `TrainerConfig` class exposes common Lightning `Trainer` settings:

```python
trainer_config = TrainerConfig(
    # Hardware
    accelerator="gpu",
    devices=[0, 1],
    precision="16-mixed",

    # Training
    max_epochs=100,
    gradient_clip_val=1.0,

    # Validation
    val_check_interval=1.0,
    check_val_every_n_epoch=1,

    # Callbacks
    early_stopping=EarlyStoppingConfig(...),
    checkpoint=ModelCheckpointConfig(...),

    # Logging
    loggers=["tensorboard", "wandb"],
)
```

## Best Practices

1. Use `self.log()` for tracking metrics during training
2. Enable checkpointing to save model states
3. Set appropriate early stopping criteria
4. Use appropriate precision settings for your hardware
5. Configure multi-GPU training based on available resources
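The sketch below combines these practices into a single `TrainerConfig`. It reuses only the configuration fields shown earlier on this page; the import path is an assumption, and exact field names or defaults may differ between MatterTune versions, so treat it as a starting point rather than a verbatim recipe.

```python
# Assumed import path; adjust to wherever TrainerConfig and the callback
# configs are exposed in your MatterTune version.
from mattertune import TrainerConfig, EarlyStoppingConfig, ModelCheckpointConfig

trainer_config = TrainerConfig(
    # (5) Multi-GPU training sized to the available hardware
    accelerator="gpu",
    devices=[0, 1],
    # (4) Precision suited to the hardware: mixed precision on recent GPUs
    precision="16-mixed",
    max_epochs=100,
    gradient_clip_val=1.0,
    # (3) Early stopping on a validation metric
    early_stopping=EarlyStoppingConfig(
        monitor="val/forces_mae", patience=20, mode="min"
    ),
    # (2) Checkpointing so the best model state is always recoverable
    checkpoint=ModelCheckpointConfig(
        monitor="val/forces_mae",
        dirpath="./checkpoints",
        filename="best-model",
        save_top_k=1,
        mode="min",
    ),
    # (1) Metrics recorded with self.log() are routed to these loggers
    loggers=["tensorboard", "wandb"],
)
```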
## Advanced Usage

For advanced use cases:

```python
# Custom training-loop behaviour (inside a FinetuneModuleBase subclass)
@override
def training_step(self, batch, batch_idx):
    if self.trainer.global_rank == 0:
        # Do something only on the main process
        pass

    # Access trainer properties
    if self.trainer.is_last_batch:
        # Special handling for the last batch
        pass

# Custom validation (Lightning 2.x hook; replaces the removed validation_epoch_end)
@override
def on_validation_epoch_end(self):
    # Compute epoch-level metrics
    pass
```

This integration provides a robust foundation for training atomistic models while handling common ML engineering concerns automatically.
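As a closing illustration, the pieces above can be wired together roughly as follows. This is a hedged sketch: it assumes `model` is an instance of a concrete `FinetuneModuleBase` subclass and `datamodule` is a constructed `MatterTuneDataModule` (their constructors are omitted because the exact signatures depend on the MatterTune version), and it reuses the `checkpoint_callback` and `early_stopping` objects from the sections above together with standard Lightning `Trainer` calls.

```python
from lightning.pytorch import Trainer  # on Lightning 1.x: from pytorch_lightning import Trainer

# `model` and `datamodule` are assumed to have been built already from their
# respective configs; construction is omitted here.
trainer = Trainer(
    accelerator="gpu",
    devices=[0, 1],
    precision="16-mixed",
    max_epochs=100,
    callbacks=[checkpoint_callback, early_stopping],  # callbacks from the sections above
)

trainer.fit(model, datamodule=datamodule)   # train with periodic validation
trainer.test(model, datamodule=datamodule)  # final evaluation on the test split
```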