from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
import torch.nn as nn
import torchmetrics
from typing_extensions import override
if TYPE_CHECKING:
from .properties import PropertyConfig
class MetricBase(nn.Module, ABC):
    """Abstract base for per-property metric modules.

    A subclass is constructed with the name of the property it scores and
    implements :meth:`forward`, which updates its metric state from a batch
    of predictions/labels and returns a mapping of metric name to the
    underlying ``torchmetrics.Metric`` object.
    """

    @override
    def __init__(self, property_name: str):
        """Remember which property key this module reads from the input dicts."""
        super().__init__()
        self.property_name = property_name

    @abstractmethod
    @override
    def forward(
        self, prediction: dict[str, Any], ground_truth: dict[str, Any]
    ) -> Mapping[str, torchmetrics.Metric]: ...
class PropertyMetrics(MetricBase):
    """MAE/MSE/RMSE regression metrics for a single property.

    Each call to :meth:`forward` folds one batch into the running metric
    state; the returned ``torchmetrics.Metric`` objects accumulate across
    calls until the caller resets them.
    """

    @override
    def __init__(self, property_name: str):
        super().__init__(property_name)
        self.mae = torchmetrics.MeanAbsoluteError()
        self.mse = torchmetrics.MeanSquaredError(squared=True)
        self.rmse = torchmetrics.MeanSquaredError(squared=False)
        # NOTE(review): R2Score was deliberately disabled; re-enable here and
        # in forward() together if it is needed again.
        # self.r2 = torchmetrics.R2Score()

    @override
    def forward(
        self,
        prediction: dict[str, Any],
        ground_truth: dict[str, Any],
    ) -> Mapping[str, torchmetrics.Metric]:
        """Update all metrics for this property and return them keyed by name.

        Raises:
            ValueError: if the prediction cannot be reshaped to the ground
                truth's shape (element counts differ).
        """
        y_hat, y = prediction[self.property_name], ground_truth[self.property_name]
        # Force the prediction into the target's shape (e.g. drop a trailing
        # singleton dim) so torchmetrics never silently broadcasts.
        try:
            y_hat = y_hat.reshape(y.shape)
        except RuntimeError as err:
            # Chain the original reshape failure so the root cause stays
            # visible in the traceback.
            raise ValueError(
                f"Prediction shape {y_hat.shape} does not match ground truth shape {y.shape}"
            ) from err
        self.mae(y_hat, y)
        self.mse(y_hat, y)
        self.rmse(y_hat, y)
        # self.r2(y_hat, y)
        return {
            f"{self.property_name}_mae": self.mae,
            f"{self.property_name}_mse": self.mse,
            f"{self.property_name}_rmse": self.rmse,
            # f"{self.property_name}_r2": self.r2,
        }
class FinetuneMetrics(nn.Module):
    """Aggregates per-property metric modules under an optional name prefix.

    One metric module is instantiated per property config (via each config's
    ``metric_cls`` factory); on forward, their outputs are merged into a
    single mapping whose keys carry ``metric_prefix``.
    """

    def __init__(
        self,
        properties: Sequence[PropertyConfig],
        metric_prefix: str = "",
    ):
        super().__init__()
        # ModuleList registers every per-property module so its metric state
        # follows .to(device)/state_dict like any other submodule.
        per_property = [prop.metric_cls()(prop.name) for prop in properties]
        self.metric_modules = nn.ModuleList(per_property)
        self.metric_prefix = metric_prefix

    @override
    def forward(
        self, predictions: dict[str, Any], labels: dict[str, Any]
    ) -> Mapping[str, torchmetrics.Metric]:
        """Update every per-property module and return the prefixed metrics."""
        combined: dict[str, Any] = {}
        for module in self.metric_modules:
            for metric_name, metric in module(predictions, labels).items():
                combined[f"{self.metric_prefix}{metric_name}"] = metric
        return combined