mattertune.callbacks.multi_gpu_writer
Classes
CustomWriter | PyTorch Lightning callback for saving predictions from multiple GPUs during prediction.
- class mattertune.callbacks.multi_gpu_writer.CustomWriter(write_interval='epoch')
PyTorch Lightning callback for saving predictions from multiple GPUs during prediction. Requirements:
1. Collect all predictions from the different GPUs.
2. Follow the input order of the dataloader.
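The two requirements above can be illustrated without Lightning. The following is a minimal sketch, not mattertune's implementation. It assumes each GPU's output has already been gathered into a list of (sample_index, prediction) pairs. `merge_rank_outputs` is a hypothetical helper, and plain strings stand in for prediction tensors. The padding in the example mirrors PyTorch's `DistributedSampler` with `shuffle=False`, which repeats samples so every rank receives the same number.

```python
from typing import Any


def merge_rank_outputs(rank_outputs: list[list[tuple[int, Any]]]) -> list[Any]:
    """Merge per-rank (sample_index, prediction) pairs into dataloader order.

    Hypothetical helper. Duplicated indices (e.g. from sampler padding)
    are dropped, keeping the first prediction seen for each index.
    """
    merged: dict[int, Any] = {}
    for pairs in rank_outputs:
        for idx, pred in pairs:
            # setdefault keeps only the first occurrence of a padded/duplicated index
            merged.setdefault(idx, pred)
    # Sorting by the global sample index restores the dataloader's input order
    return [merged[idx] for idx in sorted(merged)]


# 5 samples on 2 ranks: an interleaved split padded to 3 samples per rank,
# so sample 0 is predicted twice (once on each rank).
rank0 = [(0, "a"), (2, "c"), (4, "e")]
rank1 = [(1, "b"), (3, "d"), (0, "a")]
print(merge_rank_outputs([rank0, rank1]))  # ['a', 'b', 'c', 'd', 'e']
```

Keying on the global sample index, rather than on the order in which ranks finish, is what makes the merged result independent of the number of GPUs.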
- write_on_epoch_end(trainer, pl_module, predictions, batch_indices)
Called at the end of each prediction epoch; saves the predictions and their corresponding batch indices.
- Parameters:
predictions (list[dict[str, Tensor]])
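The following is a minimal sketch of how `predictions` of the documented shape could be split into per-sample records and put back into input order. It assumes that `batch_indices` holds one list of dataset indices per predicted batch, as in PyTorch Lightning's `BasePredictionWriter`. `flatten_predictions` and the `energy` key are hypothetical, and plain lists stand in for the batched `Tensor` values.

```python
from typing import Any


def flatten_predictions(
    predictions: list[dict[str, list[Any]]],
    batch_indices: list[list[int]],
) -> list[dict[str, Any]]:
    """Split per-batch prediction dicts into per-sample dicts ordered by index.

    Hypothetical helper; lists stand in for batched Tensors.
    """
    if len(predictions) != len(batch_indices):
        raise ValueError("expected one list of indices per prediction batch")
    per_sample: dict[int, dict[str, Any]] = {}
    for batch_pred, indices in zip(predictions, batch_indices):
        for row, idx in enumerate(indices):
            # Take row `row` of every batched output for this sample
            per_sample[idx] = {key: values[row] for key, values in batch_pred.items()}
    # Return samples in ascending dataset-index order
    return [per_sample[idx] for idx in sorted(per_sample)]


predictions = [{"energy": [1.0, 2.0]}, {"energy": [3.0]}]
batch_indices = [[2, 0], [1]]
print(flatten_predictions(predictions, batch_indices))
# [{'energy': 2.0}, {'energy': 3.0}, {'energy': 1.0}]
```

With real tensors, `values[row]` would index the leading batch dimension in the same way.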