mythos.simulators.jax_md.utils

Utilities for JAX-MD samplers.

Attributes

ERR_CHKPNT_SCN

Classes

SimulationState

Protocol describing the minimal interface of a simulation state, used for type hints.

NeighborHelper

Helper class for managing neighbor lists.

NoNeighborList

A dummy neighbor list that does nothing.

NeighborList

Neighbor list for managing unbonded neighbors.

StaticSimulatorParams

Static parameters for the simulator.

Functions

split_and_stack(→ jax.numpy.ndarray)

Split x into n pieces and stack them.

flatten_n(→ jax.numpy.ndarray)

Flatten x by n levels.

checkpoint_scan(→ tuple[jax_md.rigid_body.RigidBody, ...)

Replicates the behavior of jax.lax.scan but checkpoints gradients every checkpoint_every steps.

Module Contents

mythos.simulators.jax_md.utils.ERR_CHKPNT_SCN = '`checkpoint_every` must evenly divide the length of `xs`. Got {} and {}.'
class mythos.simulators.jax_md.utils.SimulationState

Bases: Protocol

Protocol describing the minimal interface of a simulation state, used for type hints.

Every state provides at least the position and mass attributes. The concrete state types are defined in JAX-MD's simulate module:

https://github.com/jax-md/jax-md/blob/main/jax_md/simulate.py

position: jax_md.rigid_body.RigidBody
mass: jax_md.rigid_body.RigidBody
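Because SimulationState is a typing.Protocol, a state type conforms structurally: it only needs position and mass attributes and never has to subclass the protocol. The sketch below restates the protocol as a hypothetical SimulationStateLike, uses plain jax.numpy arrays in place of jax_md.rigid_body.RigidBody, and defines a hypothetical ToyState NamedTuple. The @runtime_checkable decorator is added only so the sketch can demonstrate an isinstance check; whether the real protocol uses it is not stated here.

```python
from typing import NamedTuple, Protocol, runtime_checkable

import jax.numpy as jnp


@runtime_checkable
class SimulationStateLike(Protocol):
    """Hypothetical stand-in for SimulationState: any object with these attributes."""

    position: jnp.ndarray  # the real protocol annotates this as jax_md.rigid_body.RigidBody
    mass: jnp.ndarray


class ToyState(NamedTuple):
    """Hypothetical state type; note that it never subclasses the protocol."""

    position: jnp.ndarray
    mass: jnp.ndarray
    momentum: jnp.ndarray  # extra fields are fine; only position and mass are required


def total_mass(state: SimulationStateLike) -> float:
    # A type checker accepts any object that structurally provides `mass`.
    return float(jnp.sum(state.mass))


state = ToyState(position=jnp.zeros((3, 3)), mass=jnp.ones(3), momentum=jnp.zeros((3, 3)))
print(isinstance(state, SimulationStateLike))  # True
print(total_mass(state))  # 3.0
```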
class mythos.simulators.jax_md.utils.NeighborHelper

Bases: Protocol

Helper class for managing neighbor lists.

property idx: jax.numpy.ndarray

Return the indices of the unbonded neighbors.

allocate(locs: jax_md.rigid_body.RigidBody) → NeighborHelper

Allocate memory for the neighbor list.

update(locs: jax_md.rigid_body.RigidBody) → NeighborHelper

Update the neighbor list.
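The protocol implies a simple calling pattern: allocate once from the initial positions, call update as positions change, and read idx for the current pairs. Since both methods are typed to return a NeighborHelper, a caller should rebind the result rather than assume in-place mutation. The sketch restates the interface as a hypothetical NeighborHelperLike and exercises it with a hypothetical CountingHelper; plain arrays stand in for RigidBody positions.

```python
from dataclasses import dataclass, replace
from typing import Protocol

import jax.numpy as jnp


class NeighborHelperLike(Protocol):
    """Hypothetical restatement of the NeighborHelper interface (sketch only)."""

    @property
    def idx(self) -> jnp.ndarray: ...

    def allocate(self, locs: jnp.ndarray) -> "NeighborHelperLike": ...

    def update(self, locs: jnp.ndarray) -> "NeighborHelperLike": ...


@dataclass(frozen=True)
class CountingHelper:
    """Hypothetical helper: a fixed pair list plus a count of update() calls."""

    pairs: jnp.ndarray
    n_updates: int = 0

    @property
    def idx(self) -> jnp.ndarray:
        return self.pairs

    def allocate(self, locs: jnp.ndarray) -> "CountingHelper":
        return replace(self, n_updates=0)

    def update(self, locs: jnp.ndarray) -> "CountingHelper":
        # Return a new object instead of mutating self.
        return replace(self, n_updates=self.n_updates + 1)


def collect_pairs(helper: NeighborHelperLike, frames: list):
    """Allocate once from the first frame, then update at every frame."""
    helper = helper.allocate(frames[0])
    pairs = []
    for locs in frames:
        helper = helper.update(locs)  # rebind: update returns a helper
        pairs.append(helper.idx)
    return helper, pairs


frames = [jnp.zeros((3, 3)) for _ in range(4)]
final, pairs = collect_pairs(CountingHelper(pairs=jnp.array([[0, 2]])), frames)
print(final.n_updates, len(pairs))  # 4 4
```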

class mythos.simulators.jax_md.utils.NoNeighborList

Bases: NeighborHelper

A dummy neighbor list that does nothing.

unbonded_nbrs: jax.numpy.ndarray
property idx: jax.numpy.ndarray

Return the indices of the unbonded neighbors.

allocate(locs: jax_md.rigid_body.RigidBody) → NoNeighborList

Allocate memory for the neighbor list.

update(locs: jax_md.rigid_body.RigidBody) → NoNeighborList

Update the neighbor list.
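A no-op helper such as NoNeighborList can satisfy the same interface by returning itself from allocate and update and exposing a fixed pair array through idx. The sketch below assumes unbonded_nbrs is a precomputed (n_pairs, 2) integer array of particle indices; that layout, the StaticPairs class, and the all_pairs helper are hypothetical.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class StaticPairs:
    """Hypothetical NoNeighborList-style helper: a precomputed, never-changing pair list."""

    unbonded_nbrs: jnp.ndarray  # assumed shape (n_pairs, 2)

    @property
    def idx(self) -> jnp.ndarray:
        return self.unbonded_nbrs

    def allocate(self, locs: jnp.ndarray) -> "StaticPairs":
        return self  # nothing to allocate: the pair list never changes

    def update(self, locs: jnp.ndarray) -> "StaticPairs":
        return self  # positions are ignored


def all_pairs(n: int) -> jnp.ndarray:
    """Every i < j pair for n particles (hypothetical helper, O(n^2) pairs)."""
    i, j = jnp.triu_indices(n, k=1)
    return jnp.stack([i, j], axis=1)


nl = StaticPairs(all_pairs(4))
print(nl.update(jnp.zeros((4, 3))).idx.shape)  # (6, 2)
```

Skipping the neighbor search entirely is only cheap for small systems, since every listed pair is evaluated at every step.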

class mythos.simulators.jax_md.utils.NeighborList

Bases: NeighborHelper

Neighbor list for managing unbonded neighbors.

displacement_fn: collections.abc.Callable
topology: mythos.input.topology.Topology
r_cutoff: float
dr_threshold: float
box_size: jax.numpy.ndarray
init_positions: jax_md.rigid_body.RigidBody
__post_init__() → None

Initialize the neighbor list.

property idx: jax.numpy.ndarray

Return the indices of the unbonded neighbors.

allocate(locs: jax_md.rigid_body.RigidBody) → NeighborList

Allocate memory for the neighbor list.

update(locs: jax_md.rigid_body.RigidBody) → NeighborList

Update the neighbor list.
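The r_cutoff and dr_threshold fields suggest the usual Verlet-skin scheme: pairs are listed out to r_cutoff + dr_threshold, and the list only needs rebuilding once some particle has moved more than dr_threshold / 2 since the last build (two particles each moving that far toward each other use up the whole skin). The sketch below illustrates that scheme with a brute-force O(N²) search. It is not JAX-MD's cell-list implementation, it ignores displacement_fn and box_size (no periodic boundaries), and a hypothetical bonded array of excluded pairs stands in for topology. It runs eagerly, because boolean-mask indexing produces data-dependent shapes that jax.jit cannot trace.

```python
from dataclasses import dataclass, replace
from typing import Optional

import jax.numpy as jnp


@dataclass(frozen=True)
class SkinNeighborList:
    """Hypothetical brute-force neighbor list with a Verlet skin (no periodic box)."""

    r_cutoff: float
    dr_threshold: float
    bonded: jnp.ndarray  # (n_bonded, 2) pairs with i < j that are never "unbonded"
    ref_locs: Optional[jnp.ndarray] = None  # positions at the last rebuild
    pairs: Optional[jnp.ndarray] = None  # (n_pairs, 2) unbonded pairs

    @property
    def idx(self) -> jnp.ndarray:
        return self.pairs

    def _build(self, locs: jnp.ndarray) -> jnp.ndarray:
        i, j = jnp.triu_indices(locs.shape[0], k=1)
        dist = jnp.linalg.norm(locs[i] - locs[j], axis=-1)
        # List pairs out to r_cutoff + dr_threshold so the list stays valid
        # while no particle has moved more than dr_threshold / 2.
        near = dist < self.r_cutoff + self.dr_threshold
        is_bonded = jnp.any(
            (i[:, None] == self.bonded[:, 0]) & (j[:, None] == self.bonded[:, 1]), axis=1
        )
        keep = near & ~is_bonded
        return jnp.stack([i[keep], j[keep]], axis=1)

    def allocate(self, locs: jnp.ndarray) -> "SkinNeighborList":
        return replace(self, ref_locs=locs, pairs=self._build(locs))

    def update(self, locs: jnp.ndarray) -> "SkinNeighborList":
        moved = jnp.max(jnp.linalg.norm(locs - self.ref_locs, axis=-1))
        if moved > self.dr_threshold / 2:  # skin exhausted: rebuild from scratch
            return self.allocate(locs)
        return self  # cheap path: reuse the existing pair list


locs = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
no_bonds = jnp.zeros((0, 2), dtype=jnp.int32)
nl = SkinNeighborList(r_cutoff=1.5, dr_threshold=0.5, bonded=no_bonds).allocate(locs)
print(nl.idx.tolist())  # [[0, 1]]
```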

class mythos.simulators.jax_md.utils.StaticSimulatorParams

Static parameters for the simulator.

seq: mythos.utils.types.Arr_Nucleotide
mass: jax_md.rigid_body.RigidBody
gamma: jax_md.rigid_body.RigidBody
bonded_neighbors: jax.numpy.ndarray
checkpoint_every: int
dt: float
kT: float
property sim_init_fn: collections.abc.Callable

Return the simulator init function.

property init_fn: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

Return the kwargs for the initial state of the simulator.

property step_fn: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

Return the kwargs for the step_fn of the simulator.
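The init_fn and step_fn properties return keyword-argument dictionaries rather than functions, which lets static values be bound once and reused at every step. A minimal sketch of that pattern follows; ToyStaticParams, toy_init, and toy_step are hypothetical stand-ins, and toy_step simply shifts coordinates instead of integrating equations of motion.

```python
import functools
from dataclasses import dataclass

import jax.numpy as jnp


def toy_init(positions, *, mass, seq):
    """Hypothetical init function: packs positions and static data into a state dict."""
    return {"position": positions, "mass": mass, "seq": seq}


def toy_step(state, *, dt):
    """Hypothetical step: shifts every coordinate by dt (not a real integrator)."""
    return {**state, "position": state["position"] + dt}


@dataclass(frozen=True)
class ToyStaticParams:
    """Hypothetical analogue of StaticSimulatorParams for this sketch."""

    seq: jnp.ndarray
    mass: jnp.ndarray
    dt: float

    @property
    def init_fn(self) -> dict:
        # Keyword arguments for the init function, mirroring the documented property.
        return {"mass": self.mass, "seq": self.seq}

    @property
    def step_fn(self) -> dict:
        # Keyword arguments for the step function.
        return {"dt": self.dt}


params = ToyStaticParams(seq=jnp.array([0, 1, 2]), mass=jnp.ones(3), dt=0.1)
init = functools.partial(toy_init, **params.init_fn)  # bind static values once
step = functools.partial(toy_step, **params.step_fn)
state = step(init(jnp.zeros((3, 3))))
```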

mythos.simulators.jax_md.utils.split_and_stack(x: jax.numpy.ndarray, n: int) → jax.numpy.ndarray

Split x into n pieces and stack them.
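A minimal sketch of split_and_stack, assuming the split happens along the leading axis, so a length-L array becomes shape (n, L // n, ...). The jax.tree_util.tree_map wrapper is an assumption that lets the same helper handle pytrees such as RigidBody. jnp.split raises an error if L is not divisible by n.

```python
import jax
import jax.numpy as jnp


def split_and_stack(x, n: int):
    """Sketch: split the leading axis of every leaf into n equal chunks and stack them.

    A leaf of shape (L, ...) becomes (n, L // n, ...); L must be divisible by n.
    """
    return jax.tree_util.tree_map(lambda y: jnp.stack(jnp.split(y, n)), x)


xs = jnp.arange(6)
print(split_and_stack(xs, 3).tolist())  # [[0, 1], [2, 3], [4, 5]]
```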

mythos.simulators.jax_md.utils.flatten_n(x: jax.numpy.ndarray, n: int) → jax.numpy.ndarray

Flatten x by n levels.
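One reading of "flatten by n levels" is merging the first n axes of each leaf into a single axis. The sketch below assumes that reading and again uses tree_map as an assumption for pytree support. Under the same leading-axis assumption for split_and_stack, flatten_n(split_and_stack(x, n), 2) would recover x.

```python
import jax
import jax.numpy as jnp


def flatten_n(x, n: int):
    """Sketch: merge the first n axes of every leaf into one axis."""
    return jax.tree_util.tree_map(lambda y: jnp.reshape(y, (-1,) + y.shape[n:]), x)


x = jnp.arange(6).reshape(3, 2)
print(flatten_n(x, 2).tolist())  # [0, 1, 2, 3, 4, 5]
print(flatten_n(jnp.zeros((2, 3, 4, 5)), 2).shape)  # (6, 4, 5)
```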

mythos.simulators.jax_md.utils.checkpoint_scan(f: collections.abc.Callable, init: jax_md.rigid_body.RigidBody, xs: jax.numpy.ndarray, checkpoint_every: int) → tuple[jax_md.rigid_body.RigidBody, jax.numpy.ndarray]

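Replicates the behavior of jax.lax.scan, but checkpoints gradients every checkpoint_every steps, trading extra recomputation in the backward pass for lower memory use. checkpoint_every must evenly divide the length of xs; otherwise the call fails with the ERR_CHKPNT_SCN message.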
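The documented behavior of checkpoint_scan can be sketched by chunking xs into groups of checkpoint_every steps, wrapping the scan over each chunk in jax.checkpoint, and scanning over the chunks. This is one standard way to build a checkpointed scan, not necessarily the module's exact code. The reshapes are inlined so the block runs on its own, and raising ValueError with the ERR_CHKPNT_SCN message is an assumption about how the divisibility check fails.

```python
import jax
import jax.numpy as jnp

ERR_CHKPNT_SCN = "`checkpoint_every` must evenly divide the length of `xs`. Got {} and {}."


def checkpoint_scan(f, init, xs, checkpoint_every: int):
    """Sketch: jax.lax.scan with rematerialization inside each chunk of steps."""
    length = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_chunks, remainder = divmod(length, checkpoint_every)
    if remainder:
        raise ValueError(ERR_CHKPNT_SCN.format(checkpoint_every, length))

    # (length, ...) -> (n_chunks, checkpoint_every, ...)
    chunked_xs = jax.tree_util.tree_map(lambda y: jnp.stack(jnp.split(y, n_chunks)), xs)

    # The outer scan stores only each chunk's inputs; intermediates inside a
    # chunk are recomputed during the backward pass.
    @jax.checkpoint
    def run_chunk(carry, chunk):
        return jax.lax.scan(f, carry, chunk)

    final, ys = jax.lax.scan(run_chunk, init, chunked_xs)
    # (n_chunks, checkpoint_every, ...) -> (length, ...)
    ys = jax.tree_util.tree_map(lambda y: jnp.reshape(y, (-1,) + y.shape[2:]), ys)
    return final, ys


def step(carry, x):
    new = carry + x
    return new, new  # running sum as both carry and per-step output


xs = jnp.arange(1.0, 7.0)
final, ys = checkpoint_scan(step, jnp.array(0.0), xs, checkpoint_every=3)
print(float(final), ys.tolist())  # 21.0 [1.0, 3.0, 6.0, 10.0, 15.0, 21.0]
```

During the backward pass each chunk is re-run from its saved input carry, so stored intermediates scale with the number of chunks plus the chunk length rather than with the total number of steps, at the cost of roughly one extra forward pass.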