# Source code for mythos.energy.dna1.fene

"""FENE energy function for DNA1 model."""

import chex
import jax.numpy as jnp
from typing_extensions import override

import mythos.energy.base as je_base
import mythos.energy.configuration as config
import mythos.energy.dna1.interactions as dna1_interactions
import mythos.utils.types as typ


@chex.dataclass(frozen=True)
class FeneConfiguration(config.BaseConfiguration):
    """Configuration for the FENE energy function.

    Every parameter defaults to ``None``. All five are listed in
    ``required_params``. They are passed, in declaration order, to
    ``dna1_interactions.v_fene_smooth`` by ``Fene.pairwise_energies``.
    """

    # Parameters forwarded to the smoothed FENE potential. Their physical
    # meaning is defined by ``v_fene_smooth`` (not visible here).
    eps_backbone: float | None = None
    r0_backbone: float | None = None
    delta_backbone: float | None = None
    fmax: float | None = None
    finf: float | None = None

    # override
    # Names of the parameters this configuration requires. The annotation is
    # the variable-length form ``tuple[str, ...]``; plain ``tuple[str]``
    # would denote a tuple of exactly one string.
    required_params: tuple[str, ...] = (
        "eps_backbone",
        "r0_backbone",
        "delta_backbone",
        "fmax",
        "finf",
    )

    @override
    def init_params(self) -> "FeneConfiguration":
        """Return this configuration unchanged.

        FENE computes no derived parameters, so initialization is the
        identity.
        """
        return self
@chex.dataclass(frozen=True)
class Fene(je_base.BaseEnergyFunction):
    """FENE energy function for DNA1 model.

    Evaluates a smoothed FENE potential on the distance between the
    backbone sites of each bonded nucleotide pair, then sums the result.
    """

    params: FeneConfiguration

    def pairwise_energies(
        self,
        body: je_base.BaseNucleotide,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
    ) -> typ.Arr_Bonded_Neighbors:
        """Computes the FENE energy for each bonded pair.

        Args:
            body: Nucleotide state; only its ``back_sites`` are read.
            bonded_neighbors: Index pairs; column 0 holds the first partner
                and column 1 the second.

        Returns:
            The FENE energy for each bonded pair, as produced by
            ``dna1_interactions.v_fene_smooth``.
        """
        fene = self.params
        backbone = body.back_sites
        first, second = bonded_neighbors[:, 0], bonded_neighbors[:, 1]

        # Backbone-to-backbone separation vectors via ``displacement_mapped``
        # (defined outside this class), reduced to scalar distances.
        separation = self.displacement_mapped(backbone[first], backbone[second])
        distance = jnp.linalg.norm(separation, axis=1)

        return dna1_interactions.v_fene_smooth(
            distance,
            fene.eps_backbone,
            fene.r0_backbone,
            fene.delta_backbone,
            fene.fmax,
            fene.finf,
        )

    @override
    def compute_energy(self, nucleotide: je_base.BaseNucleotide) -> typ.Scalar:
        """Total FENE energy: the sum of per-pair energies over ``self.bonded_neighbors``."""
        per_pair = self.pairwise_energies(nucleotide, self.bonded_neighbors)
        return per_pair.sum()