# Source code for mythos.input.trajectory

"""Trajectory information for RNA/DNA strands."""

import concurrent.futures as cf
import functools
import itertools
import multiprocessing as mp
from pathlib import Path
from typing import TextIO, TypeAlias

import chex
import jax.numpy as jnp
import jax_md
import numpy as np

import mythos.utils.math as jdm
import mythos.utils.types as typ

# Expected number of dimensions of ``Trajectory.times``.
TRAJECTORY_TIMES_DIMS = 1
# Expected shape of ``Trajectory.energies``; only the rank and the second
# entry are compared during validation.
TRAJECTORY_ENERGIES_SHAPE = (None, 3)
# Expected shape of ``NucleotideState.array``; only the rank and the second
# entry are compared during validation.
NUCLEOTIDE_STATE_SHAPE = (None, 15)

# Messages used when loading or validating a ``Trajectory``.
ERR_TRAJECTORY_FILE_NOT_FOUND = "Trajectory file not found: {}"
ERR_TRAJECTORY_N_NUCLEOTIDE_STRAND_LEGNTHS = "n_nucleotides and sum(strand_lengths) do not match"
ERR_TRAJECTORY_TIMES_TYPE = "times must be a numpy array"
ERR_TRAJECTORY_ENERGIES_TYPE = "energies must be a numpy array"
ERR_TRAJECTORY_T_E_S_LENGTHS = "times, energies, and states do not have the same length"
ERR_TRAJECTORY_TIMES_DIMS = "times must be a 1D array"
ERR_TRAJECTORY_ENERGIES_SHAPE = "energies must be a 2D array with shape (n_states, 3)"

# Message prefixes used when validating a ``NucleotideState``; the offending
# type or shape is appended to them.
ERR_NUCLEOTIDE_STATE_TYPE = "Invalid type for nucleotide states:"
ERR_NUCLEOTIDE_STATE_SHAPE = "Invalid shape for nucleotide states:"

# Raised by ``validate_box_size`` when the box sizes of the states differ.
# NOTE(review): the message misspells "trajectories"; it is left untouched
# here because the text is emitted at runtime.
ERR_FIXED_BOX_SIZE = "Only trajecories in a fixed box size are supported"


# Raw parse result before it is packaged into a ``Trajectory``:
# (times, box sizes, energies, state arrays), one list entry per state.
RawTrajectory: TypeAlias = tuple[list[typ.Scalar], list[typ.Vector3D], list[typ.Vector3D], list[typ.Arr_Nucleotide_15]]


@chex.dataclass(frozen=True)
class Trajectory:
    """Trajectory information for a RNA/DNA strand.

    Attributes:
        n_nucleotides: total number of nucleotides; must equal
            ``sum(strand_lengths)``.
        strand_lengths: number of nucleotides in each strand.
        times: 1D numpy array with one entry per state.
        energies: 2D numpy array with shape ``(n_states, 3)``.
        states: one ``NucleotideState`` per entry of ``times``.
        box_size: box size written by ``to_file``; ``None`` when it is not
            specified.
    """

    n_nucleotides: int
    strand_lengths: list[int]
    times: typ.Arr_States
    energies: typ.Arr_States_3
    states: list["NucleotideState"]
    box_size: typ.Vector3D | None = None

    def __post_init__(self) -> None:
        """Validate the input.

        Raises:
            ValueError: if ``n_nucleotides`` differs from
                ``sum(strand_lengths)``, if ``times``, ``energies`` and
                ``states`` differ in length, or if ``times``/``energies``
                have the wrong shape.
            TypeError: if ``times`` or ``energies`` is not a numpy array.
        """
        if self.n_nucleotides != sum(self.strand_lengths):
            raise ValueError(ERR_TRAJECTORY_N_NUCLEOTIDE_STRAND_LEGNTHS)

        if not isinstance(self.times, np.ndarray):
            raise TypeError(ERR_TRAJECTORY_TIMES_TYPE)

        if not isinstance(self.energies, np.ndarray):
            raise TypeError(ERR_TRAJECTORY_ENERGIES_TYPE)

        if len(self.times) != len(self.energies) or len(self.times) != len(self.states):
            raise ValueError(ERR_TRAJECTORY_T_E_S_LENGTHS)

        if len(self.times.shape) != TRAJECTORY_TIMES_DIMS:
            raise ValueError(ERR_TRAJECTORY_TIMES_DIMS)

        # the rank is checked first so that shape[1] is only read on 2D input
        if (
            len(self.energies.shape) != len(TRAJECTORY_ENERGIES_SHAPE)
            or self.energies.shape[1] != TRAJECTORY_ENERGIES_SHAPE[1]
        ):
            raise ValueError(ERR_TRAJECTORY_ENERGIES_SHAPE)

    @property
    def state_rigid_bodies(self) -> list[jax_md.rigid_body.RigidBody]:
        """Convert the states to a list of rigid bodies, one per state."""
        return [state.to_rigid_body() for state in self.states]

    @property
    def state_rigid_body(self) -> jax_md.rigid_body.RigidBody:
        """Convert the states to a single rigid body.

        The centers of mass and the quaternions of all states are stacked
        along a new leading axis.
        """
        return jax_md.rigid_body.RigidBody(
            center=jnp.stack([state.com for state in self.states]),
            orientation=jax_md.rigid_body.Quaternion(jnp.stack([state.quaternions for state in self.states])),
        )

    def slice(self, key: int | slice) -> "Trajectory":
        """Get a subset of the trajectory.

        Args:
            key: a slice selecting several states, or an integer index
                selecting a single state.

        Returns:
            Trajectory: a new trajectory holding the selected states and the
            same ``n_nucleotides``, ``strand_lengths`` and ``box_size``.

        Raises:
            IndexError: if an integer ``key`` is out of range.
        """
        if isinstance(key, (int, np.integer)):
            key = int(key)
            n_states = len(self.states)
            if not -n_states <= key < n_states:
                raise IndexError("trajectory index out of range")
            # Indexing with a bare integer would collapse times/energies to
            # scalars and fail validation, so select a one-element window.
            # ``or None`` keeps key == -1 from producing the empty slice [-1:0].
            key = slice(key, (key + 1) or None)

        return Trajectory(
            n_nucleotides=self.n_nucleotides,
            strand_lengths=self.strand_lengths,
            times=self.times[key],
            energies=self.energies[key],
            states=self.states[key],
            # the box does not depend on which states are selected
            box_size=self.box_size,
        )

    def __repr__(self) -> str:
        """Return a string representation of the trajectory."""
        return "\n".join(
            [
                "Trajectory:",
                f"n_nucleotides: {self.n_nucleotides}",
                f"strand_lengths: {self.strand_lengths}",
                f"# times: {len(self.times)}",
                f"# energies: {len(self.energies)}",
                f"# states: {len(self.states)}",
            ]
        )

    def to_file(self, filepath: Path) -> None:
        """Write a jaxDNA simulation trajectory to oxDNA file format.

        In cases where the box_size is not specified, it will be written as
        "0 0 0".

        Args:
            filepath: destination file; it is opened in write mode.
        """
        # substitute the documented placeholder instead of forwarding None
        box_size = (0, 0, 0) if self.box_size is None else self.box_size
        with Path(filepath).open("w") as f:
            for time, energies, state in zip(self.times, self.energies, self.states, strict=True):
                _write_state(
                    file=f,
                    time=time,
                    energies=energies,
                    state=state.array,
                    box_size=box_size,
                )
@chex.dataclass(frozen=True)
class NucleotideState:
    """Per-nucleotide data of one trajectory state.

    ``array`` holds one row per nucleotide with 15 columns: center of mass
    (0-2), backbone-base vector (3-5), base normal (6-8), velocity (9-11) and
    angular velocity (12-14).
    """

    array: typ.Arr_Nucleotide_15

    def __post_init__(self) -> None:
        """Check that ``array`` is a 2D numpy array with 15 columns.

        Raises:
            TypeError: if ``array`` is not a numpy array.
            ValueError: if ``array`` is not 2D or its second dimension is not 15.
        """
        data = self.array
        if not isinstance(data, np.ndarray):
            raise TypeError(ERR_NUCLEOTIDE_STATE_TYPE + str(type(data)))

        shape = data.shape
        expected_rank = len(NUCLEOTIDE_STATE_SHAPE)
        expected_columns = NUCLEOTIDE_STATE_SHAPE[1]
        # the rank test comes first so shape[1] is only read on 2D input
        if len(shape) != expected_rank or shape[1] != expected_columns:
            raise ValueError(ERR_NUCLEOTIDE_STATE_SHAPE + str(shape))

    @property
    def com(self) -> typ.Arr_Nucleotide_3:
        """Center of mass of the nucleotides (columns 0-2)."""
        return self.array[:, 0:3]

    @property
    def back_base_vector(self) -> typ.Arr_Nucleotide_3:
        """Backbone base vector (columns 3-5)."""
        return self.array[:, 3:6]

    @property
    def base_normal(self) -> typ.Arr_Nucleotide_3:
        """Base normal to the base plane (columns 6-8)."""
        return self.array[:, 6:9]

    @property
    def velocity(self) -> typ.Arr_Nucleotide_3:
        """Velocity of the nucleotides (columns 9-11)."""
        return self.array[:, 9:12]

    @property
    def angular_velocity(self) -> typ.Arr_Nucleotide_3:
        """Angular velocity of the nucleotides (columns 12-14)."""
        return self.array[:, 12:15]

    @property
    def euler_angles(self) -> tuple[typ.Arr_Nucleotide, typ.Arr_Nucleotide, typ.Arr_Nucleotide]:
        """Convert principal axes to Tait-Bryan Euler angles."""
        first_axis = self.back_base_vector
        third_axis = self.base_normal
        # the second axis is the cross product of the base normal and the
        # backbone-base vector
        second_axis = np.cross(third_axis, first_axis)
        return jdm.principal_axes_to_euler_angles(first_axis, second_axis, third_axis)

    @property
    def quaternions(self) -> typ.Arr_Nucleotide_4:
        """Convert Euler angles to quaternions."""
        angles = self.euler_angles
        return jdm.euler_angles_to_quaternion(*angles)

    def to_rigid_body(self) -> jax_md.rigid_body.RigidBody:
        """Convert the nucleotide state to jax-md rigid bodies."""
        orientation = jax_md.rigid_body.Quaternion(vec=self.quaternions)
        return jax_md.rigid_body.RigidBody(self.com, orientation)
def validate_box_size(state_box_sizes: list[typ.Vector3D]) -> None:
    """Validate the volume for a simulation is fixed.

    Args:
        state_box_sizes: the box size of every state. An empty list is
            accepted, since there is nothing to compare.

    Raises:
        ValueError: if any box size differs from the first one.
    """
    # a separate name keeps the caller-facing parameter from being rebound
    box_sizes = np.asarray(state_box_sizes)
    if box_sizes.size == 0:
        return

    if not np.all(box_sizes == box_sizes[0]):
        raise ValueError(ERR_FIXED_BOX_SIZE)
def from_file(
    path: typ.PathOrStr,
    strand_lengths: list[int],
    *,
    is_5p_3p: bool = True,
    n_processes: int = 1,
) -> Trajectory:
    """Load a trajectory file into a ``Trajectory``.

    Each state in the file is laid out as::

        t = number
        b = number number number
        E = number number number
        com_x com_y com_z a1_x a1_y a1_z a3_x a3_y a3_z v_x v_y v_z L_x L_y L_z
        ...one such row per nucleotide, n_nucleotides rows in total

    where every entry of a row is a floating point number. The layout is
    repeated once per stored time step.

    The original oxDNA format stores nucleotides in 3'->5' order, which is
    also the in-memory layout used here. Files produced with the new oxDNA
    topology format are written in 5'->3' order, so the rows of each strand
    are reversed while reading when ``is_5p_3p`` is true.

    Args:
        path (PathOrStr): location of the trajectory file
        strand_lengths (list[int]): number of nucleotides in each strand;
            used to find the strand boundaries when reversing rows
        is_5p_3p (bool): whether the file is in 5'->3' order (for example
            when the topology file used with oxDNA standalone is in the new
            oxDNA format)
        n_processes (int): number of processes used to read the file

    Returns:
        Trajectory: trajectory information

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    trajectory_path = Path(path)
    if not trajectory_path.exists():
        raise FileNotFoundError(ERR_TRAJECTORY_FILE_NOT_FOUND.format(trajectory_path))

    if n_processes != 1:
        raw = _read_parallel(trajectory_path, strand_lengths, is_5p_3p=is_5p_3p, n_processes=n_processes)
    else:
        file_size = trajectory_path.stat().st_size
        raw = _read_file(trajectory_path, 0, file_size, strand_lengths, is_5p_3p=is_5p_3p)
    times, box_sizes, energies, state_arrays = raw

    validate_box_size(box_sizes)

    # every box size is identical after validation, so the first one stands
    # for the whole trajectory
    shared_box = box_sizes[0]
    total_nucleotides = sum(strand_lengths)
    time_array = np.array(times, dtype=np.float64)
    energy_array = np.array(energies, dtype=np.float64)
    nucleotide_states = [NucleotideState(array=state_array) for state_array in state_arrays]

    return Trajectory(
        n_nucleotides=total_nucleotides,
        strand_lengths=strand_lengths,
        times=time_array,
        energies=energy_array,
        states=nucleotide_states,
        box_size=shared_box,
    )
def _read_parallel(path: Path, strand_lengths: list[int], *, is_5p_3p: bool, n_processes: int) -> RawTrajectory:
    """Read a trajectory file with several worker processes.

    The file is divided into ``n_processes`` contiguous byte ranges. Each
    range is parsed by ``_read_file`` in its own spawned process and the
    per-range results are concatenated in file order.

    Args:
        path: path to the trajectory file
        strand_lengths: number of nucleotides in each strand
        is_5p_3p: forwarded to ``_read_file``
        n_processes: number of worker processes and of byte ranges

    Returns:
        RawTrajectory: ``(times, box_sizes, energies, states)`` for the whole
        file.
    """
    # plain Python ints are handed to the workers instead of numpy scalars
    boundaries = np.linspace(0, path.stat().st_size, n_processes + 1, dtype=np.int64).tolist()
    jobs = [(path, start, end, strand_lengths, is_5p_3p) for start, end in itertools.pairwise(boundaries)]

    with cf.ProcessPoolExecutor(n_processes, mp_context=mp.get_context("spawn")) as pool:
        chunks = list(pool.map(_read_file_process_wrapper, jobs))

    # ``chunks`` holds one (ts, bs, es, states) tuple per byte range; zip
    # regroups them by field and each field is flattened into a single list.
    # A tuple is returned so the result matches the RawTrajectory annotation.
    return tuple(list(itertools.chain.from_iterable(parts)) for parts in zip(*chunks, strict=True))
def _read_file_process_wrapper(args: tuple[Path, int, int, list[int], bool]) -> RawTrajectory:
    """Unpack one job tuple and forward it to ``_read_file``.

    ``args`` is ``(file_path, start, end, strand_lengths, is_5p_3p)``; the
    last entry is passed on as the keyword-only ``is_5p_3p`` argument.
    """
    trajectory_path, first_byte, stop_byte, lengths, flip = args
    return _read_file(trajectory_path, first_byte, stop_byte, lengths, is_5p_3p=flip)
[docs] def _read_file(file_path: Path, start: int, end: int, strand_lengths: list[int], *, is_5p_3p: bool) -> RawTrajectory: """Read a trajectory file object.""" # we don't know where we are in the file, but we can be only in one of two # situations: We are at the start of the state or we are in the midle of a # state. If we are in the middle of a state, we need to read until the next # state starts and then parse the states from there. Importantly, we need # to pass our 'end' if the end is in the middle of a state, because the # worker ahead of in the file will not read it. parse_str = functools.partial(np.fromstring, sep=" ", dtype=np.float64) state_length = sum(strand_lengths) strand_bounds = list(itertools.pairwise([0, *itertools.accumulate(strand_lengths)])) file_obj = file_path.open() file_obj.seek(start) line = file_obj.readline() while not line.startswith("t"): line = file_obj.readline() ts, bs, es, states = [], [], [], [] state = [] current = file_obj.tell() while current < end: if line[0] == "t": t = float(line.strip().split("=")[1]) ts.append(t) elif line[0] == "b": b = parse_str(line.strip().split("=")[1]) bs.append(b) elif line[0] == "E": e = parse_str(line.strip().split("=")[1]) es.append(e) else: state.append(parse_str(line.strip())) if len(state) == state_length: # if the trajectory is stored in 3'->5' order, we need to flip # the order of the nucleotides in each strand if is_5p_3p: state = list(itertools.chain.from_iterable([state[s:e][::-1] for s, e in strand_bounds])) state = np.array(state, dtype=np.float64) states.append(np.array(state, dtype=np.float64)) state = [] current = file_obj.tell() line = file_obj.readline() return ts, bs, es, states
def _write_state(
    file: TextIO,
    time: float,
    energies: typ.Vector3D,
    state: typ.Arr_Nucleotide_15,
    box_size: typ.Vector3D | None = (0, 0, 0),
) -> None:
    """Write a single state in oxDNA trajectory format.

    The output is a ``t = ...`` line, a ``b = ...`` line, an ``E = ...`` line
    and then one space-separated line per row of ``state``.

    Args:
        file: open text file to write to
        time: value written on the ``t`` line
        energies: three values written on the ``E`` line
        state: iterable of rows; the entries of each row are written
            separated by single spaces
        box_size: three values written on the ``b`` line; ``None`` is written
            as "0 0 0"
    """
    if box_size is None:
        box_size = (0, 0, 0)

    lines = [
        f"t = {time}\n",
        f"b = {box_size[0]} {box_size[1]} {box_size[2]}\n",
        f"E = {energies[0]} {energies[1]} {energies[2]}\n",
    ]
    lines.extend(" ".join(map(str, nucleotide)) + "\n" for nucleotide in state)
    # one batched write instead of one call per row
    file.writelines(lines)