"""Melting temperature observable."""
import dataclasses as dc
from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
import mythos.observables.base as jd_obs
import mythos.simulators.io as jd_sio
import mythos.utils.types as jd_types
from mythos.utils.units import get_kt_from_c
# Target melting temperatures keyed by system name. The literals are written
# in degrees and each one is passed through get_kt_from_c before being stored.
# NOTE(review): presumably degrees Celsius converted to simulation kT units,
# going by the helper's name — confirm against mythos.utils.units.
TARGETS = {
    label: get_kt_from_c(degrees)
    for label, degrees in (
        ("SL_avg_6bp", 31.2),
        ("SL_avg_8bp", 48.2),
        ("SL_avg_12bp", 64.7),
    )
}
[docs]
def jax_interp1d(x: jnp.ndarray, y: jnp.ndarray, x_new: float) -> jnp.ndarray:
    """Linearly interpolate ``y`` as a function of ``x`` at ``x_new`` using JAX.

    The sample points do not have to be ordered: they are sorted by ``x`` (with
    ``y`` permuted the same way) before being handed to ``jnp.interp``.

    Args:
        x: Array of x coordinates
        y: Array of y coordinates, paired element-wise with ``x``
        x_new: Point(s) at which to interpolate

    Returns:
        Interpolated y value(s)
    """
    # Order the (x, y) pairs by x so the sample points passed to jnp.interp are ascending.
    order = jnp.argsort(x)
    return jnp.interp(x_new, x[order], y[order])
[docs]
def compute_finf(ratio: jnp.ndarray) -> jnp.ndarray:
    """Finite size correction to bound:unbound ratio.

    Mathematically this is ``a - sqrt(a**2 - 1)`` with ``a = 1 + 1 / (2 * ratio)``.
    Because ``(a - s) * (a + s) == 1`` for ``s = sqrt(a**2 - 1)``, the same value
    is computed here as ``1 / (a + s)``, and ``a**2 - 1`` is expanded to
    ``h * (h + 2)`` with ``h = 1 / (2 * ratio)``.

    The subtraction form loses all significant digits for small ratios, where
    ``a`` and ``s`` are nearly equal, and evaluates to ``inf - inf = nan`` at
    ``ratio == 0``. The form used here stays accurate for small positive ratios
    and returns the limiting value ``0`` at ``ratio == 0``.

    Args:
        ratio: Ratio value(s) to correct; a scalar or an array. Values for which
            ``a**2 - 1`` is negative (ratios below ``-0.25``) yield ``nan``, as before.

    Returns:
        The corrected value(s), element-wise, as a JAX array.
    """
    ratio = jnp.asarray(ratio)
    h = 1 / (2 * ratio)
    # h * (h + 2) == (1 + h)**2 - 1, written without the cancelling "- 1".
    return 1 / (1 + h + jnp.sqrt(h * (h + 2)))
[docs]
def find_melting_temp(
    temperatures: jnp.ndarray,
    ratios: jnp.ndarray,
    target_ratio: float = 0.5,
) -> float:
    """Find the temperature at which the ratio curve takes the value ``target_ratio``.

    With the default ``target_ratio`` of 0.5 this is the temperature at which the
    concentration of single strands = 0.5 * duplex concentration.

    The lookup is a linear interpolation of ``temperatures`` as a function of
    ``ratios`` (via ``jax_interp1d``, which sorts the samples by ratio first).

    Args:
        temperatures: Array of temperature values
        ratios: Array of unbound:bound ratio values corresponding to each temperature
        target_ratio: The ratio value to find (default: 0.5)

    Returns:
        The interpolated temperature where ratio = target_ratio
    """
    # Interpolate the inverse relation: ratios act as x, temperatures as y.
    return jax_interp1d(ratios, temperatures, target_ratio)
[docs]
def compute_curve_width(temperatures: jnp.ndarray, ratios: jnp.ndarray) -> float:
    """Find the width of the melting curve.

    The width is the temperature separation between the points where the
    unbound:bound ratio equals 0.2 and where it equals 0.8. Both temperatures
    are obtained by linear interpolation of ``temperatures`` against ``ratios``.

    Args:
        temperatures: Array of temperature values
        ratios: Array of unbound:bound ratio values corresponding to each temperature

    Returns:
        The interpolated temperature at ratio 0.8 minus the one at ratio 0.2
    """
    temp_at_high_ratio = jax_interp1d(ratios, temperatures, 0.8)
    temp_at_low_ratio = jax_interp1d(ratios, temperatures, 0.2)
    return temp_at_high_ratio - temp_at_low_ratio
[docs]
@chex.dataclass(frozen=True)
class MeltingTemp(jd_obs.BaseObservable):
    """Computes the melting temperature of a duplex using umbrella sampling.

    The melting temperature is defined as the temperature at which the
    concentration of DNA duplexes is double that of the concentration of single
    strands.

    Every public method forwards its four arguments, unchanged and in order, to
    ``self.get_extrap_ratios`` and post-processes the returned ratios together
    with ``temperature_range``.

    Args:
        sim_temperature: float. the temperature at which the SimulatorTrajectory
            was collected, in sim. units.
        temperature_range: a vector containing the temperatures to extrapolate
            the SimulatorTrajectory data to (via histogram reweighting), in sim.
            units.
        energy_fn: a callable energy function. NOTE(review): its call signature
            is not visible in this part of the file — confirm where it is used.
    """

    # Temperature at which the simulation was conducted, in sim. units.
    sim_temperature: float
    # Temperatures the ratios are evaluated at; excluded from the generated hash.
    temperature_range: jnp.ndarray = dc.field(hash=False)
    # Energy function callable (not referenced by the methods in this section).
    energy_fn: Callable

    def __call__(
        self,
        trajectory: jd_sio.SimulatorTrajectory,
        bind_states: jnp.ndarray,
        umbrella_weights: jnp.ndarray,
        opt_params: jd_types.PyTree,
    ) -> float:
        """Calculate the melting temperature.

        Thin alias for ``get_melting_temperature`` so the observable can be
        called directly.

        Args:
            trajectory (jd_sio.SimulatorTrajectory): the trajectory to calculate the melting temperature for
            bind_states (jnp.ndarray): an array of the sampled states of the "bond" order parameter
            umbrella_weights (jnp.ndarray): an N-dimensional array containing umbrella sampling weights
            opt_params: the parameters to optimize; use the current vals in building the energy functions

        Returns:
            float: the melting temperature in simulation units
        """
        melting_temp = self.get_melting_temperature(trajectory, bind_states, umbrella_weights, opt_params)
        return melting_temp

    def get_melting_temperature(
        self,
        trajectory: jd_sio.SimulatorTrajectory,
        bind_states: jnp.ndarray,
        umbrella_weights: jnp.ndarray,
        opt_params: jd_types.PyTree,
    ) -> float:
        """Calculate the melting temperature.

        Obtains the extrapolated ratios and returns the result of
        ``find_melting_temp`` applied to them and ``temperature_range``.
        """
        ratios = self.get_extrap_ratios(trajectory, bind_states, umbrella_weights, opt_params)
        melting_temp = find_melting_temp(self.temperature_range, ratios)
        return melting_temp

    def get_melting_curve(
        self,
        trajectory: jd_sio.SimulatorTrajectory,
        bind_states: jnp.ndarray,
        umbrella_weights: jnp.ndarray,
        opt_params: jd_types.PyTree,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Calculate the melting curve.

        Returns:
            A ``(temperature_range, extrapolated ratios)`` pair.
        """
        ratios = self.get_extrap_ratios(trajectory, bind_states, umbrella_weights, opt_params)
        curve = (self.temperature_range, ratios)
        return curve

    def get_melting_curve_width(
        self,
        trajectory: jd_sio.SimulatorTrajectory,
        bind_states: jnp.ndarray,
        umbrella_weights: jnp.ndarray,
        opt_params: jd_types.PyTree,
    ) -> float:
        """Calculate the melting curve width.

        Obtains the extrapolated ratios and returns the result of
        ``compute_curve_width`` applied to them and ``temperature_range``.
        """
        ratios = self.get_extrap_ratios(trajectory, bind_states, umbrella_weights, opt_params)
        width = compute_curve_width(self.temperature_range, ratios)
        return width