from __future__ import annotations
from typing import Annotated, Literal
import nshconfig as C
import torch
import torch.nn.functional as F
from typing_extensions import TypeAliasType, assert_never
[docs]
class MAELossConfig(C.Config):
    """Configuration selecting the mean absolute error (L1) loss.

    ``compute_loss`` maps this config to ``F.l1_loss``.
    """

    # Tag used as the discriminator of the ``LossConfig`` union.
    name: Literal["mae"] = "mae"
    reduction: Literal["mean", "sum"] = "mean"
    """How to reduce the loss values across the batch.
    - ``"mean"``: The mean of the loss values.
    - ``"sum"``: The sum of the loss values.
    """
[docs]
class MSELossConfig(C.Config):
    """Configuration selecting the mean squared error loss.

    ``compute_loss`` maps this config to ``F.mse_loss``.
    """

    # Tag used as the discriminator of the ``LossConfig`` union.
    name: Literal["mse"] = "mse"
    reduction: Literal["mean", "sum"] = "mean"
    """How to reduce the loss values across the batch.
    - ``"mean"``: The mean of the loss values.
    - ``"sum"``: The sum of the loss values.
    """
[docs]
class HuberLossConfig(C.Config):
    """Configuration selecting the Huber loss.

    ``compute_loss`` maps this config to ``F.huber_loss``, forwarding
    ``delta`` and ``reduction``.
    """

    # Tag used as the discriminator of the ``LossConfig`` union.
    name: Literal["huber"] = "huber"
    delta: float = 1.0
    """The threshold value for the Huber loss function."""
    reduction: Literal["mean", "sum"] = "mean"
    """How to reduce the loss values across the batch.
    - ``"mean"``: The mean of the loss values.
    - ``"sum"``: The sum of the loss values.
    """
[docs]
class L2MAELossConfig(C.Config):
    """Configuration selecting the L2-MAE loss.

    ``compute_loss`` maps this config to ``l2_mae_loss``, which reduces the
    Euclidean distances between prediction and label.
    """

    # Tag used as the discriminator of the ``LossConfig`` union.
    name: Literal["l2_mae"] = "l2_mae"
    reduction: Literal["mean", "sum"] = "mean"
    """How to reduce the loss values across the batch.
    - ``"mean"``: The mean of the loss values.
    - ``"sum"``: The sum of the loss values.
    """
[docs]
def l2_mae_loss(
output: torch.Tensor,
target: torch.Tensor,
reduction: Literal["mean", "sum", "none"] = "mean",
) -> torch.Tensor:
distances = F.pairwise_distance(output, target, p=2)
match reduction:
case "mean":
return distances.mean()
case "sum":
return distances.sum()
case "none":
return distances
case _:
assert_never(reduction)
# Union of every supported loss configuration. The ``name`` field carried by
# each member is declared as the discriminator, so it identifies which
# concrete config class a given value corresponds to.
LossConfig = TypeAliasType(
    "LossConfig",
    Annotated[
        MAELossConfig | MSELossConfig | HuberLossConfig | L2MAELossConfig,
        C.Field(discriminator="name"),
    ],
)
[docs]
def compute_loss(
config: LossConfig,
prediction: torch.Tensor,
label: torch.Tensor,
) -> torch.Tensor:
"""
Compute the loss value given the model output, ``prediction``,
and the target label, ``label``.
The loss value should be a scalar tensor.
Args:
config: The loss configuration.
prediction: The model output.
label: The target label.
Returns:
The computed loss value.
"""
try:
prediction = prediction.reshape(label.shape)
except RuntimeError:
raise ValueError(
f"Prediction shape {prediction.shape} does not match ground truth shape {label.shape}"
)
match config:
case MAELossConfig():
return F.l1_loss(prediction, label, reduction=config.reduction)
case MSELossConfig():
return F.mse_loss(prediction, label, reduction=config.reduction)
case HuberLossConfig():
return F.huber_loss(
prediction, label, delta=config.delta, reduction=config.reduction
)
case L2MAELossConfig():
return l2_mae_loss(prediction, label, reduction=config.reduction)
case _:
assert_never(config)