mattertune.configs.data
- class mattertune.configs.data.AutoSplitDataModuleConfig(*, batch_size, num_workers='auto', pin_memory=True, dataset, train_split, validation_split='auto', shuffle=True, shuffle_seed=42)
- Parameters:
batch_size (int)
num_workers (int | Literal['auto'])
pin_memory (bool)
dataset (DatasetConfig)
train_split (float)
validation_split (float | Literal['auto', 'disable'])
shuffle (bool)
shuffle_seed (int)
- dataset: DatasetConfig
The configuration for the dataset.
- train_split: float
The proportion of the dataset to include in the training split.
- validation_split: float | Literal['auto', 'disable']
The proportion of the dataset to include in the validation split.
If set to “auto”, the validation split will be automatically determined as the complement of the training split, i.e. validation_split = 1 - train_split.
If set to “disable”, the validation split will be disabled.
- shuffle: bool
Whether to shuffle the dataset before splitting.
- shuffle_seed: int
The seed to use for shuffling the dataset.
- batch_size: int
The batch size for the dataloaders.
- num_workers: int | Literal['auto']
The number of workers for the dataloaders.
This is the number of processes that generate batches in parallel.
If set to “auto”, the number of workers will be automatically set based on the number of available CPUs.
Set to 0 to load data in the main process without worker subprocesses.
- pin_memory: bool
Whether to pin memory in the dataloaders.
Pinned (page-locked) host memory speeds up data transfer from the CPU to the GPU.
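The split fields of AutoSplitDataModuleConfig can be illustrated with a minimal stdlib sketch. The function name `auto_split_indices` is hypothetical. The rounding rule is an assumption: the sketch truncates `n * train_split` with `int()`. MatterTune's own implementation may round or order indices differently.

```python
from __future__ import annotations

import random
from typing import List, Literal, Tuple, Union


def auto_split_indices(
    n: int,
    train_split: float,
    validation_split: Union[float, Literal["auto", "disable"]] = "auto",
    shuffle: bool = True,
    shuffle_seed: int = 42,
) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch of AutoSplitDataModuleConfig's split semantics."""
    indices = list(range(n))
    if shuffle:
        # A dedicated, seeded RNG makes the split reproducible for a given seed.
        random.Random(shuffle_seed).shuffle(indices)

    # Assumption: the training size is truncated to an integer.
    n_train = int(n * train_split)
    train_idx = indices[:n_train]

    if validation_split == "disable":
        # No validation set at all.
        return train_idx, []
    if validation_split == "auto":
        # Complement of the training split: every remaining index.
        return train_idx, indices[n_train:]

    # Explicit fraction: take the next block of indices after the training set.
    n_val = int(n * validation_split)
    return train_idx, indices[n_train : n_train + n_val]


# Usage: a 75/25 split of 100 samples, reproducible through the seed.
train, val = auto_split_indices(100, train_split=0.75)
print(len(train), len(val))
```

For "auto", the sketch takes the remaining indices instead of computing `1 - train_split` in floating point. This avoids off-by-one sizes caused by rounding error. It also uses a local `random.Random` instance rather than the global RNG, so the split does not depend on other code that draws random numbers.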
- class mattertune.configs.data.DBDatasetConfig(*, type='db', src, energy_key=None, forces_key=None, stress_key=None, preload=True)
Configuration for a dataset stored in an ASE database.
- Parameters:
type (Literal['db'])
src (Database | str | Path)
energy_key (str | None)
forces_key (str | None)
stress_key (str | None)
preload (bool)
- type: Literal['db']
Discriminator for the DB dataset.
- src: Database | str | Path
Path to the ASE database file or a database object.
- energy_key: str | None
Key for the energy label in the database.
- forces_key: str | None
Key for the force label in the database.
- stress_key: str | None
Key for the stress label in the database.
- preload: bool
Whether to load all entries from the database into memory at once.
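The following stdlib-only sketch shows what `energy_key`, `forces_key`, `stress_key` and `preload` control. It does not use ASE. The class `LabeledRowDataset` and the `fetch_row` callback are hypothetical stand-ins for reading rows from an ASE database. The lazy, on-access reading for `preload=False` is an assumption based on the field description.

```python
from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional

# Hypothetical stand-in for one database row: a plain mapping of stored values.
Row = Dict[str, Any]


class LabeledRowDataset:
    """Hypothetical sketch of DBDatasetConfig's label keys and `preload` flag."""

    def __init__(
        self,
        fetch_row: Callable[[int], Row],
        n_rows: int,
        energy_key: Optional[str] = None,
        forces_key: Optional[str] = None,
        stress_key: Optional[str] = None,
        preload: bool = True,
    ) -> None:
        self._fetch_row = fetch_row
        self._n_rows = n_rows
        # Map each label name to the key it is stored under (None = not used).
        self._label_keys = {"energy": energy_key, "forces": forces_key, "stress": stress_key}
        # preload=True reads every row once up front; preload=False defers reads to access time.
        self._cache: Optional[List[Row]] = (
            [fetch_row(i) for i in range(n_rows)] if preload else None
        )

    def __len__(self) -> int:
        return self._n_rows

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self._cache[idx] if self._cache is not None else self._fetch_row(idx)
        # Extract only the labels whose key was configured.
        return {label: row[key] for label, key in self._label_keys.items() if key is not None}
```

Preloading trades memory for faster repeated access. Deferring reads keeps memory low but touches the database on every access.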
- class mattertune.configs.data.DataModuleBaseConfig(*, batch_size, num_workers='auto', pin_memory=True)
- Parameters:
batch_size (int)
num_workers (int | Literal['auto'])
pin_memory (bool)
- batch_size: int
The batch size for the dataloaders.
- num_workers: int | Literal['auto']
The number of workers for the dataloaders.
This is the number of processes that generate batches in parallel.
If set to “auto”, the number of workers will be automatically set based on the number of available CPUs.
Set to 0 to load data in the main process without worker subprocesses.
- pin_memory: bool
Whether to pin memory in the dataloaders.
Pinned (page-locked) host memory speeds up data transfer from the CPU to the GPU.
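The three DataModuleBaseConfig fields correspond to the `batch_size`, `num_workers` and `pin_memory` parameters of `torch.utils.data.DataLoader`. The sketch below resolves them into a keyword dict. The helper name `dataloader_kwargs` is hypothetical. The "auto" heuristic (one worker per CPU reported by `os.cpu_count()`) is an assumption; MatterTune may choose the worker count differently.

```python
from __future__ import annotations

import os
from typing import Any, Dict, Literal, Union


def dataloader_kwargs(
    batch_size: int,
    num_workers: Union[int, Literal["auto"]] = "auto",
    pin_memory: bool = True,
) -> Dict[str, Any]:
    """Hypothetical helper turning DataModuleBaseConfig fields into DataLoader kwargs."""
    if num_workers == "auto":
        # Assumed heuristic: one worker per CPU reported by the OS (0 if unknown).
        num_workers = os.cpu_count() or 0
    if num_workers < 0:
        raise ValueError("num_workers must be >= 0 or 'auto'")
    # These keys match the parameter names of torch.utils.data.DataLoader.
    return {"batch_size": batch_size, "num_workers": num_workers, "pin_memory": pin_memory}


# Usage: single-process loading with pinned memory.
print(dataloader_kwargs(32, num_workers=0))
```

With PyTorch installed, the resulting dict could be unpacked as `DataLoader(dataset, **kwargs)`. PyTorch is not imported here, so the sketch runs without it.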
- class mattertune.configs.data.DatasetConfigBase
- prepare_data()
Prepare the dataset for training.
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within this method.
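A `prepare_data()` override is typically written to be idempotent, so it can be safely called again after a restart. The sketch below shows this pattern with a hypothetical `ToyDatasetConfig`; it is not a MatterTune class. The "download" is simulated by writing a small JSON file, and `download_count` exists only to make the behavior observable.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path


class ToyDatasetConfig:
    """Hypothetical dataset config illustrating an idempotent prepare_data()."""

    def __init__(self, cache_dir: Path) -> None:
        self.cache_dir = Path(cache_dir)
        self.download_count = 0  # counts how often the "download" actually ran

    @property
    def data_file(self) -> Path:
        return self.cache_dir / "data.json"

    def prepare_data(self) -> None:
        # Skip the work if a previous call already produced the file.
        if self.data_file.exists():
            return
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Stand-in for a real download: write to a temporary name, then rename,
        # so an interrupted run never leaves a half-written data file behind.
        tmp_file = self.data_file.with_suffix(".tmp")
        tmp_file.write_text(json.dumps([{"energy": -1.0}]))
        tmp_file.replace(self.data_file)
        self.download_count += 1


# Usage: the second call finds the file and does nothing.
with tempfile.TemporaryDirectory() as tmp_dir:
    config = ToyDatasetConfig(Path(tmp_dir) / "cache")
    config.prepare_data()
    config.prepare_data()
    print(config.download_count)
```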
- class mattertune.configs.data.JSONDatasetConfig(*, type='json', src, tasks)
Configuration for a dataset stored in a JSON file.
- Parameters:
type (Literal['json'])
src (str | Path)
tasks (dict[str, str])
- type: Literal['json']
Discriminator for the JSON dataset.
- src: str | Path
The path to the JSON dataset.
- tasks: dict[str, str]
Attributes in the JSON file that correspond to the tasks to be predicted.
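The reference does not specify the JSON layout or the direction of the `tasks` mapping. The sketch below therefore makes two loudly labeled assumptions. First, the file is a JSON array of objects, one per structure. Second, `tasks` maps a task name to the attribute that stores its target in each object. Both function names are hypothetical.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Union


def extract_tasks(records: List[Dict[str, Any]], tasks: Dict[str, str]) -> List[Dict[str, Any]]:
    """Hypothetical: pull task targets out of JSON records.

    Assumes `tasks` maps a task name to the attribute that stores it in each record.
    """
    return [{task: record[attr] for task, attr in tasks.items()} for record in records]


def load_json_tasks(src: Union[str, Path], tasks: Dict[str, str]) -> List[Dict[str, Any]]:
    # Assumption: the file holds a JSON array of objects, one per structure.
    records = json.loads(Path(src).read_text())
    return extract_tasks(records, tasks)


# Usage with toy records: keep only the attribute mapped to the "energy" task.
records = json.loads('[{"e_form": -1.5, "gap": 0.0}, {"e_form": -0.3, "gap": 2.1}]')
print(extract_tasks(records, {"energy": "e_form"}))
```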
- class mattertune.configs.data.MPDatasetConfig(*, type='mp', api, fields, query)
Configuration for a dataset stored in the Materials Project database.
- Parameters:
type (Literal['mp'])
api (str)
fields (list[str])
query (dict)
- type: Literal['mp']
Discriminator for the MP dataset.
- api: str
The API key for accessing the Materials Project database.
- fields: list[str]
Fields to retrieve from the Materials Project database.
- query: dict
Query to filter the data from the Materials Project database.
- class mattertune.configs.data.MPTrajDatasetConfig(*, type='mptraj', split='train', min_num_atoms=5, max_num_atoms=None, elements=None)
Configuration for the MPTraj dataset of relaxation trajectories from the Materials Project.
- Parameters:
type (Literal['mptraj'])
split (Literal['train', 'val', 'test'])
min_num_atoms (int | None)
max_num_atoms (int | None)
elements (list[str] | None)
- type: Literal['mptraj']
Discriminator for the MPTraj dataset.
- split: Literal['train', 'val', 'test']
Split of the dataset to use.
- min_num_atoms: int | None
Minimum number of atoms to be considered. Drops structures with fewer atoms.
- max_num_atoms: int | None
Maximum number of atoms to be considered. Drops structures with more atoms.
- elements: list[str] | None
List of elements to be considered. Drops any structure that contains an element not in the list, so a structure is kept only if its set of elements is a subset of the list. For example, [“Li”, “Na”] keeps structures composed of only Li, only Na, or both Li and Na.
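The three MPTraj filters can be read as a single keep/drop predicate per structure. The sketch below is a hypothetical illustration, not MatterTune's filtering code. It uses the signature defaults (`min_num_atoms=5`, no maximum, no element restriction). It assumes the subset semantics described for `elements`, and it represents a structure by its list of chemical symbols.

```python
from __future__ import annotations

from typing import Optional, Sequence


def keep_structure(
    symbols: Sequence[str],
    min_num_atoms: Optional[int] = 5,
    max_num_atoms: Optional[int] = None,
    elements: Optional[Sequence[str]] = None,
) -> bool:
    """Hypothetical predicate mirroring the MPTrajDatasetConfig filters."""
    n_atoms = len(symbols)
    if min_num_atoms is not None and n_atoms < min_num_atoms:
        return False  # too few atoms
    if max_num_atoms is not None and n_atoms > max_num_atoms:
        return False  # too many atoms
    if elements is not None and not set(symbols) <= set(elements):
        return False  # contains an element outside the allowed list
    return True


# Usage: a 6-atom Li-only cell passes an ["Li", "Na"] filter; adding O fails it.
print(keep_structure(["Li"] * 6, elements=["Li", "Na"]))
print(keep_structure(["Li"] * 4 + ["O"] * 2, elements=["Li", "Na"]))
```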
- class mattertune.configs.data.ManualSplitDataModuleConfig(*, batch_size, num_workers='auto', pin_memory=True, train, validation=None)
- Parameters:
batch_size (int)
num_workers (int | Literal['auto'])
pin_memory (bool)
train (DatasetConfig)
validation (DatasetConfig | None)
- train: DatasetConfig
The configuration for the training data.
- validation: DatasetConfig | None
The configuration for the validation data.
- batch_size: int
The batch size for the dataloaders.
- num_workers: int | Literal['auto']
The number of workers for the dataloaders.
This is the number of processes that generate batches in parallel.
If set to “auto”, the number of workers will be automatically set based on the number of available CPUs.
Set to 0 to load data in the main process without worker subprocesses.
- pin_memory: bool
Whether to pin memory in the dataloaders.
Pinned (page-locked) host memory speeds up data transfer from the CPU to the GPU.
- class mattertune.configs.data.MatbenchDatasetConfig(*, type='matbench', task=None, property_name=None, fold_idx=0)
Configuration for the Matbench dataset.
- Parameters:
type (Literal['matbench'])
task (str | None)
property_name (str | None)
fold_idx (Literal[0, 1, 2, 3, 4])
- type: Literal['matbench']
Discriminator for the Matbench dataset.
- task: str | None
The name of the Matbench task to include in the dataset.
- property_name: str | None
The property name to assign to the task's target values. Must match the name of a property head in the model.
- fold_idx: Literal[0, 1, 2, 3, 4]
The index of the Matbench cross-validation fold to use (Matbench defines five folds, 0 to 4).
- class mattertune.configs.data.OMAT24DatasetConfig(*, type='omat24', src)
Configuration for the OMAT24 (Open Materials 2024) dataset.
- Parameters:
type (Literal['omat24'])
src (Path)
- type: Literal['omat24']
Discriminator for the OMAT24 dataset.
- src: Path
The path to the OMAT24 dataset.
- class mattertune.configs.data.XYZDatasetConfig(*, type='xyz', src)
Configuration for a dataset stored in an XYZ file.
- Parameters:
type (Literal['xyz'])
src (str | Path)
- type: Literal['xyz']
Discriminator for the XYZ dataset.
- src: str | Path
The path to the XYZ dataset.
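Every dataset config above has a `type` field documented as a discriminator. Its literal value identifies which config class a raw mapping describes. The stdlib sketch below shows the idea with two hypothetical stand-in classes (`XYZSketch`, `OMAT24Sketch`) and a hypothetical `parse_dataset_config` function. It is not MatterTune's parsing code, which this reference does not show.

```python
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Union


@dataclass
class XYZSketch:
    src: Path
    type: str = "xyz"


@dataclass
class OMAT24Sketch:
    src: Path
    type: str = "omat24"


# Hypothetical registry: discriminator value -> config class.
_REGISTRY: Dict[str, Any] = {"xyz": XYZSketch, "omat24": OMAT24Sketch}


def parse_dataset_config(raw: Dict[str, Any]) -> Union[XYZSketch, OMAT24Sketch]:
    kind = raw.get("type")
    if kind not in _REGISTRY:
        raise ValueError(f"unknown dataset type: {kind!r}")
    # The discriminator picks the class; the remaining keys become its fields.
    fields = {key: value for key, value in raw.items() if key != "type"}
    fields["src"] = Path(fields["src"])
    return _REGISTRY[kind](**fields)


# Usage: the same kind of mapping yields different classes depending on "type".
print(parse_dataset_config({"type": "xyz", "src": "train.xyz"}))
```

Dispatching on one literal field lets a single config entry accept any of the dataset kinds while still rejecting unknown values early.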
Modules