Source code for mythos.energy.dna1.unbonded_excluded_volume

"""Unbonded excluded volume energy function for DNA1 model."""

import dataclasses as dc

import chex
import jax.numpy as jnp
from typing_extensions import override

import mythos.energy.base as je_base
import mythos.energy.configuration as config
import mythos.energy.dna1.base_smoothing_functions as bsf
import mythos.energy.dna1.interactions as dna1_interactions
import mythos.utils.types as typ


@chex.dataclass(frozen=True)
class UnbondedExcludedVolumeConfiguration(config.BaseConfiguration):
    """Configuration for the unbonded excluded volume energy function.

    Parameters come in four groups, distinguished by suffix: ``*_base``,
    ``*_back_base``, ``*_base_back`` and ``*_backbone``. Each group has two
    independent values (``dr_star_*``, ``sigma_*``) and two dependent values
    (``b_*``, ``dr_c_*``). The dependent values are left as ``None`` until
    :meth:`init_params` derives them from the independent ones.
    """

    # independent parameters
    eps_exc: float | None = None
    dr_star_base: float | None = None
    sigma_base: float | None = None
    dr_star_back_base: float | None = None
    sigma_back_base: float | None = None
    dr_star_base_back: float | None = None
    sigma_base_back: float | None = None
    dr_star_backbone: float | None = None
    sigma_backbone: float | None = None

    # dependent parameters
    b_base: float | None = None
    dr_c_base: float | None = None
    b_back_base: float | None = None
    dr_c_back_base: float | None = None
    b_base_back: float | None = None
    dr_c_base_back: float | None = None
    b_backbone: float | None = None
    dr_c_backbone: float | None = None

    # override
    # Names of the independent parameters listed above. Annotated as a
    # variable-length tuple: ``tuple[str]`` would describe a 1-tuple.
    required_params: tuple[str, ...] = (
        "eps_exc",
        "dr_star_base",
        "sigma_base",
        "dr_star_back_base",
        "sigma_back_base",
        "dr_star_base_back",
        "sigma_base_back",
        "dr_star_backbone",
        "sigma_backbone",
    )

    # override
    # Names of the parameters that ``init_params`` computes.
    dependent_params: tuple[str, ...] = (
        "b_base",
        "dr_c_base",
        "b_back_base",
        "dr_c_back_base",
        "b_base_back",
        "dr_c_base_back",
        "b_backbone",
        "dr_c_backbone",
    )

    @override
    def init_params(self) -> "UnbondedExcludedVolumeConfiguration":
        """Return a copy of this configuration with the dependent parameters set.

        For each of the four parameter groups, ``(b, dr_c)`` is obtained from
        ``bsf.get_f3_smoothing_params(dr_star, sigma)``. The instance is
        frozen, so the results are applied with ``dataclasses.replace`` and a
        new configuration is returned; ``self`` is not modified.

        Returns:
            A new ``UnbondedExcludedVolumeConfiguration`` whose ``b_*`` and
            ``dr_c_*`` fields hold the computed smoothing parameters.
        """
        # f3 smoothing parameters for the ``*_base`` group
        b_base, dr_c_base = bsf.get_f3_smoothing_params(
            self.dr_star_base,
            self.sigma_base,
        )

        # f3 smoothing parameters for the ``*_back_base`` group
        b_back_base, dr_c_back_base = bsf.get_f3_smoothing_params(
            self.dr_star_back_base,
            self.sigma_back_base,
        )

        # f3 smoothing parameters for the ``*_base_back`` group
        b_base_back, dr_c_base_back = bsf.get_f3_smoothing_params(
            self.dr_star_base_back,
            self.sigma_base_back,
        )

        # f3 smoothing parameters for the ``*_backbone`` group
        b_backbone, dr_c_backbone = bsf.get_f3_smoothing_params(
            self.dr_star_backbone,
            self.sigma_backbone,
        )

        return dc.replace(
            self,
            b_base=b_base,
            dr_c_base=dr_c_base,
            b_back_base=b_back_base,
            dr_c_back_base=dr_c_back_base,
            b_base_back=b_base_back,
            dr_c_base_back=dr_c_base_back,
            b_backbone=b_backbone,
            dr_c_backbone=dr_c_backbone,
        )
@chex.dataclass(frozen=True)
class UnbondedExcludedVolume(je_base.BaseEnergyFunction):
    """Excluded volume energy between unbonded nucleotide pairs (DNA1 model)."""

    params: UnbondedExcludedVolumeConfiguration

    def pairwise_energies(
        self,
        body_i: je_base.BaseNucleotide,
        body_j: je_base.BaseNucleotide,
        unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
    ) -> typ.Arr_Bonded_Neighbors:
        """Compute the excluded volume energy of every unbonded pair.

        Args:
            body_i: Nucleotides indexed by the first row of
                ``unbonded_neighbors``.
            body_j: Nucleotides indexed by the second row of
                ``unbonded_neighbors``.
            unbonded_neighbors: Pair indices; row 0 selects from ``body_i``
                and row 1 selects from ``body_j``.

        Returns:
            One energy per pair, with zero wherever the ``body_i`` index is
            not smaller than ``body_i.center.shape[0]``.
        """
        # NOTE(review): the return annotation names the *bonded* alias even
        # though one value is produced per unbonded pair -- confirm whether an
        # unbonded alias in `typ` was intended.
        cfg = self.params
        idx_i = unbonded_neighbors[0]
        idx_j = unbonded_neighbors[1]

        # Gather each interaction site once per side and reuse it below.
        base_i = body_i.base_sites[idx_i]
        back_i = body_i.back_sites[idx_i]
        base_j = body_j.base_sites[idx_j]
        back_j = body_j.back_sites[idx_j]

        dr_base = self.displacement_mapped(base_j, base_i)
        dr_backbone = self.displacement_mapped(back_j, back_i)
        dr_back_base = self.displacement_mapped(back_i, base_j)
        dr_base_back = self.displacement_mapped(base_i, back_j)

        pair_energies = dna1_interactions.exc_vol_unbonded(
            dr_base,
            dr_backbone,
            dr_back_base,
            dr_base_back,
            cfg.eps_exc,
            # ``*_base`` group
            cfg.dr_star_base,
            cfg.sigma_base,
            cfg.b_base,
            cfg.dr_c_base,
            # ``*_back_base`` group
            cfg.dr_star_back_base,
            cfg.sigma_back_base,
            cfg.b_back_base,
            cfg.dr_c_back_base,
            # ``*_base_back`` group
            cfg.dr_star_base_back,
            cfg.sigma_base_back,
            cfg.b_base_back,
            cfg.dr_c_base_back,
            # ``*_backbone`` group
            cfg.dr_star_backbone,
            cfg.sigma_backbone,
            cfg.b_backbone,
            cfg.dr_c_backbone,
        )

        # Only pairs whose first index addresses an existing row of
        # ``body_i.center`` contribute; presumably the remaining entries are
        # neighbor-list padding -- TODO confirm against the caller.
        in_range = jnp.array(idx_i < body_i.center.shape[0], dtype=jnp.float32)
        return jnp.where(in_range, pair_energies, 0.0)

    @override
    def compute_energy(self, nucleotide: je_base.BaseNucleotide) -> typ.Scalar:
        """Return the total unbonded excluded volume energy.

        The same ``nucleotide`` is used for both sides of every pair in
        ``self.unbonded_neighbors``, and the pairwise energies are summed.
        """
        return self.pairwise_energies(
            nucleotide,
            nucleotide,
            self.unbonded_neighbors,
        ).sum()