mythos.optimization.objective
Objectives implemented as frozen chex dataclasses.
Objectives are immutable dataclasses that compute gradients from observables. State is passed into the calculate method and returned in the ObjectiveOutput.
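The immutability and state-threading pattern described above can be illustrated with a small sketch. It uses the standard library's `dataclasses.dataclass(frozen=True)` as a stand-in for chex's frozen dataclasses. `CounterObjective`, `Output`, and the `calls` state key are hypothetical names, not part of mythos.

```python
from __future__ import annotations

from dataclasses import FrozenInstanceError, dataclass, field
from typing import Any


@dataclass(frozen=True)
class Output:
    """Hypothetical stand-in for ObjectiveOutput: a result plus state to thread back in."""

    value: float
    state: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class CounterObjective:
    """Hypothetical frozen objective: holds configuration only, never mutable state."""

    name: str
    scale: float

    def calculate(self, x: float, state: dict[str, Any] | None = None) -> Output:
        state = state or {}
        # Read the incoming state and return an updated copy instead of mutating self.
        calls = state.get("calls", 0) + 1
        return Output(value=self.scale * x, state={**state, "calls": calls})


obj = CounterObjective(name="demo", scale=2.0)
out1 = obj.calculate(1.0)
out2 = obj.calculate(3.0, state=out1.state)  # the caller threads state back in
print(out2.value, out2.state["calls"])

try:
    obj.scale = 5.0  # assigning to a frozen dataclass field raises
except FrozenInstanceError:
    print("objective is immutable")
```

Because the objective never mutates itself, the caller owns the state; passing an older `state` back in simply resumes from that point.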
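Attributes
- ERR_DIFFTRE_MISSING_KWARGS
- ERR_MISSING_ARG
- ERR_OBJECTIVE_NOT_READY
- EnergyFn
- empty_dict
- compute_loss_and_grad
Classes
- ObjectiveOutput: Output of an objective calculation.
- Objective: Frozen dataclass for objectives that calculate gradients.
- DiffTReObjective: Frozen dataclass for DiffTRe-based gradient computation.
Functions
- compute_weights_and_neff(beta, new_energies, ref_energies): Compute the weights and normalized effective sample size of a trajectory.
- compute_loss(opt_params, energy_fn, beta, loss_fn, ref_states, ref_energies, observables): Compute the grads, loss, and auxiliary values.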
Module Contents
- mythos.optimization.objective.ERR_DIFFTRE_MISSING_KWARGS = 'Missing required kwargs: {missing_kwargs}.'
- mythos.optimization.objective.ERR_MISSING_ARG = 'Missing required argument: {missing_arg}.'
- mythos.optimization.objective.ERR_OBJECTIVE_NOT_READY = 'Not all required observables have been obtained.'
- mythos.optimization.objective.EnergyFn
- mythos.optimization.objective.empty_dict
- class mythos.optimization.objective.ObjectiveOutput
Output of an objective calculation.
- is_ready
Whether the objective has computed gradients.
- grads
The computed gradients, if ready.
- observables
Observable values to preserve between calls.
- state
State information to pass back to the next calculate call. For DiffTRe, this includes reference_states, reference_energies, and opt_steps.
- needs_update
List of observable names that need new values.
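A caller of an objective plausibly loops until `is_ready` is true, computing the observables named in `needs_update` and passing the preserved `observables` and `state` back in. The sketch below assumes a hypothetical `OutputSketch` record that mirrors the fields listed above, a toy objective, and a hypothetical `compute_observable` callback; the actual mythos optimizer loop may be organized differently.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass(frozen=True)
class OutputSketch:
    """Hypothetical mirror of the ObjectiveOutput fields."""

    is_ready: bool
    grads: Any | None = None
    observables: dict[str, Any] = field(default_factory=dict)
    state: dict[str, Any] = field(default_factory=dict)
    needs_update: list[str] = field(default_factory=list)


REQUIRED = ("energy", "rmsd")  # hypothetical observable names


def toy_calculate(observables: dict[str, Any], state: dict[str, Any]) -> OutputSketch:
    """Toy objective: ready only once every required observable is present."""
    missing = [name for name in REQUIRED if name not in observables]
    if missing:
        return OutputSketch(is_ready=False, observables=observables, state=state, needs_update=missing)
    grads = {"k": observables["energy"] - observables["rmsd"]}
    return OutputSketch(is_ready=True, grads=grads, observables=observables, state=state)


def run_until_ready(
    calculate: Callable[[dict[str, Any], dict[str, Any]], OutputSketch],
    compute_observable: Callable[[str], Any],
    max_rounds: int = 10,
) -> OutputSketch:
    observables: dict[str, Any] = {}
    state: dict[str, Any] = {}
    for _ in range(max_rounds):
        out = calculate(observables, state)
        if out.is_ready:
            return out
        # Keep the preserved observables and compute only the requested ones.
        fresh = {name: compute_observable(name) for name in out.needs_update}
        observables = {**out.observables, **fresh}
        state = out.state
    raise RuntimeError("objective never became ready")


values = {"energy": 3.0, "rmsd": 1.0}
out = run_until_ready(toy_calculate, lambda name: values[name])
print(out.grads)
```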
- class mythos.optimization.objective.Objective
Frozen dataclass for objectives that calculate gradients.
Objectives are immutable: all state is passed in and out through the calculate method. The ObjectiveOutput.state field carries state that needs to persist between calculate calls (e.g., reference states for DiffTRe).
- name
The name of the objective.
- required_observables
Observable names required to compute gradients.
- logging_observables
Observable names used for logging.
- grad_or_loss_fn
Function that computes gradients from observables.
- grad_or_loss_fn: Callable[[tuple[Any, ...]], tuple[mythos.utils.types.Grads, list[tuple[str, Any]]]]
- calculate(observables: dict[str, Any], opt_params: mythos.utils.types.Params | None = None, **_kwargs) → ObjectiveOutput
Compute gradients from observables.
- Parameters:
observables – Dictionary mapping observable names to their values.
opt_params – Current optimization parameters (unused in base class).
- Returns:
ObjectiveOutput containing gradients and updated state.
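One plausible reading of the base calculate contract: if any required observable is missing, return a not-ready output that lists it in `needs_update`; otherwise call `grad_or_loss_fn` and return its gradients. This is a hypothetical sketch (`SketchObjective`, `SketchOutput`, `toy_grad`), not the mythos source. It assumes that `grad_or_loss_fn` receives the required observables as a single tuple in `required_observables` order, as the `Callable[[tuple[Any, ...]], ...]` annotation suggests.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass(frozen=True)
class SketchOutput:
    """Hypothetical mirror of ObjectiveOutput."""

    is_ready: bool
    grads: Any = None
    observables: dict[str, Any] = field(default_factory=dict)
    state: dict[str, Any] = field(default_factory=dict)
    needs_update: list[str] = field(default_factory=list)


@dataclass(frozen=True)
class SketchObjective:
    """Hypothetical frozen objective following the documented attributes."""

    name: str
    required_observables: tuple[str, ...]
    logging_observables: tuple[str, ...]
    grad_or_loss_fn: Callable[[tuple[Any, ...]], tuple[Any, list[tuple[str, Any]]]]

    def calculate(self, observables: dict[str, Any], opt_params: Any = None, **_kwargs) -> SketchOutput:
        missing = [n for n in self.required_observables if n not in observables]
        if missing:
            # Not ready: hand back what we have and name what is still needed.
            return SketchOutput(is_ready=False, observables=observables, needs_update=missing)
        # Assumption: observables are passed as one tuple, in required_observables order.
        grads, _aux = self.grad_or_loss_fn(tuple(observables[n] for n in self.required_observables))
        # This sketch ignores the auxiliary (name, value) pairs.
        return SketchOutput(is_ready=True, grads=grads, observables=observables)


def toy_grad(obs: tuple[Any, ...]) -> tuple[dict[str, float], list[tuple[str, Any]]]:
    (mean_energy,) = obs
    # Gradient of (mean_energy - 1)^2 with respect to a hypothetical parameter "eps".
    return {"eps": 2.0 * (mean_energy - 1.0)}, [("mean_energy", mean_energy)]


obj = SketchObjective("toy", ("mean_energy",), ("mean_energy",), toy_grad)
print(obj.calculate({}).needs_update)
print(obj.calculate({"mean_energy": 1.5}).grads)
```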
- mythos.optimization.objective.compute_weights_and_neff(beta: float, new_energies: mythos.utils.types.Arr_N, ref_energies: mythos.utils.types.Arr_N) → tuple[jax.numpy.ndarray, float]
Compute the weights and normalized effective sample size of a trajectory.
The calculation follows the DiffTRe algorithm; see equations 4 and 5 of https://www.nature.com/articles/s41467-021-27241-4.
- Parameters:
beta – The inverse temperature.
new_energies – The new energies of the trajectory.
ref_energies – The reference energies of the trajectory.
- Returns:
The weights and the normalized effective sample size.
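A minimal NumPy sketch of the two quantities, assuming the entropy-based effective sample size used in the DiffTRe paper. The weights are w_i ∝ exp(-beta * (U_new_i - U_ref_i)), normalized to sum to one, and N_eff = exp(-sum_i w_i ln w_i), divided by the number of frames so it lies in (0, 1]. `compute_weights_and_neff_sketch` is a hypothetical name, and the library itself works with `jax.numpy` rather than NumPy.

```python
import numpy as np


def compute_weights_and_neff_sketch(beta, new_energies, ref_energies):
    """Hypothetical sketch: reweighting weights and normalized entropy-based n_eff."""
    new_energies = np.asarray(new_energies, dtype=float)
    ref_energies = np.asarray(ref_energies, dtype=float)
    # Log of the unnormalized weights: -beta * (U_new - U_ref).
    log_w = -beta * (new_energies - ref_energies)
    # Subtract the maximum before exponentiating to avoid overflow.
    weights = np.exp(log_w - log_w.max())
    weights /= weights.sum()
    # Zero weights contribute nothing to the entropy (w ln w -> 0).
    nz = weights[weights > 0]
    n_eff = np.exp(-np.sum(nz * np.log(nz)))
    return weights, float(n_eff / weights.size)


# Identical energies: uniform weights, normalized n_eff of 1.
w_same, neff_same = compute_weights_and_neff_sketch(1.0, [0.0] * 4, [0.0] * 4)
# One frame dominates: normalized n_eff close to 1/3.
w_skew, neff_skew = compute_weights_and_neff_sketch(1.0, [0.0, 100.0, 100.0], [0.0, 0.0, 0.0])
print(w_same, neff_same, neff_skew)
```

A low normalized n_eff means a few reference frames carry almost all the weight, which is the signal that the reference trajectory no longer represents the current parameters.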
- mythos.optimization.objective.compute_loss(opt_params: mythos.utils.types.Params, energy_fn: mythos.energy.base.EnergyFunction, beta: float, loss_fn: collections.abc.Callable[[jax_md.rigid_body.RigidBody, mythos.utils.types.Arr_N, EnergyFn], tuple[jax.numpy.ndarray, tuple[str, Any]]], ref_states: jax_md.rigid_body.RigidBody, ref_energies: mythos.utils.types.Arr_N, observables: list[Any]) → tuple[float, tuple[float, jax.numpy.ndarray]]
Compute the grads, loss, and auxiliary values.
- Parameters:
opt_params – The optimization parameters.
energy_fn – The energy function.
beta – The inverse temperature.
loss_fn – The loss function.
ref_states – The reference states of the trajectory.
ref_energies – The reference energies of the trajectory.
observables – The observables passed to the loss function.
- Returns:
The grads, the loss, a tuple containing the normalized effective sample size and the measured value of the trajectory, and the new energies.
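The reweighting inside a loss of this kind can be sketched on a toy system. Everything here is an assumption for illustration: a harmonic energy U_k(x) = k x^2 / 2 stands in for `energy_fn`, the loss is the squared error of the reweighted <x^2>, and instead of JAX autodiff the gradient uses the reweighting identity d<O>/dk = -beta * (<O f> - <O><f>) with f = dU/dk. The sketch checks that gradient against a central finite difference.

```python
import numpy as np


def harmonic_energy(k, x):
    """Hypothetical toy energy U_k(x) = k x^2 / 2, standing in for energy_fn."""
    return 0.5 * k * x**2


def reweighted_loss(k, beta, ref_x, ref_energies, target):
    """Loss on the reweighted <x^2>; returns (loss, (n_eff, measured), weights)."""
    log_w = -beta * (harmonic_energy(k, ref_x) - ref_energies)
    w = np.exp(log_w - log_w.max())  # stabilized before normalizing
    w /= w.sum()
    nz = w[w > 0]
    n_eff = np.exp(-np.sum(nz * np.log(nz))) / w.size
    measured = np.sum(w * ref_x**2)  # reweighted ensemble average of x^2
    return (measured - target) ** 2, (n_eff, measured), w


def reweighted_grad(k, beta, ref_x, ref_energies, target):
    """dL/dk via d<O>/dk = -beta * (<O f> - <O><f>), with f = dU/dk = x^2 / 2."""
    loss, (n_eff, measured), w = reweighted_loss(k, beta, ref_x, ref_energies, target)
    obs, f = ref_x**2, 0.5 * ref_x**2
    dm_dk = -beta * (np.sum(w * obs * f) - measured * np.sum(w * f))
    return 2.0 * (measured - target) * dm_dk, loss, n_eff


rng = np.random.default_rng(0)
beta, ref_k, target = 1.0, 1.0, 0.5
# Exact Boltzmann samples of the toy energy at ref_k (a stand-in for a reference trajectory).
ref_x = rng.normal(0.0, np.sqrt(1.0 / (beta * ref_k)), size=20_000)
ref_energies = harmonic_energy(ref_k, ref_x)

k, h = 1.1, 1e-5
grad, loss, n_eff = reweighted_grad(k, beta, ref_x, ref_energies, target)
fd = (
    reweighted_loss(k + h, beta, ref_x, ref_energies, target)[0]
    - reweighted_loss(k - h, beta, ref_x, ref_energies, target)[0]
) / (2 * h)
print(grad, fd, n_eff)  # analytic and finite-difference gradients should agree closely
```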
- mythos.optimization.objective.compute_loss_and_grad
- class mythos.optimization.objective.DiffTReObjective
Bases: Objective
Frozen dataclass for DiffTRe-based gradient computation.
DiffTRe (Differentiable Trajectory Reweighting) computes gradients by reweighting an existing reference trajectory rather than running a new simulation for every parameter update. State such as reference_states, reference_energies, and opt_steps is passed through the state field of ObjectiveOutput.
- energy_fn
The energy function used to compute energies.
- beta
The inverse temperature.
- n_equilibration_steps
Number of equilibration steps to skip.
- min_n_eff_factor
Minimum normalized effective sample size threshold.
- max_valid_opt_steps
Maximum optimization steps before requiring new trajectory.
- energy_fn: mythos.energy.base.EnergyFunction
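The thresholds above suggest a simple rule for deciding when reweighting can no longer be trusted. The helpers below are a hypothetical reading, not mythos code: they assume `n_equilibration_steps` counts leading frames to drop, that `opt_steps` counts updates since the reference trajectory was generated, and that reaching `max_valid_opt_steps` (inclusive) forces a new trajectory.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def trim_equilibration(trajectory: Sequence[T], n_equilibration_steps: int) -> Sequence[T]:
    """Hypothetical helper: drop the leading frames recorded before equilibration."""
    return trajectory[n_equilibration_steps:]


def needs_new_trajectory(
    n_eff: float, opt_steps: int, min_n_eff_factor: float, max_valid_opt_steps: int
) -> bool:
    """Hypothetical helper: True when the reference trajectory should be regenerated."""
    # Too few effective samples, or too many updates since the reference run.
    return n_eff < min_n_eff_factor or opt_steps >= max_valid_opt_steps


frames = list(range(10))
print(trim_equilibration(frames, 3))
print(needs_new_trajectory(0.97, 2, 0.95, 10))  # reweighting still trusted
print(needs_new_trajectory(0.90, 2, 0.95, 10))  # n_eff too low
```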
- calculate(observables: dict[str, Any], opt_params: mythos.utils.types.Params, opt_steps: int = 0, reference_opt_params: mythos.utils.types.Params | None = None) → ObjectiveOutput
Compute gradients using DiffTRe reweighting.
- Parameters:
observables – Dictionary mapping observable names to their values.
opt_params – Current optimization parameters for energy computation.
opt_steps – Current optimization step count, carried over from the previous calculate call.
reference_opt_params – Optimization parameters used to compute the reference energies, carried over from the previous calculate call.
- Returns:
ObjectiveOutput with gradients and updated state.
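To show how the DiffTRe pieces interact across optimization steps, here is a toy loop in the spirit of the calculate method: reweight the reference samples to the current parameter, regenerate the reference when the normalized n_eff drops below a threshold or too many steps have passed, then take a gradient step. All names and constants are hypothetical; exact Gaussian sampling of a harmonic energy replaces molecular dynamics, and the closed-form reweighting gradient replaces JAX autodiff.

```python
import numpy as np

BETA = 1.0
TARGET = 0.5  # target <x^2>; the exact optimum is k = 1 / (BETA * TARGET) = 2
MIN_N_EFF_FACTOR = 0.95  # hypothetical threshold
MAX_VALID_OPT_STEPS = 5  # hypothetical limit
LEARNING_RATE = 2.0


def energy(k, x):
    return 0.5 * k * x**2  # hypothetical toy harmonic energy U_k(x)


def simulate(k, n, rng):
    """Stand-in for an MD run: exact Boltzmann samples of U_k at inverse temperature BETA."""
    return rng.normal(0.0, np.sqrt(1.0 / (BETA * k)), size=n)


def weights_and_neff(new_e, ref_e):
    log_w = -BETA * (new_e - ref_e)
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    nz = w[w > 0]
    return w, float(np.exp(-np.sum(nz * np.log(nz))) / w.size)


def optimize(k=1.5, n_steps=200, n_samples=5000, seed=0):
    rng = np.random.default_rng(seed)
    ref_x = simulate(k, n_samples, rng)
    ref_e = energy(k, ref_x)
    opt_steps = 0  # updates since the current reference was generated
    n_resims = 1
    for _ in range(n_steps):
        w, n_eff = weights_and_neff(energy(k, ref_x), ref_e)
        if n_eff < MIN_N_EFF_FACTOR or opt_steps >= MAX_VALID_OPT_STEPS:
            # Reweighting no longer trusted: resample at the current k and reset the reference.
            ref_x = simulate(k, n_samples, rng)
            ref_e = energy(k, ref_x)
            w, n_eff = weights_and_neff(energy(k, ref_x), ref_e)
            opt_steps = 0
            n_resims += 1
        obs, f = ref_x**2, 0.5 * ref_x**2  # observable and dU/dk
        m = np.sum(w * obs)
        dm_dk = -BETA * (np.sum(w * obs * f) - m * np.sum(w * f))
        k -= LEARNING_RATE * 2.0 * (m - TARGET) * dm_dk  # gradient step on (m - TARGET)^2
        opt_steps += 1
    return k, n_resims


k_final, n_resims = optimize()
print(k_final, n_resims)  # k_final should land near 2
```

Most steps reuse the existing samples; new samples are drawn only when the thresholds trip, which is where DiffTRe saves simulation cost.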