mythos.losses.observable_wrappers
Loss functions for observables.
Attributes
loss_input
Classes
LossFn | Base class for loss functions.
SquaredError | Calculate the squared error between the actual and target values.
RootMeanSquaredError | Calculate the root mean squared error between the actual and target values.
ObservableLossFn | A simple loss function wrapper for an observable.
Functions
Calculate the L2 loss between the actual and target values.
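The summary does not give the formula for the L2 loss. Below is a minimal sketch, assuming "L2 loss" means the sum of squared differences between the actual and target values, which is a common convention (some libraries scale it by 1/2 or take the square root, giving the L2 norm instead). The name `l2_loss_sketch` is hypothetical and is not the name of the mythos function.

```python
import jax.numpy as jnp


def l2_loss_sketch(actual: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical L2 loss: the sum of squared differences."""
    # Assumption: no 1/2 factor and no square root; mythos may differ.
    return jnp.sum((actual - target) ** 2)


# Differences are 0, 2 and 2, so the loss is 0 + 4 + 4 = 8.
print(l2_loss_sketch(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.0, 1.0])))  # 8.0
```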
Module Contents
- mythos.losses.observable_wrappers.loss_input
- class mythos.losses.observable_wrappers.SquaredError
Bases: LossFn
Calculate the squared error between the actual and target values.
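A minimal sketch of how a squared-error loss could sit on top of a loss base class. `LossFnSketch` and `SquaredErrorSketch` are hypothetical stand-ins, not the mythos classes; the `__call__(actual, target)` signature and the summation over elements are assumptions. Writing the loss with `jax.numpy` keeps it differentiable with `jax.grad`.

```python
import abc

import jax
import jax.numpy as jnp


class LossFnSketch(abc.ABC):
    """Hypothetical stand-in for the LossFn base class (not the mythos code)."""

    @abc.abstractmethod
    def __call__(self, actual: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
        """Return a scalar loss comparing ``actual`` with ``target``."""


class SquaredErrorSketch(LossFnSketch):
    """Squared error, summed over elements (the reduction is an assumption)."""

    def __call__(self, actual: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
        # For a scalar observable this is just (actual - target) ** 2.
        return jnp.sum((actual - target) ** 2)


loss_fn = SquaredErrorSketch()
actual = jnp.array(3.0)
target = jnp.array(1.0)

print(loss_fn(actual, target))  # 4.0

# d/da (a - t)**2 = 2 * (a - t), which is 4.0 at a = 3, t = 1.
grad_wrt_actual = jax.grad(lambda a: loss_fn(a, target))(actual)
print(grad_wrt_actual)  # 4.0
```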
- class mythos.losses.observable_wrappers.RootMeanSquaredError
Bases: LossFn
Calculate the root mean squared error between the actual and target values.
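A minimal sketch of the root mean squared error, sqrt(mean((actual - target) ** 2)), written as a plain function rather than a `LossFn` subclass for brevity. The name `root_mean_squared_error_sketch` is hypothetical; only the formula is standard.

```python
import jax
import jax.numpy as jnp


def root_mean_squared_error_sketch(
    actual: jnp.ndarray, target: jnp.ndarray
) -> jnp.ndarray:
    """Hypothetical RMSE: sqrt(mean((actual - target) ** 2))."""
    return jnp.sqrt(jnp.mean((actual - target) ** 2))


actual = jnp.array([2.0, 10.0])
target = jnp.array([1.0, 3.0])
# Differences are 1 and 7, so the mean squared error is (1 + 49) / 2 = 25.
print(root_mean_squared_error_sketch(actual, target))  # 5.0

# At exactly zero error the square root's derivative is unbounded,
# so the gradient comes out as nan.
zero_error_grad = jax.grad(root_mean_squared_error_sketch)(target, target)
```

Taking the square root keeps the loss in the same units as the observable. The caveat is the gradient at zero error: if the optimization can reach `actual == target` exactly, the squared error is the safer choice to differentiate.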
- class mythos.losses.observable_wrappers.ObservableLossFn
A simple loss function wrapper for an observable.
- observable: mythos.observables.base.BaseObservable
- __call__(trajectory: mythos.simulators.io.SimulatorTrajectory, target: jax.numpy.ndarray, weights: jax.numpy.ndarray) -> float
Calculate the loss for the observable over the trajectory.
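A minimal sketch of the wrapper pattern: evaluate an observable on a trajectory, reduce it with the weights, and compare the result with the target. Everything here is an assumption for illustration: `ObservableLossFnSketch` and its `loss_fn` field are hypothetical, the trajectory is a plain array of per-frame values rather than a `SimulatorTrajectory`, and the weights are treated as per-frame weights for a normalized weighted average. The mythos class may apply the weights differently.

```python
import dataclasses
from typing import Callable

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ObservableLossFnSketch:
    """Hypothetical analogue of ObservableLossFn; field names are assumptions."""

    # Maps a trajectory to one value per frame (stand-in for BaseObservable).
    observable: Callable[[jnp.ndarray], jnp.ndarray]
    # Compares the reduced observable with the target (stand-in for a LossFn).
    loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]

    def __call__(
        self, trajectory: jnp.ndarray, target: jnp.ndarray, weights: jnp.ndarray
    ) -> jnp.ndarray:
        per_frame = self.observable(trajectory)
        # Assumption: per-frame weights, normalized so they sum to 1.
        actual = jnp.sum(weights * per_frame) / jnp.sum(weights)
        return self.loss_fn(actual, target)


# Toy "trajectory": 4 frames of a 1D coordinate; the observable is the value itself.
trajectory = jnp.array([1.0, 2.0, 3.0, 4.0])
weights = jnp.array([0.0, 0.0, 1.0, 1.0])  # only the last two frames count
wrapper = ObservableLossFnSketch(
    observable=lambda traj: traj,
    loss_fn=lambda actual, target: (actual - target) ** 2,
)
# Weighted average is (3 + 4) / 2 = 3.5, so the loss is (3.5 - 3.0) ** 2.
print(wrapper(trajectory, jnp.array(3.0), weights))  # 0.25
```

Keeping the observable and the loss as separate fields lets the same observable be paired with different losses (squared error, RMSE, and so on) without changing either piece.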