"""DNA1 interactions.
These functions are based on the oxDNA1 model paper found here:
https://ora.ox.ac.uk/objects/uuid:b2415bb2-7975-4f59-b5e2-8c022b4a3719/files/mdcac62bc9133143fc05070ed20048c50
"""
import jax.numpy as jnp
import jax.tree_util as tu
import mythos.energy.dna1.base_functions as jd_base_functions
import mythos.energy.potentials as jd_potentials
import mythos.utils.math as jd_math
import mythos.utils.types as typ
[docs]
def v_fene_smooth(
r: typ.ARR_OR_SCALAR,
eps_backbone: typ.Scalar,
r0_backbone: typ.Scalar,
delta_backbone: typ.Scalar,
fmax: typ.Scalar = 500,
finf: typ.Scalar = 4.0,
) -> typ.ARR_OR_SCALAR:
"""Smoothed version of the FENE potential."""
eps = eps_backbone
r0 = r0_backbone
delt = delta_backbone
diff = jd_math.smooth_abs(r - r0)
delt2 = delt**2
eps2 = eps**2
fmax2 = fmax**2
xmax = (-eps + jnp.sqrt(eps2 + 4 * fmax2 * delt2)) / (2 * fmax)
# precompute terms for smoothed case
fene_xmax = -(eps / 2.0) * jnp.log(1.0 - xmax**2 / delt2)
long_xmax = (fmax - finf) * xmax * jnp.log(xmax) + finf * xmax
smoothed_energy = (fmax - finf) * xmax * jnp.log(diff) + finf * diff - long_xmax + fene_xmax
return jnp.where(diff > xmax, smoothed_energy, jd_potentials.v_fene(r, eps, r0, delt))
[docs]
def exc_vol_bonded(
dr_base: typ.ARR_OR_SCALAR,
dr_back_base: typ.ARR_OR_SCALAR,
dr_base_back: typ.ARR_OR_SCALAR,
eps_exc: typ.Scalar,
# reference to f3(dr_base)
dr_star_base: typ.Scalar,
sigma_base: typ.Scalar,
b_base: typ.Scalar,
dr_c_base: typ.Scalar,
# reference to f3(dr_back_base)
dr_star_back_base: typ.Scalar,
sigma_back_base: typ.Scalar,
b_back_base: typ.Scalar,
dr_c_back_base: typ.Scalar,
# reference to f3(dr_base_back)
dr_star_base_back: typ.Scalar,
sigma_base_back: typ.Scalar,
b_base_back: typ.Scalar,
dr_c_base_back: typ.Scalar,
) -> typ.Scalar:
"""Excluded volume energy for bonded interactions."""
# Note: r_c must be greater than r*
r_base = jnp.linalg.norm(dr_base, axis=1)
r_back_base = jnp.linalg.norm(dr_back_base, axis=1)
r_base_back = jnp.linalg.norm(dr_base_back, axis=1)
f3_base_exc_vol = jd_base_functions.f3(
r_base, r_star=dr_star_base, r_c=dr_c_base, eps=eps_exc, sigma=sigma_base, b=b_base
)
f3_back_base_exc_vol = jd_base_functions.f3(
r_back_base, r_star=dr_star_back_base, r_c=dr_c_back_base, eps=eps_exc, sigma=sigma_back_base, b=b_back_base
)
f3_base_back_exc_vol = jd_base_functions.f3(
r_base_back, r_star=dr_star_base_back, r_c=dr_c_base_back, eps=eps_exc, sigma=sigma_base_back, b=b_base_back
)
return f3_base_exc_vol + f3_back_base_exc_vol + f3_base_back_exc_vol
[docs]
def exc_vol_unbonded(
dr_base: typ.ARR_OR_SCALAR,
dr_backbone: typ.ARR_OR_SCALAR,
dr_back_base: typ.ARR_OR_SCALAR,
dr_base_back: typ.ARR_OR_SCALAR,
eps_exc: typ.Scalar,
# reference to f3(dr_base)
dr_star_base: typ.Scalar,
sigma_base: typ.Scalar,
b_base: typ.Scalar,
dr_c_base: typ.Scalar,
# reference to f3(dr_back_base)
dr_star_back_base: typ.Scalar,
sigma_back_base: typ.Scalar,
b_back_base: typ.Scalar,
dr_c_back_base: typ.Scalar,
# reference to f3(dr_base_back)
dr_star_base_back: typ.Scalar,
sigma_base_back: typ.Scalar,
b_base_back: typ.Scalar,
dr_c_base_back: typ.Scalar,
# reference to f3(backbone)
dr_star_backbone: typ.Scalar,
sigma_backbone: typ.Scalar,
b_backbone: typ.Scalar,
dr_c_backbone: typ.Scalar,
) -> typ.Scalar:
"""Excluded volume energy for unbonded interactions."""
r_back = jnp.linalg.norm(dr_backbone, axis=1)
f3_back_exc_vol = jd_base_functions.f3(
r_back, r_star=dr_star_backbone, r_c=dr_c_backbone, eps=eps_exc, sigma=sigma_backbone, b=b_backbone
)
return f3_back_exc_vol + exc_vol_bonded(
dr_base,
dr_back_base,
dr_base_back,
eps_exc,
dr_star_base,
sigma_base,
b_base,
dr_c_base,
dr_star_back_base,
sigma_back_base,
b_back_base,
dr_c_back_base,
dr_star_base_back,
sigma_base_back,
b_base_back,
dr_c_base_back,
)
# comments from original code
# Note that we use r_stack instead of dr_stack
# TODO(rkruegs123): fix this one with bs and rcs
# https://github.com/ssec-jhu/mythos/issues/7
[docs]
def stacking(
# obervables
r_stack: typ.ARR_OR_SCALAR,
theta4: typ.ARR_OR_SCALAR,
theta5: typ.ARR_OR_SCALAR,
theta6: typ.ARR_OR_SCALAR,
cosphi1: typ.ARR_OR_SCALAR,
cosphi2: typ.ARR_OR_SCALAR,
# params
dr_low_stack: typ.Scalar,
dr_high_stack: typ.Scalar,
eps_stack: typ.Scalar,
a_stack: typ.Scalar,
dr0_stack: typ.Scalar,
dr_c_stack: typ.Scalar,
dr_c_low_stack: typ.Scalar,
dr_c_high_stack: typ.Scalar,
b_low_stack: typ.Scalar,
b_high_stack: typ.Scalar,
theta0_stack_4: typ.Scalar,
delta_theta_star_stack_4: typ.Scalar,
a_stack_4: typ.Scalar,
delta_theta_stack_4_c: typ.Scalar,
b_stack_4: typ.Scalar,
theta0_stack_5: typ.Scalar,
delta_theta_star_stack_5: typ.Scalar,
a_stack_5: typ.Scalar,
delta_theta_stack_5_c: typ.Scalar,
b_stack_5: typ.Scalar,
theta0_stack_6: typ.Scalar,
delta_theta_star_stack_6: typ.Scalar,
a_stack_6: typ.Scalar,
delta_theta_stack_6_c: typ.Scalar,
b_stack_6: typ.Scalar,
neg_cos_phi1_star_stack: typ.Scalar,
a_stack_1: typ.Scalar,
neg_cos_phi1_c_stack: typ.Scalar,
b_neg_cos_phi1_stack: typ.Scalar,
neg_cos_phi2_star_stack: typ.Scalar,
a_stack_2: typ.Scalar,
neg_cos_phi2_c_stack: typ.Scalar,
b_neg_cos_phi2_stack: typ.Scalar,
) -> typ.Scalar:
"""Stacking energy."""
f1_dr_stack = jd_base_functions.f1(
r_stack,
r_low=dr_low_stack,
r_high=dr_high_stack,
r_c_low=dr_c_low_stack,
r_c_high=dr_c_high_stack,
eps=eps_stack,
a=a_stack,
r0=dr0_stack,
r_c=dr_c_stack,
b_low=b_low_stack,
b_high=b_high_stack,
)
f4_theta_4_stack = jd_base_functions.f4(
theta4,
theta0=theta0_stack_4,
delta_theta_star=delta_theta_star_stack_4,
delta_theta_c=delta_theta_stack_4_c,
a=a_stack_4,
b=b_stack_4,
)
f4_theta_5p_stack = jd_base_functions.f4(
theta5,
theta0=theta0_stack_5,
delta_theta_star=delta_theta_star_stack_5,
delta_theta_c=delta_theta_stack_5_c,
a=a_stack_5,
b=b_stack_5,
)
f4_theta_6p_stack = jd_base_functions.f4(
theta6,
theta0=theta0_stack_6,
delta_theta_star=delta_theta_star_stack_6,
delta_theta_c=delta_theta_stack_6_c,
a=a_stack_6,
b=b_stack_6,
)
f5_neg_cosphi1_stack = jd_base_functions.f5(
-cosphi1,
x_star=neg_cos_phi1_star_stack,
x_c=neg_cos_phi1_c_stack,
a=a_stack_1,
b=b_neg_cos_phi1_stack,
)
f5_neg_cosphi2_stack = jd_base_functions.f5(
-cosphi2,
x_star=neg_cos_phi2_star_stack,
x_c=neg_cos_phi2_c_stack,
a=a_stack_2,
b=b_neg_cos_phi2_stack,
)
return (
f1_dr_stack
* f4_theta_4_stack
* f4_theta_5p_stack
* f4_theta_6p_stack
* f5_neg_cosphi1_stack
* f5_neg_cosphi2_stack
)
[docs]
def cross_stacking(
# observables
r_hb: typ.ARR_OR_SCALAR,
theta1: typ.ARR_OR_SCALAR,
theta2: typ.ARR_OR_SCALAR,
theta3: typ.ARR_OR_SCALAR,
theta4: typ.ARR_OR_SCALAR,
theta7: typ.ARR_OR_SCALAR,
theta8: typ.ARR_OR_SCALAR,
# reference to f2_dr_cross
dr_low_cross: typ.Scalar,
dr_high_cross: typ.Scalar,
dr_c_low_cross: typ.Scalar,
dr_c_high_cross: typ.Scalar,
k_cross: typ.Scalar,
r0_cross: typ.Scalar,
dr_c_cross: typ.Scalar,
b_low_cross: typ.Scalar,
b_high_cross: typ.Scalar,
# reference to f4(theta1)
theta0_cross_1: typ.Scalar,
delta_theta_star_cross_1: typ.Scalar,
delta_theta_cross_1_c: typ.Scalar,
a_cross_1: typ.Scalar,
b_cross_1: typ.Scalar,
# reference to f4(theta2)
theta0_cross_2: typ.Scalar,
delta_theta_star_cross_2: typ.Scalar,
delta_theta_cross_2_c: typ.Scalar,
a_cross_2: typ.Scalar,
b_cross_2: typ.Scalar,
# reference to f4(theta3)
theta0_cross_3: typ.Scalar,
delta_theta_star_cross_3: typ.Scalar,
delta_theta_cross_3_c: typ.Scalar,
a_cross_3: typ.Scalar,
b_cross_3: typ.Scalar,
# reference to f4(theta4)
theta0_cross_4: typ.Scalar,
delta_theta_star_cross_4: typ.Scalar,
delta_theta_cross_4_c: typ.Scalar,
a_cross_4: typ.Scalar,
b_cross_4: typ.Scalar,
# reference to f7(theta7)
theta0_cross_7: typ.Scalar,
delta_theta_star_cross_7: typ.Scalar,
delta_theta_cross_7_c: typ.Scalar,
a_cross_7: typ.Scalar,
b_cross_7: typ.Scalar,
# reference to f8(theta8)
theta0_cross_8: typ.Scalar,
delta_theta_star_cross_8: typ.Scalar,
delta_theta_cross_8_c: typ.Scalar,
a_cross_8: typ.Scalar,
b_cross_8: typ.Scalar,
) -> typ.Scalar:
"""Cross-stacking energy."""
f2_dr_cross = jd_base_functions.f2(
r_hb,
r_low=dr_low_cross,
r_high=dr_high_cross,
r_c_low=dr_c_low_cross,
r_c_high=dr_c_high_cross,
k=k_cross,
r0=r0_cross,
r_c=dr_c_cross,
b_low=b_low_cross,
b_high=b_high_cross,
)
f4_theta_1_cross = jd_base_functions.f4(
theta1,
theta0=theta0_cross_1,
delta_theta_star=delta_theta_star_cross_1,
delta_theta_c=delta_theta_cross_1_c,
a=a_cross_1,
b=b_cross_1,
)
f4_theta_2_cross = jd_base_functions.f4(
theta2,
theta0=theta0_cross_2,
delta_theta_star=delta_theta_star_cross_2,
delta_theta_c=delta_theta_cross_2_c,
a=a_cross_2,
b=b_cross_2,
)
f4_theta_3_cross = jd_base_functions.f4(
theta3,
theta0=theta0_cross_3,
delta_theta_star=delta_theta_star_cross_3,
delta_theta_c=delta_theta_cross_3_c,
a=a_cross_3,
b=b_cross_3,
)
f4_theta_4_cross_fn = tu.Partial(
jd_base_functions.f4,
theta0=theta0_cross_4,
delta_theta_star=delta_theta_star_cross_4,
delta_theta_c=delta_theta_cross_4_c,
a=a_cross_4,
b=b_cross_4,
)
f4_theta_7_cross_fn = tu.Partial(
jd_base_functions.f4,
theta0=theta0_cross_7,
delta_theta_star=delta_theta_star_cross_7,
delta_theta_c=delta_theta_cross_7_c,
a=a_cross_7,
b=b_cross_7,
)
f4_theta_8_cross_fn = tu.Partial(
jd_base_functions.f4,
theta0=theta0_cross_8,
delta_theta_star=delta_theta_star_cross_8,
delta_theta_c=delta_theta_cross_8_c,
a=a_cross_8,
b=b_cross_8,
)
return (
f2_dr_cross
* f4_theta_1_cross
* f4_theta_2_cross
* f4_theta_3_cross
* (f4_theta_4_cross_fn(theta4) + f4_theta_4_cross_fn(jnp.pi - theta4))
* (f4_theta_7_cross_fn(theta7) + f4_theta_7_cross_fn(jnp.pi - theta7))
* (f4_theta_8_cross_fn(theta8) + f4_theta_8_cross_fn(jnp.pi - theta8))
)
[docs]
def coaxial_stacking(
# obersvables
dr_stack: typ.ARR_OR_SCALAR,
theta4: typ.ARR_OR_SCALAR,
theta1: typ.ARR_OR_SCALAR,
theta5: typ.ARR_OR_SCALAR,
theta6: typ.ARR_OR_SCALAR,
cosphi3: typ.ARR_OR_SCALAR,
cosphi4: typ.ARR_OR_SCALAR,
# reference to f2(dr_stack)
dr_low_coax: typ.Scalar,
dr_high_coax: typ.Scalar,
dr_c_low_coax: typ.Scalar,
dr_c_high_coax: typ.Scalar,
k_coax: typ.Scalar,
dr0_coax: typ.Scalar,
dr_c_coax: typ.Scalar,
b_low_coax: typ.Scalar,
b_high_coax: typ.Scalar,
# reference to f4(theta4)
theta0_coax_4: typ.Scalar,
delta_theta_star_coax_4: typ.Scalar,
delta_theta_coax_4_c: typ.Scalar,
a_coax_4: typ.Scalar,
b_coax_4: typ.Scalar,
# reference to f4(theta1)
theta0_coax_1: typ.Scalar,
delta_theta_star_coax_1: typ.Scalar,
delta_theta_coax_1_c: typ.Scalar,
a_coax_1: typ.Scalar,
b_coax_1: typ.Scalar,
# reference to f4(theta5)
theta0_coax_5: typ.Scalar,
delta_theta_star_coax_5: typ.Scalar,
delta_theta_coax_5_c: typ.Scalar,
a_coax_5: typ.Scalar,
b_coax_5: typ.Scalar,
# reference to f4(theta6)
theta0_coax_6: typ.Scalar,
delta_theta_star_coax_6: typ.Scalar,
delta_theta_coax_6_c: typ.Scalar,
a_coax_6: typ.Scalar,
b_coax_6: typ.Scalar,
# reference to f5(cosphi3)
cos_phi3_star_coax: typ.Scalar,
cos_phi3_c_coax: typ.Scalar,
a_coax_3p: typ.Scalar,
b_cos_phi3_coax: typ.Scalar,
# reference to f5(cosphi4)
cos_phi4_star_coax: typ.Scalar,
cos_phi4_c_coax: typ.Scalar,
a_coax_4p: typ.Scalar,
b_cos_phi4_coax: typ.Scalar,
) -> typ.Scalar:
"""Coaxial stacking energy."""
r_stack = jnp.linalg.norm(dr_stack, axis=1)
f2_dr_coax = jd_base_functions.f2(
r_stack,
r_low=dr_low_coax,
r_high=dr_high_coax,
r_c_low=dr_c_low_coax,
r_c_high=dr_c_high_coax,
k=k_coax,
r0=dr0_coax,
r_c=dr_c_coax,
b_low=b_low_coax,
b_high=b_high_coax,
)
f4_theta_4_coax = jd_base_functions.f4(
theta4,
theta0=theta0_coax_4,
delta_theta_star=delta_theta_star_coax_4,
delta_theta_c=delta_theta_coax_4_c,
a=a_coax_4,
b=b_coax_4,
)
f4_theta_1_coax_fn = tu.Partial(
jd_base_functions.f4,
theta0=theta0_coax_1,
delta_theta_star=delta_theta_star_coax_1,
delta_theta_c=delta_theta_coax_1_c,
a=a_coax_1,
b=b_coax_1,
)
f4_theta_5_coax_fn = tu.Partial(
jd_base_functions.f4,
theta0=theta0_coax_5,
delta_theta_star=delta_theta_star_coax_5,
delta_theta_c=delta_theta_coax_5_c,
a=a_coax_5,
b=b_coax_5,
)
f4_theta_6_coax_fn = tu.Partial(
jd_base_functions.f4,
theta0=theta0_coax_6,
delta_theta_star=delta_theta_star_coax_6,
delta_theta_c=delta_theta_coax_6_c,
a=a_coax_6,
b=b_coax_6,
)
f5_cosphi3_coax = jd_base_functions.f5(
cosphi3, x_star=cos_phi3_star_coax, x_c=cos_phi3_c_coax, a=a_coax_3p, b=b_cos_phi3_coax
)
f5_cosphi4_coax = jd_base_functions.f5(
cosphi4, x_star=cos_phi4_star_coax, x_c=cos_phi4_c_coax, a=a_coax_4p, b=b_cos_phi4_coax
)
return (
f2_dr_coax
* f4_theta_4_coax
* (f4_theta_1_coax_fn(theta1) + f4_theta_1_coax_fn(2 * jnp.pi - theta1))
* (f4_theta_5_coax_fn(theta5) + f4_theta_5_coax_fn(jnp.pi - theta5))
* (f4_theta_6_coax_fn(theta6) + f4_theta_6_coax_fn(jnp.pi - theta6))
* f5_cosphi3_coax
* f5_cosphi4_coax
)
[docs]
def hydrogen_bonding(
# observables
dr_hb: typ.ARR_OR_SCALAR,
theta1: typ.ARR_OR_SCALAR,
theta2: typ.ARR_OR_SCALAR,
theta3: typ.ARR_OR_SCALAR,
theta4: typ.ARR_OR_SCALAR,
theta7: typ.ARR_OR_SCALAR,
theta8: typ.ARR_OR_SCALAR,
# reference to f1_dr_hb
dr_low_hb: typ.Scalar,
dr_high_hb: typ.Scalar,
dr_c_low_hb: typ.Scalar,
dr_c_high_hb: typ.Scalar,
eps_hb: typ.Scalar,
a_hb: typ.Scalar,
dr0_hb: typ.Scalar,
dr_c_hb: typ.Scalar,
b_low_hb: typ.Scalar,
b_high_hb: typ.Scalar,
# reference to f4_theta_1_hb
theta0_hb_1: typ.Scalar,
delta_theta_star_hb_1: typ.Scalar,
a_hb_1: typ.Scalar,
delta_theta_hb_1_c: typ.Scalar,
b_hb_1: typ.Scalar,
# reference to f4_theta_2_hb
theta0_hb_2: typ.Scalar,
delta_theta_star_hb_2: typ.Scalar,
a_hb_2: typ.Scalar,
delta_theta_hb_2_c: typ.Scalar,
b_hb_2: typ.Scalar,
# reference to f4_theta_3_hb
theta0_hb_3: typ.Scalar,
delta_theta_star_hb_3: typ.Scalar,
a_hb_3: typ.Scalar,
delta_theta_hb_3_c: typ.Scalar,
b_hb_3: typ.Scalar,
# reference to f4_theta_4_hb
theta0_hb_4: typ.Scalar,
delta_theta_star_hb_4: typ.Scalar,
a_hb_4: typ.Scalar,
delta_theta_hb_4_c: typ.Scalar,
b_hb_4: typ.Scalar,
# reference to f4_theta_7_hb
theta0_hb_7: typ.Scalar,
delta_theta_star_hb_7: typ.Scalar,
a_hb_7: typ.Scalar,
delta_theta_hb_7_c: typ.Scalar,
b_hb_7: typ.Scalar,
# reference to f4_theta_8_hb
theta0_hb_8: typ.Scalar,
delta_theta_star_hb_8: typ.Scalar,
a_hb_8: typ.Scalar,
delta_theta_hb_8_c: typ.Scalar,
b_hb_8: typ.Scalar,
) -> typ.Scalar:
"""Hydrogen bonding energy."""
r_hb = jnp.linalg.norm(dr_hb, axis=1)
f1_dr_hb = jd_base_functions.f1(
r_hb,
r_low=dr_low_hb,
r_high=dr_high_hb,
r_c_low=dr_c_low_hb,
r_c_high=dr_c_high_hb,
eps=eps_hb,
a=a_hb,
r0=dr0_hb,
r_c=dr_c_hb,
b_low=b_low_hb,
b_high=b_high_hb,
)
f4_theta_1_hb = jd_base_functions.f4(
theta1,
theta0=theta0_hb_1,
delta_theta_star=delta_theta_star_hb_1,
delta_theta_c=delta_theta_hb_1_c,
a=a_hb_1,
b=b_hb_1,
)
f4_theta_2_hb = jd_base_functions.f4(
theta2,
theta0=theta0_hb_2,
delta_theta_star=delta_theta_star_hb_2,
delta_theta_c=delta_theta_hb_2_c,
a=a_hb_2,
b=b_hb_2,
)
f4_theta_3_hb = jd_base_functions.f4(
theta3,
theta0=theta0_hb_3,
delta_theta_star=delta_theta_star_hb_3,
delta_theta_c=delta_theta_hb_3_c,
a=a_hb_3,
b=b_hb_3,
)
f4_theta_4_hb = jd_base_functions.f4(
theta4,
theta0=theta0_hb_4,
delta_theta_star=delta_theta_star_hb_4,
delta_theta_c=delta_theta_hb_4_c,
a=a_hb_4,
b=b_hb_4,
)
f4_theta_7_hb = jd_base_functions.f4(
theta7,
theta0=theta0_hb_7,
delta_theta_star=delta_theta_star_hb_7,
delta_theta_c=delta_theta_hb_7_c,
a=a_hb_7,
b=b_hb_7,
)
f4_theta_8_hb = jd_base_functions.f4(
theta8,
theta0=theta0_hb_8,
delta_theta_star=delta_theta_star_hb_8,
delta_theta_c=delta_theta_hb_8_c,
a=a_hb_8,
b=b_hb_8,
)
return f1_dr_hb * f4_theta_1_hb * f4_theta_2_hb * f4_theta_3_hb * f4_theta_4_hb * f4_theta_7_hb * f4_theta_8_hb