"""Base classes for energy functions."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import InitVar
from typing import Any, Union
import chex
import jax
import jax.numpy as jnp
import jax_md
from typing_extensions import override
import mythos.utils.types as typ
from mythos.energy.configuration import BaseConfiguration
from mythos.input.topology import Topology
# Error-message templates shared by the energy function classes.

# Template with `{key}` and `{class_name}` placeholders; it is not referenced
# by the code visible in this file.
ERR_PARAM_NOT_FOUND = "Parameter '{key}' not found in {class_name}"
# Not referenced by the code visible in this file.
ERR_CALL_NOT_IMPLEMENTED = "Subclasses must implement this method"
# Raised as a ValueError by ComposedEnergyFunction.__post_init__ when
# len(weights) differs from len(energy_fns).
ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH = "Weights must have the same length as energy functions"
# Raised as a TypeError by ComposedEnergyFunction.__post_init__ when energy_fns
# is not a list of BaseEnergyFunction instances.
ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS = "energy_fns must be a list of energy functions"
class EnergyFunction(ABC):
    """Interface implemented by every energy function.

    Energy functions are callable objects: given a RigidBody they return the
    energy of the system as a scalar float.
    """

    @abstractmethod
    def __call__(self, body: jax_md.rigid_body.RigidBody) -> float:
        """Evaluate the energy of the system described by ``body``."""

    @abstractmethod
    def with_params(self, *repl_dicts: dict, **repl_kwargs: Any) -> "EnergyFunction":
        """Build a copy of this energy function with some parameters replaced.

        Args:
            *repl_dicts (dict): parameter dictionaries, given positionally and
                applied one after another in the order supplied.
            **repl_kwargs: individual parameter replacements; these are applied
                once all positional dictionaries have been processed.
        """

    @abstractmethod
    def with_props(self, **kwargs) -> "EnergyFunction":
        """Build a copy of this energy function with some properties replaced.

        Properties live on the energy function class itself, as opposed to the
        parameters held inside it. The `displacement_fn` is one example of a
        property that can be swapped this way.
        """

    @abstractmethod
    def with_noopt(self, *params: str) -> "EnergyFunction":
        """Build a copy in which the named parameters are marked non-optimizable."""

    @abstractmethod
    def params_dict(self, *, include_dependent: bool = True, exclude_non_optimizable: bool = False) -> dict:
        """Return the parameters as a dictionary.

        Args:
            include_dependent (bool): if True, dependent parameters are included
            exclude_non_optimizable (bool): if True, non-optimizable parameters
                are left out
        """

    @abstractmethod
    def opt_params(self) -> dict[str, typ.Scalar]:
        """Return the parameters currently configured as optimizable."""

    def map(self, body_sequence: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the energy function over a sequence of rigid bodies.

        The call is vectorised with ``jax.vmap`` over ``body_sequence``.
        """
        batched_energy = jax.vmap(self.__call__)
        return batched_energy(body_sequence)
@chex.dataclass(frozen=True)
class BaseNucleotide(jax_md.rigid_body.RigidBody, ABC):
    """Abstract nucleotide description derived from a jax_md RigidBody.

    Concrete subclasses must provide ``from_rigid_body`` to construct
    themselves from a plain RigidBody.
    """

    # Rigid-body state.
    center: typ.Arr_Nucleotide_3
    orientation: typ.Arr_Nucleotide_3 | jax_md.rigid_body.Quaternion

    # Per-nucleotide site arrays (presumably interaction-site positions;
    # confirm against the concrete subclasses).
    stack_sites: typ.Arr_Nucleotide_3
    back_sites: typ.Arr_Nucleotide_3
    base_sites: typ.Arr_Nucleotide_3

    # Per-nucleotide vector arrays (presumably derived from the orientation;
    # confirm against the concrete subclasses).
    back_base_vectors: typ.Arr_Nucleotide_3
    base_normals: typ.Arr_Nucleotide_3
    cross_prods: typ.Arr_Nucleotide_3

    @staticmethod
    @abstractmethod
    def from_rigid_body(rigid_body: jax_md.rigid_body.RigidBody, **kwargs) -> "BaseNucleotide":
        """Build an instance of the subclass from a RigidBody."""
@chex.dataclass(frozen=True, kw_only=True)
class BaseEnergyFunction(EnergyFunction):
    """Base class for energy functions.

    This class should not be used directly. Subclasses implement
    ``compute_energy``; ``__call__`` applies the optional ``transform_fn`` to
    the input and then delegates to ``compute_energy``.

    Parameters:
        params (BaseConfiguration): the parameter configuration of this term
        displacement_fn (Callable): an instance of a displacement function from jax_md.space
        seq: sequence information; taken from ``topology`` when one is given
        bonded_neighbors: bonded neighbor pairs; taken from ``topology`` when
            one is given
        unbonded_neighbors: unbonded neighbor pairs; taken (transposed) from
            ``topology`` when one is given
        topology (Topology | None): init-only convenience argument used to fill
            ``seq``, ``bonded_neighbors`` and ``unbonded_neighbors``
        transform_fn (Callable | None): optional function applied to the rigid
            body before the energy is computed

    Raises:
        ValueError: if no topology is given and any of ``seq``,
            ``bonded_neighbors`` or ``unbonded_neighbors`` is missing.
    """

    params: BaseConfiguration
    displacement_fn: Callable
    seq: typ.Sequence | None = None
    bonded_neighbors: typ.Arr_Bonded_Neighbors_2 | None = None
    unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2 | None = None
    topology: InitVar[Topology | None] = None
    transform_fn: Callable | None = None

    @override
    def __post_init__(self, topology: Topology | None) -> None:
        """Fill the topology-derived fields, or validate that they were supplied."""
        # Compare against None explicitly: a Topology object must never be
        # ignored just because it happens to evaluate as falsy.
        if topology is not None:
            # The dataclass is frozen, so fields are set through object.__setattr__.
            object.__setattr__(self, "seq", topology.seq)
            object.__setattr__(self, "bonded_neighbors", topology.bonded_neighbors)
            # The topology's unbonded neighbors are stored transposed here.
            object.__setattr__(self, "unbonded_neighbors", topology.unbonded_neighbors.T)
        elif self.seq is None or self.bonded_neighbors is None or self.unbonded_neighbors is None:
            raise ValueError("Missing topology information")

    @classmethod
    def create_from(cls, other: "EnergyFunction", **kwargs) -> "EnergyFunction":
        """Create a new energy function from another with updated properties.

        Args:
            other: the energy function to copy properties from; it must be
                convertible to a dictionary via ``dict(other)``
            **kwargs: properties to update, overriding those from other
        """
        props = dict(other) | kwargs
        return cls(**props)

    @property
    def displacement_mapped(self) -> Callable:
        """Returns the displacement function mapped to the space."""
        return jax_md.space.map_bond(self.displacement_fn)

    def __add__(self, other: "BaseEnergyFunction") -> "ComposedEnergyFunction":
        """Add two energy functions together to create a ComposedEnergyFunction."""
        if not isinstance(other, BaseEnergyFunction):
            # Let Python try the reflected operation on `other`.
            return NotImplemented
        return ComposedEnergyFunction(energy_fns=[self, other])

    def __mul__(self, other: float) -> "ComposedEnergyFunction":
        """Multiply an energy function by a scalar to create a ComposedEnergyFunction.

        The result contains this single term with ``other`` as its weight.
        """
        if not isinstance(other, float | int):
            return NotImplemented
        return ComposedEnergyFunction(
            energy_fns=[self],
            weights=jnp.array([other], dtype=float),
        )

    @override
    def with_props(self, **kwargs: Any) -> EnergyFunction:
        return self.replace(**kwargs)

    @override
    def with_noopt(self, *params: str) -> EnergyFunction:
        # Merge the existing and the newly requested names without duplicates
        # while keeping a deterministic order (existing names first); a plain
        # list(set(...)) would vary with string hash randomisation.
        existing = list(self.params.non_optimizable_required_params)
        merged = list(dict.fromkeys([*existing, *params]))
        new_params = self.params.replace(non_optimizable_required_params=merged)
        return self.replace(params=new_params)

    @override
    def opt_params(self) -> dict[str, typ.Scalar]:
        return self.params.opt_params

    @override
    def with_params(self, *repl_dicts: dict, **repl_kwargs: Any) -> EnergyFunction:
        new_params = self.params
        # Positional dictionaries are applied in order, keyword replacements last.
        for replacements in (*repl_dicts, repl_kwargs):
            new_params = new_params | replacements
        return self.replace(params=new_params.init_params())

    @override
    def params_dict(self, include_dependent: bool = True, exclude_non_optimizable: bool = False) -> dict:
        return self.params.to_dictionary(
            include_dependent=include_dependent,
            exclude_non_optimizable=exclude_non_optimizable,
        )

    @override
    def __call__(self, body: jax_md.rigid_body.RigidBody) -> float:
        # Apply the optional transform before handing the body to the
        # subclass-specific energy computation.
        if self.transform_fn is not None:
            body = self.transform_fn(body)
        return self.compute_energy(body)

    @abstractmethod
    def compute_energy(self, nucleotide: BaseNucleotide) -> float:
        """Compute the energy of the system given the nucleotide."""
@chex.dataclass(frozen=True)
class ComposedEnergyFunction(EnergyFunction):
    """Represents a linear combination of energy functions.

    The parameters of all composite energy functions are treated as sharing a
    global namespace in all setting and retrieval methods. For example, calling
    `with_params(kt=0.1)` will set the parameter `kt` in all those energy
    functions that contain a parameter name `kt`.

    Parameters:
        energy_fns (list[BaseEnergyFunction]): a list of energy functions
        weights (jnp.ndarray): optional, the weights of the energy functions;
            when omitted the terms are summed unweighted

    Raises:
        TypeError: if ``energy_fns`` is not a list of BaseEnergyFunction
        ValueError: if ``weights`` is given with a length different from
            ``energy_fns``
    """

    energy_fns: list[BaseEnergyFunction]
    weights: jnp.ndarray | None = None

    def __post_init__(self) -> None:
        """Check that the input is valid."""
        if not isinstance(self.energy_fns, list) or not all(
            isinstance(fn, BaseEnergyFunction) for fn in self.energy_fns
        ):
            raise TypeError(ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS)
        if self.weights is not None and len(self.weights) != len(self.energy_fns):
            raise ValueError(ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH)

    @override
    def with_props(self, **kwargs: Any) -> "ComposedEnergyFunction":
        # The same property updates are forwarded to every term.
        energy_fns = [fn.with_props(**kwargs) for fn in self.energy_fns]
        return self.replace(energy_fns=energy_fns)

    def _param_in_fn(self, param: str, fn: BaseEnergyFunction) -> bool:
        """Helper for with_params to check if a param is in a given energy function."""
        return param in fn.params

    def _rename_param_for_fn(self, param: str, _fn: BaseEnergyFunction) -> str:
        """Helper to rename a param for input to a given energy function."""
        return param

    def _rename_param_from_fn(self, param: str, _fn: BaseEnergyFunction) -> str:
        """Helper to rename a param for output from a given energy function."""
        return param

    @override
    def with_noopt(self, *params: str) -> "ComposedEnergyFunction":
        energy_fns = []
        for fn in self.energy_fns:
            # Only names that belong to this term are forwarded to it.
            fn_params = [self._rename_param_for_fn(p, fn) for p in params if self._param_in_fn(p, fn)]
            energy_fns.append(fn.with_noopt(*fn_params))
        return self.replace(energy_fns=energy_fns)

    @override
    def opt_params(self, from_fns: list[type] | None = None) -> dict[str, typ.Scalar]:
        """Get the optimizable parameters of the terms as one dictionary.

        Args:
            from_fns: optional list of energy function types; when given, only
                terms whose exact type is listed contribute parameters.
        """
        energy_fns = self.energy_fns if from_fns is None else [fn for fn in self.energy_fns if type(fn) in from_fns]
        # Later terms overwrite earlier ones when output names collide.
        return {self._rename_param_from_fn(k, fn): v for fn in energy_fns for k, v in fn.opt_params().items()}

    @override
    def with_params(self, *repl_dicts: dict, **repl_kwargs: Any) -> "ComposedEnergyFunction":
        # Track replacements which are actually applied to functions in order to
        # error on unused replacements (assume this is unintended).
        all_replacements = set(repl_kwargs) | {k for arg in repl_dicts for k in arg}
        used_replacements = set()
        energy_fns = []
        for fn in self.energy_fns:
            # Flatten all the dict-type arguments for this term. Keyword
            # arguments come last and therefore win over the dictionaries.
            new_params = {}
            for source in (*repl_dicts, repl_kwargs):
                new_params.update({k: v for k, v in source.items() if self._param_in_fn(k, fn)})
            used_replacements.update(new_params.keys())
            # Rename replacement keys if necessary (e.g. for qualified overload).
            new_params = {self._rename_param_for_fn(k, fn): v for k, v in new_params.items()}
            energy_fns.append(fn.with_params(**new_params))
        if unused := all_replacements - used_replacements:
            raise ValueError(f"Some parameters were not used in any energy function: {unused}.")
        return self.replace(energy_fns=energy_fns)

    @override
    def params_dict(self, *, include_dependent: bool = True, exclude_non_optimizable: bool = False) -> dict:
        params = {}
        for fn in self.energy_fns:
            fn_params = fn.params_dict(
                include_dependent=include_dependent,
                exclude_non_optimizable=exclude_non_optimizable,
            )
            # Later terms overwrite earlier ones when output names collide.
            params.update({self._rename_param_from_fn(k, fn): v for k, v in fn_params.items()})
        return params

    def compute_terms(self, body: jax_md.rigid_body.RigidBody) -> jnp.ndarray:
        """Compute each of the energy terms in the energy function."""
        return jnp.array([fn(body) for fn in self.energy_fns])

    @override
    def __call__(self, body: jax_md.rigid_body.RigidBody) -> float:
        energy_vals = self.compute_terms(body)
        # Without weights the terms are simply summed.
        return jnp.sum(energy_vals) if self.weights is None else jnp.dot(self.weights, energy_vals)

    def without_terms(self, *terms: str | type) -> "ComposedEnergyFunction":
        """Create a new ComposedEnergyFunction without the specified terms.

        Args:
            *terms: all positional arguments should be either a type or a string
                which is the name of the type to exclude.

        Returns:
            ComposedEnergyFunction: a new ComposedEnergyFunction without the
                specified terms
        """
        # Indices of the terms to keep; a term is dropped when its exact type
        # or its class name was passed in.
        kept = [
            i
            for i, fn in enumerate(self.energy_fns)
            if not (type(fn) in terms or fn.__class__.__name__ in terms)
        ]
        new_energy_fns = [self.energy_fns[i] for i in kept]
        new_weights = None if self.weights is None else jnp.array([self.weights[i] for i in kept])
        return self.replace(energy_fns=new_energy_fns, weights=new_weights)

    def add_energy_fn(self, energy_fn: BaseEnergyFunction, weight: float = 1.0) -> "ComposedEnergyFunction":
        """Add an energy function to the list of energy functions.

        Args:
            energy_fn (BaseEnergyFunction): the energy function to add
            weight (float): the weight of the energy function

        Returns:
            ComposedEnergyFunction: a new ComposedEnergyFunction with the added energy function
        """
        if self.weights is None:
            # Stay unweighted unless the new term needs a non-unit weight, in
            # which case the existing terms get an explicit weight of 1.
            weights = None if weight == 1.0 else jnp.array([1.0] * len(self.energy_fns) + [weight])
        else:
            weights = jnp.concatenate([self.weights, jnp.array([weight])])
        return ComposedEnergyFunction(
            energy_fns=[*self.energy_fns, energy_fn],
            weights=weights,
        )

    def add_composable_energy_fn(self, energy_fn: "ComposedEnergyFunction") -> "ComposedEnergyFunction":
        """Add a ComposedEnergyFunction to the list of energy functions.

        Args:
            energy_fn (ComposedEnergyFunction): the ComposedEnergyFunction to add

        Returns:
            ComposedEnergyFunction: a new ComposedEnergyFunction with the added energy function
        """
        if self.weights is None and energy_fn.weights is None:
            weights = None
        else:
            # Whichever side has no weights gets unit weights, one per term of
            # THAT side, so the result always has one weight per combined term.
            this_weights = jnp.ones(len(self.energy_fns)) if self.weights is None else self.weights
            other_weights = jnp.ones(len(energy_fn.energy_fns)) if energy_fn.weights is None else energy_fn.weights
            weights = jnp.concatenate([this_weights, other_weights])
        return ComposedEnergyFunction(
            energy_fns=self.energy_fns + energy_fn.energy_fns,
            weights=weights,
        )

    def __add__(self, other: Union[BaseEnergyFunction, "ComposedEnergyFunction"]) -> "ComposedEnergyFunction":
        """Create a new ComposedEnergyFunction by adding another energy function.

        This is a convenience method for the add_energy_fn and add_composable_energy_fn methods.
        """
        if isinstance(other, BaseEnergyFunction):
            return self.add_energy_fn(other)
        if isinstance(other, ComposedEnergyFunction):
            return self.add_composable_energy_fn(other)
        return NotImplemented

    def __radd__(self, other: Union[BaseEnergyFunction, "ComposedEnergyFunction"]) -> "ComposedEnergyFunction":
        """Create a new ComposedEnergyFunction by adding another energy function.

        This is a convenience method for the add_energy_fn and add_composable_energy_fn methods.
        The terms of ``other`` are appended after the terms of this object.
        """
        return self.__add__(other)

    @classmethod
    def from_lists(
        cls,
        energy_fns: list[BaseEnergyFunction],
        energy_configs: list[BaseConfiguration],
        weights: list[float] | None = None,
        **kwargs,
    ) -> "ComposedEnergyFunction":
        """Create a ComposedEnergyFunction from lists of energy functions and weights.

        Args:
            energy_fns (list[BaseEnergyFunction]): a list of energy functions
            energy_configs (list[BaseConfiguration]): a list of energy configurations
            weights (list[float] | None): optional, a list of weights for the
                energy functions
            **kwargs: keyword arguments to pass to each energy function

        Returns:
            ComposedEnergyFunction: a new ComposedEnergyFunction

        Raises:
            ValueError: if ``energy_fns`` and ``energy_configs`` differ in length
        """
        # Store the weights as an array: jax.numpy operations such as dot and
        # concatenate do not accept plain Python lists.
        weights = jnp.ones(len(energy_fns)) if weights is None else jnp.asarray(weights, dtype=float)
        functions_configs = zip(energy_fns, energy_configs, strict=True)
        energy_fns = [ef(**kwargs, params=ec.init_params()) for ef, ec in functions_configs]
        return cls(energy_fns=energy_fns, weights=weights)
class QualifiedComposedEnergyFunction(ComposedEnergyFunction):
    """A ComposedEnergyFunction that qualifies parameters by their function.

    Parameters for composite functions do not share a global namespace, but
    instead are qualified by the function they belong to in all setting and
    retrieval methods. For example, parameter `eps_backbone` in Fene energy
    function would be referred to as `Fene.eps_backbone` in this energy
    function. This is useful for isolating parameters from a specific energy
    function for optimization, however note that not all simulations will
    support this functionality - for example oxDNA simulations write only one
    value per parameter.
    """

    @override
    def _param_in_fn(self, param: str, fn: BaseEnergyFunction) -> bool:
        """Helper for with_params to check if a param is in a given energy function.

        ``param`` is expected in the form ``<ClassQualname>.<name>``. A name
        without a qualifier belongs to no function, so False is returned for it
        rather than failing while splitting the name.
        """
        owner, separator, name = param.partition(".")
        if not separator:
            return False
        return name in fn.params and fn.__class__.__qualname__ == owner

    @override
    def _rename_param_for_fn(self, param: str, fn: BaseEnergyFunction) -> str:
        """Strip the class qualifier so the name matches the function's own parameter."""
        return param.split(".", 1)[1]

    @override
    def _rename_param_from_fn(self, param: str, fn: BaseEnergyFunction) -> str:
        """Prefix the parameter name with the qualified class name of its function."""
        return f"{fn.__class__.__qualname__}.{param}"