Source code for mythos.observables.stretch_torsion

"""Utility functions for computing stretch-torsion moduli."""

import dataclasses as dc
from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp

import mythos.observables.base as jd_obs
import mythos.simulators.io as jd_sio
import mythos.utils.math as jd_math
import mythos.utils.types as jd_types


def single_angle_xy(
    quartet: jnp.ndarray,
    base_sites: jnp.ndarray,
    displacement_fn: Callable,
) -> jd_types.ARR_OR_SCALAR:
    """Return the unsigned angle, in the X-Y plane, between two adjacent base pairs.

    Args:
        quartet: two base pairs ``((a1, b1), (a2, b2))``; each entry is used as an index into ``base_sites``
        base_sites: base-site positions, indexed by the entries of ``quartet``
        displacement_fn: a function for computing displacements between two positions

    Returns:
        jd_types.ARR_OR_SCALAR: arccosine of the clamped dot product of the two normalised X-Y projections
    """
    first_pair, second_pair = quartet
    a1, b1 = first_pair
    a2, b2 = second_pair

    def unit_xy(idx_a, idx_b):
        # Base-to-base vector of one pair; the z-component is dropped *before* normalising
        # so that only the in-plane direction contributes to the angle.
        in_plane = displacement_fn(base_sites[idx_b], base_sites[idx_a])[:2]
        return in_plane / jnp.linalg.norm(in_plane)

    dir1 = unit_xy(a1, b1)
    dir2 = unit_xy(a2, b2)

    # Clamp before arccos so round-off cannot push the cosine outside the valid domain
    cos_angle = jd_math.clamp(jnp.dot(dir1, dir2))
    return jnp.arccos(cos_angle)
@chex.dataclass(frozen=True, kw_only=True)
class TwistXY(jd_obs.BaseObservable):
    """Total twist of a duplex in the X-Y plane, in radians.

    The total twist is the sum, over all supplied quartets, of the X-Y-plane angle
    between the two adjacent base pairs of each quartet.

    Args:
        - quartets: a (n_quartets, 2, 2) array containing the pairs of adjacent base pairs
        - displacement_fn: a function for computing displacements between two positions
    """

    quartets: jnp.ndarray = dc.field(hash=False)
    displacement_fn: Callable

    def __post_init__(self) -> None:
        """Ensure a rigid-body transform function was supplied.

        Raises:
            ValueError: if ``rigid_body_transform_fn`` is ``None``
        """
        transform_missing = self.rigid_body_transform_fn is None
        if transform_missing:
            raise ValueError(jd_obs.ERR_RIGID_BODY_TRANSFORM_FN_REQUIRED)

    def __call__(self, trajectory: jd_sio.SimulatorTrajectory) -> jd_types.ARR_OR_SCALAR:
        """Compute the total X-Y-plane twist, in radians, for every state.

        Args:
            trajectory (jd_sio.SimulatorTrajectory): the trajectory

        Returns:
            jd_types.ARR_OR_SCALAR: the total twist in radians for each state, so
            expect a size of (n_states,)
        """
        # Map every rigid body of the trajectory to its nucleotide representation
        to_nucleotides = jax.vmap(self.rigid_body_transform_fn)
        all_base_sites = to_nucleotides(trajectory.rigid_body).base_sites

        def angles_for_state(state_base_sites):
            # One angle per quartet for a single state; only the quartet axis is mapped
            over_quartets = jax.vmap(single_angle_xy, (0, None, None))
            return over_quartets(self.quartets, state_base_sites, self.displacement_fn)

        # Shape (n_states, n_quartets): outer map over states, inner map over quartets
        per_state_angles = jax.vmap(angles_for_state)(all_base_sites)
        return jnp.sum(per_state_angles, axis=1)
def single_extension_z(
    center: jd_types.Arr_Nucleotide_3,
    bp1: jnp.ndarray,
    bp2: jnp.ndarray,
    displacement_fn: Callable,
) -> jd_types.ARR_OR_SCALAR:
    """Return the absolute Z-separation between the midpoints of two base pairs.

    Args:
        center: nucleotide centers, indexed by the entries of ``bp1`` and ``bp2``
        bp1: the two nucleotide indices of the first base pair
        bp2: the two nucleotide indices of the second base pair
        displacement_fn: a function for computing displacements between two positions

    Returns:
        jd_types.ARR_OR_SCALAR: absolute value of the Z-component of the displacement
        between the two base-pair midpoints
    """

    def midpoint(pair):
        # Walk half of the displacement from the first nucleotide towards its partner.
        # Going through displacement_fn (rather than averaging raw coordinates) lets
        # the displacement function decide how the separation is measured (e.g. PBC).
        idx_a, idx_b = pair
        return center[idx_a] + displacement_fn(center[idx_b], center[idx_a]) / 2

    midpoint_1 = midpoint(bp1)
    midpoint_2 = midpoint(bp2)

    # Displacement between the two midpoints; only the magnitude of its
    # Z-component is reported.
    separation = displacement_fn(midpoint_2, midpoint_1)
    return jnp.abs(separation[2])
@chex.dataclass(frozen=True, kw_only=True)
class ExtensionZ(jd_obs.BaseObservable):
    """Total extension of a duplex along the Z-direction, in simulation units.

    The extension is the Z-direction distance between the midpoints of two
    pre-specified base pairs.

    Args:
        - bp1: a (2,) array specifying the indices of the first base pair
        - bp2: a (2,) array specifying the indices of the second base pair
        - displacement_fn: a function for computing displacements between two positions
    """

    bp1: jnp.ndarray = dc.field(hash=False)
    bp2: jnp.ndarray = dc.field(hash=False)
    displacement_fn: Callable

    def __post_init__(self) -> None:
        """Ensure a rigid-body transform function was supplied.

        Raises:
            ValueError: if ``rigid_body_transform_fn`` is ``None``
        """
        transform_missing = self.rigid_body_transform_fn is None
        if transform_missing:
            raise ValueError(jd_obs.ERR_RIGID_BODY_TRANSFORM_FN_REQUIRED)

    def __call__(self, trajectory: jd_sio.SimulatorTrajectory) -> jd_types.ARR_OR_SCALAR:
        """Compute the Z-extension, in simulation units, for every state.

        Args:
            trajectory (jd_sio.SimulatorTrajectory): the trajectory

        Returns:
            jd_types.ARR_OR_SCALAR: the total extension for each state, so expect a
            size of (n_states,)
        """
        # Map every rigid body of the trajectory to its nucleotide representation
        to_nucleotides = jax.vmap(self.rigid_body_transform_fn)
        all_centers = to_nucleotides(trajectory.rigid_body).center

        def extension_for_state(state_centers):
            # The base pairs and displacement function are the same for every state
            return single_extension_z(state_centers, self.bp1, self.bp2, self.displacement_fn)

        # Only the leading (state) axis of the centers is mapped over
        return jax.vmap(extension_for_state)(all_centers)
def stretch(forces: jnp.ndarray, extensions: jnp.ndarray) -> tuple[float, float, float]:
    r"""Fit a line to force-extension data and derive the effective stretch modulus.

    Following Assenza and Perez (JCTC 2022), the effective stretch modulus is

    .. math::

      \tilde{S} = \frac{L_0}{A_1}

    where `A_1` and `L_0` are the slope and offset, respectively, of a linear
    force-extension fit.

    Args:
        forces (jnp.ndarray): the forces applied to the polymer
        extensions (jnp.ndarray): the equilibrium extensions under the applied forces

    Returns:
        Tuple[float, float, float]: the slope and offset of the linear fit, and the
        effective stretch modulus
    """
    # Design matrix with columns [1, force], so the solution vector is [offset, slope]
    intercept_column = jnp.ones_like(forces)
    design = jnp.stack([intercept_column, forces], axis=1)

    # The offset is a free fit parameter; it is not pinned to the extension measured at zero force
    coefficients = jnp.linalg.lstsq(design, extensions)[0]

    # l0 is the equilibrium extension at zero force and torque, *not* the contour length
    l0 = coefficients[0]
    a1 = coefficients[1]

    s_eff = l0 / a1
    return a1, l0, s_eff
def torsion(torques: jnp.ndarray, extensions: jnp.ndarray, twists: jnp.ndarray) -> tuple[float, float]:
    """Fit lines to torque-extension and torque-twist data and return their slopes.

    Following Assenza and Perez (JCTC 2022), the torsional modulus and twist-stretch
    coupling can be computed via linear fits to the extension and twist of a duplex
    under torque (when combined with similar statistics from stretching experiments).
    This function computes the slopes of these linear fits.

    Args:
        torques (jnp.ndarray): the torques applied to the polymer
        extensions (jnp.ndarray): the equilibrium extensions under the applied torques
        twists (jnp.ndarray): the equilibrium twists under the applied torques

    Returns:
        Tuple[float, float]: the slopes of the linear fits to the extensions and
        twists, respectively
    """
    # Design matrix with columns [1, torque], shared by both fits
    design = jnp.stack([jnp.ones_like(torques), torques], axis=1)

    def fitted_slope(response):
        # Solution vector is [offset, slope]; only the slope is needed here
        solution = jnp.linalg.lstsq(design, response)[0]
        return solution[1]

    a3 = fitted_slope(extensions)
    a4 = fitted_slope(twists)
    return a3, a4
def stretch_torsion(
    forces: jnp.ndarray,
    force_extensions: jnp.ndarray,
    torques: jnp.ndarray,
    torque_extensions: jnp.ndarray,
    torque_twists: jnp.ndarray,
) -> tuple[float, float, float]:
    """Derive the stretch modulus, torsional modulus and twist-stretch coupling.

    Combines the linear-fit statistics of a stretching experiment with those of a
    torsion experiment.

    Args:
        forces (jnp.ndarray): the forces applied to the polymer
        force_extensions (jnp.ndarray): the equilibrium extensions under the applied forces
        torques (jnp.ndarray): the torques applied to the polymer
        torque_extensions (jnp.ndarray): the equilibrium extensions under the applied torques
        torque_twists (jnp.ndarray): the equilibrium twists under the applied torques

    Returns:
        Tuple[float, float, float]: the effective stretch modulus, torsional modulus,
        and twist-stretch coupling
    """
    # Stretching experiment: slope, offset and effective stretch modulus
    a1, l0, s_eff = stretch(forces, force_extensions)

    # Torsion experiment: slopes of extension-vs-torque and twist-vs-torque
    a3, a4 = torsion(torques, torque_extensions, torque_twists)

    # Both derived quantities share the same denominator
    denominator = a4 * a1 - a3**2
    torsional_modulus = a1 * l0 / denominator
    twist_stretch_coupling = -(a3 * l0) / denominator

    return s_eff, torsional_modulus, twist_stretch_coupling