"""Stacking energy function for DNA1 model."""
import chex
import jax.numpy as jnp
import numpy as np
from jax import vmap
from typing_extensions import override
import mythos.energy.base as je_base
import mythos.energy.configuration as config
import mythos.energy.dna1.base_smoothing_functions as bsf
import mythos.energy.dna1.interactions as dna1_interactions
import mythos.utils.math as jd_math
import mythos.utils.types as typ
from mythos.energy.utils import compute_seq_dep_weight
from mythos.input.sequence_constraints import SequenceConstraints
# Sequence-averaged stacking weights: a 4 x 4 table with one row per leading
# nucleotide, in the order AX, CX, GX, TX (columns presumably follow the same
# A, C, G, T order -- confirm against the sequence encoding). Every entry is
# 1.0, so scaling by this table applies the same strength to every pair.
STACK_WEIGHTS_SA = jnp.array([[1.0] * 4 for _ in range(4)])
@chex.dataclass(frozen=True)
class StackingConfiguration(config.BaseConfiguration):
    """Configuration for the stacking energy function.

    Fields fall into two groups, mirrored by ``required_params`` and
    ``dependent_params``: independent parameters supplied by the caller, and
    dependent parameters that ``init_params`` derives from them. Every
    parameter field defaults to ``None``.
    """

    # independent parameters
    eps_stack_base: float | None = None
    eps_stack_kt_coeff: float | None = None
    dr_low_stack: float | None = None
    dr_high_stack: float | None = None
    a_stack: float | None = None
    dr0_stack: float | None = None
    dr_c_stack: float | None = None
    theta0_stack_4: float | None = None
    delta_theta_star_stack_4: float | None = None
    a_stack_4: float | None = None
    theta0_stack_5: float | None = None
    delta_theta_star_stack_5: float | None = None
    a_stack_5: float | None = None
    theta0_stack_6: float | None = None
    delta_theta_star_stack_6: float | None = None
    a_stack_6: float | None = None
    neg_cos_phi1_star_stack: float | None = None
    a_stack_1: float | None = None
    neg_cos_phi2_star_stack: float | None = None
    a_stack_2: float | None = None

    # Set these to enable probabilistic sequence
    pseq: typ.Probabilistic_Sequence | None = None
    pseq_constraints: SequenceConstraints | None = None

    kt: float | None = None
    # Optional per-pair stacking weights; when set, they replace the
    # sequence-averaged table in the eps_stack calculation (see init_params).
    ss_stack_weights: np.ndarray | None = None

    # dependent parameters
    b_low_stack: float | None = None
    dr_c_low_stack: float | None = None
    b_high_stack: float | None = None
    dr_c_high_stack: float | None = None
    b_stack_4: float | None = None
    delta_theta_stack_4_c: float | None = None
    b_stack_5: float | None = None
    delta_theta_stack_5_c: float | None = None
    b_stack_6: float | None = None
    delta_theta_stack_6_c: float | None = None
    b_neg_cos_phi1_stack: float | None = None
    neg_cos_phi1_c_stack: float | None = None
    b_neg_cos_phi2_stack: float | None = None
    neg_cos_phi2_c_stack: float | None = None
    # NOTE(review): annotated as float, but init_params stores a per-pair
    # weight table here (a scalar times STACK_WEIGHTS_SA, or scaled
    # ss_stack_weights).
    eps_stack: float | None = None

    # Names of the independent parameters (presumably validated by
    # BaseConfiguration before init_params runs -- confirm).
    required_params: tuple[str, ...] = (
        "eps_stack_base",
        "eps_stack_kt_coeff",
        "dr_low_stack",
        "dr_high_stack",
        "a_stack",
        "dr0_stack",
        "dr_c_stack",
        "theta0_stack_4",
        "delta_theta_star_stack_4",
        "a_stack_4",
        "theta0_stack_5",
        "delta_theta_star_stack_5",
        "a_stack_5",
        "theta0_stack_6",
        "delta_theta_star_stack_6",
        "a_stack_6",
        "neg_cos_phi1_star_stack",
        "a_stack_1",
        "neg_cos_phi2_star_stack",
        "a_stack_2",
        "kt",
    )

    # Names of the parameters filled in by init_params.
    dependent_params: tuple[str, ...] = (
        "b_low_stack",
        "dr_c_low_stack",
        "b_high_stack",
        "dr_c_high_stack",
        "b_stack_4",
        "delta_theta_stack_4_c",
        "b_stack_5",
        "delta_theta_stack_5_c",
        "b_stack_6",
        "delta_theta_stack_6_c",
        "b_neg_cos_phi1_stack",
        "neg_cos_phi1_c_stack",
        "b_neg_cos_phi2_stack",
        "neg_cos_phi2_c_stack",
        "eps_stack",
    )

    @override
    def init_params(self) -> "StackingConfiguration":
        """Derive the dependent parameters and return an updated configuration.

        ``eps_stack`` is built as a weight table: without ``ss_stack_weights``
        it is ``(eps_stack_base + eps_stack_kt_coeff * kt)`` times the
        sequence-averaged table ``STACK_WEIGHTS_SA``; otherwise
        ``ss_stack_weights`` is rescaled by a temperature-dependent factor.
        The remaining dependent values come from the ``bsf`` smoothing
        helpers: f1 for the ``dr`` parameters, f4 for the three ``theta``
        terms, and f5 for the two ``neg_cos_phi`` terms.

        Returns:
            A copy of this configuration (via ``replace``) with every name in
            ``dependent_params`` populated.

        Raises:
            ValueError: If ``pseq`` is set but ``pseq_constraints`` is not.
        """
        # Compare against None explicitly: truthiness of an array-like pseq is
        # ambiguous, and only "not provided" should skip this check.
        if self.pseq is not None and self.pseq_constraints is None:
            raise ValueError("pseq_constraints must be provided when pseq is provided.")

        if self.ss_stack_weights is None:
            eps_stack = (self.eps_stack_base + self.eps_stack_kt_coeff * self.kt) * STACK_WEIGHTS_SA
        else:
            # NOTE(review): the 9.0 factor looks like it comes from the
            # sequence-dependent oxDNA parameterisation -- confirm the source.
            eps_stack = self.ss_stack_weights * (
                1.0 - self.eps_stack_kt_coeff + (self.kt * 9.0 * self.eps_stack_kt_coeff)
            )

        b_low_stack, dr_c_low_stack, b_high_stack, dr_c_high_stack = bsf.get_f1_smoothing_params(
            self.dr0_stack,
            self.a_stack,
            self.dr_c_stack,
            self.dr_low_stack,
            self.dr_high_stack,
        )

        b_stack_4, delta_theta_stack_4_c = bsf.get_f4_smoothing_params(
            self.a_stack_4,
            self.theta0_stack_4,
            self.delta_theta_star_stack_4,
        )

        b_stack_5, delta_theta_stack_5_c = bsf.get_f4_smoothing_params(
            self.a_stack_5,
            self.theta0_stack_5,
            self.delta_theta_star_stack_5,
        )

        b_stack_6, delta_theta_stack_6_c = bsf.get_f4_smoothing_params(
            self.a_stack_6,
            self.theta0_stack_6,
            self.delta_theta_star_stack_6,
        )

        b_neg_cos_phi1_stack, neg_cos_phi1_c_stack = bsf.get_f5_smoothing_params(
            self.a_stack_1,
            self.neg_cos_phi1_star_stack,
        )

        b_neg_cos_phi2_stack, neg_cos_phi2_c_stack = bsf.get_f5_smoothing_params(
            self.a_stack_2,
            self.neg_cos_phi2_star_stack,
        )

        return self.replace(
            b_low_stack=b_low_stack,
            dr_c_low_stack=dr_c_low_stack,
            b_high_stack=b_high_stack,
            dr_c_high_stack=dr_c_high_stack,
            b_stack_4=b_stack_4,
            delta_theta_stack_4_c=delta_theta_stack_4_c,
            b_stack_5=b_stack_5,
            delta_theta_stack_5_c=delta_theta_stack_5_c,
            b_stack_6=b_stack_6,
            delta_theta_stack_6_c=delta_theta_stack_6_c,
            b_neg_cos_phi1_stack=b_neg_cos_phi1_stack,
            neg_cos_phi1_c_stack=neg_cos_phi1_c_stack,
            b_neg_cos_phi2_stack=b_neg_cos_phi2_stack,
            neg_cos_phi2_c_stack=neg_cos_phi2_c_stack,
            eps_stack=eps_stack,
        )
@chex.dataclass(frozen=True)
class Stacking(je_base.BaseEnergyFunction):
    """Stacking energy function for DNA1 model.

    The energy of each bonded pair is a sequence-independent term
    (``compute_v_stack``) multiplied by a sequence-dependent weight taken from
    ``params.eps_stack``.
    """

    params: StackingConfiguration

    def compute_v_stack(
        self,
        stack_sites: typ.Arr_Nucleotide_3,
        back_sites: typ.Arr_Nucleotide_3,
        base_normals: typ.Arr_Nucleotide_3,
        cross_prods: typ.Arr_Nucleotide_3,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
    ) -> typ.Arr_Bonded_Neighbors:
        """Computes the sequence-independent energy for each bonded pair.

        Args:
            stack_sites: Per-nucleotide stacking-site positions.
            back_sites: Per-nucleotide backbone-site positions.
            base_normals: Per-nucleotide base normal vectors (used as-is in
                the angle calculations; presumably unit length -- confirm).
            cross_prods: Per-nucleotide cross-product vectors used for the
                two ``cos(phi)`` terms.
            bonded_neighbors: Index pairs; column 0 and column 1 give the two
                nucleotides of each bonded pair.

        Returns:
            The result of ``dna1_interactions.stacking`` evaluated with unit
            strength, one value per bonded pair.
        """
        nn_i = bonded_neighbors[:, 0]
        nn_j = bonded_neighbors[:, 1]

        # One displacement vector per bonded pair (rows), not an all-pairs
        # tensor: the norm below reduces over axis 1.
        dr_back_nn = self.displacement_mapped(back_sites[nn_i], back_sites[nn_j])
        r_back_nn = jnp.linalg.norm(dr_back_nn, axis=1)
        dr_stack_nn = self.displacement_mapped(stack_sites[nn_i], stack_sites[nn_j])
        r_stack_nn = jnp.linalg.norm(dr_stack_nn, axis=1)

        # Angle between the base normals of the two neighbours.
        theta4 = jnp.arccos(jd_math.clamp(jnp.einsum("ij, ij->i", base_normals[nn_i], base_normals[nn_j])))
        # pi minus the angle between the stacking displacement and each normal.
        theta5 = jnp.pi - jnp.arccos(
            jd_math.clamp(jnp.einsum("ij, ij->i", dr_stack_nn, base_normals[nn_j]) / r_stack_nn)
        )
        theta6 = jnp.pi - jnp.arccos(
            jd_math.clamp(jnp.einsum("ij, ij->i", base_normals[nn_i], dr_stack_nn) / r_stack_nn)
        )
        # Negated projection of each cross-product vector onto the normalised
        # backbone displacement.
        cosphi1 = -jnp.einsum("ij, ij->i", cross_prods[nn_i], dr_back_nn) / r_back_nn
        cosphi2 = -jnp.einsum("ij, ij->i", cross_prods[nn_j], dr_back_nn) / r_back_nn

        return dna1_interactions.stacking(
            r_stack_nn,
            theta4,
            theta5,
            theta6,
            cosphi1,
            cosphi2,
            self.params.dr_low_stack,
            self.params.dr_high_stack,
            1,  # eps is applied via weighting later, from eps_stack NT-NT matrix
            self.params.a_stack,
            self.params.dr0_stack,
            self.params.dr_c_stack,
            self.params.dr_c_low_stack,
            self.params.dr_c_high_stack,
            self.params.b_low_stack,
            self.params.b_high_stack,
            self.params.theta0_stack_4,
            self.params.delta_theta_star_stack_4,
            self.params.a_stack_4,
            self.params.delta_theta_stack_4_c,
            self.params.b_stack_4,
            self.params.theta0_stack_5,
            self.params.delta_theta_star_stack_5,
            self.params.a_stack_5,
            self.params.delta_theta_stack_5_c,
            self.params.b_stack_5,
            self.params.theta0_stack_6,
            self.params.delta_theta_star_stack_6,
            self.params.a_stack_6,
            self.params.delta_theta_stack_6_c,
            self.params.b_stack_6,
            self.params.neg_cos_phi1_star_stack,
            self.params.a_stack_1,
            self.params.neg_cos_phi1_c_stack,
            self.params.b_neg_cos_phi1_stack,
            self.params.neg_cos_phi2_star_stack,
            self.params.a_stack_2,
            self.params.neg_cos_phi2_c_stack,
            self.params.b_neg_cos_phi2_stack,
        )

    def pseq_weights(self, i: int, j: int, seq: typ.Probabilistic_Sequence) -> float:
        """Computes the probabilistic sequence-dependent weight for a bonded pair.

        Delegates to ``compute_seq_dep_weight`` with the ``eps_stack`` table
        and the lookup arrays held by ``params.pseq_constraints``, which must
        therefore be set.

        Args:
            i: Index of the first nucleotide of the pair.
            j: Index of the second nucleotide of the pair.
            seq: The probabilistic sequence.

        Returns:
            The weight for the pair ``(i, j)``.
        """
        sc = self.params.pseq_constraints
        return compute_seq_dep_weight(
            seq, i, j, self.params.eps_stack, sc.is_unpaired, sc.idx_to_unpaired_idx, sc.idx_to_bp_idx
        )

    def pairwise_energies(
        self,
        body: je_base.BaseNucleotide,
        seq: typ.Discrete_Sequence,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
    ) -> typ.Arr_Bonded_Neighbors:
        """Computes the stacking energy for each bonded pair.

        Args:
            body: Nucleotide state providing ``stack_sites``, ``back_sites``,
                ``base_normals`` and ``cross_prods``.
            seq: Discrete sequence used to index ``eps_stack``; ignored when
                ``params.pseq`` is set.
            bonded_neighbors: Index pairs; column 0 and column 1 give the two
                nucleotides of each bonded pair.

        Returns:
            The element-wise product of the per-pair weights and the
            sequence-independent energies.
        """
        # Compute sequence-independent energy for each bonded pair
        v_stack = self.compute_v_stack(
            body.stack_sites, body.back_sites, body.base_normals, body.cross_prods, bonded_neighbors
        )

        # Compute sequence-dependent weight for each bonded pair
        nn_i = bonded_neighbors[:, 0]
        nn_j = bonded_neighbors[:, 1]
        # Compare against None explicitly: truthiness of an array-like pseq is
        # ambiguous, and only "not provided" should select the discrete path.
        if self.params.pseq is not None:
            stack_weights = vmap(self.pseq_weights, (0, 0, None))(nn_i, nn_j, self.params.pseq)
        else:
            stack_weights = self.params.eps_stack[seq[nn_i], seq[nn_j]]

        return jnp.multiply(stack_weights, v_stack)

    @override
    def compute_energy(self, nucleotide: je_base.BaseNucleotide) -> typ.Scalar:
        """Returns the total stacking energy: the sum over all bonded pairs.

        Uses ``self.seq`` and ``self.bonded_neighbors`` (not defined in this
        class; presumably provided by ``BaseEnergyFunction``).
        """
        dgs = self.pairwise_energies(nucleotide, self.seq, self.bonded_neighbors)
        return dgs.sum()