mattertune.configs.data

class mattertune.configs.data.AutoSplitDataModuleConfig(*, batch_size, num_workers='auto', pin_memory=True, dataset, train_split, validation_split='auto', shuffle=True, shuffle_seed=42)[source]
Parameters:
  • batch_size (int)

  • num_workers (int | Literal['auto'])

  • pin_memory (bool)

  • dataset (DatasetConfig)

  • train_split (float)

  • validation_split (float | Literal['auto', 'disable'])

  • shuffle (bool)

  • shuffle_seed (int)

dataset: DatasetConfig

The configuration for the dataset.

train_split: float

The proportion of the dataset to include in the training split.

validation_split: float | Literal['auto', 'disable']

The proportion of the dataset to include in the validation split.

If set to “auto”, the validation split will be automatically determined as the complement of the training split, i.e. validation_split = 1 - train_split.

If set to “disable”, the validation split will be disabled.
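The “auto”/“disable” resolution described above can be sketched in plain Python (a hypothetical `resolve_validation_split` helper, not part of mattertune):

```python
def resolve_validation_split(train_split, validation_split):
    """Resolve a validation_split value as the docstring above describes."""
    if validation_split == "auto":
        # Complement of the training split.
        return 1.0 - train_split
    if validation_split == "disable":
        # No validation split at all.
        return None
    # Already an explicit float proportion.
    return validation_split
```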

shuffle: bool

Whether to shuffle the dataset before splitting.

shuffle_seed: int

The seed to use for shuffling the dataset.

dataset_configs()[source]
create_datasets()[source]
batch_size: int

The batch size for the dataloaders.

num_workers: int | Literal['auto']

The number of workers for the dataloaders.

This is the number of processes that generate batches in parallel.

If set to “auto”, the number of workers will be automatically set based on the number of available CPUs.

Set to 0 to disable parallelism.
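The “auto” behavior amounts to reading the CPU count at runtime; a minimal sketch (hypothetical helper, mattertune's actual heuristic may differ):

```python
import os

def resolve_num_workers(num_workers):
    """Resolve 'auto' to a worker count based on available CPUs."""
    if num_workers == "auto":
        # Fall back to 1 if the CPU count cannot be determined.
        return os.cpu_count() or 1
    # 0 disables parallelism: batches are produced in the main process.
    return num_workers
```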

pin_memory: bool

Whether to pin memory in the dataloaders.

This is useful for speeding up GPU data transfer.
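Putting the fields together, an 80/20 auto-split over an XYZ file might be configured as follows (a sketch using the constructor signatures documented on this page; the file path is illustrative):

```python
from mattertune.configs.data import AutoSplitDataModuleConfig, XYZDatasetConfig

data = AutoSplitDataModuleConfig(
    batch_size=32,
    num_workers="auto",          # resolved from the available CPU count
    dataset=XYZDatasetConfig(src="data/structures.xyz"),  # illustrative path
    train_split=0.8,
    validation_split="auto",     # complement of train_split, i.e. 0.2
    shuffle=True,
    shuffle_seed=42,
)
```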

class mattertune.configs.data.DBDatasetConfig(*, type='db', src, energy_key=None, forces_key=None, stress_key=None, preload=True)[source]

Configuration for a dataset stored in an ASE database.

Parameters:
  • type (Literal['db'])

  • src (Database | str | Path)

  • energy_key (str | None)

  • forces_key (str | None)

  • stress_key (str | None)

  • preload (bool)

type: Literal['db']

Discriminator for the DB dataset.

src: Database | str | Path

Path to the ASE database file or a database object.

energy_key: str | None

Key for the energy label in the database.

forces_key: str | None

Key for the force label in the database.

stress_key: str | None

Key for the stress label in the database.

preload: bool

Whether to load the entire dataset into memory upfront, rather than reading rows from the database on demand.

create_dataset()[source]
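For example, pointing at an ASE database whose rows store labels under custom keys (a sketch based on the signature above; the path and key names are illustrative):

```python
from mattertune.configs.data import DBDatasetConfig

dataset = DBDatasetConfig(
    src="data/structures.db",   # illustrative path to an ASE database file
    energy_key="dft_energy",    # illustrative custom label keys
    forces_key="dft_forces",
    stress_key=None,            # no stress labels in this database
    preload=True,               # load everything into memory upfront
)
```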
class mattertune.configs.data.DataModuleBaseConfig(*, batch_size, num_workers='auto', pin_memory=True)[source]
Parameters:
  • batch_size (int)

  • num_workers (int | Literal['auto'])

  • pin_memory (bool)

batch_size: int

The batch size for the dataloaders.

num_workers: int | Literal['auto']

The number of workers for the dataloaders.

This is the number of processes that generate batches in parallel.

If set to “auto”, the number of workers will be automatically set based on the number of available CPUs.

Set to 0 to disable parallelism.

pin_memory: bool

Whether to pin memory in the dataloaders.

This is useful for speeding up GPU data transfer.

dataloader_kwargs()[source]
Return type:

DataLoaderKwargs

abstract dataset_configs()[source]
Return type:

Iterable[DatasetConfig]

abstract create_datasets()[source]
Return type:

DatasetMapping

class mattertune.configs.data.DatasetConfigBase[source]
abstract create_dataset()[source]
Return type:

Dataset[Atoms]

prepare_data()[source]

Prepare the dataset for training.

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within this method.

classmethod ensure_dependencies()[source]

Ensure that all dependencies are installed.

This method should raise an exception if any dependencies are missing, with a message indicating which dependencies are missing and how to install them.
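The pattern such a check typically follows can be sketched without mattertune (hypothetical standalone function; the default dependency list is illustrative):

```python
import importlib.util

def ensure_dependencies(packages=("ase",)):
    """Raise with an actionable message if any required package is missing."""
    missing = [
        pkg for pkg in packages
        if importlib.util.find_spec(pkg) is None
    ]
    if missing:
        raise ImportError(
            f"Missing required packages: {', '.join(missing)}. "
            f"Install them with: pip install {' '.join(missing)}"
        )
```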

class mattertune.configs.data.JSONDatasetConfig(*, type='json', src, tasks)[source]
Parameters:
  • type (Literal['json'])

  • src (str | Path)

  • tasks (dict[str, str])

type: Literal['json']

Discriminator for the JSON dataset.

src: str | Path

The path to the JSON dataset.

tasks: dict[str, str]

Attributes in the JSON file that correspond to the tasks to be predicted.

create_dataset()[source]
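For example, mapping JSON attributes onto predicted properties (a sketch based on the signature above; the path, attribute names, and property names are illustrative):

```python
from mattertune.configs.data import JSONDatasetConfig

dataset = JSONDatasetConfig(
    src="data/labeled.json",   # illustrative path
    tasks={
        "energy": "energy",    # JSON attribute -> predicted property name
        "band_gap": "bandgap", # illustrative mapping
    },
)
```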
class mattertune.configs.data.MPDatasetConfig(*, type='mp', api, fields, query)[source]

Configuration for a dataset stored in the Materials Project database.

Parameters:
  • type (Literal['mp'])

  • api (str)

  • fields (list[str])

  • query (dict)

type: Literal['mp']

Discriminator for the MP dataset.

api: str

API key for accessing the Materials Project database.

fields: list[str]

Fields to retrieve from the Materials Project database.

query: dict

Query to filter the data from the Materials Project database.

create_dataset()[source]
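For example, retrieving binary Li–O compounds (a sketch based on the signature above; the API key is a placeholder, and the field and query names are assumptions about the Materials Project API, not guaranteed by this page):

```python
from mattertune.configs.data import MPDatasetConfig

dataset = MPDatasetConfig(
    api="YOUR_MP_API_KEY",   # placeholder; substitute your own key
    fields=["material_id", "structure", "formation_energy_per_atom"],
    query={"elements": ["Li", "O"], "nelements": 2},  # illustrative filter
)
```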
class mattertune.configs.data.MPTrajDatasetConfig(*, type='mptraj', split='train', min_num_atoms=5, max_num_atoms=None, elements=None)[source]

Configuration for the MPTraj (Materials Project trajectory) dataset.

Parameters:
  • type (Literal['mptraj'])

  • split (Literal['train', 'val', 'test'])

  • min_num_atoms (int | None)

  • max_num_atoms (int | None)

  • elements (list[str] | None)

type: Literal['mptraj']

Discriminator for the MPTraj dataset.

split: Literal['train', 'val', 'test']

Split of the dataset to use.

min_num_atoms: int | None

Minimum number of atoms to be considered. Drops structures with fewer atoms.

max_num_atoms: int | None

Maximum number of atoms to be considered. Drops structures with more atoms.

elements: list[str] | None

List of elements to be considered. Drops structures with elements not in the list. Subsets are also allowed. For example, [“Li”, “Na”] will keep structures with either Li or Na.
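The element filter described above keeps a structure only when every element it contains appears in the allowed list; a plain-Python sketch (hypothetical helper, not mattertune's implementation):

```python
def keep_structure(structure_elements, allowed):
    """Keep the structure iff all of its elements are in the allowed list."""
    if allowed is None:
        return True  # no element filter configured
    return set(structure_elements) <= set(allowed)

# With allowed = ["Li", "Na"]: pure Li, pure Na, and Li-Na structures pass;
# any structure containing another element (e.g. Li-Cl) is dropped.
```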

create_dataset()[source]
class mattertune.configs.data.ManualSplitDataModuleConfig(*, batch_size, num_workers='auto', pin_memory=True, train, validation=None)[source]
Parameters:
  • batch_size (int)

  • num_workers (int | Literal['auto'])

  • pin_memory (bool)

  • train (DatasetConfig)

  • validation (DatasetConfig | None)

train: DatasetConfig

The configuration for the training data.

validation: DatasetConfig | None

The configuration for the validation data.

dataset_configs()[source]
create_datasets()[source]
batch_size: int

The batch size for the dataloaders.

num_workers: int | Literal['auto']

The number of workers for the dataloaders.

This is the number of processes that generate batches in parallel.

If set to “auto”, the number of workers will be automatically set based on the number of available CPUs.

Set to 0 to disable parallelism.

pin_memory: bool

Whether to pin memory in the dataloaders.

This is useful for speeding up GPU data transfer.
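When the train/validation split already exists on disk, the two datasets are supplied directly (a sketch based on the signature above; the file paths are illustrative):

```python
from mattertune.configs.data import ManualSplitDataModuleConfig, XYZDatasetConfig

data = ManualSplitDataModuleConfig(
    batch_size=16,
    train=XYZDatasetConfig(src="data/train.xyz"),     # illustrative paths
    validation=XYZDatasetConfig(src="data/val.xyz"),  # or None to skip validation
)
```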

class mattertune.configs.data.MatbenchDatasetConfig(*, type='matbench', task=None, property_name=None, fold_idx=0)[source]

Configuration for the Matbench dataset.

Parameters:
  • type (Literal['matbench'])

  • task (str | None)

  • property_name (str | None)

  • fold_idx (Literal[0, 1, 2, 3, 4])

type: Literal['matbench']

Discriminator for the Matbench dataset.

task: str | None

The name of the Matbench task to include in the dataset.

property_name: str | None

The property name to assign to the task. Must match the property head in the model.

fold_idx: Literal[0, 1, 2, 3, 4]

The index of the fold to be used in the dataset.

create_dataset()[source]
class mattertune.configs.data.OMAT24DatasetConfig(*, type='omat24', src)[source]
Parameters:
  • type (Literal['omat24'])

  • src (Path)

type: Literal['omat24']

Discriminator for the OMAT24 dataset.

src: Path

The path to the OMAT24 dataset.

create_dataset()[source]
class mattertune.configs.data.XYZDatasetConfig(*, type='xyz', src)[source]
Parameters:
  • type (Literal['xyz'])

  • src (str | Path)

type: Literal['xyz']

Discriminator for the XYZ dataset.

src: str | Path

The path to the XYZ dataset.

create_dataset()[source]

Modules

base

datamodule

db

json_data

matbench

mp

mptraj

omat24

xyz