mythos.simulators.jax_md.utils
==============================

.. py:module:: mythos.simulators.jax_md.utils

.. autoapi-nested-parse::

   Utilities for JAX-MD samplers.


Attributes
----------

.. autoapisummary::

   mythos.simulators.jax_md.utils.ERR_CHKPNT_SCN


Classes
-------

.. autoapisummary::

   mythos.simulators.jax_md.utils.SimulationState
   mythos.simulators.jax_md.utils.NeighborHelper
   mythos.simulators.jax_md.utils.NoNeighborList
   mythos.simulators.jax_md.utils.NeighborList
   mythos.simulators.jax_md.utils.StaticSimulatorParams


Functions
---------

.. autoapisummary::

   mythos.simulators.jax_md.utils.split_and_stack
   mythos.simulators.jax_md.utils.flatten_n
   mythos.simulators.jax_md.utils.checkpoint_scan


Module Contents
---------------

.. py:data:: ERR_CHKPNT_SCN
   :value: '`checkpoint_every` must evenly divide the length of `xs`. Got {} and {}.'


.. py:class:: SimulationState

   Bases: :py:obj:`Protocol`

   This is a protocol to help with typing.

   Every state implements at least position and mass. More info about the
   specific states can be found here:
   https://github.com/jax-md/jax-md/blob/main/jax_md/simulate.py


   .. py:attribute:: position
      :type: jax_md.rigid_body.RigidBody


   .. py:attribute:: mass
      :type: jax_md.rigid_body.RigidBody


.. py:class:: NeighborHelper

   Bases: :py:obj:`Protocol`

   Helper class for managing neighbor lists.


   .. py:property:: idx
      :type: jax.numpy.ndarray

      Return the indices of the unbonded neighbors.


   .. py:method:: allocate(locs: jax_md.rigid_body.RigidBody) -> NeighborHelper

      Allocate memory for the neighbor list.


   .. py:method:: update(locs: jax_md.rigid_body.RigidBody) -> NeighborHelper

      Update the neighbor list.


.. py:class:: NoNeighborList

   Bases: :py:obj:`NeighborHelper`

   A dummy neighbor list that does nothing.


   .. py:attribute:: unbonded_nbrs
      :type: jax.numpy.ndarray


   .. py:property:: idx
      :type: jax.numpy.ndarray

      Return the indices of the unbonded neighbors.


   .. py:method:: allocate(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

      Allocate memory for the neighbor list.


   .. py:method:: update(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

      Update the neighbor list.


.. py:class:: NeighborList

   Bases: :py:obj:`NeighborHelper`

   Neighbor list for managing unbonded neighbors.


   .. py:attribute:: displacement_fn
      :type: collections.abc.Callable


   .. py:attribute:: topology
      :type: mythos.input.topology.Topology


   .. py:attribute:: r_cutoff
      :type: float


   .. py:attribute:: dr_threshold
      :type: float


   .. py:attribute:: box_size
      :type: jax.numpy.ndarray


   .. py:attribute:: init_positions
      :type: jax_md.rigid_body.RigidBody


   .. py:method:: __post_init__() -> None

      Initialize the neighbor list.


   .. py:property:: idx
      :type: jax.numpy.ndarray

      Return the indices of the unbonded neighbors.


   .. py:method:: allocate(locs: jax_md.rigid_body.RigidBody) -> NeighborList

      Allocate memory for the neighbor list.


   .. py:method:: update(locs: jax_md.rigid_body.RigidBody) -> NeighborList

      Update the neighbor list.


.. py:class:: StaticSimulatorParams

   Static parameters for the simulator.


   .. py:attribute:: seq
      :type: mythos.utils.types.Arr_Nucleotide


   .. py:attribute:: mass
      :type: jax_md.rigid_body.RigidBody


   .. py:attribute:: gamma
      :type: jax_md.rigid_body.RigidBody


   .. py:attribute:: bonded_neighbors
      :type: jax.numpy.ndarray


   .. py:attribute:: checkpoint_every
      :type: int


   .. py:attribute:: dt
      :type: float


   .. py:attribute:: kT
      :type: float


   .. py:property:: sim_init_fn
      :type: collections.abc.Callable

      Return the simulator init function.


   .. py:property:: init_fn
      :type: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

      Return the kwargs for the initial state of the simulator.


   .. py:property:: step_fn
      :type: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

      Return the kwargs for the step_fn of the simulator.


.. py:function:: split_and_stack(x: jax.numpy.ndarray, n: int) -> jax.numpy.ndarray

   Split `x` into `n` pieces and stack them.


.. py:function:: flatten_n(x: jax.numpy.ndarray, n: int) -> jax.numpy.ndarray

   Flatten `x` by `n` levels.


.. py:function:: checkpoint_scan(f: collections.abc.Callable, init: jax_md.rigid_body.RigidBody, xs: jax.numpy.ndarray, checkpoint_every: int) -> tuple[jax_md.rigid_body.RigidBody, jax.numpy.ndarray]

   Replicate the behavior of `jax.lax.scan`, but checkpoint gradients every
   `checkpoint_every` steps.
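
The three functions above are documented only by one-line summaries, so the following is a minimal sketch of how a checkpointed scan of this kind could be assembled from ``jax.lax.scan`` and ``jax.checkpoint``. It is not the module's actual implementation. The ``*_sketch`` names are hypothetical, and the reading of "split and stack" as a reshape of the leading axis and of "flatten by ``n`` levels" as merging the first ``n + 1`` axes is an assumption, as is the argument order used to fill the error message.

```python
import jax
import jax.numpy as jnp

# Message text copied from ERR_CHKPNT_SCN; the argument order is an assumption.
ERR_MSG = "`checkpoint_every` must evenly divide the length of `xs`. Got {} and {}."


def split_and_stack_sketch(x, n):
    """Hypothetical stand-in: reshape the leading axis (length L) to (n, L // n)."""
    return x.reshape((n, x.shape[0] // n) + x.shape[1:])


def flatten_n_sketch(x, n):
    """Hypothetical stand-in: merge the first n + 1 axes of every leaf into one."""
    return jax.tree_util.tree_map(
        lambda a: a.reshape((-1,) + a.shape[n + 1:]), x
    )


def checkpoint_scan_sketch(f, init, xs, checkpoint_every):
    """Scan f over xs, rematerializing each chunk of steps in the backward pass."""
    length = xs.shape[0]
    if length % checkpoint_every != 0:
        raise ValueError(ERR_MSG.format(checkpoint_every, length))
    n_chunks = length // checkpoint_every

    # Inner scan over one chunk. jax.checkpoint discards the chunk's
    # intermediates after the forward pass and recomputes them when
    # gradients are needed.
    @jax.checkpoint
    def scan_chunk(carry, chunk):
        return jax.lax.scan(f, carry, chunk)

    # Outer scan over the chunks; ys has shape (n_chunks, checkpoint_every, ...).
    carry, ys = jax.lax.scan(scan_chunk, init, split_and_stack_sketch(xs, n_chunks))
    return carry, flatten_n_sketch(ys, 1)


def cumulative_sum_step(carry, x):
    """Toy step function used only for this illustration."""
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.arange(1.0, 9.0)
init = jnp.asarray(0.0, dtype=xs.dtype)
final, trajectory = checkpoint_scan_sketch(cumulative_sum_step, init, xs, 4)
```

Under these assumptions the sketch returns the same carry and stacked outputs as a plain ``jax.lax.scan`` over ``xs``. The only difference is the memory/compute trade-off during differentiation: residuals are stored per chunk rather than per step, at the cost of recomputing each chunk once in the backward pass.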