mythos.losses.observable_wrappers
=================================

.. py:module:: mythos.losses.observable_wrappers

.. autoapi-nested-parse::

   Loss functions for observables.


Attributes
----------

.. autoapisummary::

   mythos.losses.observable_wrappers.loss_input


Classes
-------

.. autoapisummary::

   mythos.losses.observable_wrappers.LossFn
   mythos.losses.observable_wrappers.SquaredError
   mythos.losses.observable_wrappers.RootMeanSquaredError
   mythos.losses.observable_wrappers.ObservableLossFn


Functions
---------

.. autoapisummary::

   mythos.losses.observable_wrappers.l2_loss


Module Contents
---------------

.. py:data:: loss_input

.. py:class:: LossFn

   Base class for loss functions.

   .. py:method:: __call__(actual: loss_input, target: loss_input, weights: jax.numpy.ndarray) -> float
      :abstractmethod:

      Calculate the loss.


.. py:class:: SquaredError

   Bases: :py:obj:`LossFn`

   Calculate the squared error between the actual and target values.

   .. py:method:: __call__(actual: jax.numpy.ndarray, target: jax.numpy.ndarray) -> float

      Calculate the loss.


.. py:class:: RootMeanSquaredError

   Bases: :py:obj:`LossFn`

   Calculate the root mean squared error between the actual and target values.

   .. py:method:: __call__(actual: jax.numpy.ndarray, target: jax.numpy.ndarray) -> float

      Calculate the loss.


.. py:class:: ObservableLossFn

   A simple loss function wrapper for an observable.

   .. py:attribute:: observable
      :type: mythos.observables.base.BaseObservable

   .. py:attribute:: loss_fn
      :type: LossFn

   .. py:attribute:: return_observable
      :type: bool
      :value: False

   .. py:method:: __call__(trajectory: mythos.simulators.io.SimulatorTrajectory, target: jax.numpy.ndarray, weights: jax.numpy.ndarray) -> float

      Calculate the loss for the observable over the trajectory.


.. py:function:: l2_loss(actual: jax.numpy.ndarray, target: jax.numpy.ndarray) -> float

   Calculate the L2 loss between the actual and target values.
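
The following is a minimal sketch of how an observable wrapper and the squared-error and RMSE losses listed above could fit together. It does not import ``mythos``. The names ``squared_error``, ``root_mean_squared_error``, ``ObservableLossSketch`` and ``mean_x`` are hypothetical stand-ins, not the library's API. The sketch makes several assumptions that the listing above does not confirm: ``SquaredError`` sums the squared differences, a trajectory can be treated as an array of shape ``(n_states, n_particles, 3)``, and ``weights`` are per-state factors used to form a weighted ensemble average of the observable before the loss is applied.

.. code-block:: python

   import dataclasses
   from typing import Callable

   import jax.numpy as jnp


   def squared_error(actual: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
       # Assumption: the squared differences are summed into a scalar.
       return jnp.sum((actual - target) ** 2)


   def root_mean_squared_error(actual: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
       # Standard RMSE: square root of the mean squared difference.
       return jnp.sqrt(jnp.mean((actual - target) ** 2))


   @dataclasses.dataclass(frozen=True)
   class ObservableLossSketch:
       """Hypothetical stand-in for ObservableLossFn (not the mythos class)."""

       observable: Callable[[jnp.ndarray], jnp.ndarray]  # trajectory -> per-state values
       loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
       return_observable: bool = False

       def __call__(self, trajectory: jnp.ndarray, target: jnp.ndarray, weights: jnp.ndarray):
           # Evaluate the observable for every state in the trajectory.
           per_state = self.observable(trajectory)
           # Assumption: weights are per-state factors; take the weighted
           # average along the state axis (jnp.average normalizes the weights).
           expected = jnp.average(per_state, axis=0, weights=weights)
           loss = self.loss_fn(expected, target)
           # Optionally return the averaged observable alongside the loss.
           return (loss, expected) if self.return_observable else loss


   def mean_x(traj: jnp.ndarray) -> jnp.ndarray:
       # Hypothetical observable: mean x-coordinate of the particles in each state.
       return jnp.mean(traj[..., 0], axis=1)


   # Usage: two states with two particles each, weighted equally.
   trajectory = jnp.array(
       [[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
        [[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]]
   )
   weights = jnp.array([0.5, 0.5])

   wrapped = ObservableLossSketch(observable=mean_x, loss_fn=squared_error)
   print(wrapped(trajectory, jnp.array(2.0), weights))  # weighted mean x = 1.5, so loss = 0.25

With ``return_observable=True`` the sketch returns a ``(loss, observable_value)`` tuple, which is one plausible reading of the ``return_observable`` attribute; the real ``ObservableLossFn`` may return something different.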