"""Utility functions for energy calculations."""
import functools
import importlib
import importlib.resources

import jax
import jax.numpy as jnp
import jax_md
import numpy as np
from jax import vmap
from jaxtyping import PyTree

import mythos.utils.constants as jd_const
import mythos.utils.types as typ
from mythos.input import toml
[docs]
@vmap
def q_to_back_base(q: jax_md.rigid_body.Quaternion) -> jnp.ndarray:
    """Return the vector pointing from a nucleotide's center toward its base.

    This is the first column of the rotation matrix encoded by ``q``, using the
    scalar-first convention ``q.vec = (w, x, y, z)``. The result has unit length
    when ``q`` is a unit quaternion. ``vmap`` batches over the leading axis of ``q``.
    """
    w, x, y, z = q.vec
    x_component = w**2 + x**2 - y**2 - z**2
    y_component = 2 * (x * y + w * z)
    z_component = 2 * (x * z - w * y)
    return jnp.array([x_component, y_component, z_component])
[docs]
@vmap
def q_to_base_normal(q: jax_md.rigid_body.Quaternion) -> jnp.ndarray:
    """Return the vector normal to the plane of the nucleotide's base.

    This is the third column of the rotation matrix encoded by ``q``, using the
    scalar-first convention ``q.vec = (w, x, y, z)``. The result has unit length
    when ``q`` is a unit quaternion. ``vmap`` batches over the leading axis of ``q``.
    """
    w, x, y, z = q.vec
    normal_x = 2 * (x * z + w * y)
    normal_y = 2 * (y * z - w * x)
    normal_z = w**2 - x**2 - y**2 + z**2
    return jnp.array([normal_x, normal_y, normal_z])
[docs]
@vmap
def q_to_cross_prod(q: jax_md.rigid_body.Quaternion) -> jnp.ndarray:
    """Return the cross-product axis of the nucleotide's local frame.

    This is the second column of the rotation matrix encoded by ``q``, using the
    scalar-first convention ``q.vec = (w, x, y, z)``. For a unit quaternion it
    equals ``q_to_base_normal(q) x q_to_back_base(q)``, completing a right-handed
    orthonormal frame. ``vmap`` batches over the leading axis of ``q``.
    """
    w, x, y, z = q.vec
    axis_x = 2 * (x * y - w * z)
    axis_y = w**2 - x**2 + y**2 - z**2
    axis_z = 2 * (y * z + w * x)
    return jnp.array([axis_x, axis_y, axis_z])
[docs]
@functools.partial(vmap, in_axes=(None, 0, 0), out_axes=0)
def get_pair_probs(seq: typ.Arr_Nucleotide_4, i: int, j: int) -> jnp.ndarray:
    """Return joint identity probabilities for each nucleotide pair ``(i, j)``.

    ``i`` and ``j`` are batched along axis 0, while ``seq`` is shared across the batch.
    For each pair the result is ``kron(seq[i], seq[j])``. Assuming each row of
    ``seq`` is a 1-D probability vector (TODO confirm), entry ``a * len(seq[j]) + b``
    equals ``seq[i][a] * seq[j][b]``.
    """
    probs_i = seq[i]
    probs_j = seq[j]
    return jnp.kron(probs_i, probs_j)
[docs]
def _weight_unpaired_vs_paired(
    unpaired_nt_probs: jnp.ndarray,
    bp_type_probs: jnp.ndarray,
    within_bp_idx: int,
    weights_table: jnp.ndarray,
) -> jnp.ndarray:
    """Return the expected weight between an unpaired nucleotide and a base-paired one.

    The unpaired nucleotide's identity indexes the rows of ``weights_table``, and the
    paired nucleotide's identity indexes the columns. Pass ``weights_table.T`` when
    the roles are reversed.
    """

    def term(unpaired_nt: int, bp_type_idx: int) -> jnp.ndarray:
        # Identity of the paired nucleotide implied by this base-pair type.
        paired_nt = jd_const.BP_IDXS[bp_type_idx][within_bp_idx]
        return unpaired_nt_probs[unpaired_nt] * bp_type_probs[bp_type_idx] * weights_table[unpaired_nt, paired_nt]

    # Outer vmap runs over unpaired identities; inner vmap runs over base-pair types.
    return vmap(vmap(term, (None, 0)), (0, None))(
        jnp.arange(jd_const.N_NT), jnp.arange(jd_const.N_BP_TYPES)
    ).sum()


def _weight_same_bp(
    bp_type_probs: jnp.ndarray,
    within_nt1_bp_idx: int,
    within_nt2_bp_idx: int,
    weights_table: jnp.ndarray,
) -> jnp.ndarray:
    """Return the expected weight when both nucleotides belong to the same base pair."""

    def term(bp_type_idx: int) -> jnp.ndarray:
        # One base-pair type fixes both identities at once, so the two are correlated.
        nt1_nt, nt2_nt = jd_const.BP_IDXS[bp_type_idx][jnp.array([within_nt1_bp_idx, within_nt2_bp_idx])]
        return bp_type_probs[bp_type_idx] * weights_table[nt1_nt, nt2_nt]

    return vmap(term)(jnp.arange(jd_const.N_BP_TYPES)).sum()


def _weight_diff_bps(
    bp1_type_probs: jnp.ndarray,
    within_nt1_bp_idx: int,
    bp2_type_probs: jnp.ndarray,
    within_nt2_bp_idx: int,
    weights_table: jnp.ndarray,
) -> jnp.ndarray:
    """Return the expected weight when the two nucleotides belong to different base pairs."""

    def term(bp1_type_idx: int, bp2_type_idx: int) -> jnp.ndarray:
        nt1_nt = jd_const.BP_IDXS[bp1_type_idx][within_nt1_bp_idx]
        nt2_nt = jd_const.BP_IDXS[bp2_type_idx][within_nt2_bp_idx]
        return bp1_type_probs[bp1_type_idx] * bp2_type_probs[bp2_type_idx] * weights_table[nt1_nt, nt2_nt]

    # Outer vmap runs over nt1's base-pair types; inner vmap runs over nt2's.
    return vmap(vmap(term, (None, 0)), (0, None))(
        jnp.arange(jd_const.N_BP_TYPES), jnp.arange(jd_const.N_BP_TYPES)
    ).sum()


def compute_seq_dep_weight(
    pseq: typ.Probabilistic_Sequence,
    nt1: int,
    nt2: int,
    weights_table: np.ndarray,
    is_unpaired: typ.Arr_Nucleotide_Int,
    idx_to_unpaired_idx: typ.Arr_Nucleotide_Int,
    idx_to_bp_idx: typ.Arr_Nucleotide_2_Int,
) -> float:
    """Compute the sequence-dependent weight for an interaction given a probabilistic sequence.

    The result is the expectation of ``weights_table[id(nt1), id(nt2)]`` over the
    nucleotide identities described by ``pseq``. Four cases are distinguished:
    both unpaired, exactly one unpaired, and both base-paired (either in the same
    pair or in different pairs).

    Args:
        pseq: A pair ``(unpaired_pseq, bp_pseq)``.
            ``unpaired_pseq[k]`` gives identity probabilities for the k-th unpaired
            nucleotide. ``bp_pseq[k]`` gives base-pair-type probabilities for the
            k-th base pair.
        nt1: Index of the first nucleotide.
        nt2: Index of the second nucleotide.
        weights_table: Weight indexed by ``[nt1 identity, nt2 identity]``;
            presumably shape ``(N_NT, N_NT)`` (TODO confirm).
        is_unpaired: Per-nucleotide flag; true when the nucleotide is unpaired.
        idx_to_unpaired_idx: Maps a nucleotide index to its row in ``unpaired_pseq``.
        idx_to_bp_idx: Maps a nucleotide index to
            ``(base-pair index, position within the pair)``.

    Returns:
        A scalar array holding the expected weight.
    """
    unpaired_pseq, bp_pseq = pseq
    nt1_unpaired = is_unpaired[nt1]
    nt2_unpaired = is_unpaired[nt2]

    # Every case is evaluated unconditionally, and jnp.where picks the applicable
    # one so the function stays traceable under jit/vmap. Lookups for a role a
    # nucleotide does not play read whatever the index maps hold there.
    nt1_up_probs = unpaired_pseq[idx_to_unpaired_idx[nt1]]
    nt2_up_probs = unpaired_pseq[idx_to_unpaired_idx[nt2]]
    nt1_bp_idx, within_nt1_bp_idx = idx_to_bp_idx[nt1]
    nt2_bp_idx, within_nt2_bp_idx = idx_to_bp_idx[nt2]
    nt1_bp_probs = bp_pseq[nt1_bp_idx]
    nt2_bp_probs = bp_pseq[nt2_bp_idx]

    # Case 1: both unpaired. The identities are independent, so the joint
    # distribution is the outer product, flattened to match the row-major table.
    pair_probs = jnp.kron(nt1_up_probs, nt2_up_probs)
    pair_weight_unpaired = jnp.dot(pair_probs, weights_table.flatten())

    # Case 2: nt1 unpaired, nt2 base paired.
    pair_weight_nt1_up = _weight_unpaired_vs_paired(nt1_up_probs, nt2_bp_probs, within_nt2_bp_idx, weights_table)

    # Case 3: nt2 unpaired, nt1 base paired. Transposing puts the unpaired
    # nucleotide (nt2) on the row axis, as the helper expects.
    pair_weight_nt2_up = _weight_unpaired_vs_paired(nt2_up_probs, nt1_bp_probs, within_nt1_bp_idx, weights_table.T)

    # Case 4: both base paired, either in the same pair (4.I) or in different pairs (4.II).
    pair_weight_same_bp = _weight_same_bp(nt1_bp_probs, within_nt1_bp_idx, within_nt2_bp_idx, weights_table)
    pair_weight_diff_bps = _weight_diff_bps(
        nt1_bp_probs, within_nt1_bp_idx, nt2_bp_probs, within_nt2_bp_idx, weights_table
    )
    pair_weight_both_paired = jnp.where(nt1_bp_idx == nt2_bp_idx, pair_weight_same_bp, pair_weight_diff_bps)

    return jnp.where(
        jnp.logical_and(nt1_unpaired, nt2_unpaired),
        pair_weight_unpaired,
        jnp.where(
            nt1_unpaired, pair_weight_nt1_up, jnp.where(nt2_unpaired, pair_weight_nt2_up, pair_weight_both_paired)
        ),
    )
[docs]
def default_configs_for(base: str) -> tuple[PyTree, PyTree]:
    """Load the packaged default simulation and energy configurations for ``base``.

    Reads ``default_simulation.toml`` and ``default_energy.toml`` from the
    ``mythos`` package resource directory ``input/<base>/``. Each file is parsed
    with ``toml.parse_toml``, and every leaf is cast to a ``jnp.float64`` array.
    JAX silently downcasts to float32 unless x64 mode is enabled.

    Args:
        base: Name of the subdirectory under ``mythos/input`` holding the defaults.

    Returns:
        A ``(simulation_config, energy_config)`` tuple of pytrees with array leaves.
    """
    # ``importlib.resources`` is a submodule; a bare ``import importlib`` does not
    # load it, so it must be imported explicitly at module level.
    config_dir = importlib.resources.files("mythos").joinpath("input").joinpath(base)
    sim_config_path = config_dir.joinpath("default_simulation.toml")
    energy_config_path = config_dir.joinpath("default_energy.toml")

    def cast_f(x: float | list[float]) -> jnp.ndarray:
        return jnp.array(x, dtype=jnp.float64)

    return (
        jax.tree.map(cast_f, toml.parse_toml(sim_config_path)),
        jax.tree.map(cast_f, toml.parse_toml(energy_config_path)),
    )