Training Configuration Guide

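MatterTune training is driven by three configuration objects: a model configuration, a data configuration, and a trainer configuration, which are combined into a single MatterTunerConfig. This guide covers each one, then shows the ways to create and load configurations.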

Model Configuration

Select the pre-trained backbone, the properties to predict (with their losses), the optimizer, and an optional learning rate scheduler:

import mattertune as mt

model = mt.configs.JMPBackboneConfig(
    # Specify pre-trained model checkpoint
    ckpt_path="path/to/pretrained/model.pt",

    # Define properties to predict
    properties=[
        mt.configs.EnergyPropertyConfig(
            loss=mt.configs.MAELossConfig(),
            loss_coefficient=1.0
        ),
        mt.configs.ForcesPropertyConfig(
            loss=mt.configs.MAELossConfig(),
            loss_coefficient=10.0,
            conservative=True  # Use energy-conserving force prediction
        )
    ],

    # Configure optimizer
    optimizer=mt.configs.AdamWConfig(lr=1e-4),

    # Optional: Configure learning rate scheduler
    lr_scheduler=mt.configs.CosineAnnealingLRConfig(
        T_max=100,  # Number of epochs
        eta_min=1e-6  # Minimum learning rate
    )
)
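
To make the roles of loss_coefficient and the cosine schedule concrete, here is a minimal plain-Python sketch. It is not MatterTune's implementation. It assumes the total loss is the coefficient-weighted sum of the per-property MAE losses, and that the scheduler follows the standard cosine annealing formula (the one used by PyTorch's CosineAnnealingLR). The helper names mae, total_loss, and cosine_annealing_lr are hypothetical.

```python
import math


def mae(pred, target):
    """Mean absolute error over two equal-length sequences of floats."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def total_loss(energy_mae, forces_mae, energy_coeff=1.0, forces_coeff=10.0):
    """Assumed combination: coefficient-weighted sum of per-property losses."""
    return energy_coeff * energy_mae + forces_coeff * forces_mae


def cosine_annealing_lr(step, base_lr=1e-4, eta_min=1e-6, t_max=100):
    """Closed-form cosine annealing from base_lr (step 0) to eta_min (step t_max)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


if __name__ == "__main__":
    e = mae([1.0, 2.0], [1.5, 2.5])    # toy energy errors
    f = mae([0.1, -0.2], [0.0, 0.0])   # toy force-component errors
    print("total loss:", total_loss(e, f))
    for step in (0, 50, 100):
        print(step, cosine_annealing_lr(step))
```

With the coefficients above, a force error contributes ten times as much to the total as an energy error of the same size, and the learning rate falls from 1e-4 at step 0 to 1e-6 at step 100.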

Data Configuration

Configure data loading and processing:

data = mt.configs.AutoSplitDataModuleConfig(
    # Specify dataset source
    dataset=mt.configs.XYZDatasetConfig(
        src="path/to/your/data.xyz"
    ),

    # Control data splitting
    train_split=0.8,  # 80% for training

    # Configure batch size and loading
    batch_size=32,
    num_workers=4,  # Number of data loading workers
    pin_memory=True  # Use pinned (page-locked) host memory to speed up CPU-to-GPU transfer
)
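
As a rough illustration of what train_split=0.8 and batch_size=32 imply for dataset sizes, the sketch below splits sample indices in plain Python. Whether MatterTune shuffles before splitting, which seed it uses, and how it uses the held-out part are not stated in this guide; the shuffling, the seed argument, and the function name split_indices are assumptions made for the example.

```python
import math
import random


def split_indices(n_samples, train_split=0.8, seed=0):
    """Shuffle indices 0..n_samples-1 and cut them into train and held-out parts."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_train = int(n_samples * train_split)
    return indices[:n_train], indices[n_train:]


if __name__ == "__main__":
    train_idx, held_out_idx = split_indices(1000, train_split=0.8)
    print(len(train_idx), len(held_out_idx))  # 800 200
    # Number of training batches per epoch at batch_size=32
    print(math.ceil(len(train_idx) / 32))     # 25
```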

Training Process Configuration

Control the training loop behavior:

trainer = mt.configs.TrainerConfig(
    # Hardware configuration
    accelerator="gpu",
    devices=[0, 1],  # Use GPUs 0 and 1

    # Training stopping criteria
    max_epochs=100,
    # OR: max_steps=1000,  # Stop after 1000 steps
    # OR: max_time=datetime.timedelta(hours=1),  # Stop after 1 hour (needs `import datetime`)

    # Validation frequency
    check_val_every_n_epoch=1,

    # Gradient clipping: Prevent exploding gradients
    gradient_clip_val=1.0,

    # Early stopping configuration
    early_stopping=mt.configs.EarlyStoppingConfig(
        monitor="val/energy_mae",
        patience=20,
        mode="min"
    ),

    # Model checkpointing
    checkpoint=mt.configs.ModelCheckpointConfig(
        monitor="val/energy_mae",
        save_top_k=1,
        mode="min"
    ),

    # Configure logging
    loggers=[
        mt.configs.WandbLoggerConfig(
            project="my-project",
            name="experiment-1"
        )
    ]
)

# Combine all configurations
config = mt.configs.MatterTunerConfig(
    model=model,
    data=data,
    trainer=trainer
)
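
The early stopping settings above (patience=20, mode="min") can be read as the following rule, sketched in plain Python. This is a simplified illustration of patience-based early stopping, not the trainer's actual code; it ignores options such as a minimum improvement threshold, and the function name early_stop_epoch is hypothetical.

```python
def early_stop_epoch(val_metrics, patience=20, mode="min"):
    """Return the 0-based validation check at which training would stop, or None.

    Training stops once the monitored metric has failed to improve for
    `patience` consecutive validation checks.
    """
    best = None
    bad_checks = 0
    for epoch, value in enumerate(val_metrics):
        if best is None:
            improved = True
        elif mode == "min":
            improved = value < best
        else:
            improved = value > best
        if improved:
            best = value
            bad_checks = 0  # any improvement resets the counter
        else:
            bad_checks += 1
            if bad_checks >= patience:
                return epoch
    return None


if __name__ == "__main__":
    # Metric improves for three checks, then plateaus at a worse value.
    history = [1.0, 0.9, 0.8] + [0.85] * 5
    print(early_stop_epoch(history, patience=3))   # 5
    print(early_stop_epoch(history, patience=20))  # None
```

A larger patience tolerates longer plateaus at the cost of extra epochs; monitoring the same metric for checkpointing, as in the configuration above, keeps the saved model aligned with the stopping criterion.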

Configuration Management

MatterTune uses nshconfig for configuration management, providing several ways to create and load configurations:

1. Direct Construction

config = mt.configs.MatterTunerConfig(
    model=mt.configs.JMPBackboneConfig(...),
    data=mt.configs.AutoSplitDataModuleConfig(...),
    trainer=mt.configs.TrainerConfig(...)
)

2. Loading from Files/Dictionaries

# Load from YAML
config = mt.configs.MatterTunerConfig.from_yaml('/path/to/config.yaml')

# Load from JSON
config = mt.configs.MatterTunerConfig.from_json('/path/to/config.json')

# Load from dictionary
config = mt.configs.MatterTunerConfig.from_dict({
    'model': {...},
    'data': {...},
    'trainer': {...}
})

3. Using Draft Configs

# Create a draft config
config = mt.configs.MatterTunerConfig.draft()

# Set values progressively
config.model = mt.configs.JMPBackboneConfig.draft()
config.model.ckpt_path = "path/to/model.pt"
# ... set other values ...

# Finalize the config
final_config = config.finalize()
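
The idea behind the draft workflow, setting fields one at a time and validating only at the end, can be illustrated with a toy class. The sketch below is not nshconfig's implementation: ToyConfig and its fields are hypothetical, and the only behavior it assumes is that validation is deferred until finalize() is called.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a config class; not part of MatterTune or nshconfig."""

    ckpt_path: Optional[str] = None
    lr: Optional[float] = None

    @classmethod
    def draft(cls):
        # A draft starts with every field unset and performs no validation.
        return cls()

    def finalize(self):
        # Validation is deferred until here: every field must have been set.
        missing = [f.name for f in fields(self) if getattr(self, f.name) is None]
        if missing:
            raise ValueError(f"unset fields: {missing}")
        return self


if __name__ == "__main__":
    cfg = ToyConfig.draft()
    cfg.ckpt_path = "path/to/model.pt"
    cfg.lr = 1e-4
    print(cfg.finalize())
```

Deferring validation this way lets a config be built up across several lines or cells, while still catching a forgotten field before training starts.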

For more advanced configuration management features, see the nshconfig documentation.