mattertune.callbacks.multi_gpu_writer

Classes

CustomWriter([write_interval])

PyTorch Lightning callback that saves predictions from multiple GPUs during prediction.

class mattertune.callbacks.multi_gpu_writer.CustomWriter(write_interval='epoch')

PyTorch Lightning callback that saves predictions from multiple GPUs during prediction. Requirements:

1. Collect all predictions from the different GPUs.
2. Follow the input order of the dataloader.
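The second requirement can be illustrated with a small, framework-free sketch. It assumes each GPU process (rank) reports the dataset indices it handled alongside one prediction per sample, and it treats dataset-index order as the dataloader's input order (true when shuffling is off). Plain Python values stand in for tensors, and the helper `merge_in_input_order` is hypothetical, not part of mattertune.

```python
from typing import Any


def merge_in_input_order(
    per_rank_indices: list[list[int]],
    per_rank_predictions: list[list[Any]],
) -> list[Any]:
    """Merge per-rank predictions back into dataset-index order (hypothetical helper)."""
    if len(per_rank_indices) != len(per_rank_predictions):
        raise ValueError("need one list of indices per list of predictions")
    by_index: dict[int, Any] = {}
    for indices, preds in zip(per_rank_indices, per_rank_predictions):
        if len(indices) != len(preds):
            raise ValueError("each index must pair with exactly one prediction")
        for idx, pred in zip(indices, preds):
            # A repeated index (e.g. sampler padding) just overwrites itself.
            by_index[idx] = pred
    # Sorting the dataset indices restores the original input order.
    return [by_index[idx] for idx in sorted(by_index)]


# Two ranks, five samples: rank 1 received index 0 a second time as padding.
print(merge_in_input_order([[0, 2, 4], [1, 3, 0]], [["p0", "p2", "p4"], ["p1", "p3", "p0"]]))
# → ['p0', 'p1', 'p2', 'p3', 'p4']
```

Keying by index rather than concatenating rank outputs is deliberate: PyTorch's `DistributedSampler` hands rank `r` every `num_replicas`-th index starting at `r`, and with `drop_last=False` it repeats samples to pad all ranks to the same length, so a plain concatenation would be both out of order and contain duplicates.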

__init__(write_interval='epoch')
Return type:

None

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Called at the end of each prediction epoch on every process; saves that process's predictions and their corresponding batch indices to a temporary folder.

Parameters:

predictions (list[dict[str, Tensor]])
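One way such a per-process save step could look, sketched with the standard library only. The file layout (one pickle per rank named `predictions_rank{rank}.pkl`), the helper `save_rank_predictions`, the assumption of one list of dataset indices per batch, and the use of plain lists in place of tensors are all illustrative assumptions, not mattertune's actual format. In a real Lightning callback the rank would typically come from `trainer.global_rank`.

```python
import os
import pickle
from typing import Any


def save_rank_predictions(
    folder: str,
    rank: int,
    predictions: list[dict[str, Any]],
    batch_indices: list[list[int]],
) -> str:
    """Write one rank's predictions and batch indices to its own file (hypothetical layout)."""
    os.makedirs(folder, exist_ok=True)
    # One file per rank, so concurrent processes never write to the same path.
    path = os.path.join(folder, f"predictions_rank{rank}.pkl")
    with open(path, "wb") as f:
        pickle.dump({"predictions": predictions, "batch_indices": batch_indices}, f)
    return path
```

Writing to separate files avoids any inter-process communication at save time; the cost is a later step that reads the files back and restores the input order.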

gather_all_predictions()

Load all saved predictions and corresponding indices from the temporary folder, sort them according to the batch indices, and aggregate the predictions into a single dictionary.
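The load, sort, and aggregate steps can be sketched as follows. This assumes each process wrote a pickle named `predictions_rank{rank}.pkl` holding `predictions` (a list of per-batch dictionaries mapping a property name to a list of per-sample values) and `batch_indices` (one list of dataset indices per batch). The file pattern and the function name `gather_from_folder` are hypothetical, and plain lists stand in for tensors.

```python
import glob
import os
import pickle
from typing import Any


def gather_from_folder(folder: str) -> dict[str, list[Any]]:
    """Load every rank's file, order samples by dataset index, and merge them."""
    samples: dict[int, dict[str, Any]] = {}
    for path in sorted(glob.glob(os.path.join(folder, "predictions_rank*.pkl"))):
        with open(path, "rb") as f:
            saved = pickle.load(f)
        for batch, indices in zip(saved["predictions"], saved["batch_indices"]):
            for pos, idx in enumerate(indices):
                # Split the batch dict into one dict per sample, keyed by index;
                # a repeated (padding) index simply overwrites the same entry.
                samples[idx] = {key: values[pos] for key, values in batch.items()}
    # Aggregate into a single dictionary: one list per property, in index order.
    merged: dict[str, list[Any]] = {}
    for idx in sorted(samples):
        for key, value in samples[idx].items():
            merged.setdefault(key, []).append(value)
    return merged
```

Collecting values per property keeps the result shaped like a single dictionary; with real tensors, each list would then typically be combined with `torch.stack` or `torch.cat`.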

cleanup()

Clean up the temporary folder used to store predictions.
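A minimal sketch of the cleanup step, assuming the temporary folder contains only the saved prediction files. `cleanup_folder` is a hypothetical name; `shutil.rmtree` is the standard-library call for deleting a directory tree.

```python
import shutil


def cleanup_folder(folder: str) -> None:
    """Remove the temporary prediction folder and all files inside it."""
    # ignore_errors=True turns a missing folder (e.g. a second call) into a no-op.
    shutil.rmtree(folder, ignore_errors=True)
```

In a multi-GPU run it is usually safest to delete the folder from a single process (for example global rank 0) and only after every process has finished writing and reading, so no rank loses files it still needs.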