mythos.simulators.jax_md.utils
Utilities for JAX-MD samplers.
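Attributes
- ERR_CHKPNT_SCN: Error message template used when checkpoint_every does not evenly divide the length of xs.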
Classes
- SimulationState: This is a protocol to help with typing.
- NeighborHelper: Helper class for managing neighbor lists.
- NoNeighborList: A dummy neighbor list that does nothing.
- NeighborList: Neighbor list for managing unbonded neighbors.
- StaticSimulatorParams: Static parameters for the simulator.
Functions
- split_and_stack: Split x into n pieces and stack them.
- flatten_n: Flatten x by n levels.
- checkpoint_scan: Replicate the behavior of jax.lax.scan, but checkpoint gradients every checkpoint_every steps.
Module Contents
- mythos.simulators.jax_md.utils.ERR_CHKPNT_SCN = '`checkpoint_every` must evenly divide the length of `xs`. Got {} and {}.'
- class mythos.simulators.jax_md.utils.SimulationState
Bases: Protocol
This is a protocol to help with typing.
Every state implements at least position and mass. More information about the specific state types can be found in the JAX-MD source:
https://github.com/jax-md/jax-md/blob/main/jax_md/simulate.py
- position: jax_md.rigid_body.RigidBody
- mass: jax_md.rigid_body.RigidBody
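A minimal sketch of how a structural protocol like SimulationState is used for typing: any object exposing position and mass satisfies it, regardless of its class or extra fields. SimulationStateLike, ToyState, and describe are hypothetical names, not part of mythos or JAX-MD, and the RigidBody annotations of the real protocol are replaced by Any to keep the example self-contained.

```python
from typing import Any, NamedTuple, Protocol, runtime_checkable


@runtime_checkable
class SimulationStateLike(Protocol):
    """Hypothetical stand-in mirroring the documented SimulationState protocol."""

    position: Any  # jax_md.rigid_body.RigidBody in the real protocol
    mass: Any      # jax_md.rigid_body.RigidBody in the real protocol


class ToyState(NamedTuple):
    """Hypothetical state; extra fields are allowed, only position and mass are required."""

    position: Any
    mass: Any
    velocity: Any


def describe(state: SimulationStateLike) -> str:
    # Code typed against the protocol may only rely on position and mass.
    return f"position={state.position}, mass={state.mass}"
```

Because the check is structural, states from different integrators can share one type annotation without a common base class.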
- class mythos.simulators.jax_md.utils.NeighborHelper
Bases: Protocol
Helper class for managing neighbor lists.
- property idx: jax.numpy.ndarray
Return the indices of the unbonded neighbors.
- allocate(locs: jax_md.rigid_body.RigidBody) -> NeighborHelper
Allocate memory for the neighbor list.
- update(locs: jax_md.rigid_body.RigidBody) -> NeighborHelper
Update the neighbor list.
- class mythos.simulators.jax_md.utils.NoNeighborList
Bases: NeighborHelper
A dummy neighbor list that does nothing.
- unbonded_nbrs: jax.numpy.ndarray
- property idx: jax.numpy.ndarray
Return the indices of the unbonded neighbors.
- allocate(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList
Allocate memory for the neighbor list.
- update(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList
Update the neighbor list.
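The NeighborHelper interface returns a helper from allocate and update rather than mutating one, which fits JAX's functional style. The sketch below assumes a NoNeighborList-style helper simply stores a fixed unbonded_nbrs array, returns it from idx, and returns itself from allocate and update; FixedNeighbors and run_with_neighbors are hypothetical names used only for illustration.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class FixedNeighbors:
    """Hypothetical NoNeighborList-like helper: the neighbor pairs never change."""

    unbonded_nbrs: jnp.ndarray

    @property
    def idx(self) -> jnp.ndarray:
        # Indices of the unbonded neighbor pairs.
        return self.unbonded_nbrs

    def allocate(self, locs) -> "FixedNeighbors":
        # Nothing to allocate; positions are ignored.
        return self

    def update(self, locs) -> "FixedNeighbors":
        # Pairs are fixed, so updating is a no-op that returns the same helper.
        return self


def run_with_neighbors(nbrs, trajectory):
    """Hypothetical driver: allocate once, then rebind the helper after every update."""
    nbrs = nbrs.allocate(trajectory[0])
    for locs in trajectory[1:]:
        nbrs = nbrs.update(locs)
    return nbrs.idx
```

A driver written this way works with any helper that follows the same allocate/update/idx pattern, whether or not the neighbor pairs ever change.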
- class mythos.simulators.jax_md.utils.NeighborList
Bases: NeighborHelper
Neighbor list for managing unbonded neighbors.
- displacement_fn: collections.abc.Callable
- topology: mythos.input.topology.Topology
- box_size: jax.numpy.ndarray
- init_positions: jax_md.rigid_body.RigidBody
- property idx: jax.numpy.ndarray
Return the indices of the unbonded neighbors.
- allocate(locs: jax_md.rigid_body.RigidBody) -> NeighborList
Allocate memory for the neighbor list.
- update(locs: jax_md.rigid_body.RigidBody) -> NeighborList
Update the neighbor list.
- class mythos.simulators.jax_md.utils.StaticSimulatorParams
Static parameters for the simulator.
- seq: mythos.utils.types.Arr_Nucleotide
- mass: jax_md.rigid_body.RigidBody
- gamma: jax_md.rigid_body.RigidBody
- bonded_neighbors: jax.numpy.ndarray
- property sim_init_fn: collections.abc.Callable
Return the simulator init function.
- mythos.simulators.jax_md.utils.split_and_stack(x: jax.numpy.ndarray, n: int) -> jax.numpy.ndarray
Split x into n pieces and stack them.
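A minimal sketch of the described behavior, assuming the split is along the leading axis into equal pieces; split_and_stack_sketch is a hypothetical name, not the mythos implementation.

```python
import jax.numpy as jnp


def split_and_stack_sketch(x: jnp.ndarray, n: int) -> jnp.ndarray:
    """Split x into n equal pieces along axis 0 and stack them on a new leading axis."""
    # jnp.split raises if x.shape[0] is not divisible by n.
    pieces = jnp.split(x, n)
    # Result shape: (n, x.shape[0] // n, *x.shape[1:]).
    return jnp.stack(pieces)
```

For evenly divisible inputs this is equivalent to x.reshape((n, -1) + x.shape[1:]); for example, an input of length 6 with n=3 becomes an array of shape (3, 2).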
- mythos.simulators.jax_md.utils.flatten_n(x: jax.numpy.ndarray, n: int) -> jax.numpy.ndarray
Flatten x by n levels.
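A minimal sketch, assuming that flattening "by one level" merges the two leading axes into one, so n levels merge the first n + 1 axes; flatten_n_sketch is a hypothetical name and this reading of "levels" is an assumption.

```python
import jax.numpy as jnp


def flatten_n_sketch(x: jnp.ndarray, n: int) -> jnp.ndarray:
    """Merge the leading axis into the next one, n times (assumed semantics)."""
    for _ in range(n):
        # (a, b, *rest) -> (a * b, *rest); once x is 1-D this leaves it unchanged.
        x = x.reshape((-1,) + x.shape[2:])
    return x
```

Under this reading, an array of shape (2, 3, 4) becomes (6, 4) with n=1 and (24,) with n=2.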
- mythos.simulators.jax_md.utils.checkpoint_scan(f: collections.abc.Callable, init: jax_md.rigid_body.RigidBody, xs: jax.numpy.ndarray, checkpoint_every: int) -> tuple[jax_md.rigid_body.RigidBody, jax.numpy.ndarray]
Replicate the behavior of jax.lax.scan, but checkpoint gradients every checkpoint_every steps. checkpoint_every must evenly divide the length of xs (see ERR_CHKPNT_SCN).
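A minimal sketch of chunked scan with gradient checkpointing, assuming xs is a single array scanned along its leading axis and f returns a single array per step. checkpoint_scan_sketch is a hypothetical name, and the argument order passed to ERR_CHKPNT_SCN.format is an assumption. The outer jax.lax.scan runs over chunks of checkpoint_every steps; because each chunk is wrapped in jax.checkpoint, only the carries at chunk boundaries are kept for the backward pass, and the steps inside a chunk are recomputed during differentiation.

```python
import jax
import jax.numpy as jnp

ERR_CHKPNT_SCN = "`checkpoint_every` must evenly divide the length of `xs`. Got {} and {}."


def checkpoint_scan_sketch(f, init, xs, checkpoint_every):
    """Hypothetical re-implementation; intended to match jax.lax.scan(f, init, xs)."""
    length = xs.shape[0]
    if length % checkpoint_every != 0:
        raise ValueError(ERR_CHKPNT_SCN.format(checkpoint_every, length))
    n_chunks = length // checkpoint_every
    # (length, ...) -> (n_chunks, checkpoint_every, ...)
    chunked = xs.reshape((n_chunks, checkpoint_every) + xs.shape[1:])

    @jax.checkpoint
    def run_chunk(carry, chunk):
        # Intermediates of this inner scan are not saved; they are
        # recomputed from the chunk's input carry on the backward pass.
        return jax.lax.scan(f, carry, chunk)

    final, ys = jax.lax.scan(run_chunk, init, chunked)
    # (n_chunks, checkpoint_every, ...) -> (length, ...)
    ys = ys.reshape((length,) + ys.shape[2:])
    return final, ys
```

Since the arithmetic is the same, the sketch should return the same final carry, stacked outputs, and gradients as jax.lax.scan(f, init, xs), up to floating-point noise, while storing fewer intermediate values during differentiation.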