"""Topological information for DNA/RNA."""
import dataclasses as dc
import itertools
import warnings
from collections.abc import Callable
from enum import IntEnum
from pathlib import Path
import jax.numpy as jnp
import numpy as np
import mythos.utils.constants as jd_const
import mythos.utils.types as typ
# Header-line field counts for the two oxDNA topology flavours. The classic
# parser below reads two integers from the first line; the new-format parser
# reads two integers followed by a trailing "5->3" token.
# NOTE(review): presumably consumed by the format-detection helper, which is
# not visible in this part of the file -- confirm.
N_1ST_LINE_OXDNA_CLASSIC = 2
N_1ST_LINE_OXDNA_NEW = 3
# Used by Topology.__post_init__ both as the required rank of the neighbor
# arrays and as the required size of their second axis.
VALID_NEIGHBOR_SECOND_DIM = 2
# Error messages raised while validating a Topology or a sequence.
ERR_TOPOLOGY_INVALID_NUMBER_NUCLEOTIDES = "Invalid number of nucleotides"
ERR_TOPOLOGY_INVALID_STRAND_COUNTS = "Invalid strand counts"
ERR_TOPOLOGY_SEQ_NOT_MATCH_NUCLEOTIDES = "Sequence does not match number of nucleotides"
ERR_TOPOLOGY_STRAND_COUNTS_NOT_MATCH = "Strand counts do not match number of nucleotides"
ERR_TOPOLOGY_BONDED_NEIGHBORS_INVALID_SHAPE = "Invalid bonded neighbors shape"
ERR_TOPOLOGY_UNBONDED_NEIGHBORS_INVALID_SHAPE = "Invalid unbonded neighbors shape"
ERR_TOPOLOGY_INVALID_SEQUENCE_LENGTH = "Invalid sequence length"
ERR_TOPOLOGY_INVALID_SEQUENCE_NUCLEOTIDES = "Invalid sequence nucleotides"
ERR_TOPOLOGY_INVALID_DISCRETE_SEQUENCE_SHAPE = "Invalid discrete sequence shape"
ERR_TOPOLOGY_INVALID_UNPAIRED_PSEQ_SHAPE = "Invalid unpaired probabilistic sequence shape"
ERR_TOPOLOGY_MISMATCH_PSEQ_SHAPE_NUM_NUCLEOTIDES = "Pseq shape does not match number of nucleotides"
ERR_TOPOLOGY_INVALID_BP_PSEQ_SHAPE = "Invalid base-paired probabilistic sequence shape"
# NOTE(review): check_valid_seq raises this only for strictly negative
# entries, so zero probabilities are accepted despite the "> 0" wording.
ERR_TOPOLOGY_INVALID_PROBABILITIES = "Probabilities must be > 0"
ERR_TOPOLOGY_PSEQ_NOT_NORMALIZED = "Probabilities must be normalized"
# Error messages raised while reading oxDNA topology files.
ERR_INVALID_OXDNA_FORMAT = (
"Invalid oxDNA topology format. See "
"https://lorenzo-rovigatti.github.io/oxDNA/configurations.html#topology-file for more information."
)
# NOTE(review): "cicularity" is a typo for "circularity"; left unchanged here
# because the text is emitted at runtime and callers may match on it.
ERR_STRAND_COUNTS_CIRCULAR_MISMATCH = "Strand counts and cicularity do not match"
ERR_FILE_NOT_FOUND = "Topology file not found"
ERR_TOPOLOGY_INVALID_SEQUENCE_TYPE = "Invalid sequence type. Must be discrete or probabilistic"
# Warning templates; both are formatted with a ``strand_idx`` keyword.
WARN_CLASSIC_UNSPECIFIED_NT_TYPE = "Type of strand {strand_idx} not specified, and did not find T/U for autodetect"
WARN_NEW_UNSPECIFIED_NT_TYPE = "Type of strand {strand_idx} not specified"
[docs]
class NucleotideType(IntEnum):
    """Integer codes identifying whether a nucleotide belongs to DNA or RNA."""

    # No type was declared for the strand and none could be detected.
    UNSPECIFIED = 0
    # The nucleotide is part of a DNA strand.
    DNA = 1
    # The nucleotide is part of an RNA strand.
    RNA = 2
[docs]
def check_valid_seq(seq: typ.Sequence, n_nucleotides: int) -> None:
    """Validate a discrete or probabilistic sequence against a nucleotide count.

    A discrete sequence must contain only the values 0-3 and have shape
    ``(n_nucleotides,)``. A probabilistic sequence is a 2-tuple
    ``(unpaired, base_paired)`` of 2-D arrays whose row counts must account for
    exactly ``n_nucleotides`` nucleotides, whose entries are non-negative and
    whose rows each sum to 1 (within ``np.allclose`` tolerance).

    Args:
        seq (typ.Sequence): the sequence to validate.
        n_nucleotides (int): the number of nucleotides the sequence must cover.

    Raises:
        ValueError: if the sequence is of an unsupported type or fails any of
            the checks above.
    """
    if isinstance(seq, typ.Discrete_Sequence):
        unknown_values = set(np.array(seq)) - {0, 1, 2, 3}
        if unknown_values:
            raise ValueError(ERR_TOPOLOGY_INVALID_SEQUENCE_NUCLEOTIDES)
        if seq.shape != (n_nucleotides,):
            raise ValueError(ERR_TOPOLOGY_INVALID_DISCRETE_SEQUENCE_SHAPE)
        return

    # Anything that is not discrete must be a typ.Probabilistic_Sequence,
    # i.e. a pair of arrays.
    if not (isinstance(seq, tuple) and len(seq) == jd_const.TWO_DIMENSIONS):
        raise ValueError(ERR_TOPOLOGY_INVALID_SEQUENCE_TYPE)

    up_pseq, bp_pseq = seq

    up_shape = up_pseq.shape
    if not (len(up_shape) == jd_const.TWO_DIMENSIONS and up_shape[1] == jd_const.N_NT):
        raise ValueError(ERR_TOPOLOGY_INVALID_UNPAIRED_PSEQ_SHAPE)

    bp_shape = bp_pseq.shape
    if not (len(bp_shape) == jd_const.TWO_DIMENSIONS and bp_shape[1] == jd_const.N_BP_TYPES):
        raise ValueError(ERR_TOPOLOGY_INVALID_BP_PSEQ_SHAPE)

    # Each base-pair row stands for N_NT_PER_BP nucleotides.
    covered = up_shape[0] + jd_const.N_NT_PER_BP * bp_shape[0]
    if covered != n_nucleotides:
        raise ValueError(ERR_TOPOLOGY_MISMATCH_PSEQ_SHAPE_NUM_NUCLEOTIDES)

    if (up_pseq < 0).any() or (bp_pseq < 0).any():
        raise ValueError(ERR_TOPOLOGY_INVALID_PROBABILITIES)

    rows_normalized = np.allclose(np.sum(up_pseq, axis=1), 1) and np.allclose(np.sum(bp_pseq, axis=1), 1)
    if not rows_normalized:
        raise ValueError(ERR_TOPOLOGY_PSEQ_NOT_NORMALIZED)
[docs]
@dc.dataclass(frozen=True)
class Topology:
    """Immutable description of how the nucleotides of a DNA/RNA system are connected."""

    # Total number of nucleotides; must be >= 1 and equal sum(strand_counts).
    n_nucleotides: int
    # Number of nucleotides in each strand.
    strand_counts: np.ndarray
    # 2-D array with two columns; each row is a pair of bonded nucleotide indices.
    bonded_neighbors: np.ndarray
    # 2-D array with two columns; each row is a pair of non-bonded nucleotide indices.
    unbonded_neighbors: np.ndarray
    # Discrete or probabilistic sequence, validated by check_valid_seq.
    seq: typ.Sequence
    # Not validated here. The file parsers in this module set 1 for the first
    # and last nucleotide of every non-circular strand and 0 elsewhere.
    is_end: np.ndarray
    # Not validated here. The file parsers in this module store one
    # NucleotideType value per nucleotide.
    nt_type: np.ndarray

    def __post_init__(self) -> None:
        """Validate the fields, raising ``ValueError`` on the first inconsistency."""
        if self.n_nucleotides < 1:
            raise ValueError(ERR_TOPOLOGY_INVALID_NUMBER_NUCLEOTIDES)

        if len(self.strand_counts) == 0:
            raise ValueError(ERR_TOPOLOGY_INVALID_STRAND_COUNTS)
        total_in_strands = sum(self.strand_counts)
        if total_in_strands == 0:
            raise ValueError(ERR_TOPOLOGY_INVALID_STRAND_COUNTS)
        if self.n_nucleotides != total_in_strands:
            raise ValueError(ERR_TOPOLOGY_STRAND_COUNTS_NOT_MATCH)

        # Both neighbor arrays must be rank 2 with exactly two columns;
        # the bonded array is checked first.
        neighbor_checks = (
            (self.bonded_neighbors, ERR_TOPOLOGY_BONDED_NEIGHBORS_INVALID_SHAPE),
            (self.unbonded_neighbors, ERR_TOPOLOGY_UNBONDED_NEIGHBORS_INVALID_SHAPE),
        )
        for neighbors, message in neighbor_checks:
            shape = neighbors.shape
            if not (len(shape) == VALID_NEIGHBOR_SECOND_DIM and shape[1] == VALID_NEIGHBOR_SECOND_DIM):
                raise ValueError(message)

        check_valid_seq(self.seq, self.n_nucleotides)
[docs]
def from_oxdna_file(path: typ.PathOrStr, *, return_format: bool = False) -> Topology | tuple[Topology, typ.oxDNAFormat]:
    """Read topology information from an oxDNA file.

    The format is detected from the first non-blank line and the matching
    parser is applied to the file contents. Whitespace-only lines are ignored.

    Args:
        path (typ.PathOrStr): Path to the oxDNA topology file.
        return_format (bool, optional): Whether to return the oxDNA format along
            with the topology as a tuple (top, fmt). Defaults to False.

    Returns:
        Topology or tuple[Topology, typ.oxDNAFormat]: The parsed topology, and
        optionally the format.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file contains no non-blank lines.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(ERR_FILE_NOT_FOUND)

    with path.open() as f:
        # Blank lines (e.g. trailing newlines left by an editor) carry no
        # topology information and would otherwise crash the parsers.
        lines = [line for line in f if line.strip()]

    # Without this guard an empty file surfaces as a bare IndexError on lines[0].
    if not lines:
        raise ValueError(ERR_INVALID_OXDNA_FORMAT)

    fmt, parse_f = _determine_oxdna_format(lines[0])
    topology = parse_f(lines)
    if return_format:
        return topology, fmt
    return topology
[docs]
def _get_bonded_neighbors(
strand_lengths: list[int],
is_circular: list[bool],
) -> list[tuple[int, int]]:
"""Convert 5' neighbors to bonded neighbors by index."""
if len(strand_lengths) != len(is_circular):
raise ValueError(ERR_STRAND_COUNTS_CIRCULAR_MISMATCH)
bonded_neighbors = []
init_idx = 0
for i, length in enumerate(strand_lengths):
pairs = list(itertools.pairwise(range(init_idx, init_idx + length)))
if is_circular[i]:
# the ordering here in intentional
pairs.append((init_idx, init_idx + length - 1))
bonded_neighbors.extend(pairs)
init_idx += length
return bonded_neighbors
[docs]
def _get_unbonded_neighbors(n_nucleotides: int, bonded_neighbors: list[tuple[int, int]]) -> list[tuple[int, int]]:
"""Get unbonded neighbors."""
all_possible_pairs = set(itertools.combinations(range(n_nucleotides), 2))
self_bonds = {(i, i) for i in range(n_nucleotides)}
return list(all_possible_pairs - set(bonded_neighbors) - self_bonds)
[docs]
def _from_file_oxdna_classic(lines: list[str]) -> Topology:
    """Read topology information from a file in the classic oxDNA format.

    See https://lorenzo-rovigatti.github.io/oxDNA/configurations.html#topology-file
    for more information.

    Args:
        lines (List[str]): lines from topology file

    Returns:
        Topology: Topology object

    Raises:
        ValueError: if the header or a nucleotide line is malformed, or if a
            strand id in ``1..n_strands`` has no nucleotides.
    """
    n_nucleotides, n_strands = map(int, lines[0].split())
    # after the first line the topology files are space delimited with the
    # following columns:
    # - strand id (1 indexed)
    # - nucleotide base (A=0, C=1, G=2, T=3, U=3), use char for now
    # - 3' neighbor (0-indexed), -1 if none
    # - 5' neighbor (0-indexed), -1 if none, -1 indicates the strand isn't circular
    #
    strand_ids, bases, _, neighbor_5p = zip(*(line.split() for line in lines[1:]), strict=True)
    strand_ids = list(map(int, strand_ids))
    neighbor_5p = list(map(int, neighbor_5p))
    _, strand_counts = np.unique(strand_ids, return_counts=True)

    # Group the nucleotides by strand in a single pass (file order is kept
    # within each strand) instead of rescanning every line once per strand.
    strands: dict[int, list[tuple[str, int]]] = {}
    for strand_id, base, five_prime in zip(strand_ids, bases, neighbor_5p, strict=True):
        strands.setdefault(strand_id, []).append((base, five_prime))

    sequence_bases = []
    is_circular = []
    is_end = []
    nt_type = []
    for i in range(1, n_strands + 1):
        members = strands.get(i)
        if not members:
            # The header announces a strand that has no nucleotide lines.
            raise ValueError(ERR_INVALID_OXDNA_FORMAT)
        strand_bases = [base for base, _five_prime in members]
        strand_length = len(strand_bases)
        sequence_bases.extend(strand_bases)

        # The last nucleotide of a linear strand has no 5' neighbor (-1).
        strand_is_circular = members[-1][1] != -1
        is_circular.append(strand_is_circular)

        strand_is_end = [0] * strand_length
        if not strand_is_circular:
            strand_is_end[0] = 1
            strand_is_end[-1] = 1
        is_end.extend(strand_is_end)

        # Classic format doesn't permit type specification, so infer it from
        # the presence of T (DNA) or U (RNA) in the strand.
        if "T" in strand_bases:
            nt_type.extend([NucleotideType.DNA] * strand_length)
        elif "U" in strand_bases:
            nt_type.extend([NucleotideType.RNA] * strand_length)
        else:
            warnings.warn(WARN_CLASSIC_UNSPECIFIED_NT_TYPE.format(strand_idx=i), stacklevel=1)
            nt_type.extend([NucleotideType.UNSPECIFIED] * strand_length)

    sequence = "".join(sequence_bases)
    bonded_neighbors = _get_bonded_neighbors(strand_counts, is_circular)
    unbonded_neighbors = _get_unbonded_neighbors(n_nucleotides, bonded_neighbors)
    return Topology(
        n_nucleotides=n_nucleotides,
        strand_counts=strand_counts,
        # reshape(-1, 2) keeps an empty pair list as a (0, 2) array; a plain
        # np.array([]) has shape (0,), which Topology rejects (e.g. for a
        # single 2-nucleotide strand, which has no unbonded pairs).
        bonded_neighbors=np.array(bonded_neighbors, dtype=int).reshape(-1, 2),
        unbonded_neighbors=np.array(unbonded_neighbors, dtype=int).reshape(-1, 2),
        seq=jnp.array([jd_const.NUCLEOTIDES_IDX[s] for s in sequence], dtype=jnp.int32),
        is_end=jnp.array(is_end).astype(jnp.int32),
        nt_type=jnp.array(nt_type).astype(jnp.int32),
    )
[docs]
def _from_file_oxdna_new(lines: list[str]) -> Topology:
    """Read topology information from a file in the new oxDNA format.

    See https://lorenzo-rovigatti.github.io/oxDNA/configurations.html#topology-file
    for more information.

    Args:
        lines (List[str]): lines from topology file

    Returns:
        Topology: Topology object
    """
    # the first line of the new oxDNA format is:
    # n_nucleotides n_strands 5->3
    # we don't need the 5->3, so we'll just ignore it. The strand count is
    # parsed (so a malformed header still fails) but not otherwise used.
    n_nucleotides, _n_strands = map(int, lines[0].split()[:-1])
    # the rest of the new oxDNA file format is laid out as follows:
    # nucleotides k=v
    # ...
    # nucleotides k=v
    # Where `nucleotides` is a string of ACTG and `k=v` is a set of key value pairs
    # the lines are repeated n_strand times
    sequence = []
    strand_counts = []
    is_circular = []
    is_end = []
    nt_type = []
    # Strands are numbered from 1, matching the classic-format parser.
    for strand_idx, line in enumerate(lines[1:], start=1):
        nucleotides = line.split()[0]
        strand_length = len(nucleotides)
        sequence.append(nucleotides[::-1])  # reverse the sequence to go 5'->3' to 3'->5'
        strand_counts.append(strand_length)

        strand_is_circular = "circular=true" in line
        is_circular.append(strand_is_circular)

        strand_is_end = [0] * strand_length
        if not strand_is_circular:
            strand_is_end[0] = 1
            strand_is_end[-1] = 1
        is_end.extend(strand_is_end)

        if "type=DNA" in line:
            strand_nt_type = NucleotideType.DNA
        elif "type=RNA" in line:
            strand_nt_type = NucleotideType.RNA
        else:
            # Report the strand's index, not the raw text of its line.
            warnings.warn(WARN_NEW_UNSPECIFIED_NT_TYPE.format(strand_idx=strand_idx), stacklevel=1)
            strand_nt_type = NucleotideType.UNSPECIFIED
        nt_type.extend([strand_nt_type] * strand_length)

    sequence = "".join(sequence)
    bonded_neighbors = _get_bonded_neighbors(strand_counts, is_circular)
    unbonded_neighbors = _get_unbonded_neighbors(n_nucleotides, bonded_neighbors)
    return Topology(
        n_nucleotides=n_nucleotides,
        strand_counts=np.array(strand_counts),
        # reshape(-1, 2) keeps an empty pair list as a (0, 2) array; a plain
        # np.array([]) has shape (0,), which Topology rejects (e.g. for a
        # single 2-nucleotide strand, which has no unbonded pairs).
        bonded_neighbors=np.array(bonded_neighbors, dtype=int).reshape(-1, 2),
        unbonded_neighbors=np.array(unbonded_neighbors, dtype=int).reshape(-1, 2),
        seq=jnp.array([jd_const.NUCLEOTIDES_IDX[s] for s in sequence], dtype=jnp.int32),
        is_end=jnp.array(is_end).astype(jnp.int32),
        nt_type=jnp.array(nt_type).astype(jnp.int32),
    )