"""Propeller twist observable."""
import dataclasses as dc
import chex
import jax
import jax.numpy as jnp
import mythos.observables.base as jd_obs
import mythos.simulators.io as jd_sio
import mythos.utils.math as jd_math
import mythos.utils.types as jd_types
TARGETS = {
"oxDNA": 21.7, # degrees
}
[docs]
def single_propeller_twist_rad(
bp: jnp.ndarray, # this is a array of shape (2,) containing the indices of the h-bonded nucleotides
base_normals: jnp.ndarray,
) -> jnp.ndarray:
"""Computes the propeller twist of a base pair."""
# get the normal vectors of the h-bonded bases
bp1, bp2 = bp
nv1 = base_normals[bp1]
nv2 = base_normals[bp2]
# compute angle between base normal vectors
return jnp.arccos(jd_math.clamp(jnp.dot(nv1, nv2)))
propeller_twist_rad = jax.vmap(single_propeller_twist_rad, in_axes=(0, None))
[docs]
@chex.dataclass(frozen=True)
class PropellerTwist(jd_obs.BaseObservable):
"""Computes the propeller twist of a base pair.
The propeller twist is defined as the angle between the normal vectors
of h-bonded bases
Args:
- bp: a 2-dimensional array containing the indices of the h-bonded nucleotides
- base_normals: the base normal vectors of the entire body
"""
h_bonded_base_pairs: jnp.ndarray = dc.field(
hash=False
) # a 2-dimensional array containing the indices of the h-bonded nucleotides
[docs]
def __post_init__(self) -> None:
"""Validate the input."""
if self.rigid_body_transform_fn is None:
raise ValueError(jd_obs.ERR_RIGID_BODY_TRANSFORM_FN_REQUIRED)
[docs]
def __call__(self, trajectory: jd_sio.SimulatorTrajectory) -> jd_types.ARR_OR_SCALAR:
"""Calculate the twist of the propeller in degrees.
Args:
trajectory (jd_traj.Trajectory): the trajectory to calculate the propeller twist for
Returns:
jd_types.ARR_OR_SCALAR: the propeller twist in degrees for each state , so expect a size
of (n_states,)
"""
nucleotides = jax.vmap(self.rigid_body_transform_fn)(trajectory.rigid_body)
base_normals = nucleotides.base_normals
ptwist = jax.vmap(lambda bn: 180.0 - (propeller_twist_rad(self.h_bonded_base_pairs, bn) * 180.0 / jnp.pi))
return jnp.mean(ptwist(base_normals), axis=1)