mythos.optimization.objective

Objectives implemented as frozen chex dataclasses.

Objectives are immutable dataclasses that compute gradients from observables. State is passed into the calculate method and returned in the ObjectiveOutput.

Attributes

ERR_DIFFTRE_MISSING_KWARGS

ERR_MISSING_ARG

ERR_OBJECTIVE_NOT_READY

EnergyFn

empty_dict

compute_loss_and_grad

Classes

ObjectiveOutput

Output of an objective calculation.

Objective

Frozen dataclass for objectives that calculate gradients.

DiffTReObjective

Frozen dataclass for DiffTRe-based gradient computation.

Functions

compute_weights_and_neff(→ tuple[jax.numpy.ndarray, float])

Compute the weights and normalized effective sample size of a trajectory.

compute_loss(→ tuple[float, tuple[float, jax.numpy.ndarray]])

Compute the loss and auxiliary values.

Module Contents

mythos.optimization.objective.ERR_DIFFTRE_MISSING_KWARGS = 'Missing required kwargs: {missing_kwargs}.'
mythos.optimization.objective.ERR_MISSING_ARG = 'Missing required argument: {missing_arg}.'
mythos.optimization.objective.ERR_OBJECTIVE_NOT_READY = 'Not all required observables have been obtained.'
mythos.optimization.objective.EnergyFn
mythos.optimization.objective.empty_dict
class mythos.optimization.objective.ObjectiveOutput[source]

Output of an objective calculation.

is_ready

Whether the objective has computed gradients.

grads

The computed gradients, if ready.

observables

Observable values to preserve between calls.

state

State information to pass back to the next compute call. For DiffTRe, this includes reference_states, reference_energies, and opt_steps.

needs_update

List of observable names that need new values.

is_ready: bool
grads: mythos.utils.types.Grads | None = None
observables: dict[str, Any]
state: dict[str, Any]
needs_update: tuple[str, ...]
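The field contract above can be illustrated with a standalone sketch. The dataclass below is a stand-in using only the documented field names; the defaults and construction details are assumptions, not the real class:

```python
# Standalone sketch of the ObjectiveOutput contract; field names come from the
# docs above, defaults are an assumption for illustration.
from dataclasses import dataclass, field
from typing import Any

@dataclass(frozen=True)
class ObjectiveOutputSketch:
    is_ready: bool
    grads: Any = None
    observables: dict[str, Any] = field(default_factory=dict)
    state: dict[str, Any] = field(default_factory=dict)
    needs_update: tuple[str, ...] = ()

# An objective that is still waiting on a new trajectory for "rdf":
pending = ObjectiveOutputSketch(
    is_ready=False,
    needs_update=("rdf",),
    state={"opt_steps": 3},
)
```

Because the dataclass is frozen, callers thread `state` forward explicitly rather than mutating the output in place.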
class mythos.optimization.objective.Objective[source]

Frozen dataclass for objectives that calculate gradients.

Objectives are immutable - all state is passed in and out through the calculate method. The ObjectiveOutput.state field carries state that needs to persist between calculate calls (e.g., reference states for DiffTRe).

name

The name of the objective.

required_observables

Observable names required to compute gradients.

logging_observables

Observable names used for logging.

grad_or_loss_fn

Function that computes gradients from observables.

name: str
required_observables: tuple[str, ...]
logging_observables: tuple[str, ...]
grad_or_loss_fn: Callable[[tuple[Any, ...]], tuple[mythos.utils.types.Grads, list[tuple[str, Any]]]]
__post_init__() None[source]

Validate required fields.

calculate(observables: dict[str, Any], opt_params: mythos.utils.types.Params | None = None, **_kwargs) ObjectiveOutput[source]

Compute gradients from observables.

Parameters:
  • observables – Dictionary mapping observable names to their values.

  • opt_params – Current optimization parameters (unused in base class).

Returns:

ObjectiveOutput containing gradients and updated state.

get_logging_observables(observables: dict[str, Any]) list[tuple[str, Any]][source]

Return the observable values for logging.

Parameters:

observables – Dictionary mapping observable names to their values.

Returns:

List of (name, value) tuples for logging observables.
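The expected behavior of get_logging_observables can be mimicked with a plain comprehension. This is a stand-in illustrating the contract, not the method's implementation:

```python
# Stand-in for Objective.get_logging_observables: select the configured
# logging observables from the full observables dict, preserving their
# declared order.
logging_observables = ("energy", "pressure")
observables = {"energy": -3.2, "pressure": 1.1, "rdf": [0.0, 0.5]}

logs = [(name, observables[name]) for name in logging_observables]
```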

mythos.optimization.objective.compute_weights_and_neff(beta: float, new_energies: mythos.utils.types.Arr_N, ref_energies: mythos.utils.types.Arr_N) tuple[jax.numpy.ndarray, float][source]

Compute the weights and normalized effective sample size of a trajectory.

Calculation derived from the DiffTRe algorithm.

See equations 4 and 5 of the DiffTRe paper: https://www.nature.com/articles/s41467-021-27241-4

Parameters:
  • beta – The inverse temperature.

  • new_energies – The new energies of the trajectory.

  • ref_energies – The reference energies of the trajectory.

Returns:

The weights and the normalized effective sample size.
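The math can be sketched in numpy (the real function operates on JAX arrays; the max-shifted softmax is a numerical-stability choice of this sketch, and the entropy-based effective sample size follows equations 4 and 5 of the DiffTRe paper):

```python
import numpy as np

def weights_and_neff(beta, new_energies, ref_energies):
    """Numpy sketch of DiffTRe reweighting (eqs. 4 and 5)."""
    # Eq. 4: Boltzmann reweighting factors, computed as a softmax of the
    # energy differences; shifting by the max avoids overflow.
    logits = -beta * (new_energies - ref_energies)
    logits = logits - logits.max()
    weights = np.exp(logits) / np.exp(logits).sum()
    # Eq. 5: entropy-based effective sample size, normalized by the
    # trajectory length so the result lies in (0, 1].
    n_eff = np.exp(-np.sum(weights * np.log(weights)))
    return weights, n_eff / len(new_energies)

w, neff = weights_and_neff(
    2.0, np.array([1.0, 1.0, 1.0]), np.array([1.0, 1.0, 1.0])
)
# Identical energies give uniform weights and a normalized n_eff of 1.
```

A normalized effective sample size near 1 means the reference trajectory is still representative; as it decays toward min_n_eff_factor, a new trajectory is needed.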

mythos.optimization.objective.compute_loss(opt_params: mythos.utils.types.Params, energy_fn: mythos.energy.base.EnergyFunction, beta: float, loss_fn: collections.abc.Callable[[jax_md.rigid_body.RigidBody, mythos.utils.types.Arr_N, EnergyFn], tuple[jax.numpy.ndarray, tuple[str, Any]]], ref_states: jax_md.rigid_body.RigidBody, ref_energies: mythos.utils.types.Arr_N, observables: list[Any]) tuple[float, tuple[float, jax.numpy.ndarray]][source]

Compute the loss and auxiliary values.

Parameters:
  • opt_params – The optimization parameters.

  • energy_fn – The energy function.

  • beta – The inverse temperature.

  • loss_fn – The loss function.

  • ref_states – The reference states of the trajectory.

  • ref_energies – The reference energies of the trajectory.

  • observables – The observables passed to the loss function.

Returns:

The loss, together with a tuple of auxiliary values: the normalized effective sample size of the trajectory and the new energies.
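Putting the pieces together, here is a minimal numpy sketch of a DiffTRe-style loss: reweight a fixed reference trajectory under new parameters, compare a reweighted observable against a target, and differentiate. The toy energy function, quadratic loss, and all names are illustrative assumptions; a central finite difference stands in for JAX autodiff:

```python
import numpy as np

def reweighted_loss(opt_param, beta, ref_states, ref_energies, target):
    # Toy energy function: energy proportional to the (scalar) state.
    new_energies = opt_param * ref_states
    # Reweight the reference trajectory (cf. compute_weights_and_neff).
    logits = -beta * (new_energies - ref_energies)
    logits = logits - logits.max()
    w = np.exp(logits) / np.exp(logits).sum()
    # Reweighted estimate of an observable (here the state itself) and a
    # quadratic mismatch against the target value.
    estimate = np.sum(w * ref_states)
    return (estimate - target) ** 2

states = np.array([0.8, 1.0, 1.2])
ref_energies = 1.0 * states          # reference energies at opt_param = 1.0
loss = reweighted_loss(1.0, 2.0, states, ref_energies, target=1.0)

# Central finite difference standing in for jax.grad:
eps = 1e-5
grad = (reweighted_loss(1.0 + eps, 2.0, states, ref_energies, 1.0)
        - reweighted_loss(1.0 - eps, 2.0, states, ref_energies, 1.0)) / (2 * eps)
```

At the reference parameters the weights are uniform, the reweighted estimate already matches the target, and both the loss and its gradient vanish; moving opt_param away from 1.0 makes the loss, and hence the gradient, nonzero without rerunning any simulation.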

mythos.optimization.objective.compute_loss_and_grad
class mythos.optimization.objective.DiffTReObjective[source]

Bases: Objective

Frozen dataclass for DiffTRe-based gradient computation.

DiffTRe (Differentiable Trajectory Reweighting) allows computing gradients by reweighting trajectories rather than running new simulations. State such as reference_states, reference_energies, and opt_steps is passed through the state field of ObjectiveOutput.

energy_fn

The energy function used to compute energies.

beta

The inverse temperature.

n_equilibration_steps

Number of equilibration steps to skip.

min_n_eff_factor

Minimum normalized effective sample size threshold.

max_valid_opt_steps

Maximum optimization steps before requiring new trajectory.

energy_fn: mythos.energy.base.EnergyFunction
beta: float
n_equilibration_steps: int = 0
min_n_eff_factor: float = 0.95
max_valid_opt_steps: float = inf
__post_init__() None[source]

Validate required fields.

calculate(observables: dict[str, Any], opt_params: mythos.utils.types.Params, opt_steps: int = 0, reference_opt_params: mythos.utils.types.Params | None = None) ObjectiveOutput[source]

Compute gradients using DiffTRe reweighting.

Parameters:
  • observables – Dictionary mapping observable names to their values.

  • opt_params – Current optimization parameters for energy computation.

  • opt_steps – Current optimization step count.

  • reference_opt_params – Optimization parameters used to compute reference energies.

Returns:

ObjectiveOutput with gradients and updated state.
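The two-phase flow that calculate implies (first signal needs_update, then compute gradients once observables arrive) can be shown with a runnable stub. The Dummy* classes, the toy gradient rule, and the observable name are assumptions for illustration only:

```python
# Runnable stub of the calculate() contract for a DiffTRe-style objective.
class DummyOutput:
    def __init__(self, is_ready, grads=None, state=None, needs_update=()):
        self.is_ready = is_ready
        self.grads = grads
        self.state = state or {}
        self.needs_update = needs_update

class DummyDiffTReObjective:
    def calculate(self, observables, opt_params, opt_steps=0,
                  reference_opt_params=None):
        if "trajectory" not in observables:
            # No valid reference trajectory yet: ask the caller for one.
            return DummyOutput(False, needs_update=("trajectory",),
                               state={"opt_steps": opt_steps})
        # Toy gradient; a real DiffTRe objective would reweight here.
        grads = {"k": observables["trajectory"] - opt_params["k"]}
        return DummyOutput(True, grads=grads,
                           state={"opt_steps": opt_steps + 1,
                                  "reference_opt_params": opt_params})

obj = DummyDiffTReObjective()
first = obj.calculate({}, {"k": 1.0})                 # no trajectory yet
second = obj.calculate({"trajectory": 1.5}, {"k": 1.0}, **first.state)
```

The caller threads `out.state` back into the next call as keyword arguments, which is how opt_steps and reference_opt_params persist across iterations without any mutation of the frozen objective.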