Source code for mythos.observables.base

"""Base class for observables."""

import itertools
from collections.abc import Callable

import chex
import jax.numpy as jnp

import mythos.simulators.io as jd_sio

ERR_RIGID_BODY_TRANSFORM_FN_REQUIRED = "rigid_body_transform_fn must be provided"


@chex.dataclass(frozen=True)
class BaseObservable:
    """Common parent for observable objects.

    An observable is a frozen ``chex`` dataclass that is invoked like a
    function on a simulator trajectory.  This base class only declares the
    shared field and the call signature.
    """

    # Callable stored on every observable.  Its signature is not visible in
    # this module -- presumably it converts rigid-body states into site
    # positions (TODO confirm against the concrete observables).
    rigid_body_transform_fn: Callable

    def __call__(self, trajectory: jd_sio.SimulatorTrajectory) -> jnp.ndarray:
        """Calculate the observable.

        Args:
            trajectory: the simulator trajectory to evaluate the observable on.

        Returns:
            Annotated as ``jnp.ndarray``.  Note that this base implementation
            has no body beyond its docstring, so calling it directly yields
            ``None``; concrete observables presumably override it.
        """
def local_helical_axis_with_norm(
    quartet: jnp.ndarray, base_sites: jnp.ndarray, displacement_fn: Callable
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the normalized local helical axis of two base pairs and its norm.

    The axis is the displacement between the midpoints of the two base pairs,
    as evaluated by ``displacement_fn``, divided by its Euclidean norm.

    Args:
        quartet: two base pairs ``((a1, b1), (a2, b2))`` of nucleotide
            indices, where ``a1`` is h-bonded to ``b1`` and ``a2`` to ``b2``.
        base_sites: array indexed by nucleotide index that yields the
            position of each base site (assumed one position vector per
            nucleotide -- confirm against the caller).
        displacement_fn: callable invoked as
            ``displacement_fn(midpoint_2, midpoint_1)`` that returns the
            displacement vector between the two midpoints.

    Returns:
        A tuple ``(axis, norm)``: the unit-length displacement and the norm
        of the unnormalized displacement, in that order.

    Note:
        No guard is applied for coincident midpoints; a zero norm propagates
        through the division unchanged.
    """
    # Extract the two base pairs. a1 is h-bonded to b1, a2 is h-bonded to b2
    (a1, b1), (a2, b2) = quartet

    # Midpoint of each base pair
    midpoint_1 = (base_sites[a1] + base_sites[b1]) / 2.0
    midpoint_2 = (base_sites[a2] + base_sites[b2]) / 2.0

    # Displacement between the midpoints (argument order: pair 2, then pair 1)
    displacement = displacement_fn(midpoint_2, midpoint_1)
    norm = jnp.linalg.norm(displacement)
    return displacement / norm, norm
def local_helical_axis(quartet: jnp.ndarray, base_sites: jnp.ndarray, displacement_fn: Callable) -> jnp.ndarray:
    """Return only the normalized local helical axis defined by two base pairs.

    Delegates to ``local_helical_axis_with_norm`` with the same arguments and
    discards the norm it also returns.

    Args:
        quartet: two base pairs of nucleotide indices.
        base_sites: array of base-site positions indexed by nucleotide index.
        displacement_fn: callable forwarded unchanged to
            ``local_helical_axis_with_norm``.

    Returns:
        The first element of the helper's result, i.e. the normalized axis.
    """
    axis_and_norm = local_helical_axis_with_norm(quartet, base_sites, displacement_fn)
    return axis_and_norm[0]
[docs] def get_duplex_quartets(n_nucs_per_strand: int) -> jnp.ndarray: """Computes all quartets (i.e. pairs of adjacent base pairs) for a duplex of a given size. Args: n_nucs_per_strand (int): number of nucleotides on each strand Returns: jnp.ndarray: array of all quartets """ # Construct the indices of the nucleotides on each strand s1_nucs = list(range(n_nucs_per_strand)) s2_nucs = list(range(n_nucs_per_strand, n_nucs_per_strand * 2)) s2_nucs.reverse() # Record all pairs of adjacent base pairs bps = list(zip(s1_nucs, s2_nucs, strict=True)) all_quartets = list(map(list, itertools.pairwise(bps))) return jnp.array(all_quartets, dtype=jnp.int32)