# Source code for mattertune.finetune.loss

from __future__ import annotations

from typing import Annotated, Literal

import nshconfig as C
import torch
import torch.nn.functional as F
from typing_extensions import TypeAliasType, assert_never


class MAELossConfig(C.Config):
    """Configuration for the mean absolute error (L1) loss."""

    # Discriminator tag identifying this member of the loss-config union.
    name: Literal["mae"] = "mae"

    reduction: Literal["mean", "sum"] = "mean"
    """Strategy for collapsing the element-wise loss values into one value.

    - ``"mean"``: average the loss values.
    - ``"sum"``: add the loss values together.
    """
class MSELossConfig(C.Config):
    """Configuration for the mean squared error loss."""

    # Discriminator tag identifying this member of the loss-config union.
    name: Literal["mse"] = "mse"

    reduction: Literal["mean", "sum"] = "mean"
    """Strategy for collapsing the element-wise loss values into one value.

    - ``"mean"``: average the loss values.
    - ``"sum"``: add the loss values together.
    """
class HuberLossConfig(C.Config):
    """Configuration for the Huber loss (quadratic for small errors, linear beyond ``delta``)."""

    # Discriminator tag identifying this member of the loss-config union.
    name: Literal["huber"] = "huber"

    # ``gt=0``: torch's ``F.huber_loss`` rejects non-positive ``delta`` at call
    # time; validating here surfaces a bad config as soon as it is constructed
    # rather than at the first loss evaluation.
    delta: Annotated[float, C.Field(gt=0.0)] = 1.0
    """The threshold value for the Huber loss function. Must be positive."""

    reduction: Literal["mean", "sum"] = "mean"
    """How to reduce the loss values across the batch.

    - ``"mean"``: The mean of the loss values.
    - ``"sum"``: The sum of the loss values.
    """
class L2MAELossConfig(C.Config):
    """Configuration for the L2-MAE loss.

    The loss is the Euclidean distance between prediction and target vectors
    (taken along the last dimension), reduced over the remaining entries.
    """

    # Discriminator tag identifying this member of the loss-config union.
    name: Literal["l2_mae"] = "l2_mae"

    reduction: Literal["mean", "sum"] = "mean"
    """Strategy for collapsing the per-vector distances into one value.

    - ``"mean"``: average the distances.
    - ``"sum"``: add the distances together.
    """
[docs] def l2_mae_loss( output: torch.Tensor, target: torch.Tensor, reduction: Literal["mean", "sum", "none"] = "mean", ) -> torch.Tensor: distances = F.pairwise_distance(output, target, p=2) match reduction: case "mean": return distances.mean() case "sum": return distances.sum() case "none": return distances case _: assert_never(reduction)
# Union of every supported loss configuration. ``C.Field(discriminator="name")``
# makes each config's ``name`` literal select which union member a value is
# validated as (a pydantic-style discriminated union, presumably provided via
# nshconfig — confirm if relying on the exact parsing behavior).
LossConfig = TypeAliasType(
    "LossConfig",
    Annotated[
        MAELossConfig | MSELossConfig | HuberLossConfig | L2MAELossConfig,
        C.Field(discriminator="name"),
    ],
)
[docs] def compute_loss( config: LossConfig, prediction: torch.Tensor, label: torch.Tensor, ) -> torch.Tensor: """ Compute the loss value given the model output, ``prediction``, and the target label, ``label``. The loss value should be a scalar tensor. Args: config: The loss configuration. prediction: The model output. label: The target label. Returns: The computed loss value. """ try: prediction = prediction.reshape(label.shape) except RuntimeError: raise ValueError( f"Prediction shape {prediction.shape} does not match ground truth shape {label.shape}" ) match config: case MAELossConfig(): return F.l1_loss(prediction, label, reduction=config.reduction) case MSELossConfig(): return F.mse_loss(prediction, label, reduction=config.reduction) case HuberLossConfig(): return F.huber_loss( prediction, label, delta=config.delta, reduction=config.reduction ) case L2MAELossConfig(): return l2_mae_loss(prediction, label, reduction=config.reduction) case _: assert_never(config)