Source code for mythos.simulators.jax_md.utils

"""Utilities for JAX-MD samplers."""

from collections.abc import Callable
from typing import Protocol

import chex
import jax
import jax.numpy as jnp
import jax_md
import numpy as np
from typing_extensions import override

import mythos.input.topology as jdna_top
import mythos.utils.types as typ

ERR_CHKPNT_SCN = "`checkpoint_every` must evenly divide the length of `xs`. Got {} and {}."


[docs] class SimulationState(Protocol): """This is a protocol to help with typing. Every state implements at least position and mass. More info about the specific states can be found here: https://github.com/jax-md/jax-md/blob/main/jax_md/simulate.py """ position: jax_md.rigid_body.RigidBody mass: jax_md.rigid_body.RigidBody
[docs] class NeighborHelper(Protocol): """Helper class for managing neighbor lists.""" @property def idx(self) -> jnp.ndarray: """Return the indices of the unbonded neighbors.""" ...
[docs] def allocate(self, locs: jax_md.rigid_body.RigidBody) -> "NeighborHelper": """Allocate memory for the neighbor list.""" ...
[docs] def update(self, locs: jax_md.rigid_body.RigidBody) -> "NeighborHelper": """Update the neighbor list.""" ...
[docs] @chex.dataclass class NoNeighborList(NeighborHelper): """A dummy neighbor list that does nothing.""" unbonded_nbrs: jnp.ndarray @property def idx(self) -> jnp.ndarray: """Return the indices of the unbonded neighbors.""" return self.unbonded_nbrs
[docs] @override def allocate(self, locs: jax_md.rigid_body.RigidBody) -> "NoNeighborList": """Allocate memory for the neighbor list.""" return self
[docs] @override def update(self, locs: jax_md.rigid_body.RigidBody) -> "NoNeighborList": """Update the neighbor list.""" return self
[docs] @chex.dataclass class NeighborList(NeighborHelper): """Neighbor list for managing unbonded neighbors.""" displacement_fn: Callable topology: jdna_top.Topology r_cutoff: float dr_threshold: float box_size: jnp.ndarray init_positions: jax_md.rigid_body.RigidBody
[docs] def __post_init__(self) -> None: """Initialize the neighbor list.""" dense_mask = np.full((self.topology.n_nucleotides, 2), self.topology.n_nucleotides, dtype=np.int32) counter = np.zeros(self.topology.n_nucleotides, dtype=np.int32) for i, j in self.topology.bonded_nbrs: dense_mask[i, counter[i]] = j counter[i] += 1 dense_mask[j, counter[j]] = i counter[j] += 1 self.dense_mask = jnp.array(dense_mask, dtype=jnp.int32) def bonded_nbrs_mask_fn(dense_idx: int) -> jnp.ndarray: nbr_mask1 = dense_idx == dense_mask[:, 0].reshape(self.topology.n_nucleotides, 1) dense_idx = jnp.where(nbr_mask1, self.topology.n_nucleotides, dense_idx) nbr_mask2 = dense_idx == dense_mask[:, 1].reshape(self.topology.n_nucleotides, 1) return jnp.where(nbr_mask2, self.topology.n_nucleotides, dense_idx) self._neighborlist_fn = jax_md.partition.neighbor_list( self.displacement_fn, box_size=self.box_size, r_cutoff=self.r_cutoff, dr_threshold=self.dr_threshold, custom_mask_function=bonded_nbrs_mask_fn, disable_cell_list=True, format=jax_md.partition.NeighborListFormat.OrderedSparse, ) self._neighbors = self._neighborlist_fn.allocate(self.init_positions.center)
@property def idx(self) -> jnp.ndarray: """Return the indices of the unbonded neighbors.""" return self._neighbors.idx
[docs] @override def allocate(self, locs: jax_md.rigid_body.RigidBody) -> "NeighborList": """Allocate memory for the neighbor list.""" self._neighbors = self._neighborlist_fn.allocate(locs) return self
[docs] @override def update(self, locs: jax_md.rigid_body.RigidBody) -> "NeighborList": """Update the neighbor list.""" self._neighbors = self._neighbors.update(locs.position.center) return self._neighbors
[docs] @chex.dataclass class StaticSimulatorParams: """Static parameters for the simulator.""" # this is commonly referred to as `init_body` in the code, but the argument is named `R` in jax_md seq: typ.Arr_Nucleotide mass: jax_md.rigid_body.RigidBody gamma: jax_md.rigid_body.RigidBody bonded_neighbors: jnp.ndarray checkpoint_every: int dt: float kT: float # noqa: N815, the variable is commonly referred to using this casing @property def sim_init_fn(self) -> Callable: """Return the simulator init function.""" return { "dt": self.dt, "kT": self.kT, "gamma": self.gamma, } @property def init_fn(self) -> dict[str, jax_md.rigid_body.RigidBody | jnp.ndarray]: """Return the kwargs for initial state of the simulator.""" return {"mass": self.mass} @property def step_fn(self) -> dict[str, jax_md.rigid_body.RigidBody | jnp.ndarray]: """Return the kwargs for the step_fn of the simulator.""" return {}
[docs] def split_and_stack(x: jnp.ndarray, n: int) -> jnp.ndarray: """Split `xs` into `n` pieces and stack them.""" return jax.tree.map(lambda y: jnp.stack(jnp.split(y, n)), x)
[docs] def flatten_n(x: jnp.ndarray, n: int) -> jnp.ndarray: """Flatten `x` by `n` levels.""" # setting n <= 1 does not achieve the desired effect chex.assert_scalar_positive(n - 1) return jax.tree.map(lambda y: jnp.reshape(y, (-1,) + y.shape[n:]), x)
[docs] def checkpoint_scan( f: Callable, init: jax_md.rigid_body.RigidBody, xs: jnp.ndarray, checkpoint_every: int ) -> tuple[jax_md.rigid_body.RigidBody, jnp.ndarray]: """Replicates the behavior of `jax.lax.scan` but checkpoints gradients every `checkpoint_every` steps.""" flat_xs, _ = jax.tree_util.tree_flatten(xs) length = flat_xs[0].shape[0] outer_iterations, residual = divmod(length, checkpoint_every) if residual: raise ValueError(ERR_CHKPNT_SCN.format(checkpoint_every, length)) reshaped_xs = split_and_stack(xs, outer_iterations) @jax.checkpoint def inner_loop( _init: jax_md.rigid_body.RigidBody, _xs: jnp.ndarray ) -> tuple[jax_md.rigid_body.RigidBody, jnp.ndarray]: return jax.lax.scan(f, _init, _xs) final, result = jax.lax.scan(inner_loop, init, reshaped_xs) return final, flatten_n(result, 2)