"""Trajectory information for RNA/DNA strands."""
import concurrent.futures as cf
import functools
import itertools
import multiprocessing as mp
from pathlib import Path
from typing import TextIO, TypeAlias
import chex
import jax.numpy as jnp
import jax_md
import numpy as np
import mythos.utils.math as jdm
import mythos.utils.types as typ
# Expected number of dimensions of ``Trajectory.times``.
TRAJECTORY_TIMES_DIMS = 1
# Expected shapes; ``None`` marks an axis whose length is not checked.
TRAJECTORY_ENERGIES_SHAPE = (None, 3)
NUCLEOTIDE_STATE_SHAPE = (None, 15)

# Error messages raised by this module.
ERR_TRAJECTORY_FILE_NOT_FOUND = "Trajectory file not found: {}"
# NOTE: the constant name is misspelled ("LEGNTHS"); it is kept as-is because
# other code refers to it by this name.
ERR_TRAJECTORY_N_NUCLEOTIDE_STRAND_LEGNTHS = "n_nucleotides and sum(strand_lengths) do not match"
ERR_TRAJECTORY_TIMES_TYPE = "times must be a numpy array"
ERR_TRAJECTORY_ENERGIES_TYPE = "energies must be a numpy array"
ERR_TRAJECTORY_T_E_S_LENGTHS = "times, energies, and states do not have the same length"
ERR_TRAJECTORY_TIMES_DIMS = "times must be a 1D array"
ERR_TRAJECTORY_ENERGIES_SHAPE = "energies must be a 2D array with shape (n_states, 3)"
ERR_NUCLEOTIDE_STATE_TYPE = "Invalid type for nucleotide states:"
ERR_NUCLEOTIDE_STATE_SHAPE = "Invalid shape for nucleotide states:"
ERR_FIXED_BOX_SIZE = "Only trajectories in a fixed box size are supported"
# Raw parser output as returned by ``_read_file``: a tuple of four lists
# (times, box sizes, energies, per-state nucleotide arrays).
RawTrajectory: TypeAlias = tuple[list[typ.Scalar], list[typ.Vector3D], list[typ.Vector3D], list[typ.Arr_Nucleotide_15]]
[docs]
@chex.dataclass(frozen=True)
class Trajectory:
    """Trajectory information for a RNA/DNA strand.

    The fields are validated in ``__post_init__``.
    """

    # Total number of nucleotides; must equal sum(strand_lengths).
    n_nucleotides: int
    # Number of nucleotides in each strand.
    strand_lengths: list[int]
    # 1D numpy array with one entry per state.
    times: typ.Arr_States
    # 2D numpy array with three values per state.
    energies: typ.Arr_States_3
    # One NucleotideState per entry in ``times``.
    states: list["NucleotideState"]
    # Simulation box size, or None when it is not specified.
    box_size: typ.Vector3D | None = None

    def __post_init__(self) -> None:
        """Validate the input.

        Raises:
            ValueError: if ``n_nucleotides`` differs from ``sum(strand_lengths)``,
                if ``times``, ``energies`` and ``states`` differ in length, or
                if ``times``/``energies`` do not have the expected shape.
            TypeError: if ``times`` or ``energies`` is not a numpy array.
        """
        if self.n_nucleotides != sum(self.strand_lengths):
            raise ValueError(ERR_TRAJECTORY_N_NUCLEOTIDE_STRAND_LEGNTHS)

        if not isinstance(self.times, np.ndarray):
            raise TypeError(ERR_TRAJECTORY_TIMES_TYPE)

        if not isinstance(self.energies, np.ndarray):
            raise TypeError(ERR_TRAJECTORY_ENERGIES_TYPE)

        if len(self.times) != len(self.energies) or len(self.times) != len(self.states):
            raise ValueError(ERR_TRAJECTORY_T_E_S_LENGTHS)

        if len(self.times.shape) != TRAJECTORY_TIMES_DIMS:
            raise ValueError(ERR_TRAJECTORY_TIMES_DIMS)

        # The rank is checked first so that indexing shape[1] is always safe.
        if (
            len(self.energies.shape) != len(TRAJECTORY_ENERGIES_SHAPE)
            or self.energies.shape[1] != TRAJECTORY_ENERGIES_SHAPE[1]
        ):
            raise ValueError(ERR_TRAJECTORY_ENERGIES_SHAPE)

    @property
    def state_rigid_bodies(self) -> list[jax_md.rigid_body.RigidBody]:
        """Convert the states to a list of rigid bodies, one per state."""
        return [state.to_rigid_body() for state in self.states]

    @property
    def state_rigid_body(self) -> jax_md.rigid_body.RigidBody:
        """Convert the states to a single rigid body stacked over all states."""
        return jax_md.rigid_body.RigidBody(
            center=jnp.stack([state.com for state in self.states]),
            orientation=jax_md.rigid_body.Quaternion(jnp.stack([state.quaternions for state in self.states])),
        )

    def slice(self, key: int | slice) -> "Trajectory":
        """Get a subset of the trajectory.

        Args:
            key: a slice selecting several states, or an integer selecting a
                single state (negative integers count from the end).

        Returns:
            Trajectory: a new trajectory holding the selected states and the
            same ``n_nucleotides``, ``strand_lengths`` and ``box_size``.

        Raises:
            IndexError: if an integer ``key`` is out of range.
        """
        if isinstance(key, (int, np.integer)):
            # Scalar indexing would turn ``times`` into a bare number and fail
            # validation, so an integer is converted to a one-element slice.
            n_states = len(self.states)
            index = int(key)
            if not -n_states <= index < n_states:
                raise IndexError("trajectory index out of range")
            index %= n_states
            key = slice(index, index + 1)

        return Trajectory(
            n_nucleotides=self.n_nucleotides,
            strand_lengths=self.strand_lengths,
            times=self.times[key],
            energies=self.energies[key],
            states=self.states[key],
            # the box does not depend on which states are selected
            box_size=self.box_size,
        )

    def __repr__(self) -> str:
        """Return a string representation of the trajectory."""
        return "\n".join(
            [
                "Trajectory:",
                f"n_nucleotides: {self.n_nucleotides}",
                f"strand_lengths: {self.strand_lengths}",
                f"# times: {len(self.times)}",
                f"# energies: {len(self.energies)}",
                f"# states: {len(self.states)}",
            ]
        )

    def to_file(self, filepath: Path) -> None:
        """Write a jaxDNA simulation trajectory to oxDNA file format.

        In cases where the box_size is not specified, it will be written as "0 0 0".

        Args:
            filepath: destination file; it is opened in write mode, so any
                existing content is replaced.
        """
        # ``_write_state`` indexes the box size, so None must be replaced here.
        box_size = (0, 0, 0) if self.box_size is None else self.box_size
        with Path(filepath).open("w") as f:
            for time, energies, state in zip(self.times, self.energies, self.states, strict=True):
                _write_state(
                    file=f,
                    time=time,
                    energies=energies,
                    state=state.array,
                    box_size=box_size,
                )
[docs]
@chex.dataclass(frozen=True)
class NucleotideState:
    """State information for the nucleotides in a single state.

    ``array`` has one row per nucleotide and 15 columns: center of mass
    (0-2), backbone base vector (3-5), base normal (6-8), velocity (9-11)
    and angular velocity (12-14).
    """

    array: typ.Arr_Nucleotide_15

    def __post_init__(self) -> None:
        """Check that ``array`` is a 2D numpy array with 15 columns.

        Raises:
            TypeError: if ``array`` is not a numpy array.
            ValueError: if ``array`` does not have the expected shape.
        """
        values = self.array
        if not isinstance(values, np.ndarray):
            raise TypeError(f"{ERR_NUCLEOTIDE_STATE_TYPE}{type(values)}")

        shape = values.shape
        # rank is tested first so that shape[1] is only read when it exists
        shape_ok = len(shape) == len(NUCLEOTIDE_STATE_SHAPE) and shape[1] == NUCLEOTIDE_STATE_SHAPE[1]
        if not shape_ok:
            raise ValueError(f"{ERR_NUCLEOTIDE_STATE_SHAPE}{shape}")

    @property
    def com(self) -> typ.Arr_Nucleotide_3:
        """Center of mass of the nucleotides (columns 0-2)."""
        return self.array[:, 0:3]

    @property
    def back_base_vector(self) -> typ.Arr_Nucleotide_3:
        """Backbone base vector (columns 3-5)."""
        return self.array[:, 3:6]

    @property
    def base_normal(self) -> typ.Arr_Nucleotide_3:
        """Base normal to the base plane (columns 6-8)."""
        return self.array[:, 6:9]

    @property
    def velocity(self) -> typ.Arr_Nucleotide_3:
        """Velocity of the nucleotides (columns 9-11)."""
        return self.array[:, 9:12]

    @property
    def angular_velocity(self) -> typ.Arr_Nucleotide_3:
        """Angular velocity of the nucleotides (columns 12-14)."""
        return self.array[:, 12:15]

    @property
    def euler_angles(self) -> tuple[typ.Arr_Nucleotide, typ.Arr_Nucleotide, typ.Arr_Nucleotide]:
        """Convert principal axes to Tait-Bryan Euler angles."""
        first_axis = self.back_base_vector
        third_axis = self.base_normal
        # the middle axis is the cross product base_normal x back_base_vector
        second_axis = np.cross(third_axis, first_axis)
        return jdm.principal_axes_to_euler_angles(first_axis, second_axis, third_axis)

    @property
    def quaternions(self) -> typ.Arr_Nucleotide_4:
        """Convert Euler angles to quaternions."""
        angles = self.euler_angles
        return jdm.euler_angles_to_quaternion(*angles)

    def to_rigid_body(self) -> jax_md.rigid_body.RigidBody:
        """Convert the nucleotide state to jax-md rigid bodies."""
        orientation = jax_md.rigid_body.Quaternion(vec=self.quaternions)
        return jax_md.rigid_body.RigidBody(center=self.com, orientation=orientation)
[docs]
def validate_box_size(state_box_sizes: list[typ.Vector3D]) -> None:
    """Validate the volume for a simulation is fixed.

    Args:
        state_box_sizes: the box size recorded for each state.

    Raises:
        ValueError: if any box size differs from the first one.
    """
    boxes = np.array(state_box_sizes)
    # With no states there is nothing to compare; indexing [0] would fail.
    if boxes.shape[0] == 0:
        return
    if not np.all(boxes == boxes[0]):
        raise ValueError(ERR_FIXED_BOX_SIZE)
[docs]
def from_file(
    path: typ.PathOrStr,
    strand_lengths: list[int],
    *,
    is_5p_3p: bool = True,
    n_processes: int = 1,
) -> Trajectory:
    """Parse a trajectory file.

    Each state in a trajectory file has the following layout::

        t = number
        b = number number number
        E = number number number
        com_x com_y com_z a1_x a1_y a1_z a3_x a3_y a3_z v_x v_y v_z L_x L_y L_z
        ...repeated n_nucleotides times in total
        com_x com_y com_z a1_x a1_y a1_z a3_x a3_y a3_z v_x v_y v_z L_x L_y L_z

    where com_x, ..., L_z are all floating point numbers. The layout is
    repeated once per stored timestep.

    The original oxDNA format stores nucleotides in 3'->5' order, which is
    also the internal memory layout. The new oxDNA topology format writes
    5'->3' order, in which case the nucleotides of every strand are reversed
    while reading.

    Args:
        path (PathOrStr): path to the trajectory file
        strand_lengths (list[int]): the length of each strand, used to
            locate the strands so each can be reversed when the file is in
            5'->3' order
        is_5p_3p (bool): whether the trajectory is in 5'->3' format (for
            example if the topology file used in oxdna standalone is in the
            new oxdna format)
        n_processes (int): number of processors to use for reading the file

    Returns:
        Trajectory: trajectory information

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the box size is not the same for every state.
    """
    trajectory_path = Path(path)
    if not trajectory_path.exists():
        raise FileNotFoundError(ERR_TRAJECTORY_FILE_NOT_FOUND.format(trajectory_path))

    if n_processes != 1:
        raw = _read_parallel(trajectory_path, strand_lengths, is_5p_3p=is_5p_3p, n_processes=n_processes)
    else:
        # a single reader covers the whole file, from byte 0 to its size
        file_size = trajectory_path.stat().st_size
        raw = _read_file(trajectory_path, 0, file_size, strand_lengths, is_5p_3p=is_5p_3p)
    times, box_sizes, energies, state_arrays = raw

    validate_box_size(box_sizes)

    # validation passed, so every box size equals the first one
    fixed_box = box_sizes[0]
    return Trajectory(
        box_size=fixed_box,
        n_nucleotides=sum(strand_lengths),
        strand_lengths=strand_lengths,
        times=np.array(times, dtype=np.float64),
        energies=np.array(energies, dtype=np.float64),
        states=[NucleotideState(array=state_array) for state_array in state_arrays],
    )
[docs]
def _read_parallel(path: Path, strand_lengths: list[int], *, is_5p_3p: bool, n_processes: int) -> RawTrajectory:
    """Read a trajectory file with several worker processes.

    The file is split into ``n_processes`` byte ranges of (almost) equal
    size; each range is parsed by ``_read_file`` in its own process and the
    per-range results are concatenated in file order.

    Args:
        path: path to the trajectory file
        strand_lengths: the length of each strand
        is_5p_3p: forwarded to ``_read_file``
        n_processes: number of worker processes and of byte ranges

    Returns:
        RawTrajectory: tuple of (times, box sizes, energies, states) lists.
    """
    boundaries = np.linspace(0, path.stat().st_size, n_processes + 1, dtype=np.int64)
    # plain ints keep the job tuples in line with the wrapper's annotation
    jobs = [(path, int(lo), int(hi), strand_lengths, is_5p_3p) for lo, hi in itertools.pairwise(boundaries)]

    with cf.ProcessPoolExecutor(n_processes, mp_context=mp.get_context("spawn")) as pool:
        chunk_results = list(pool.map(_read_file_process_wrapper, jobs))

    # ``zip(*chunk_results)`` groups the outputs by field (all times, all box
    # sizes, ...); chaining each group concatenates the per-process lists.
    ts, bs, es, states = (
        list(itertools.chain.from_iterable(field)) for field in zip(*chunk_results, strict=True)
    )
    return ts, bs, es, states
[docs]
def _read_file_process_wrapper(args: tuple[Path, int, int, list[int], bool]) -> RawTrajectory:
    """Unpack one argument tuple and forward it to ``_read_file``.

    Args:
        args: ``(file_path, start, end, strand_lengths, is_5p_3p)``.

    Returns:
        RawTrajectory: whatever ``_read_file`` returns for these arguments.
    """
    source, first_byte, stop_byte, lengths, flip_strands = args
    return _read_file(source, first_byte, stop_byte, lengths, is_5p_3p=flip_strands)
[docs]
def _read_file(file_path: Path, start: int, end: int, strand_lengths: list[int], *, is_5p_3p: bool) -> RawTrajectory:
"""Read a trajectory file object."""
# we don't know where we are in the file, but we can be only in one of two
# situations: We are at the start of the state or we are in the midle of a
# state. If we are in the middle of a state, we need to read until the next
# state starts and then parse the states from there. Importantly, we need
# to pass our 'end' if the end is in the middle of a state, because the
# worker ahead of in the file will not read it.
parse_str = functools.partial(np.fromstring, sep=" ", dtype=np.float64)
state_length = sum(strand_lengths)
strand_bounds = list(itertools.pairwise([0, *itertools.accumulate(strand_lengths)]))
file_obj = file_path.open()
file_obj.seek(start)
line = file_obj.readline()
while not line.startswith("t"):
line = file_obj.readline()
ts, bs, es, states = [], [], [], []
state = []
current = file_obj.tell()
while current < end:
if line[0] == "t":
t = float(line.strip().split("=")[1])
ts.append(t)
elif line[0] == "b":
b = parse_str(line.strip().split("=")[1])
bs.append(b)
elif line[0] == "E":
e = parse_str(line.strip().split("=")[1])
es.append(e)
else:
state.append(parse_str(line.strip()))
if len(state) == state_length:
# if the trajectory is stored in 3'->5' order, we need to flip
# the order of the nucleotides in each strand
if is_5p_3p:
state = list(itertools.chain.from_iterable([state[s:e][::-1] for s, e in strand_bounds]))
state = np.array(state, dtype=np.float64)
states.append(np.array(state, dtype=np.float64))
state = []
current = file_obj.tell()
line = file_obj.readline()
return ts, bs, es, states
[docs]
def _write_state(
    file: TextIO,
    time: float,
    energies: typ.Vector3D,
    state: typ.Arr_Nucleotide_15,
    box_size: typ.Vector3D | None = (0, 0, 0),
) -> None:
    """Write a single state to an open file in oxDNA trajectory format.

    The output is a ``t = ...`` line, a ``b = ...`` line, an ``E = ...`` line
    and then one space-separated line per row of ``state``.

    Args:
        file: text file object to write to
        time: value written on the ``t`` line
        energies: the first three entries are written on the ``E`` line
        state: iterable of rows; each row is written as one line
        box_size: the first three entries are written on the ``b`` line;
            ``None`` is written as "0 0 0"
    """
    if box_size is None:
        box_size = (0, 0, 0)

    lines = [
        f"t = {time}\n",
        f"b = {box_size[0]} {box_size[1]} {box_size[2]}\n",
        f"E = {energies[0]} {energies[1]} {energies[2]}\n",
    ]
    lines.extend(" ".join(map(str, nucleotide)) + "\n" for nucleotide in state)
    # one batched write instead of one call per line
    file.writelines(lines)