Source code for mythos.observables.diameter

"""Diameter observable."""

import dataclasses as dc
from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp
from jax_md import space

import mythos.observables.base as jd_obs
import mythos.simulators.io as jd_sio
import mythos.utils.types as jd_types
import mythos.utils.units as jd_units

# Reference helical diameters keyed by model name.
# NOTE(review): presumably used as optimization/fitting targets -- confirm against callers.
TARGETS = {
    "oxDNA": 23.0,  # Angstroms. Experimental value for helical radius is 11.5-12 A
}

# Message of the ValueError raised by Diameter.__post_init__ when displacement_fn is None.
ERR_DISPLACEMENT_FN_REQUIRED = "A displacement function is required for computing the helical diameter."


def single_diameter(
    bp: jnp.ndarray, back_sites: jnp.ndarray, displacement_fn: Callable, sigma_backbone: float
) -> jd_types.ARR_OR_SCALAR:
    """Computes the helical diameter of a single base pair.

    Args:
        bp (jnp.ndarray): the pair of indices of the two h-bonded nucleotides;
            it is unpacked into exactly two entries
        back_sites (jnp.ndarray): the positions of the backbone sites, indexed
            by nucleotide
        displacement_fn (Callable): a function for computing displacements
            between two positions
        sigma_backbone (float): the excluded volume distance between backbone
            sites

    Returns:
        jd_types.ARR_OR_SCALAR: the helical diameter of the base pair in
        Angstroms
    """
    first_idx, second_idx = bp

    # Separation of the two backbone sites that belong to the paired nucleotides.
    backbone_separation = space.distance(
        displacement_fn(back_sites[first_idx], back_sites[second_idx])
    )

    # The diameter extends to the edge of the excluded volume, so the backbone
    # excluded volume distance is added before converting to Angstroms.
    return (backbone_separation + sigma_backbone) * jd_units.ANGSTROMS_PER_OXDNA_LENGTH
# Batched form of single_diameter: maps over the leading axis of the base-pair
# index array while the backbone sites, the displacement function and
# sigma_backbone are shared by every pair.
single_diameter_mapped = jax.vmap(single_diameter, in_axes=(0, None, None, None))
@chex.dataclass(frozen=True, kw_only=True)
class Diameter(jd_obs.BaseObservable):
    """Observable giving the mean helical diameter of every state.

    The helical diameter for a given base pair is defined as the furthest
    extent of the excluded volume: the distance between the backbone sites of
    the two paired nucleotides plus the backbone excluded volume distance.
    The diameter per state is the average over all (specified) base pairs.

    Args:
        - h_bonded_base_pairs: an array containing the indices of the
          h-bonded nucleotides, one pair per entry along the leading axis
        - displacement_fn: a function for computing displacements between two
          positions
    """

    h_bonded_base_pairs: jnp.ndarray = dc.field(hash=False)
    displacement_fn: Callable

    def __post_init__(self) -> None:
        """Validate the input.

        Raises:
            ValueError: if ``rigid_body_transform_fn`` or ``displacement_fn``
                is None (checked in that order).
        """
        required = (
            (self.rigid_body_transform_fn, jd_obs.ERR_RIGID_BODY_TRANSFORM_FN_REQUIRED),
            (self.displacement_fn, ERR_DISPLACEMENT_FN_REQUIRED),
        )
        for value, message in required:
            if value is None:
                raise ValueError(message)

    def __call__(self, trajectory: jd_sio.SimulatorTrajectory, sigma_backbone: float) -> jd_types.ARR_OR_SCALAR:
        """Calculate the average helical diameter in Angstroms.

        Args:
            trajectory (jd_sio.SimulatorTrajectory): the trajectory to
                calculate the helical diameter for; its ``rigid_body`` entry
                is transformed state by state with ``rigid_body_transform_fn``
            sigma_backbone (float): the excluded volume distance between
                backbone sites

        Returns:
            jd_types.ARR_OR_SCALAR: the average helical diameter in Angstroms
            for each state, so expect a size of (n_states,)
        """
        # Per-state nucleotide representation; only the backbone sites are needed.
        transformed = jax.vmap(self.rigid_body_transform_fn)(trajectory.rigid_body)
        backbone_positions = transformed.back_sites

        # Outer map runs over states (leading axis of the backbone sites); the
        # inner, already-mapped function runs over the base pairs.
        diameters_per_state = jax.vmap(single_diameter_mapped, in_axes=(None, 0, None, None))
        diameters = diameters_per_state(
            self.h_bonded_base_pairs,
            backbone_positions,
            self.displacement_fn,
            sigma_backbone,
        )

        # Average over the base-pair axis, leaving one value per state.
        return jnp.mean(diameters, axis=1)