Source code for mattertune.finetune.loader

from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

import ase
from torch.utils.data import DataLoader, Dataset, IterableDataset, Sampler
from torch.utils.data.dataloader import _worker_init_fn_t
from typing_extensions import TypedDict, Unpack

from .data_util import IterableDatasetWrapper, MapDatasetWrapper

if TYPE_CHECKING:
    from .base import FinetuneModuleBase, TBatch, TData, TFinetuneModuleConfig


class DataLoaderKwargs(TypedDict, total=False):
    """Keyword arguments for creating a DataLoader.

    Args:
        batch_size: How many samples per batch to load (default: 1).
        shuffle: Set to True to have the data reshuffled at every epoch (default: False).
        sampler: Defines the strategy to draw samples from the dataset. Can be any
            Iterable with __len__ implemented. If specified, shuffle must not be
            specified.
        batch_sampler: Like sampler, but returns a batch of indices at a time.
            Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
        num_workers: How many subprocesses to use for data loading. 0 means that the
            data will be loaded in the main process (default: 0).
        pin_memory: If True, the data loader will copy Tensors into device/CUDA
            pinned memory before returning them.
        drop_last: Set to True to drop the last incomplete batch, if the dataset size
            is not divisible by the batch size (default: False).
        timeout: If positive, the timeout value for collecting a batch from workers.
            Should always be non-negative (default: 0).
        worker_init_fn: If not None, this will be called on each worker subprocess
            with the worker id as input, after seeding and before data loading.
        multiprocessing_context: If None, the default multiprocessing context of your
            operating system will be used.
        generator: If not None, this RNG will be used by RandomSampler to generate
            random indexes and by multiprocessing to generate base_seed for workers.
        prefetch_factor: Number of batches loaded in advance by each worker.
        persistent_workers: If True, the data loader will not shut down the worker
            processes after a dataset has been consumed once.
        pin_memory_device: The device to pin_memory to if pin_memory is True.
    """

    batch_size: int | None
    shuffle: bool | None
    sampler: Sampler | Iterable | None
    batch_sampler: Sampler[list[int]] | Iterable[list[int]] | None
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    worker_init_fn: _worker_init_fn_t | None
    multiprocessing_context: Any  # type: ignore
    generator: Any  # type: ignore
    prefetch_factor: int | None
    persistent_workers: bool
    pin_memory_device: str
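
# Because DataLoaderKwargs is declared with total=False, every key is optional,
# so any subset of these options can be supplied and forwarded to the DataLoader.
# A minimal sketch (the variable name below is illustrative only, not part of the
# original module):
_example_loader_kwargs: DataLoaderKwargs = {
    "batch_size": 16,
    "shuffle": True,
    "num_workers": 2,
}
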
def create_dataloader(
    dataset: Dataset[ase.Atoms],
    has_labels: bool,
    *,
    lightning_module: FinetuneModuleBase[TData, TBatch, TFinetuneModuleConfig],
    **kwargs: Unpack[DataLoaderKwargs],
):
    def map_fn(ase_data: ase.Atoms):
        data = lightning_module.atoms_to_data(ase_data, has_labels)
        data = lightning_module.cpu_data_transform(data)
        return data

    # Wrap the dataset with the CPU data transform
    dataset_mapped = (
        IterableDatasetWrapper(dataset, map_fn)
        if isinstance(dataset, IterableDataset)
        else MapDatasetWrapper(dataset, map_fn)
    )

    # Create the data loader with the model's collate function
    dl = DataLoader(dataset_mapped, collate_fn=lightning_module.collate_fn, **kwargs)
    return dl
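
# Usage sketch (illustrative, not part of the original module): `model` and
# `train_dataset` are hypothetical stand-ins for a concrete FinetuneModuleBase
# subclass instance and a map-style Dataset of ase.Atoms objects. The extra
# keyword arguments are DataLoaderKwargs fields and are passed through to
# torch's DataLoader after the dataset has been wrapped with the CPU transform.
def _example_build_train_loader(model, train_dataset):
    return create_dataloader(
        train_dataset,
        has_labels=True,
        lightning_module=model,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )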