Advanced Usage
Custom Energy Functions
mythos supports custom energy functions for jax_md simulations. An energy
function generally consists of two components: an EnergyFunction and,
optionally, an EnergyConfiguration.
Note
All implemented EnergyFunctions and EnergyConfigurations should be annotated
with @chex.dataclass, from chex.
This decorator makes the class compatible with jax.
Custom energy functions should be implemented as subclasses of the BaseEnergy
class, see mythos.energy.base. Further, any custom
energy function should also implement the __call__ function with the
following signature:
def __call__(
    self,
    body: jax_md.rigid_body.RigidBody,
    seq: typ.Sequence,
    bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
    unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
) -> float:
    # return a single float that describes the energy for the current state
    ...
Here body is a jax_md.rigid_body.RigidBody object that contains the
current state of the system, seq is a sequence of nucleotides, and
bonded_neighbors and unbonded_neighbors are arrays of bonded and
unbonded neighbors respectively.
By deriving from the BaseEnergy class, a custom energy function shares a common interface with the other implemented functions and can be used alongside them. Further, the BaseEnergy class implements helpers such as the + operator, which can be used to combine energy functions.
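The idea of combining energy terms with + can be illustrated with a minimal sketch. The ToyEnergy class below is a hypothetical stand-in written only for this illustration. It is not the mythos BaseEnergy class, and the way mythos implements the operator internally may differ.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class ToyEnergy:
    """Hypothetical stand-in for an energy function (not the mythos BaseEnergy)."""

    fn: Callable[[Sequence[float]], float]

    def __call__(self, positions: Sequence[float]) -> float:
        return self.fn(positions)

    def __add__(self, other: "ToyEnergy") -> "ToyEnergy":
        # The sum of two energies is a new energy that evaluates both terms
        # on the same state and adds the results.
        return ToyEnergy(fn=lambda positions: self(positions) + other(positions))


# Two toy terms: a harmonic well around the origin and a constant offset.
harmonic = ToyEnergy(fn=lambda positions: sum(0.5 * x * x for x in positions))
offset = ToyEnergy(fn=lambda positions: 1.0)

# Combining the terms yields an object with the same call interface.
total = harmonic + offset
total_energy = total([1.0, 2.0])
```

Because the combined object has the same call signature as its parts, code that consumes an energy function does not need to know how many terms it contains.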
EnergyFunctions are paired with EnergyConfigurations, which are used to store the parameters of the energy function. More information is available in the BaseEnergyConfiguration class, see mythos.energy.base.
Warning
Any EnergyConfiguration that defines parameters should be annotated with
@chex.dataclass, and all parameters should be optional in the derived
classes because of the way the base configuration is implemented.
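One plausible reason for the optional-parameter rule can be shown with the standard library alone. The sketch assumes that the base configuration declares fields with default values, which is not confirmed here. ToyBaseConfiguration and its field are hypothetical names, and plain dataclasses are used in place of chex.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyBaseConfiguration:
    # Hypothetical base field that carries a default value.
    params_to_optimize: tuple = ()


def define_required_field():
    # A subclass field without a default would follow a defaulted base field
    # in the generated __init__, which dataclasses reject.
    @dataclass
    class BadConfiguration(ToyBaseConfiguration):
        strength: float

    return BadConfiguration


try:
    define_required_field()
    required_field_allowed = True
except TypeError:
    required_field_allowed = False


@dataclass
class GoodConfiguration(ToyBaseConfiguration):
    # Giving the field a default of None keeps the class definition valid.
    strength: Optional[float] = None


config = GoodConfiguration(strength=1.5)
```

Under this assumption, a parameter that is logically required is still declared with a None default and checked separately, which matches the required_params list used in the trivial example.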
An example of a trivial energy function is shown below:
from typing_extensions import override

import chex
import jax.numpy as jnp
import jax_md.rigid_body

import mythos.energy.base as jdna_energy
import mythos.utils.types as typ


@chex.dataclass
class TrivialEnergyConfiguration(jdna_energy.BaseEnergyConfiguration):
    # All parameters are optional, as the base configuration requires.
    some_opt_parameter: float | None = None
    some_dep_parameter: float | None = None
    required_params = ["some_opt_parameter"]

    @override
    def init_params(self) -> "TrivialEnergyConfiguration":
        # Derive the dependent parameter from the required one.
        self.some_dep_parameter = 2 * self.some_opt_parameter
        return self


@chex.dataclass
class TrivialEnergy(jdna_energy.BaseEnergy):
    @override
    def __call__(
        self,
        body: jax_md.rigid_body.RigidBody,
        seq: typ.Sequence,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
        unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
    ) -> float:
        # Center positions of the two members of each bonded pair.
        bonded_i = body[bonded_neighbors[0, :]].center
        bonded_j = body[bonded_neighbors[1, :]].center
        # Sum of the per-pair distances (norm along the last axis) plus a constant offset.
        return jnp.sum(jnp.linalg.norm(bonded_i - bonded_j, axis=-1)) + self.config.some_dep_parameter
More examples can be found by looking at the implemented energies in mythos.energy.base.
Advanced Optimizations
Beyond the simple optimization covered in Basic Usage, more sophisticated
optimizations require multiple heterogeneous simulations with multiple kinds
of loss functions. To accommodate this, mythos sets up optimizations using
the following abstractions:
Simulator: A Simulator is an actor that exposes one or more Observables.
Observable: An Observable is something produced by a Simulator. It can be a trajectory, a scalar, a vector, a tensor, or anything else that an Objective needs to compute its loss and gradients.
Objective: An Objective is an actor that takes in one or more Observables and returns the gradients of the Objective with respect to the parameters being optimized.
Optimizer: An Optimizer coordinates running the Simulators to produce the Observables that the Objectives need in order to optimize the parameters of interest.
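How these four roles fit together can be sketched in a few lines of plain Python. All class names, method names, and the toy loss below are hypothetical and chosen only for illustration. They are not the mythos classes, and everything runs sequentially in a single process.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToySimulator:
    """Exposes one Observable: a scalar that depends on the parameter."""

    name: str

    def run(self, param: float) -> Dict[str, float]:
        # The Observable here is simply 2 * param, keyed by the simulator name.
        return {self.name: 2.0 * param}


@dataclass
class ToyObjective:
    """Turns an Observable into the gradient of loss = (observable - target)^2."""

    needs: str
    target: float

    def gradient(self, observables: Dict[str, float]) -> float:
        observable = observables[self.needs]
        # Chain rule: d(loss)/d(param) = 2 * (observable - target) * d(observable)/d(param),
        # and d(observable)/d(param) = 2 for the toy simulator above.
        return 2.0 * (observable - self.target) * 2.0


@dataclass
class ToyOptimizer:
    """Runs the simulators, collects observables, and applies a gradient step."""

    simulators: List[ToySimulator]
    objectives: List[ToyObjective]
    learning_rate: float = 0.1

    def step(self, param: float) -> float:
        observables: Dict[str, float] = {}
        for simulator in self.simulators:
            observables.update(simulator.run(param))
        grad = sum(objective.gradient(observables) for objective in self.objectives)
        return param - self.learning_rate * grad


optimizer = ToyOptimizer(
    simulators=[ToySimulator(name="sim_a")],
    objectives=[ToyObjective(needs="sim_a", target=4.0)],
)
param = 0.0
for _ in range(20):
    param = optimizer.step(param)
```

In this toy setup the observable equals the target when the parameter is 2, so repeated steps move the parameter toward that value.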
Using these abstractions, mythos leverages the ray library to schedule
Simulators and calculate gradients with Objectives in parallel across multiple
heterogeneous devices. This is particularly useful when the Simulators are
slow to run and the Objectives are expensive to compute.
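The benefit of scheduling slow simulations concurrently can be sketched without any extra dependencies. The example below deliberately swaps in ThreadPoolExecutor from the Python standard library as a stand-in for ray, so it does not show how mythos or ray schedule work. The slow_simulation function and the simulator names are hypothetical.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple


def slow_simulation(name: str, param: float) -> Tuple[str, float]:
    """Hypothetical simulator: sleeps to mimic a slow run, then returns one observable."""
    time.sleep(0.05)
    return name, param * param


def run_all(names: List[str], param: float) -> Dict[str, float]:
    # Submit every simulation before waiting on any result so the runs overlap,
    # then gather the observables into a single dictionary.
    with ThreadPoolExecutor(max_workers=len(names)) as pool:
        futures = [pool.submit(slow_simulation, name, param) for name in names]
        return dict(future.result() for future in futures)


observables = run_all(["sim_a", "sim_b", "sim_c"], 3.0)
```

The submit-then-gather pattern is the part that carries over: all simulations are launched first and their observables are collected afterwards, so the slowest run determines the waiting time rather than the sum of all runs.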
See advanced_optimizations for more details and examples.