Source code for mythos.observables.rise

"""Rise observable."""

import dataclasses as dc
from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp

import mythos.observables.base as jd_obs
import mythos.simulators.io as jd_sio
import mythos.utils.types as jd_types
import mythos.utils.units as jd_units

# Rise values in Angstroms, keyed by model name.
# NOTE(review): presumably the reference/target rise per model; nothing in this
# module reads it, so confirm the intended use against callers.
TARGETS = {"oxDNA": 3.4}


def single_rise(
    quartet: jnp.ndarray,
    base_sites: jnp.ndarray,
    displacement_fn: Callable,
) -> jd_types.ARR_OR_SCALAR:
    """Return the rise between the two adjacent base pairs of one quartet.

    The rise is the displacement between the midpoints of the two base pairs,
    projected onto the local helical axis of the quartet, and scaled by
    ``jd_units.ANGSTROMS_PER_OXDNA_LENGTH``.

    Args:
        quartet: two adjacent base pairs, each given as a pair of indices into
            ``base_sites``
        base_sites: array of base-site positions, indexed by the entries of
            ``quartet``
        displacement_fn: a function for computing the displacement between two
            positions; it is called as ``displacement_fn(midpoint_2, midpoint_1)``

    Returns:
        The rise for this quartet, in Angstroms.
    """
    # A quartet unpacks into two base pairs, each of which is two site indices.
    (a1, b1), (a2, b2) = quartet

    # Direction of the helix in the neighbourhood of this quartet.
    helix_axis = jd_obs.local_helical_axis(quartet, base_sites, displacement_fn)

    # Midpoint of each base pair.
    center_first = (base_sites[a1] + base_sites[b1]) / 2.0
    center_second = (base_sites[a2] + base_sites[b2]) / 2.0

    # Component of the midpoint-to-midpoint displacement along the helix axis.
    separation = displacement_fn(center_second, center_first)
    projected = jnp.dot(separation, helix_axis)

    return projected * jd_units.ANGSTROMS_PER_OXDNA_LENGTH
# Vectorise `single_rise` over the leading axis of the quartets argument; the
# base sites and the displacement function are shared by every quartet.
single_rise_mapped = jax.vmap(single_rise, in_axes=(0, None, None))
@chex.dataclass(frozen=True, kw_only=True)
class Rise(jd_obs.BaseObservable):
    """Observable giving the average rise of each state in a trajectory.

    The rise between two adjacent base pairs is the displacement between their
    midpoints, projected onto the local helical axis. For each state, the
    observable is the mean of that quantity over all supplied pairs of
    adjacent base pairs (i.e. quartets).

    Args:
        quartets: an array of shape (n, 2, 2) holding, for each of the n
            quartets, the two adjacent base pairs for which to compute the rise
        displacement_fn: a function for computing displacements between two
            positions
    """

    quartets: jnp.ndarray = dc.field(hash=False)
    displacement_fn: Callable

    def __post_init__(self) -> None:
        """Validate the input.

        Raises:
            ValueError: if ``rigid_body_transform_fn`` is ``None``.
        """
        transform_fn = self.rigid_body_transform_fn
        if transform_fn is None:
            raise ValueError(jd_obs.ERR_RIGID_BODY_TRANSFORM_FN_REQUIRED)

    def __call__(self, trajectory: jd_sio.SimulatorTrajectory) -> jd_types.ARR_OR_SCALAR:
        """Calculate the average rise in Angstroms.

        Args:
            trajectory: the trajectory to calculate the rise for; its
                ``rigid_body`` attribute is transformed state by state with
                ``rigid_body_transform_fn``

        Returns:
            The average rise in Angstroms for each state, so expect a size of
            (n_states,).
        """
        # Transform every state's rigid bodies and keep only the base sites.
        per_state_nucleotides = jax.vmap(self.rigid_body_transform_fn)(trajectory.rigid_body)
        sites = per_state_nucleotides.base_sites

        # Outer map runs over states (axis 0 of the base sites); the quartets
        # and the displacement function are the same for every state.
        rise_over_states = jax.vmap(single_rise_mapped, in_axes=(None, 0, None))
        per_quartet_rise = rise_over_states(self.quartets, sites, self.displacement_fn)

        # Average over the quartet axis, leaving one value per state.
        return jnp.mean(per_quartet_rise, axis=1)