mythos.optimization.objective
=============================

.. py:module:: mythos.optimization.objective

.. autoapi-nested-parse::

   Objectives implemented as frozen chex dataclasses.

   Objectives are immutable dataclasses that compute gradients from observables.
   State is passed through the compute method and returned in the ObjectiveOutput.


Attributes
----------

.. autoapisummary::

   mythos.optimization.objective.ERR_DIFFTRE_MISSING_KWARGS
   mythos.optimization.objective.ERR_MISSING_ARG
   mythos.optimization.objective.ERR_OBJECTIVE_NOT_READY
   mythos.optimization.objective.EnergyFn
   mythos.optimization.objective.empty_dict
   mythos.optimization.objective.compute_loss_and_grad


Classes
-------

.. autoapisummary::

   mythos.optimization.objective.ObjectiveOutput
   mythos.optimization.objective.Objective
   mythos.optimization.objective.DiffTReObjective


Functions
---------

.. autoapisummary::

   mythos.optimization.objective.compute_weights_and_neff
   mythos.optimization.objective.compute_loss


Module Contents
---------------

.. py:data:: ERR_DIFFTRE_MISSING_KWARGS
   :value: 'Missing required kwargs: {missing_kwargs}.'

.. py:data:: ERR_MISSING_ARG
   :value: 'Missing required argument: {missing_arg}.'

.. py:data:: ERR_OBJECTIVE_NOT_READY
   :value: 'Not all required observables have been obtained.'

.. py:data:: EnergyFn

.. py:data:: empty_dict

.. py:class:: ObjectiveOutput

   Output of an objective calculation.

   .. attribute:: is_ready

      Whether the objective has computed gradients.

   .. attribute:: grads

      The computed gradients, if ready.

   .. attribute:: observables

      Observable values to preserve between calls.

   .. attribute:: state

      State information to pass back to the next compute call. For DiffTRe,
      this includes reference_states, reference_energies, and opt_steps.

   .. attribute:: needs_update

      List of observable names that need new values.

   .. py:attribute:: is_ready
      :type: bool

   .. py:attribute:: grads
      :type: mythos.utils.types.Grads | None
      :value: None

   .. py:attribute:: observables
      :type: dict[str, Any]

   .. py:attribute:: state
      :type: dict[str, Any]

   .. py:attribute:: needs_update
      :type: tuple[str, Ellipsis]

.. py:class:: Objective

   Frozen dataclass for objectives that calculate gradients.

   Objectives are immutable: all state is passed in and out through the
   calculate method. The ObjectiveOutput.state field carries state that needs
   to persist between calculate calls (e.g., reference states for DiffTRe).

   .. attribute:: name

      The name of the objective.

   .. attribute:: required_observables

      Observable names required to compute gradients.

   .. attribute:: logging_observables

      Observable names used for logging.

   .. attribute:: grad_or_loss_fn

      Function that computes gradients from observables.

   .. py:attribute:: name
      :type: str

   .. py:attribute:: required_observables
      :type: tuple[str, Ellipsis]

   .. py:attribute:: logging_observables
      :type: tuple[str, Ellipsis]

   .. py:attribute:: grad_or_loss_fn
      :type: Callable[[tuple[Any, Ellipsis]], tuple[mythos.utils.types.Grads, list[tuple[str, Any]]]]

   .. py:method:: __post_init__() -> None

      Validate required fields.

   .. py:method:: calculate(observables: dict[str, Any], opt_params: mythos.utils.types.Params | None = None, **_kwargs) -> ObjectiveOutput

      Compute gradients from observables.

      :param observables: Dictionary mapping observable names to their values.
      :param opt_params: Current optimization parameters (unused in the base class).

      :returns: ObjectiveOutput containing gradients and updated state.

   .. py:method:: get_logging_observables(observables: dict[str, Any]) -> list[tuple[str, Any]]

      Return the observable values for logging.

      :param observables: Dictionary mapping observable names to their values.

      :returns: List of (name, value) tuples for logging observables.

.. py:function:: compute_weights_and_neff(beta: float, new_energies: mythos.utils.types.Arr_N, ref_energies: mythos.utils.types.Arr_N) -> tuple[jax.numpy.ndarray, float]

   Compute the weights and normalized effective sample size of a trajectory.

   Calculation derived from the DiffTRe algorithm.
   https://www.nature.com/articles/s41467-021-27241-4

   See equations 4 and 5.

   :param beta: The inverse temperature.
   :param new_energies: The new energies of the trajectory.
   :param ref_energies: The reference energies of the trajectory.

   :returns: The weights and the normalized effective sample size.

.. py:function:: compute_loss(opt_params: mythos.utils.types.Params, energy_fn: mythos.energy.base.EnergyFunction, beta: float, loss_fn: collections.abc.Callable[[jax_md.rigid_body.RigidBody, mythos.utils.types.Arr_N, EnergyFn], tuple[jax.numpy.ndarray, tuple[str, Any]]], ref_states: jax_md.rigid_body.RigidBody, ref_energies: mythos.utils.types.Arr_N, observables: list[Any]) -> tuple[float, tuple[float, jax.numpy.ndarray]]

   Compute the loss and auxiliary values.

   :param opt_params: The optimization parameters.
   :param energy_fn: The energy function.
   :param beta: The inverse temperature.
   :param loss_fn: The loss function.
   :param ref_states: The reference states of the trajectory.
   :param ref_energies: The reference energies of the trajectory.
   :param observables: The observables passed to the loss function.

   :returns: The loss and a tuple of auxiliary values containing the
             normalized effective sample size, the measured value of the
             trajectory, and the new energies.

.. py:data:: compute_loss_and_grad

.. py:class:: DiffTReObjective

   Bases: :py:obj:`Objective`

   Frozen dataclass for DiffTRe-based gradient computation.

   DiffTRe (Differentiable Trajectory Reweighting) allows computing gradients
   by reweighting trajectories rather than running new simulations. State such
   as reference_states, reference_energies, and opt_steps is passed through
   the state field of ObjectiveOutput.

   .. attribute:: energy_fn

      The energy function used to compute energies.

   .. attribute:: beta

      The inverse temperature.

   .. attribute:: n_equilibration_steps

      Number of equilibration steps to skip.

   .. attribute:: min_n_eff_factor

      Minimum normalized effective sample size threshold.

   .. attribute:: max_valid_opt_steps

      Maximum number of optimization steps before a new trajectory is required.

   .. py:attribute:: energy_fn
      :type: mythos.energy.base.EnergyFunction

   .. py:attribute:: beta
      :type: float

   .. py:attribute:: n_equilibration_steps
      :type: int
      :value: 0

   .. py:attribute:: min_n_eff_factor
      :type: float
      :value: 0.95

   .. py:attribute:: max_valid_opt_steps
      :type: float
      :value: inf

   .. py:method:: __post_init__() -> None

      Validate required fields.

   .. py:method:: calculate(observables: dict[str, Any], opt_params: mythos.utils.types.Params, opt_steps: int = 0, reference_opt_params: mythos.utils.types.Params | None = None) -> ObjectiveOutput

      Compute gradients using DiffTRe reweighting.

      :param observables: Dictionary mapping observable names to their values.
      :param opt_params: Current optimization parameters for energy computation.
      :param opt_steps: Current optimization step count.
      :param reference_opt_params: Optimization parameters used to compute reference energies.

      :returns: ObjectiveOutput with gradients and updated state.
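.. rubric:: Example: reweighting sketch

The reweighting referenced by ``compute_weights_and_neff`` (equations 4 and 5
of the DiffTRe paper) can be sketched as below. This is an illustrative
reimplementation for intuition only, not the ``mythos`` function itself; it
assumes the normalized effective sample size is the exponentiated entropy of
the weights divided by the trajectory length.

.. code-block:: python

   import jax
   import jax.numpy as jnp


   def compute_weights_and_neff(beta, new_energies, ref_energies):
       """Sketch of DiffTRe reweighting (eqs. 4-5, hypothetical variant).

       Weights are a softmax over -beta * (U_new - U_ref); the normalized
       effective sample size is exp(weight entropy) / N, in (0, 1].
       """
       log_weights = -beta * (new_energies - ref_energies)
       weights = jax.nn.softmax(log_weights)  # normalized Boltzmann ratios
       entropy = -jnp.sum(weights * jnp.log(weights))
       n_eff = jnp.exp(entropy) / weights.shape[0]  # 1.0 when U_new == U_ref
       return weights, n_eff

When the new and reference energies coincide, the weights are uniform and the
normalized effective sample size is 1; as the potentials diverge, the weights
concentrate and ``n_eff`` drops toward ``1 / N``, which is what the
``min_n_eff_factor`` threshold guards against.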
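.. rubric:: Example: the immutable-objective pattern

The contract described above (all state in through ``calculate``, results and
carry-over state out through ``ObjectiveOutput``) can be illustrated with a
self-contained toy. The class and field names below deliberately mirror the
API but everything here is a hypothetical sketch, not ``mythos`` code; the
real classes are frozen chex dataclasses and return an ``ObjectiveOutput``
rather than a plain dict.

.. code-block:: python

   from dataclasses import dataclass
   from typing import Any, Callable


   @dataclass(frozen=True)  # immutable, like the real Objective
   class ToyObjective:
       name: str
       required_observables: tuple[str, ...]
       grad_or_loss_fn: Callable[..., Any]

       def calculate(self, observables: dict[str, Any]) -> dict[str, Any]:
           # Missing observables -> not ready; report what must be computed,
           # mirroring ObjectiveOutput.is_ready / needs_update.
           missing = tuple(
               k for k in self.required_observables if k not in observables
           )
           if missing:
               return {"is_ready": False, "grads": None, "needs_update": missing}
           args = tuple(observables[k] for k in self.required_observables)
           return {
               "is_ready": True,
               "grads": self.grad_or_loss_fn(*args),
               "needs_update": (),
           }

A caller loops on ``calculate``: when ``needs_update`` is non-empty it obtains
those observables (e.g., by running a simulation) and calls again; because the
objective itself is frozen, nothing persists except what the caller threads
back in.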