Advanced Usage

Custom Energy Functions

mythos supports custom energy functions for jax_md simulations. Energy functions generally consist of two components: an EnergyFunction and, optionally, an EnergyConfiguration.

Note

All implemented EnergyFunctions and EnergyConfigurations should be annotated with @chex.dataclass from chex. This decorator makes the class compatible with jax transformations.
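
As a quick illustration (a minimal sketch, not part of mythos), a @chex.dataclass is registered as a jax pytree, so it can flow through jax tree utilities and transformations:

import chex
import jax

@chex.dataclass
class ExampleParams:
    k: float

# chex dataclasses are pytrees, so jax tree utilities and transforms work on them.
params = ExampleParams(k=2.0)
scaled = jax.tree_util.tree_map(lambda x: 2 * x, params)  # ExampleParams(k=4.0)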

Custom energy functions should be implemented as subclasses of the BaseEnergy class; see mythos.energy.base. Further, any custom energy function should implement the __call__ method with the following signature:

def __call__(
    self,
    body: jax_md.rigid_body.RigidBody,
    seq: typ.Sequence,
    bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
    unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
) -> float:
    # return a single float that describes the energy for the current state
    ...

Here, body is a jax_md.rigid_body.RigidBody object that contains the current state of the system, seq is a sequence of nucleotides, and bonded_neighbors and unbonded_neighbors are arrays of bonded and unbonded neighbors, respectively.

By deriving from the BaseEnergy class, the function can be used alongside the other implemented functions, since they share a common interface. Further, the BaseEnergy class provides helpers such as the + operator, which can be used to combine energy functions.
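
For example, two energy terms can be summed into a single callable with the same interface (a hedged sketch: the energy classes and configurations named here are hypothetical placeholders, and the exact composition semantics are defined by BaseEnergy):

# Hypothetical BaseEnergy subclasses; any implemented energies combine the same way.
stacking = StackingEnergy(config=stacking_config)
excluded_volume = ExcludedVolumeEnergy(config=excluded_volume_config)

total = stacking + excluded_volume  # still exposes the common __call__ interface
energy = total(body, seq, bonded_neighbors, unbonded_neighbors)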

EnergyFunctions are paired with EnergyConfigurations, which store the parameters of the energy function. More information is available in the BaseEnergyConfiguration class; see mythos.energy.base.

Warning

Any EnergyConfiguration defining parameters should be annotated with @chex.dataclass, and all parameters should be optional in derived classes due to the way the base configuration is implemented.

An example of a trivial energy function is shown below:

from typing_extensions import override

import chex

import jax.numpy as jnp
import jax_md
import mythos.energy.base as jdna_energy
import mythos.utils.types as typ

@chex.dataclass
class TrivialEnergyConfiguration(jdna_energy.BaseEnergyConfiguration):
    some_opt_parameter: float | None = None
    some_dep_parameter: float | None = None

    required_params = ["some_opt_parameter"]

    @override
    def init_params(self) -> "TrivialEnergyConfiguration":
        self.some_dep_parameter = 2 * self.some_opt_parameter
        return self


@chex.dataclass
class TrivialEnergy(jdna_energy.BaseEnergy):

    @override
    def __call__(
        self,
        body: jax_md.rigid_body.RigidBody,
        seq: typ.Sequence,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
        unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
    ) -> float:

        # Centers of the paired nucleotides in each bonded pair.
        bonded_i = body[bonded_neighbors[0, :]].center
        bonded_j = body[bonded_neighbors[1, :]].center

        # Sum of the pairwise distances, shifted by the derived parameter.
        return jnp.sum(jnp.linalg.norm(bonded_i - bonded_j, axis=-1)) + self.config.some_dep_parameter
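A hedged usage sketch follows; the config= keyword is an assumption about how BaseEnergy stores its configuration (the example above reads it back through self.config):

# Instantiate the configuration and derive dependent parameters.
config = TrivialEnergyConfiguration(some_opt_parameter=1.5).init_params()

# Build the energy function; `config=` is assumed to be the field exposed by BaseEnergy.
energy_fn = TrivialEnergy(config=config)

# body, seq, bonded_neighbors, and unbonded_neighbors come from the simulation
# state described above.
# energy = energy_fn(body, seq, bonded_neighbors, unbonded_neighbors)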

More examples can be found by looking at the implemented energies in mythos.energy.base.

Advanced Optimizations

Beyond the simple optimization covered in Basic Usage, more sophisticated optimizations require multiple heterogeneous simulations and multiple kinds of loss functions. To accommodate this, mythos sets up optimizations using the following abstractions:

[Diagram: mythos optimization abstractions (_images/mythos_opt_diagram.svg)]
  • Simulator: A Simulator is an actor that exposes one or more Observables.

  • Observable: An Observable is something produced by a Simulator. It can be a trajectory, a scalar, a vector, a tensor, or anything else an Objective needs to compute its loss and gradients.

  • Objective: An Objective is an actor that takes in one or more Observables and returns the gradients of the Objective with respect to the parameters we want to optimize.

  • Optimizer: An Optimizer coordinates running the Simulators to produce the Observables needed by the Objectives to optimize the parameters of interest.

Using these abstractions, mythos leverages the ray library to run Simulators and Objectives in parallel across multiple heterogeneous devices, scheduling Simulators and computing gradients with Objectives concurrently. This is particularly useful when the Simulators are slow to run and the Objectives are expensive to compute.
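
The sketch below illustrates this division of labor with plain ray actors and tasks; the class and function names are illustrative stand-ins and do not correspond to mythos's actual Simulator/Objective/Optimizer API:

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class ToySimulator:
    # Illustrative stand-in for a Simulator: produces an Observable.
    def __init__(self, seed: int):
        self.seed = seed

    def run(self) -> list[float]:
        # Placeholder for an expensive simulation producing a trajectory-like observable.
        return [float(self.seed + i) for i in range(5)]

@ray.remote
def toy_objective(observable: list[float]) -> float:
    # Illustrative stand-in for an Objective: reduces an Observable to a scalar.
    return sum(observable) / len(observable)

# The Optimizer role: schedule simulators in parallel, then hand their
# Observables to objectives, which also run in parallel.
simulators = [ToySimulator.remote(seed) for seed in range(4)]
observables = [sim.run.remote() for sim in simulators]
results = ray.get([toy_objective.remote(obs) for obs in observables])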

See advanced_optimizations for more details and examples.