Source code for mythos.energy.dna1.hydrogen_bonding

"""Hydrogen bonding energy function for DNA1 model."""


import chex
import jax.numpy as jnp
import numpy as np
from jax import vmap
from typing_extensions import override

import mythos.energy.base as je_base
import mythos.energy.configuration as config
import mythos.energy.dna1.base_smoothing_functions as bsf
import mythos.energy.dna1.interactions as dna1_interactions
import mythos.utils.math as jd_math
import mythos.utils.types as typ
from mythos.energy.utils import compute_seq_dep_weight
from mythos.input.sequence_constraints import SequenceConstraints

HB_WEIGHTS_SA = jnp.array(
    [
        [0.0, 0.0, 0.0, 1.0],  # AX
        [0.0, 0.0, 1.0, 0.0],  # CX
        [0.0, 1.0, 0.0, 0.0],  # GX
        [1.0, 0.0, 0.0, 0.0],  # TX
    ]
)


[docs] @chex.dataclass(frozen=True) class HydrogenBondingConfiguration(config.BaseConfiguration): """Configuration for the hydrogen bonding energy function.""" # independent parameters =================================================== # reference to f1(dr_hb) eps_hb: float | None = None a_hb: float | None = None dr0_hb: float | None = None dr_c_hb: float | None = None dr_low_hb: float | None = None dr_high_hb: float | None = None # reference to f4(theta_1) a_hb_1: float | None = None theta0_hb_1: float | None = None delta_theta_star_hb_1: float | None = None # reference to f4(theta_2) a_hb_2: float | None = None theta0_hb_2: float | None = None delta_theta_star_hb_2: float | None = None # reference to f4(theta_3) a_hb_3: float | None = None theta0_hb_3: float | None = None delta_theta_star_hb_3: float | None = None # reference to f4(theta_4) a_hb_4: float | None = None theta0_hb_4: float | None = None delta_theta_star_hb_4: float | None = None # reference to f4(theta_7) a_hb_7: float | None = None theta0_hb_7: float | None = None delta_theta_star_hb_7: float | None = None # reference to f4(theta_8) a_hb_8: float | None = None theta0_hb_8: float | None = None delta_theta_star_hb_8: float | None = None # required but not optimizable ss_hb_weights: np.ndarray | None = None # dependent parameters ===================================================== b_low_hb: float | None = None dr_c_low_hb: float | None = None b_high_hb: float | None = None dr_c_high_hb: float | None = None b_hb_1: float | None = None delta_theta_hb_1_c: float | None = None b_hb_2: float | None = None delta_theta_hb_2_c: float | None = None b_hb_3: float | None = None delta_theta_hb_3_c: float | None = None b_hb_4: float | None = None delta_theta_hb_4_c: float | None = None b_hb_7: float | None = None delta_theta_hb_7_c: float | None = None b_hb_8: float | None = None delta_theta_hb_8_c: float | None = None # Sequence-dependence parameters eps_hb_weights: np.ndarray | None = None # probabilistic sequence and constraints pseq: typ.Probabilistic_Sequence | None = None pseq_constraints: SequenceConstraints | None = None # override required_params: tuple[str] = ( # Sequence-independence "eps_hb", "a_hb", "dr0_hb", "dr_c_hb", "dr_low_hb", "dr_high_hb", "a_hb_1", "theta0_hb_1", "delta_theta_star_hb_1", "a_hb_2", "theta0_hb_2", "delta_theta_star_hb_2", "a_hb_3", "theta0_hb_3", "delta_theta_star_hb_3", "a_hb_4", "theta0_hb_4", "delta_theta_star_hb_4", "a_hb_7", "theta0_hb_7", "delta_theta_star_hb_7", "a_hb_8", "theta0_hb_8", "delta_theta_star_hb_8", ) # override dependent_params: tuple[str] = ( "b_low_hb", "dr_c_low_hb", "b_high_hb", "dr_c_high_hb", "b_hb_1", "delta_theta_hb_1_c", "b_hb_2", "delta_theta_hb_2_c", "b_hb_3", "delta_theta_hb_3_c", "b_hb_4", "delta_theta_hb_4_c", "b_hb_7", "delta_theta_hb_7_c", "b_hb_8", "delta_theta_hb_8_c", "eps_hb_weights", )
[docs] @override def init_params(self) -> "HydrogenBondingConfiguration": if self.pseq is not None and self.pseq_constraints is None: raise ValueError("pseq_constraints must be provided when pseq is provided.") eps_hb_weights = HB_WEIGHTS_SA * self.eps_hb if self.ss_hb_weights is None else self.ss_hb_weights # reference to f1(dr_hb) b_low_hb, dr_c_low_hb, b_high_hb, dr_c_high_hb = bsf.get_f1_smoothing_params( self.dr0_hb, self.a_hb, self.dr_c_hb, self.dr_low_hb, self.dr_high_hb, ) # reference to f4(theta_1) b_hb_1, delta_theta_hb_1_c = bsf.get_f4_smoothing_params( self.a_hb_1, self.theta0_hb_1, self.delta_theta_star_hb_1, ) # reference to f4(theta_2) b_hb_2, delta_theta_hb_2_c = bsf.get_f4_smoothing_params( self.a_hb_2, self.theta0_hb_2, self.delta_theta_star_hb_2, ) # reference to f4(theta_3) b_hb_3, delta_theta_hb_3_c = bsf.get_f4_smoothing_params( self.a_hb_3, self.theta0_hb_3, self.delta_theta_star_hb_3, ) # reference to f4(theta_4) b_hb_4, delta_theta_hb_4_c = bsf.get_f4_smoothing_params( self.a_hb_4, self.theta0_hb_4, self.delta_theta_star_hb_4, ) # reference to f4(theta_7) b_hb_7, delta_theta_hb_7_c = bsf.get_f4_smoothing_params( self.a_hb_7, self.theta0_hb_7, self.delta_theta_star_hb_7, ) # reference to f4(theta_8) b_hb_8, delta_theta_hb_8_c = bsf.get_f4_smoothing_params( self.a_hb_8, self.theta0_hb_8, self.delta_theta_star_hb_8, ) return self.replace( b_low_hb=b_low_hb, dr_c_low_hb=dr_c_low_hb, b_high_hb=b_high_hb, dr_c_high_hb=dr_c_high_hb, b_hb_1=b_hb_1, delta_theta_hb_1_c=delta_theta_hb_1_c, b_hb_2=b_hb_2, delta_theta_hb_2_c=delta_theta_hb_2_c, b_hb_3=b_hb_3, delta_theta_hb_3_c=delta_theta_hb_3_c, b_hb_4=b_hb_4, delta_theta_hb_4_c=delta_theta_hb_4_c, b_hb_7=b_hb_7, delta_theta_hb_7_c=delta_theta_hb_7_c, b_hb_8=b_hb_8, delta_theta_hb_8_c=delta_theta_hb_8_c, eps_hb_weights=eps_hb_weights, )
[docs] @chex.dataclass(frozen=True) class HydrogenBonding(je_base.BaseEnergyFunction): """Hydrogen bonding energy function for DNA1 model.""" params: HydrogenBondingConfiguration
[docs] def compute_v_hb( self, body_i: je_base.BaseNucleotide, body_j: je_base.BaseNucleotide, unbonded_neighbors: typ.Arr_Unbonded_Neighbors, ) -> typ.Arr_Unbonded_Neighbors: """Computes the sequence-independent energy for each unbonded pair.""" op_i = unbonded_neighbors[0] op_j = unbonded_neighbors[1] mask = jnp.array(op_i < body_i.center.shape[0], dtype=jnp.float64) dr_base_op = self.displacement_mapped(body_j.base_sites[op_j], body_i.base_sites[op_i]) r_base_op = jnp.linalg.norm(dr_base_op, axis=1) theta1_op = jnp.arccos( jd_math.clamp(jd_math.mult(-body_i.back_base_vectors[op_i], body_j.back_base_vectors[op_j])) ) theta2_op = jnp.arccos(jd_math.clamp(jd_math.mult(-body_j.back_base_vectors[op_j], dr_base_op) / r_base_op)) theta3_op = jnp.arccos(jd_math.clamp(jd_math.mult(body_i.back_base_vectors[op_i], dr_base_op) / r_base_op)) theta4_op = jnp.arccos(jd_math.clamp(jd_math.mult(body_i.base_normals[op_i], body_j.base_normals[op_j]))) # note: are these swapped in Lorenzo's code? theta7_op = jnp.arccos(jd_math.clamp(jd_math.mult(-body_j.base_normals[op_j], dr_base_op) / r_base_op)) theta8_op = jnp.pi - jnp.arccos(jd_math.clamp(jd_math.mult(body_i.base_normals[op_i], dr_base_op) / r_base_op)) v_hb = dna1_interactions.hydrogen_bonding( dr_base_op, theta1_op, theta2_op, theta3_op, theta4_op, theta7_op, theta8_op, self.params.dr_low_hb, self.params.dr_high_hb, self.params.dr_c_low_hb, self.params.dr_c_high_hb, 1, # eps is handled via eps_hb_weights NT-NT matrix self.params.a_hb, self.params.dr0_hb, self.params.dr_c_hb, self.params.b_low_hb, self.params.b_high_hb, self.params.theta0_hb_1, self.params.delta_theta_star_hb_1, self.params.a_hb_1, self.params.delta_theta_hb_1_c, self.params.b_hb_1, self.params.theta0_hb_2, self.params.delta_theta_star_hb_2, self.params.a_hb_2, self.params.delta_theta_hb_2_c, self.params.b_hb_2, self.params.theta0_hb_3, self.params.delta_theta_star_hb_3, self.params.a_hb_3, self.params.delta_theta_hb_3_c, self.params.b_hb_3, self.params.theta0_hb_4, self.params.delta_theta_star_hb_4, self.params.a_hb_4, self.params.delta_theta_hb_4_c, self.params.b_hb_4, self.params.theta0_hb_7, self.params.delta_theta_star_hb_7, self.params.a_hb_7, self.params.delta_theta_hb_7_c, self.params.b_hb_7, self.params.theta0_hb_8, self.params.delta_theta_star_hb_8, self.params.a_hb_8, self.params.delta_theta_hb_8_c, self.params.b_hb_8, ) return jnp.where(mask, v_hb, 0.0) # Mask for neighbors
[docs] def weight(self, i: int, j: int, seq: typ.Probabilistic_Sequence) -> float: """Computes the sequence-dependent weight for an unbonded pair.""" sc = self.params.pseq_constraints return compute_seq_dep_weight( seq, i, j, self.params.eps_hb_weights, sc.is_unpaired, sc.idx_to_unpaired_idx, sc.idx_to_bp_idx )
[docs] def pairwise_energies( self, body_i: je_base.BaseNucleotide, body_j: je_base.BaseNucleotide, seq: typ.Discrete_Sequence, unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2, ) -> typ.Arr_Unbonded_Neighbors: """Computes the hydrogen bonding energy for each unbonded pair.""" # Compute sequence-independent energy for each unbonded pair v_hb = self.compute_v_hb(body_i, body_j, unbonded_neighbors) # Compute sequence-dependent weight for each unbonded pair op_i = unbonded_neighbors[0] op_j = unbonded_neighbors[1] if self.params.pseq: hb_weights = vmap(self.weight, (0, 0, None))(op_i, op_j, self.params.pseq) else: hb_weights = self.params.eps_hb_weights[seq[op_i], seq[op_j]] return jnp.multiply(hb_weights, v_hb)
[docs] @override def compute_energy(self, nucleotide: je_base.BaseNucleotide) -> typ.Scalar: dgs = self.pairwise_energies(nucleotide, nucleotide, self.seq, self.unbonded_neighbors) return dgs.sum()