"""Utility functions for fitting the worm-like chain (WLC) model to force-extension data."""
import jax.numpy as jnp
from jaxopt import GaussNewton
import mythos.utils.types as jd_types
def coth(x: jd_types.ARR_OR_SCALAR) -> jd_types.ARR_OR_SCALAR:
    """Hyperbolic cotangent function, ``coth(x) = cosh(x) / sinh(x)``.

    Computed as ``1 / tanh(x)`` rather than ``(exp(2x) + 1) / (exp(2x) - 1)``.
    The exponential form overflows to ``inf / inf = nan`` once ``exp(2x)``
    exceeds the floating-point range (``x`` above roughly 44 in float32), and
    its gradient is ``nan`` there as well. ``tanh`` saturates smoothly to +/-1,
    so both the value and its gradient stay finite for large ``|x|``. It also
    avoids the cancellation in ``exp(2x) - 1`` for small ``x``.

    Args:
        x (jd_types.ARR_OR_SCALAR): the input value(s)

    Returns:
        jd_types.ARR_OR_SCALAR: the hyperbolic cotangent of ``x`` (infinite at
        ``x == 0``)
    """
    return 1 / jnp.tanh(x)
def calculate_extension(
    force: jd_types.ARR_OR_SCALAR,
    l0: jd_types.ARR_OR_SCALAR,
    lp: jd_types.ARR_OR_SCALAR,
    k: jd_types.ARR_OR_SCALAR,
    kT: float,  # noqa: N803 -- kT is a special unit variable
) -> jd_types.ARR_OR_SCALAR:
    r"""Compute the extension of an extensible wormlike chain (WLC) under a specified force.

    Via the model of Odijk, the extension of an extensible wormlike chain (WLC)
    under a force :math:`F` can be computed as

    .. math::

        x = L_0 \left(1 + \frac{F}{K} - \frac{kT}{2FL_0} \left[1 + y\coth y\right] \right)

    where

    .. math::

        y = \left( \frac{FL_0^2}{L_p kT} \right)^{1/2}

    and :math:`L_0` is the contour length, :math:`L_p` is the persistence length
    and :math:`K` is the extensional modulus. Note that :math:`kT / (F L_0)` is
    dimensionless, so every term inside the parentheses is dimensionless. For
    large :math:`y`, :math:`y \coth y \to y` and the correction term reduces to
    :math:`\tfrac{1}{2}\sqrt{kT / (F L_p)}`.

    This function implements this model for computing the extension.

    Args:
        force (jd_types.ARR_OR_SCALAR): the force applied to the duplex
        l0 (jd_types.ARR_OR_SCALAR): the contour length
        lp (jd_types.ARR_OR_SCALAR): the persistence length
        k (jd_types.ARR_OR_SCALAR): the extensional modulus
        kT (float): the temperature, expressed as the thermal energy kT

    Returns:
        jd_types.ARR_OR_SCALAR: the predicted extension
    """
    # Dimensionless argument of the coth correction term.
    y = ((force * l0**2) / (lp * kT)) ** (1 / 2)
    # The L0 in the denominator of kT / (2 F L0) keeps this term dimensionless.
    return l0 * (1 + force / k - kT / (2 * force * l0) * (1 + y * coth(y)))
def loss(
    coeffs: jnp.ndarray,
    extensions: jnp.ndarray,
    forces: jnp.ndarray,
    kT: float,  # noqa: N803 -- kT is a special unit variable
) -> jnp.ndarray:
    """Residual function of the WLC model, in the form expected by JAX least-squares solvers.

    Args:
        coeffs (jnp.ndarray): The parameters of the WLC model, ordered as [L_0, L_p, K]
        extensions (jnp.ndarray): The measured extensions (via simulation) to which we are fitting the model
        forces (jnp.ndarray): The forces under which the extensions were measured
        kT (float): the temperature

    Returns:
        jnp.ndarray: measured minus predicted extension, one residual per measurement
    """
    # Parameter vector layout is [L0, Lp, K]; only the first three entries are read.
    contour_length, persistence_length, modulus = coeffs[0], coeffs[1], coeffs[2]
    predicted = calculate_extension(
        force=forces,
        l0=contour_length,
        lp=persistence_length,
        k=modulus,
        kT=kT,
    )
    return extensions - predicted
def fit_wlc(
    extensions: jnp.ndarray,
    forces: jnp.ndarray,
    init_guess: jnp.ndarray,
    kT: float,  # noqa: N803 -- kT is a special unit variable
    *,
    implicit_diff: bool = True,
) -> jnp.ndarray:
    """Fit the WLC model by nonlinear least squares (Gauss-Newton) to measured force-extension pairs.

    Args:
        extensions (jnp.ndarray): The measured extensions (via simulation) to which we are fitting the model
        forces (jnp.ndarray): The forces under which the extensions were measured
        init_guess (jnp.ndarray): An initial guess for the parameters of the WLC model, ordered as [L_0, L_p, K]
        kT (float): the temperature
        implicit_diff (bool): Whether or not to use implicit differentiation for the numerical solver

    Returns:
        jnp.ndarray: the fit parameters of the WLC model, ordered as [L_0, L_p, K]
    """
    solver = GaussNewton(residual_fun=loss, implicit_diff=implicit_diff)
    # The keyword arguments supply the remaining (non-optimized) inputs of ``loss``.
    solution = solver.run(init_guess, extensions=extensions, forces=forces, kT=kT)
    return solution.params