Source code for mythos.energy

"""The energy model and function for mythos."""

from collections.abc import Callable

import jax_md

import mythos.utils.types as jdna_types
from mythos.energy import base, configuration

DEFAULT_DISPLACEMENT = jax_md.space.free()[0]


[docs] def energy_fn_builder( energy_fns: list[base.BaseEnergyFunction], energy_configs: list[configuration.BaseConfiguration], transform_fn: Callable[[jdna_types.PyTree], jdna_types.PyTree], displacement_fn: Callable[[jdna_types.PyTree], jdna_types.PyTree] = DEFAULT_DISPLACEMENT, ) -> Callable[[jdna_types.PyTree], float]: """Build an energy function from a list of energy functions and configurations.""" def energy_fn( opt_params: jdna_types.PyTree, ) -> base.ComposedEnergyFunction: """Energy function generated using mythos.energy.energy_fn_builder. Input: opt_params (jdna_types.PyTree): the parameters of the energy function Returns: float: the energy of the system """ transformed_fns = [ e_fn( displacement_fn=displacement_fn, params=(e_c | param).init_params(), ) for param, e_c, e_fn in zip(opt_params, energy_configs, energy_fns, strict=True) ] return base.ComposedEnergyFunction( energy_fns=transformed_fns, rigid_body_transform_fn=transform_fn, ) return energy_fn