mythos.losses.observable_wrappers

Loss functions for observables.

Attributes

loss_input

Classes

LossFn

Base class for loss functions.

SquaredError

Calculate the squared error between the actual and target values.

RootMeanSquaredError

Calculate the root mean squared error between the actual and target values.

ObservableLossFn

A simple loss function wrapper for an observable.

Functions

l2_loss(→ float)

Calculate the L2 loss between the actual and target values.

Module Contents

mythos.losses.observable_wrappers.loss_input
class mythos.losses.observable_wrappers.LossFn

Base class for loss functions.

abstractmethod __call__(actual: loss_input, target: loss_input, weights: jax.numpy.ndarray) → float

Calculate the loss of actual relative to target.
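To illustrate the abstract-base pattern, the sketch below defines a hypothetical stand-in, LossFnSketch, with the same __call__(actual, target, weights) signature as LossFn, plus a hypothetical concrete subclass, WeightedAbsoluteError. Neither class is part of mythos; the sketch only shows what a subclass has to provide. It imports jax.numpy when available and falls back to NumPy otherwise.

```python
from abc import ABC, abstractmethod

try:
    import jax.numpy as jnp  # the documented API works with jax.numpy arrays
except ImportError:  # fall back to NumPy so the sketch also runs without JAX
    import numpy as jnp


class LossFnSketch(ABC):
    """Hypothetical stand-in for LossFn: subclasses must implement __call__."""

    @abstractmethod
    def __call__(self, actual, target, weights) -> float:
        """Return a scalar loss comparing actual with target."""


class WeightedAbsoluteError(LossFnSketch):
    """Hypothetical concrete loss: weighted sum of absolute differences."""

    def __call__(self, actual, target, weights) -> float:
        diff = jnp.abs(jnp.asarray(actual) - jnp.asarray(target))
        # Return the array scalar (no float()) so the value stays traceable by JAX.
        return jnp.sum(jnp.asarray(weights) * diff)


loss = WeightedAbsoluteError()([1.0, 2.0], [1.5, 1.0], [1.0, 0.5])
```

The loss is returned as an array scalar rather than converted with float(), because calling float() on a traced value fails inside jax.grad or jax.jit.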

class mythos.losses.observable_wrappers.SquaredError

Bases: LossFn

Calculate the squared error between the actual and target values.

__call__(actual: jax.numpy.ndarray, target: jax.numpy.ndarray) → float

Calculate the squared error between actual and target.
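For a scalar observable, the squared error is (actual - target) ** 2. The function below is a hypothetical sketch of that formula, applied element-wise, and not the source of SquaredError; whether the library also reduces over array elements is not stated here.

```python
try:
    import jax.numpy as jnp
except ImportError:  # fall back to NumPy when JAX is not installed
    import numpy as jnp


def squared_error(actual, target):
    """Hypothetical sketch: (actual - target) ** 2, applied element-wise."""
    diff = jnp.asarray(actual) - jnp.asarray(target)
    return diff ** 2


scalar_loss = squared_error(3.0, 1.0)  # (3 - 1) ** 2
```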

class mythos.losses.observable_wrappers.RootMeanSquaredError

Bases: LossFn

Calculate the root mean squared error between the actual and target values.

__call__(actual: jax.numpy.ndarray, target: jax.numpy.ndarray) → float

Calculate the root mean squared error between actual and target.
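The standard root mean squared error is sqrt(mean((actual - target) ** 2)). A minimal sketch of that formula follows, assuming RootMeanSquaredError averages uniformly over all elements (its documented __call__ takes no weights argument). The function name is hypothetical.

```python
try:
    import jax.numpy as jnp
except ImportError:  # fall back to NumPy when JAX is not installed
    import numpy as jnp


def root_mean_squared_error(actual, target):
    """Hypothetical sketch: sqrt of the mean of squared element-wise differences."""
    diff = jnp.asarray(actual) - jnp.asarray(target)
    return jnp.sqrt(jnp.mean(diff ** 2))


rmse = root_mean_squared_error([1.0, 7.0], [0.0, 0.0])
```

For actual = [1.0, 7.0] and target = [0.0, 0.0], the mean of the squared differences is (1 + 49) / 2 = 25, so the RMSE is 5.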

class mythos.losses.observable_wrappers.ObservableLossFn

A simple loss function wrapper for an observable.

observable: mythos.observables.base.BaseObservable
loss_fn: LossFn
return_observable: bool = False
__call__(trajectory: mythos.simulators.io.SimulatorTrajectory, target: jax.numpy.ndarray, weights: jax.numpy.ndarray) → float

Calculate the loss for the observable over the trajectory.
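One plausible wiring of the three documented fields is sketched below: evaluate the observable on the trajectory, score the result with loss_fn using the LossFn base-class signature (actual, target, weights), and, when return_observable is True, also return the computed observable. This is an assumption about behavior, not the mythos implementation. ObservableLossFnSketch, mean_over_frames, and weighted_squared_error are hypothetical, a plain nested list stands in for SimulatorTrajectory, and weights are simply forwarded to loss_fn because what they represent is not stated here.

```python
import dataclasses
from typing import Any, Callable

try:
    import jax.numpy as jnp
except ImportError:  # fall back to NumPy when JAX is not installed
    import numpy as jnp


@dataclasses.dataclass(frozen=True)
class ObservableLossFnSketch:
    """Hypothetical mirror of ObservableLossFn's documented fields."""

    observable: Callable[[Any], Any]  # stands in for BaseObservable
    loss_fn: Callable[[Any, Any, Any], Any]  # stands in for LossFn
    return_observable: bool = False

    def __call__(self, trajectory, target, weights):
        # Assumption: evaluate the observable on the trajectory first.
        actual = self.observable(trajectory)
        # Score it with the LossFn base-class signature; weights are forwarded as-is.
        loss = self.loss_fn(actual, target, weights)
        # Assumption: return_observable=True also hands back the observable value.
        if self.return_observable:
            return loss, actual
        return loss


def mean_over_frames(trajectory):
    """Hypothetical observable: per-feature mean over frames (rows)."""
    return jnp.mean(jnp.asarray(trajectory), axis=0)


def weighted_squared_error(actual, target, weights):
    """Hypothetical loss: weighted sum of squared differences."""
    diff = jnp.asarray(actual) - jnp.asarray(target)
    return jnp.sum(jnp.asarray(weights) * diff ** 2)


trajectory = [[1.0, 2.0], [3.0, 4.0]]  # 2 frames x 2 features
wrapper = ObservableLossFnSketch(mean_over_frames, weighted_squared_error)
loss = wrapper(trajectory, target=[2.0, 1.0], weights=[1.0, 0.5])

wrapper_with_obs = ObservableLossFnSketch(
    mean_over_frames, weighted_squared_error, return_observable=True
)
loss2, observed = wrapper_with_obs(trajectory, [2.0, 1.0], [1.0, 0.5])
```

Here the observable is [2.0, 3.0], the differences from the target are [0.0, 2.0], and the weighted loss is 1.0 * 0 + 0.5 * 4 = 2.0.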

mythos.losses.observable_wrappers.l2_loss(actual: jax.numpy.ndarray, target: jax.numpy.ndarray) → float

Calculate the L2 loss between the actual and target values.
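The name "L2 loss" covers more than one quantity: often the sum of squared differences, sometimes the Euclidean norm of the difference (the square root of that sum). Which convention l2_loss follows is not stated here, so the hypothetical sketch below shows both side by side.

```python
try:
    import jax.numpy as jnp
except ImportError:  # fall back to NumPy when JAX is not installed
    import numpy as jnp


def l2_sum_of_squares(actual, target):
    """Hypothetical: sum of squared element-wise differences."""
    diff = jnp.asarray(actual) - jnp.asarray(target)
    return jnp.sum(diff ** 2)


def l2_norm(actual, target):
    """Hypothetical: Euclidean norm of the difference, i.e. sqrt of the sum above."""
    diff = jnp.asarray(actual) - jnp.asarray(target)
    return jnp.linalg.norm(diff)


actual = [3.0, 4.0]
target = [0.0, 0.0]
sum_sq = l2_sum_of_squares(actual, target)  # 3 ** 2 + 4 ** 2
norm = l2_norm(actual, target)  # sqrt(25)
```

The two conventions differ only by a square root, so they share minimizers but scale gradients differently.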