Source code for mythos.energy.dna1.base_smoothing_functions

"""Smoothing functions for the base functions in DNA1 model.

Functional forms are taken from the oxDNA paper, Section 2.4.1:
https://ora.ox.ac.uk/objects/uuid:b2415bb2-7975-4f59-b5e2-8c022b4a3719/files/mdcac62bc9133143fc05070ed20048c50
"""

import jax.numpy as jnp

import mythos.utils.types as typ


def _solve_f1_b(x: typ.Scalar, a: typ.Scalar, x0: typ.Scalar, xc: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing coefficient ``b`` of the f1 function.

    Algebraically the result equals
    ``a**2 * u**2 * (1 - u)**2 / ((uc - u) * (2 - u - uc))`` with
    ``u = exp(-a * (x - x0))`` and ``uc = exp(-a * (xc - x0))``.

    Args:
        x: Point at which the expression is evaluated.
        a: Factor multiplying every exponent.
        x0: Offset appearing in the exponents.
        xc: Second offset appearing in the exponents.

    Returns:
        The value of ``b`` at ``x``.
    """
    # Numerator: three exponentials, rescaled by exp(-2*a*x).
    exp_sum_top = (
        -jnp.exp(a * (3 * x0 + 2 * xc))
        + 2 * jnp.exp(a * (x + 2 * x0 + 2 * xc))
        - jnp.exp(a * (2 * x + x0 + 2 * xc))
    )
    rescale = jnp.exp(-2 * a * x)

    # Denominator: four exponentials mixing x, x0 and xc.
    exp_sum_bottom = (
        2 * jnp.exp(a * (x + 2 * xc))
        + jnp.exp(a * (2 * x + x0))
        - 2 * jnp.exp(a * (2 * x + xc))
        - jnp.exp(a * (x0 + 2 * xc))
    )
    return a**2 * exp_sum_top * rescale / exp_sum_bottom
def _solve_f1_xc_star(x: typ.Scalar, a: typ.Scalar, x0: typ.Scalar, xc: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing parameter ``xc_star`` of the f1 function.

    Algebraically the result equals
    ``x - (uc - u) * (2 - u - uc) / (a * u * (1 - u))`` with
    ``u = exp(-a * (x - x0))`` and ``uc = exp(-a * (xc - x0))``.

    Args:
        x: Point at which the expression is evaluated.
        a: Factor multiplying every exponent.
        x0: Offset appearing in the exponents.
        xc: Second offset appearing in the exponents.

    Returns:
        The value of ``xc_star`` at ``x``.
    """
    # Two exponentials occur twice in the numerator; name them once.
    exp_x_2xc = jnp.exp(a * (x + 2 * xc))
    exp_x0_2xc = jnp.exp(a * (x0 + 2 * xc))

    top = (
        a * x * exp_x_2xc
        - a * x * exp_x0_2xc
        + 2 * exp_x_2xc
        + jnp.exp(a * (2 * x + x0))
        - 2 * jnp.exp(a * (2 * x + xc))
        - exp_x0_2xc
    )
    bottom = a * (jnp.exp(a * x) - jnp.exp(a * x0))
    return top * jnp.exp(-2 * a * xc) / bottom
def get_f1_smoothing_params(
    x0: typ.Scalar, a: typ.Scalar, xc: typ.Scalar, x_low: typ.Scalar, x_high: typ.Scalar
) -> tuple[typ.Scalar, typ.Scalar, typ.Scalar, typ.Scalar]:
    """Compute the f1 smoothing parameters at both evaluation points.

    Args:
        x0: Forwarded to the f1 solvers as ``x0``.
        a: Forwarded to the f1 solvers as ``a``.
        xc: Forwarded to the f1 solvers as ``xc``.
        x_low: Lower point at which ``b`` and ``xc_star`` are solved.
        x_high: Upper point at which ``b`` and ``xc_star`` are solved.

    Returns:
        ``(b_low, xc_star_low, b_high, xc_star_high)``.
    """
    # Solve each side as a (b, xc_star) pair, then flatten low-before-high.
    low_pair = (_solve_f1_b(x_low, a, x0, xc), _solve_f1_xc_star(x_low, a, x0, xc))
    high_pair = (_solve_f1_b(x_high, a, x0, xc), _solve_f1_xc_star(x_high, a, x0, xc))
    return (*low_pair, *high_pair)
def _solve_f2_b(x: typ.Scalar, x0: typ.Scalar, xc: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing coefficient ``b`` of the f2 function.

    Computes ``(x - x0)**2 / (2 * (x - xc) * (x - 2*x0 + xc))``.

    Args:
        x: Point at which the expression is evaluated.
        x0: Offset subtracted from ``x`` in the numerator.
        xc: Offset appearing in both denominator factors.

    Returns:
        The value of ``b`` at ``x``.
    """
    squared_offset = (x - x0) ** 2
    divisor = 2 * (x - xc) * (x - 2 * x0 + xc)
    return squared_offset / divisor
def _solve_f2_xc_star(x: typ.Scalar, x0: typ.Scalar, xc: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing parameter ``xc_star`` of the f2 function.

    Computes ``(x*x0 - 2*x0*xc + xc**2) / (x - x0)``.

    Args:
        x: Point at which the expression is evaluated.
        x0: Offset; the expression is singular at ``x == x0``.
        xc: Offset entering the numerator linearly and quadratically.

    Returns:
        The value of ``xc_star`` at ``x``.
    """
    top = x * x0 - 2 * x0 * xc + xc**2
    bottom = x - x0
    return top / bottom
def get_f2_smoothing_params(
    x0: typ.Scalar, xc: typ.Scalar, x_low: typ.Scalar, x_high: typ.Scalar
) -> tuple[typ.Scalar, typ.Scalar, typ.Scalar, typ.Scalar]:
    """Compute the f2 smoothing parameters at both evaluation points.

    Args:
        x0: Forwarded to the f2 solvers as ``x0``.
        xc: Forwarded to the f2 solvers as ``xc``.
        x_low: Lower point at which ``b`` and ``xc_star`` are solved.
        x_high: Upper point at which ``b`` and ``xc_star`` are solved.

    Returns:
        ``(b_low, xc_star_low, b_high, xc_star_high)``.
    """
    # Solve each side as a (b, xc_star) pair, then flatten low-before-high.
    low_pair = (_solve_f2_b(x_low, x0, xc), _solve_f2_xc_star(x_low, x0, xc))
    high_pair = (_solve_f2_b(x_high, x0, xc), _solve_f2_xc_star(x_high, x0, xc))
    return (*low_pair, *high_pair)
def _solve_f3_b(x: typ.Scalar, sigma: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing coefficient ``b`` of the f3 function.

    The four bracketed denominator factors multiply out to ``x**6 - sigma**6``;
    they are kept in factored form here.

    Args:
        x: Point at which the expression is evaluated.
        sigma: Length-like parameter entering through its 6th power.

    Returns:
        The value of ``b`` at ``x``.
    """
    sigma_pow6 = sigma**6
    top = -36 * sigma_pow6 * (-2 * sigma_pow6 + x**6) ** 2
    bottom = (
        x**14
        * (-sigma + x)
        * (sigma + x)
        * (sigma**2 - sigma * x + x**2)
        * (sigma**2 + sigma * x + x**2)
    )
    return top / bottom
def _solve_f3_xc(x: typ.Scalar, sigma: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing parameter ``xc`` of the f3 function.

    Computes ``x * (4*x**6 - 7*sigma**6) / (3 * (x**6 - 2*sigma**6))``.

    Args:
        x: Point at which the expression is evaluated.
        sigma: Length-like parameter entering through its 6th power.

    Returns:
        The value of ``xc`` at ``x``.
    """
    sigma_pow6 = sigma**6
    x_pow6 = x**6
    top = x * (-7 * sigma_pow6 + 4 * x_pow6)
    bottom = 3 * (-2 * sigma_pow6 + x_pow6)
    return top / bottom
def get_f3_smoothing_params(r_star: typ.Scalar, sigma: typ.Scalar) -> tuple[typ.Scalar, typ.Scalar]:
    """Compute the f3 smoothing parameters at ``r_star``.

    Args:
        r_star: Point at which both f3 solvers are evaluated.
        sigma: Forwarded to the f3 solvers as ``sigma``.

    Returns:
        ``(b, xc)`` as produced by the two f3 solvers.
    """
    return _solve_f3_b(r_star, sigma), _solve_f3_xc(r_star, sigma)
def _solve_f4_b(x: typ.Scalar, x0: typ.Scalar, a: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing coefficient ``b`` of the f4 function.

    Algebraically this is ``a**2 * (x - x0)**2 / (1 - a * (x - x0)**2)``;
    the denominator is written with the square expanded.

    Args:
        x: Point at which the expression is evaluated.
        x0: Centre offset subtracted from ``x``.
        a: Quadratic coefficient.

    Returns:
        The value of ``b`` at ``x``.
    """
    top = -(a**2) * (x - x0) ** 2
    # Expanded form of a * (x - x0)**2 - 1.
    bottom = a * x**2 - 2 * a * x * x0 + a * x0**2 - 1
    return top / bottom
def _solve_f4_xc(x: typ.Scalar, x0: typ.Scalar, a: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing parameter ``xc`` of the f4 function.

    Computes ``(-a*x*x0 + a*x0**2 - 1) / (a * (x0 - x))``.

    Args:
        x: Point at which the expression is evaluated.
        x0: Centre offset; the expression is singular at ``x == x0``.
        a: Quadratic coefficient.

    Returns:
        The value of ``xc`` at ``x``.
    """
    top = -a * x * x0 + a * x0**2 - 1
    bottom = a * (-x + x0)
    return top / bottom
def get_f4_smoothing_params(a: typ.Scalar, x0: typ.Scalar, delta_x_star: typ.Scalar) -> tuple[typ.Scalar, typ.Scalar]:
    """Compute the f4 smoothing parameters on the positive side of ``x0``.

    Args:
        a: Forwarded to the f4 solvers as ``a``.
        x0: Centre offset; the solvers are evaluated at ``x0 + delta_x_star``.
        delta_x_star: Displacement from ``x0`` to the evaluation point.

    Returns:
        ``(b, delta_xc)`` where ``delta_xc`` is the solved ``xc`` minus ``x0``.
    """
    eval_point = x0 + delta_x_star
    b_plus = _solve_f4_b(eval_point, x0, a)
    # Report the cutoff as a displacement from x0 rather than an absolute value.
    delta_xc_plus = _solve_f4_xc(eval_point, x0, a) - x0
    return b_plus, delta_xc_plus
def _solve_f5_b(x: typ.Scalar, x0: typ.Scalar, a: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing coefficient ``b`` of the f5 function.

    Algebraically this is ``a**2 * (x - x0)**2 / (1 - a * (x - x0)**2)``;
    the denominator is written with the square expanded.

    Args:
        x: Point at which the expression is evaluated.
        x0: Centre offset subtracted from ``x``.
        a: Quadratic coefficient.

    Returns:
        The value of ``b`` at ``x``.
    """
    top = -(a**2) * (x - x0) ** 2
    # Expanded form of a * (x - x0)**2 - 1.
    bottom = a * x**2 - 2 * a * x * x0 + a * x0**2 - 1
    return top / bottom
def _solve_f5_xc(x: typ.Scalar, x0: typ.Scalar, a: typ.Scalar) -> typ.Scalar:
    """Evaluate the closed-form smoothing parameter ``xc`` of the f5 function.

    Computes ``(a*x*x0 - a*x0**2 + 1) / (a * (x - x0))``.

    Args:
        x: Point at which the expression is evaluated.
        x0: Centre offset; the expression is singular at ``x == x0``.
        a: Quadratic coefficient.

    Returns:
        The value of ``xc`` at ``x``.
    """
    top = a * x * x0 - a * x0**2 + 1
    bottom = a * (x - x0)
    return top / bottom
def get_f5_smoothing_params(a: typ.Scalar, x_star: typ.Scalar) -> tuple[typ.Scalar, typ.Scalar]:
    """Compute the f5 smoothing parameters at ``x_star``.

    Both solvers are called with the centre offset fixed at ``0.0``.

    Args:
        a: Forwarded to the f5 solvers as ``a``.
        x_star: Point at which both f5 solvers are evaluated.

    Returns:
        ``(b, xc)`` as produced by the two f5 solvers.
    """
    centre = 0.0
    return _solve_f5_b(x_star, centre, a), _solve_f5_xc(x_star, centre, a)