"""FENE energy function for NA1 model."""
import dataclasses as dc
import chex
import jax
import jax.numpy as jnp
from typing_extensions import override
import mythos.energy.base as je_base
import mythos.energy.configuration as config
import mythos.energy.dna1 as dna1_energy
import mythos.energy.na1.nucleotide as na1_nucleotide
import mythos.energy.na1.utils as je_utils
import mythos.utils.types as typ
@chex.dataclass(frozen=True)
class FeneConfiguration(config.BaseConfiguration):
    """Parameter container for the NA1 FENE energy term.

    Two parallel groups of FENE parameters are stored, one prefixed ``dna_``
    and one prefixed ``rna_``. ``init_params`` converts each group into an
    initialised ``dna1_energy.FeneConfiguration`` and stores the results in
    ``dna_config`` and ``rna_config``.
    """

    # ---- independent parameters ----
    # Nucleotide-type array; ``Fene.compute_energy`` forwards it unchanged to
    # ``je_utils.is_rna_pair`` to decide which bonds are RNA bonds.
    nt_type: typ.Arr_Nucleotide | None = None

    # DNA2-specific
    dna_eps_backbone: float | None = None
    dna_r0_backbone: float | None = None
    dna_delta_backbone: float | None = None
    dna_fmax: float | None = None
    dna_finf: float | None = None

    # RNA2-specific
    rna_eps_backbone: float | None = None
    rna_r0_backbone: float | None = None
    rna_delta_backbone: float | None = None
    rna_fmax: float | None = None
    rna_finf: float | None = None

    # ---- dependent parameters (populated by ``init_params``) ----
    dna_config: dna1_energy.FeneConfiguration | None = None
    rna_config: dna1_energy.FeneConfiguration | None = None

    # Override of the inherited ``required_params`` field.
    # NOTE(review): presumably consumed by ``config.BaseConfiguration`` to
    # check that these fields are set -- confirm against the base class.
    required_params: tuple[str] = (
        "nt_type",
        # DNA2-specific
        "dna_eps_backbone",
        "dna_r0_backbone",
        "dna_delta_backbone",
        "dna_fmax",
        "dna_finf",
        # RNA2-specific
        "rna_eps_backbone",
        "rna_r0_backbone",
        "rna_delta_backbone",
        "rna_fmax",
        "rna_finf",
    )

    @override
    def init_params(self) -> "FeneConfiguration":
        """Return a copy of this configuration with the sub-configs filled in.

        ``dna_config`` is built from the ``dna_*`` fields and ``rna_config``
        from the ``rna_*`` fields. Both are ``dna1_energy.FeneConfiguration``
        instances that have been passed through their own ``init_params``.
        """

        def _initialised(eps, r0, delta, fmax, finf):
            # One dna1 FENE configuration, already initialised.
            return dna1_energy.FeneConfiguration(
                eps_backbone=eps,
                r0_backbone=r0,
                delta_backbone=delta,
                fmax=fmax,
                finf=finf,
            ).init_params()

        # Keyword arguments are evaluated left to right, so the DNA
        # configuration is still built before the RNA one.
        return dc.replace(
            self,
            dna_config=_initialised(
                self.dna_eps_backbone,
                self.dna_r0_backbone,
                self.dna_delta_backbone,
                self.dna_fmax,
                self.dna_finf,
            ),
            rna_config=_initialised(
                self.rna_eps_backbone,
                self.rna_r0_backbone,
                self.rna_delta_backbone,
                self.rna_fmax,
                self.rna_finf,
            ),
        )
@chex.dataclass(frozen=True)
class Fene(je_base.BaseEnergyFunction):
    """FENE energy function for NA1 model.

    Delegates the per-bond evaluation to ``dna1_energy.Fene``, once with the
    DNA configuration and once with the RNA configuration, then selects the
    value for each bond according to ``je_utils.is_rna_pair``.
    """

    params: FeneConfiguration

    @override
    def compute_energy(self, nucleotide: na1_nucleotide.HybridNucleotide) -> typ.Scalar:
        """Return the FENE energy summed over ``self.bonded_neighbors``.

        Args:
            nucleotide: hybrid nucleotide whose ``dna`` and ``rna`` members
                are handed to the DNA- and RNA-parameterised energy
                functions respectively.

        Returns:
            The sum of the selected per-bond energies.
        """
        bonds = self.bonded_neighbors
        cfg = self.params

        def _per_bond(sub_params, sub_nucleotide):
            # Per-bond energies from a dna1 FENE function cloned from ``self``.
            return dna1_energy.Fene.create_from(self, params=sub_params).pairwise_energies(
                sub_nucleotide,
                bonds,
            )

        # Map over the two index columns; ``nt_type`` is passed whole to
        # every call (in_axes entry ``None``).
        rna_mask = jax.vmap(je_utils.is_rna_pair, in_axes=(0, 0, None))(
            bonds[:, 0],
            bonds[:, 1],
            cfg.nt_type,
        )

        # Both variants are evaluated for every bond; the mask picks one.
        dna_energies = _per_bond(cfg.dna_config, nucleotide.dna)
        rna_energies = _per_bond(cfg.rna_config, nucleotide.rna)

        return jnp.where(rna_mask, rna_energies, dna_energies).sum()