# Source code for mythos.input.sequence_constraints

"""Sequence constraint information for DNA/RNA."""

import chex
import jax.numpy as jnp
import numpy as np

import mythos.utils.constants as jd_const
import mythos.utils.types as typ

# Error messages raised (as ValueError) by the validation and conversion helpers in this module.
ERR_SEQ_CONSTRAINTS_INVALID_NUMBER_NUCLEOTIDES = "Invalid number of nucleotides"
ERR_SEQ_CONSTRAINTS_INVALID_UNPAIRED_SHAPE = "Invalid shape for unpaired nucleotides"
ERR_INVALID_BP_SHAPE = "Invalid shape for base pairs"
ERR_SEQ_CONSTRAINTS_INVALID_IS_UNPAIRED_SHAPE = "Invalid shape for array specifying if unpaired"
ERR_SEQ_CONSTRAINTS_INVALID_UNPAIRED_MAPPER_SHAPE = "Invalid shape for unpaired nucleotide index mapper"
ERR_SEQ_CONSTRAINTS_INVALID_BP_MAPPER_SHAPE = "Invalid shape for base pair index mapper"
# Raised when n_unpaired + 2 * n_bp != n_nucleotides: every base pair accounts for two nucleotides.
ERR_SEQ_CONSTRAINTS_MISMATCH_NUM_TYPES = (
    "Number of nucleotides should equal the number of unpaired nucleotides plus twice the number of base pairs"
)
ERR_SEQ_CONSTRAINTS_INVALID_COVER = "Unpaired and coupled nucleotides do not cover all nucleotides"
# Raised when the is_unpaired array holds anything other than 0 / 1 (or False / True).
ERR_SEQ_CONSTRAINTS_IS_UNPAIRED_INVALID_VALUES = (
    "Array specifying if unpaired contains invalid values, can only be 0 or 1"
)
ERR_SEQ_CONSTRAINTS_INVALID_IS_UNPAIRED = "Array specifying if is_unpaired disagrees with list of unpaired nucleotides"
ERR_SEQ_CONSTRAINTS_PAIRED_NT_MAPPED_TO_UNPAIRED = "Base paired nucleotides cannot be mapped to an unpaired nucleotide"
ERR_SEQ_CONSTRAINTS_INCOMPLETE_UNPAIRED_MAPPED_IDXS = (
    "Map of position indices to indices of unpaired nucleotides does not cover number of unpaired nucleotides"
)
ERR_SEQ_CONSTRAINTS_UNPAIRED_NT_MAPPED_TO_PAIRED = "Unpaired nucleotides cannot be mapped to a base paired nucleotide"
ERR_SEQ_CONSTRAINTS_INCOMPLETE_BP_MAPPED_IDXS = (
    "Map of position indices to indices of base paired nucleotides does not cover number of base paired nucleotides"
)
ERR_BP_ARR_CONTAINS_DUPLICATES = "Array specifying base paired indices cannot contain duplicates"
ERR_INVALID_BP_INDICES = "Base paired indices must be between 0 and n_nucleotides-1"
ERR_DSEQ_TO_PSEQ_INVALID_BP = (
    "Invalid base pair encountered when converting discrete sequence to probabilistic sequence"
)


def check_consistent_constraints(
    n_unpaired: int,
    n_bp: int,
    unpaired: typ.Arr_Unpaired,
    idx_to_unpaired_idx: typ.Arr_Nucleotide_Int,
    idx_to_bp_idx: typ.Arr_Nucleotide_2_Int,
) -> None:
    """Checks for consistency between specified nucleotide constraints and index mappers.

    Nucleotides listed in ``unpaired`` must be mapped by ``idx_to_unpaired_idx``
    onto exactly the slots ``{0, ..., n_unpaired - 1}``. Every other nucleotide
    must map to ``-1`` there. Conversely, nucleotides *not* in ``unpaired`` must
    be mapped by ``idx_to_bp_idx`` onto exactly the pairs
    ``{0, ..., n_bp - 1} x {0, 1}``, and unpaired nucleotides must map to
    ``(-1, -1)``.

    Args:
        n_unpaired: number of unpaired nucleotides.
        n_bp: number of base pairs.
        unpaired: indices of the unpaired nucleotides.
        idx_to_unpaired_idx: per-nucleotide unpaired slot, or -1.
        idx_to_bp_idx: per-nucleotide ``(bp_idx, 0 or 1)`` pair, or ``(-1, -1)``.

    Raises:
        ValueError: if any of the conditions above is violated.
    """
    # Build the membership set a single time. Rebuilding it per nucleotide
    # inside the loops would make this check quadratic in the strand length.
    unpaired_set = set(np.asarray(unpaired).tolist())

    unpaired_mapped_idxs = []
    for idx, mapped_idx in enumerate(np.asarray(idx_to_unpaired_idx).tolist()):
        if idx in unpaired_set:
            unpaired_mapped_idxs.append(mapped_idx)
        elif mapped_idx != -1:
            raise ValueError(ERR_SEQ_CONSTRAINTS_PAIRED_NT_MAPPED_TO_UNPAIRED)
    if set(unpaired_mapped_idxs) != set(range(n_unpaired)):
        raise ValueError(ERR_SEQ_CONSTRAINTS_INCOMPLETE_UNPAIRED_MAPPED_IDXS)

    bp_mapped_idxs = []
    for idx, (mapped_idx1, mapped_idx2) in enumerate(np.asarray(idx_to_bp_idx).tolist()):
        if idx not in unpaired_set:
            bp_mapped_idxs.append((mapped_idx1, mapped_idx2))
        elif mapped_idx1 != -1 or mapped_idx2 != -1:
            raise ValueError(ERR_SEQ_CONSTRAINTS_UNPAIRED_NT_MAPPED_TO_PAIRED)
    # Each base pair must be referenced from both of its sides (0 and 1).
    expected_bp_idxs = {(bp_idx, side) for bp_idx in range(n_bp) for side in (0, 1)}
    if set(bp_mapped_idxs) != expected_bp_idxs:
        raise ValueError(ERR_SEQ_CONSTRAINTS_INCOMPLETE_BP_MAPPED_IDXS)
def check_cover(n_nucleotides: int, n_unpaired: int, n_bp: int, unpaired: typ.Arr_Unpaired, bps: typ.Arr_Bp) -> None:
    """Checks if unpaired and paired nucleotides cover the entire set of nucleotides.

    Args:
        n_nucleotides: total number of nucleotides.
        n_unpaired: number of unpaired nucleotides.
        n_bp: number of base pairs; each one accounts for two nucleotides.
        unpaired: indices of the unpaired nucleotides.
        bps: base-pair index array (flattened before use).

    Raises:
        ValueError: ``ERR_SEQ_CONSTRAINTS_MISMATCH_NUM_TYPES`` if the counts do
            not add up, or ``ERR_SEQ_CONSTRAINTS_INVALID_COVER`` if the indices in
            ``unpaired`` and ``bps`` together are not exactly
            ``{0, ..., n_nucleotides - 1}``.
    """
    if n_nucleotides != n_unpaired + 2 * n_bp:
        raise ValueError(ERR_SEQ_CONSTRAINTS_MISMATCH_NUM_TYPES)

    listed_indices = np.concatenate([unpaired, bps.flatten()])
    if set(listed_indices) == set(np.arange(n_nucleotides)):
        return
    raise ValueError(ERR_SEQ_CONSTRAINTS_INVALID_COVER)
@chex.dataclass(frozen=True)
class SequenceConstraints:
    """Constraint information for a RNA/DNA strand.

    Instances are validated on construction in ``__post_init__``. Any mismatch
    between the counts, array shapes, array values, and index mappers raises
    ``ValueError``.
    """

    # Total number of nucleotides in the strand; must be >= 1.
    n_nucleotides: int
    # Number of unpaired nucleotides.
    n_unpaired: int
    # Number of base pairs.
    n_bp: int
    # Shape (n_nucleotides,); 1 for nucleotides listed in `unpaired`, 0 otherwise.
    is_unpaired: typ.Arr_Nucleotide_Int
    # Shape (n_unpaired,); indices of the unpaired nucleotides.
    unpaired: typ.Arr_Unpaired
    # Shape (n_bp, 2); nucleotide indices making up each base pair.
    bps: typ.Arr_Bp
    # Shape (n_nucleotides,); slot in 0..n_unpaired-1 for unpaired nucleotides, -1 for paired ones.
    idx_to_unpaired_idx: typ.Arr_Nucleotide_Int
    # Shape (n_nucleotides, 2); (base pair index, 0 or 1) for paired nucleotides, (-1, -1) for unpaired ones.
    idx_to_bp_idx: typ.Arr_Nucleotide_2_Int

    def __post_init__(self) -> None:
        """Check that the sequence constraints are valid.

        Raises:
            ValueError: if ``n_nucleotides < 1``, if any array has the wrong
                shape, if unpaired and paired nucleotides do not exactly cover the
                strand, if ``is_unpaired`` holds values other than 0/1 or disagrees
                with ``unpaired``, or if the index mappers are inconsistent.
        """
        # Check valid numbers
        if self.n_nucleotides < 1:
            raise ValueError(ERR_SEQ_CONSTRAINTS_INVALID_NUMBER_NUCLEOTIDES)

        # Check valid shapes
        if self.unpaired.shape != (self.n_unpaired,):
            raise ValueError(ERR_SEQ_CONSTRAINTS_INVALID_UNPAIRED_SHAPE)
        if self.bps.shape != (self.n_bp, 2):
            raise ValueError(ERR_INVALID_BP_SHAPE)
        if self.is_unpaired.shape != (self.n_nucleotides,):
            raise ValueError(ERR_SEQ_CONSTRAINTS_INVALID_IS_UNPAIRED_SHAPE)
        if self.idx_to_unpaired_idx.shape != (self.n_nucleotides,):
            raise ValueError(ERR_SEQ_CONSTRAINTS_INVALID_UNPAIRED_MAPPER_SHAPE)
        if self.idx_to_bp_idx.shape != (self.n_nucleotides, 2):
            raise ValueError(ERR_SEQ_CONSTRAINTS_INVALID_BP_MAPPER_SHAPE)

        # Check cover
        check_cover(self.n_nucleotides, self.n_unpaired, self.n_bp, self.unpaired, self.bps)

        # Check values. Pull the arrays to host-side Python values once.
        # Iterating a JAX array element by element, and rebuilding the unpaired
        # set on every step, would make this check slow and quadratic.
        is_unpaired = np.asarray(self.is_unpaired).tolist()
        if not set(is_unpaired).issubset({0, 1}):
            raise ValueError(ERR_SEQ_CONSTRAINTS_IS_UNPAIRED_INVALID_VALUES)
        unpaired_set = set(np.asarray(self.unpaired).tolist())
        for idx, flag in enumerate(is_unpaired):
            expected = 1 if idx in unpaired_set else 0
            if flag != expected:
                raise ValueError(ERR_SEQ_CONSTRAINTS_INVALID_IS_UNPAIRED)

        check_consistent_constraints(
            self.n_unpaired,
            self.n_bp,
            self.unpaired,
            self.idx_to_unpaired_idx,
            self.idx_to_bp_idx,
        )
def from_bps(n_nucleotides: int, bps: typ.Arr_Bp) -> SequenceConstraints:
    """Construct a SequenceConstraints object from a set of base pairs.

    Every nucleotide index in ``[0, n_nucleotides)`` that does not appear in
    ``bps`` is treated as unpaired.

    Args:
        n_nucleotides: total number of nucleotides in the strand.
        bps: 2-D base-pair array with ``jd_const.N_NT_PER_BP`` columns. Each
            row holds the nucleotide indices of one base pair. Any input accepted
            by ``np.asarray`` works, e.g. NumPy/JAX arrays or nested lists.

    Returns:
        A validated ``SequenceConstraints``. Its ``unpaired`` entries are in
        sorted order, and ``is_unpaired`` is a boolean array.

    Raises:
        ValueError: if ``bps`` has the wrong shape or too many rows for
            ``n_nucleotides``, contains duplicate indices, or contains indices
            outside ``[0, n_nucleotides)``. ``SequenceConstraints`` itself raises
            ``ValueError`` if ``n_nucleotides < 1``.
    """
    # Work on a host-side NumPy copy: validation and the index loops below are
    # much cheaper on NumPy than on JAX device arrays.
    bps = np.asarray(bps)

    # Check format of base pairs
    if (
        len(bps.shape) != jd_const.TWO_DIMENSIONS
        or bps.shape[1] != jd_const.N_NT_PER_BP
        or jd_const.N_NT_PER_BP * bps.shape[0] > n_nucleotides
    ):
        raise ValueError(ERR_INVALID_BP_SHAPE)

    paired_nucleotides = bps.flatten()
    if len(np.unique(paired_nucleotides)) < len(paired_nucleotides):
        raise ValueError(ERR_BP_ARR_CONTAINS_DUPLICATES)
    if not np.all((paired_nucleotides >= 0) & (paired_nucleotides < n_nucleotides)):
        raise ValueError(ERR_INVALID_BP_INDICES)

    n_bp = bps.shape[0]

    # Infer the unpaired nucleotides (setdiff1d returns them sorted).
    unpaired = np.setdiff1d(np.arange(n_nucleotides), paired_nucleotides)
    n_unpaired = unpaired.shape[0]

    # Unpaired nucleotide -> its position within `unpaired`; paired ones stay -1.
    idx_to_unpaired_idx = np.full((n_nucleotides,), -1, dtype=np.int32)
    idx_to_unpaired_idx[unpaired] = np.arange(n_unpaired, dtype=np.int32)

    # Paired nucleotide -> (base pair index, position within the pair); unpaired stay (-1, -1).
    idx_to_bp_idx = np.full((n_nucleotides, 2), -1, dtype=np.int32)
    for bp_idx, (nt1, nt2) in enumerate(bps):
        idx_to_bp_idx[nt1] = [bp_idx, 0]
        idx_to_bp_idx[nt2] = [bp_idx, 1]

    # A nucleotide is unpaired exactly when it received an unpaired slot. This
    # gives a boolean array in O(n), with no per-index set rebuild.
    is_unpaired = idx_to_unpaired_idx >= 0

    return SequenceConstraints(
        n_nucleotides=n_nucleotides,
        n_unpaired=n_unpaired,
        n_bp=n_bp,
        is_unpaired=jnp.array(is_unpaired),
        unpaired=jnp.array(unpaired),
        bps=jnp.array(bps),
        idx_to_unpaired_idx=jnp.array(idx_to_unpaired_idx),
        idx_to_bp_idx=jnp.array(idx_to_bp_idx),
    )
def dseq_to_pseq(dseq: typ.Discrete_Sequence, sc: SequenceConstraints) -> typ.Probabilistic_Sequence:
    """Convert a discrete sequence into its one-hot probabilistic representation.

    Args:
        dseq: per-nucleotide type indices. It is indexed at the positions in
            ``sc.unpaired`` and ``sc.bps``.
        sc: constraints describing which positions are unpaired and which
            form base pairs.

    Returns:
        A tuple ``(unpaired_pseq, bp_pseq)`` of JAX arrays.

        - ``unpaired_pseq`` has shape ``(sc.n_unpaired, jd_const.N_NT)``, with
          a 1.0 at column ``dseq[position]`` of each row.
        - ``bp_pseq`` has 4 columns, with a 1.0 at column
          ``jd_const.BP_IDX_MAP[(nt1, nt2)]`` of each base-pair row.
        - When ``sc.n_bp == 0``, ``bp_pseq`` is a single all-zero dummy row.

    Raises:
        ValueError: if a base pair's nucleotide types are not a key of
            ``jd_const.BP_IDX_MAP``.
    """
    n_bp_types = 4  # width of the base-pair encoding

    # One-hot rows for the unpaired positions.
    unpaired_onehot = np.zeros((sc.n_unpaired, jd_const.N_NT), dtype=np.float64)
    for row, position in enumerate(sc.unpaired):
        unpaired_onehot[row][dseq[position]] = 1.0

    # One-hot rows for the base-paired positions.
    bp_onehot = np.zeros((sc.n_bp, n_bp_types), dtype=np.float64)
    for row, (first, second) in enumerate(sc.bps):
        pair = (int(dseq[first]), int(dseq[second]))
        if pair not in jd_const.BP_IDX_MAP:
            raise ValueError(ERR_DSEQ_TO_PSEQ_INVALID_BP)
        bp_onehot[row][jd_const.BP_IDX_MAP[pair]] = 1.0

    # With no base pairs, return one all-zero dummy row instead of an empty
    # array. Pair-weight computations index with -1 for nucleotides outside any
    # base pair, so that access needs at least one row to exist.
    if sc.n_bp == 0:
        bp_onehot = np.zeros((1, n_bp_types), dtype=np.float64)

    return (jnp.array(unpaired_onehot), jnp.array(bp_onehot))