# Source code for mythos.losses.observable_wrappers

"""Loss functions for observables."""

from typing import Any

import chex
import jax.numpy as jnp
from typing_extensions import override

import mythos.observables.base as jd_obs_base
import mythos.simulators.io as jd_sio

# Accepted loss-value input: a bare array, or an array paired with a dict of
# auxiliary metadata.
# NOTE(review): lowercase alias name is unconventional (PEP 8 would suggest
# `LossInput`), but renaming would break existing annotations — left as-is.
loss_input = jnp.ndarray | tuple[jnp.ndarray, dict[str, Any]]


@chex.dataclass
class LossFn:
    """Base class for loss functions.

    Subclasses override ``__call__`` to reduce an actual/target pair to a
    scalar loss value.
    """

    def __call__(
        self,
        actual: loss_input,
        target: loss_input,
        weights: jnp.ndarray | None = None,
    ) -> float:
        """Calculate the loss.

        Args:
            actual: Computed value(s).
            target: Desired value(s).
            weights: Optional weighting array. Defaulted to ``None`` because
                the concrete subclasses in this module override ``__call__``
                without a ``weights`` parameter and ``ObservableLossFn``
                invokes the loss with only ``(actual, target)``; requiring it
                here made the base contract impossible to satisfy.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError("Subclasses must implement this method.")
@chex.dataclass
class SquaredError(LossFn):
    """Elementwise squared error between actual and target values."""

    @override
    def __call__(self, actual: jnp.ndarray, target: jnp.ndarray) -> float:
        """Return ``(target - actual) ** 2``."""
        error = target - actual
        return error ** 2
@chex.dataclass
class RootMeanSquaredError(LossFn):
    """Root mean squared error between actual and target values."""

    @override
    def __call__(self, actual: jnp.ndarray, target: jnp.ndarray) -> float:
        """Return ``sqrt(mean((target - actual) ** 2))``."""
        squared_error = (target - actual) ** 2
        return jnp.sqrt(jnp.mean(squared_error))
@chex.dataclass
class ObservableLossFn:
    """A simple loss function wrapper for an observable."""

    # Observable evaluated over a trajectory; its weighted sum is fed to loss_fn.
    observable: jd_obs_base.BaseObservable
    # Loss applied to the weighted observable value versus the target.
    loss_fn: LossFn
    # When True, the computed observable value is appended to the return tuple.
    return_observable: bool = False

    def __call__(
        self,
        trajectory: jd_sio.SimulatorTrajectory,
        target: jnp.ndarray,
        # presumably per-frame weights over the trajectory — TODO confirm
        weights: jnp.ndarray,
    ) -> tuple:
        """Calculate the loss for the observable over the trajectory.

        Returns a tuple in ALL cases: ``(loss,)`` when ``return_observable``
        is False, ``(loss, observable_value)`` when True. (The annotation
        previously said ``float``; the code below always builds a tuple.)
        """
        # Weighted reduction of the observable over the whole trajectory.
        observable = jnp.sum(self.observable(trajectory) * weights)
        # Note: loss_fn is called with two positional args (no weights).
        vals = [self.loss_fn(observable, target)]
        if self.return_observable:
            vals.append(observable)
        return tuple(vals)
def l2_loss(actual: jnp.ndarray, target: jnp.ndarray) -> float:
    """Calculate the L2 loss between the actual and target values.

    The L2 loss is the sum of squared differences, ``sum((actual - target)**2)``.
    """
    residual = actual - target
    return jnp.sum(residual ** 2)