mythos.simulators.jax_md

jax_md sampler implementation for mythos.


Classes

JaxMDSimulator: A sampler based on running a jax_md simulation routine.

NeighborList: Neighbor list for managing unbonded neighbors.

NoNeighborList: A no-op neighbor list that performs no neighbor search.

SimulationState: A typing protocol for jax_md simulation states.

StaticSimulatorParams: Static parameters for the simulator.

Package Contents

class mythos.simulators.jax_md.JaxMDSimulator

Bases: mythos.simulators.base.Simulator

A sampler based on running a jax_md simulation routine.

energy_fn: mythos.energy.base.EnergyFunction
simulator_params: mythos.simulators.jax_md.utils.StaticSimulatorParams
space: jax_md.space.Space
simulator_init: collections.abc.Callable[[collections.abc.Callable, collections.abc.Callable], jax_md.simulate.Simulator]
neighbors: mythos.simulators.jax_md.utils.NeighborHelper
__post_init__() -> None

Builds the run function using the provided parameters.
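This page does not show how __post_init__ assembles the run function from simulator_init, energy_fn and the static parameters. The sketch below illustrates one plausible shape under stated assumptions: a simulator is an (init_fn, apply_fn) pair, as in jax_md, and the run function keeps every checkpoint_every-th state. ToyState, toy_simulator and build_run are hypothetical names, the integrator is a plain symplectic Euler step on a 1D harmonic oscillator rather than a jax_md routine, and a JAX version would typically replace the Python loop with jax.lax.scan.

```python
from typing import Callable, NamedTuple


class ToyState(NamedTuple):
    # Hypothetical stand-in for a jax_md state: position, velocity and mass of one 1D particle.
    position: float
    velocity: float
    mass: float


def toy_simulator(energy_grad: Callable[[float], float], dt: float):
    """Hypothetical simulator returning an (init_fn, apply_fn) pair, the jax_md convention."""

    def init_fn(position: float, mass: float = 1.0) -> ToyState:
        return ToyState(position=position, velocity=0.0, mass=mass)

    def apply_fn(state: ToyState) -> ToyState:
        # Symplectic Euler: update the velocity from the force, then the position from the new velocity.
        force = -energy_grad(state.position)
        velocity = state.velocity + dt * force / state.mass
        return ToyState(state.position + dt * velocity, velocity, state.mass)

    return init_fn, apply_fn


def build_run(simulator, n_steps: int, checkpoint_every: int):
    """Hypothetical builder: run n_steps and keep every checkpoint_every-th state."""
    init_fn, apply_fn = simulator

    def run(position: float) -> list[ToyState]:
        state = init_fn(position)
        trajectory = []
        for step in range(1, n_steps + 1):
            state = apply_fn(state)
            if step % checkpoint_every == 0:
                trajectory.append(state)
        return trajectory

    return run
```

For example, build_run(toy_simulator(lambda x: x, dt=0.01), n_steps=100, checkpoint_every=10)(1.0) returns 10 states, one every 10 steps.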

class mythos.simulators.jax_md.NeighborList

Bases: NeighborHelper

Neighbor list for managing unbonded neighbors.

displacement_fn: collections.abc.Callable
topology: mythos.input.topology.Topology
r_cutoff: float
dr_threshold: float
box_size: jax.numpy.ndarray
init_positions: jax_md.rigid_body.RigidBody
__post_init__() -> None

Initialize the neighbor list.

property idx: jax.numpy.ndarray

Return the indices of the unbonded neighbors.

allocate(locs: jax_md.rigid_body.RigidBody) -> NeighborList

Allocate memory for the neighbor list.

update(locs: jax_md.rigid_body.RigidBody) -> NeighborList

Update the neighbor list.
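The fields r_cutoff and dr_threshold match the Verlet-list-with-skin scheme used by jax_md neighbor lists: pairs are collected out to r_cutoff + dr_threshold, and the list is rebuilt only once some particle has moved more than dr_threshold / 2 since the last build, because until then no pair can have moved from outside the skin to within r_cutoff. The pure-Python sketch below shows that logic plus exclusion of bonded pairs. SkinNeighborList and periodic_displacement are hypothetical, the O(N^2) pair search stands in for jax_md's cell lists, and whether mythos's NeighborList filters bonded pairs exactly this way is an assumption.

```python
import math
from dataclasses import dataclass, field, replace
from typing import Callable

Vec = tuple[float, ...]


def periodic_displacement(box_size: float) -> Callable[[Vec, Vec], Vec]:
    """Hypothetical minimum-image displacement in a cubic periodic box."""

    def displacement(a: Vec, b: Vec) -> Vec:
        return tuple(d - box_size * round(d / box_size) for d in (x - y for x, y in zip(a, b)))

    return displacement


@dataclass
class SkinNeighborList:
    displacement_fn: Callable[[Vec, Vec], Vec]
    bonded: frozenset  # pairs (i, j) with i < j that must never appear in idx
    r_cutoff: float
    dr_threshold: float
    idx: list = field(default_factory=list)
    ref_positions: list = field(default_factory=list)  # positions at the last build

    def _norm(self, a: Vec, b: Vec) -> float:
        return math.sqrt(sum(d * d for d in self.displacement_fn(a, b)))

    def allocate(self, positions: list) -> "SkinNeighborList":
        # Build with a skin: keep unbonded pairs closer than r_cutoff + dr_threshold.
        reach = self.r_cutoff + self.dr_threshold
        n = len(positions)
        pairs = [
            (i, j)
            for i in range(n)
            for j in range(i + 1, n)
            if (i, j) not in self.bonded and self._norm(positions[i], positions[j]) < reach
        ]
        return replace(self, idx=pairs, ref_positions=list(positions))

    def update(self, positions: list) -> "SkinNeighborList":
        # Rebuild only if some particle moved more than half the skin since the last build.
        max_move = max(self._norm(p, q) for p, q in zip(positions, self.ref_positions))
        if max_move > self.dr_threshold / 2:
            return self.allocate(positions)
        return self
```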

class mythos.simulators.jax_md.NoNeighborList

Bases: NeighborHelper

A no-op neighbor list that performs no neighbor search.

unbonded_nbrs: jax.numpy.ndarray
property idx: jax.numpy.ndarray

Return the indices of the unbonded neighbors.

allocate(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

Allocate memory for the neighbor list.

update(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

Update the neighbor list.
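Because NoNeighborList exposes the same idx, allocate and update members as NeighborList, code written against NeighborHelper can use either one. A minimal sketch, assuming (as the field name unbonded_nbrs suggests) that the dummy list simply stores every non-bonded pair and that allocate and update return it unchanged; AllPairsNeighbors and from_topology are hypothetical names.

```python
from dataclasses import dataclass
from itertools import combinations


@dataclass(frozen=True)
class AllPairsNeighbors:
    unbonded_nbrs: tuple[tuple[int, int], ...]

    @classmethod
    def from_topology(cls, n_nucleotides: int, bonded: set[tuple[int, int]]) -> "AllPairsNeighbors":
        # Every pair (i, j) with i < j that is not bonded.
        pairs = tuple(p for p in combinations(range(n_nucleotides), 2) if p not in bonded)
        return cls(unbonded_nbrs=pairs)

    @property
    def idx(self) -> tuple[tuple[int, int], ...]:
        return self.unbonded_nbrs

    def allocate(self, locs) -> "AllPairsNeighbors":
        return self  # nothing to allocate: the pair set never changes

    def update(self, locs) -> "AllPairsNeighbors":
        return self  # positions are ignored
```

The trade-off is that every step evaluates all O(N^2) unbonded pairs, but there is no rebuild bookkeeping, which is usually acceptable only for small systems.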

class mythos.simulators.jax_md.SimulationState

Bases: Protocol

A typing protocol for jax_md simulation states.

Every state provides at least position and mass. The specific state types are defined in jax_md's simulate module:

https://github.com/jax-md/jax-md/blob/main/jax_md/simulate.py

position: jax_md.rigid_body.RigidBody
mass: jax_md.rigid_body.RigidBody
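Since SimulationState is a typing.Protocol, any object with position and mass attributes satisfies it structurally, without inheriting from it, so jax_md's own state classes need no changes. The sketch below shows this with hypothetical stand-ins (SimulationStateLike, LangevinLikeState, PositionOnly and snapshot are not part of mythos or jax_md). With @runtime_checkable, isinstance only checks that the attributes exist, not their types.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class SimulationStateLike(Protocol):
    # Hypothetical mirror of SimulationState: any object with these attributes conforms.
    position: object
    mass: object


@dataclass
class LangevinLikeState:
    # Hypothetical stand-in; real jax_md states carry more fields (e.g. momentum, force).
    position: tuple
    momentum: tuple
    mass: float


@dataclass
class PositionOnly:
    position: tuple


def snapshot(state: SimulationStateLike) -> tuple:
    # Code typed against the protocol may rely only on position and mass.
    return state.position, state.mass
```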

class mythos.simulators.jax_md.StaticSimulatorParams

Static parameters for the simulator.

seq: mythos.utils.types.Arr_Nucleotide
mass: jax_md.rigid_body.RigidBody
gamma: jax_md.rigid_body.RigidBody
bonded_neighbors: jax.numpy.ndarray
checkpoint_every: int
dt: float
kT: float
property sim_init_fn: collections.abc.Callable

Return the simulator init function.
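JaxMDSimulator.simulator_init has type Callable[[Callable, Callable], Simulator], which suggests that sim_init_fn binds the static settings (dt, kT, gamma) to a simulator constructor and leaves only the energy or force function and the shift function open, the two arguments jax_md's simulate constructors take first. A minimal sketch of that pattern with functools.partial, assuming this is what the property does. toy_brownian and ToyStaticParams are hypothetical, and toy_brownian is a 1D overdamped (Brownian) Langevin step, not jax_md's integrator.

```python
import functools
import math
import random
from dataclasses import dataclass
from typing import Callable


def toy_brownian(force_fn: Callable[[float], float],
                 shift_fn: Callable[[float, float], float],
                 dt: float, kT: float, gamma: float):
    """Hypothetical constructor returning an (init_fn, apply_fn) pair, the jax_md convention."""

    def init_fn(seed: int, position: float):
        return position, random.Random(seed)

    def apply_fn(state):
        position, rng = state
        # Overdamped step: drift dt*F/gamma plus Gaussian noise with variance 2*kT*dt/gamma.
        drift = dt * force_fn(position) / gamma
        noise = math.sqrt(2.0 * kT * dt / gamma) * rng.gauss(0.0, 1.0)
        return shift_fn(position, drift + noise), rng

    return init_fn, apply_fn


@dataclass(frozen=True)
class ToyStaticParams:
    dt: float
    kT: float
    gamma: float

    @property
    def sim_init_fn(self) -> Callable:
        # Bind the static settings; the result only needs (force_fn, shift_fn).
        return functools.partial(toy_brownian, dt=self.dt, kT=self.kT, gamma=self.gamma)
```

With dt=0.1, kT=0.0 and gamma=2.0 the noise term vanishes, so one step of params.sim_init_fn(lambda x: -x, lambda x, dx: x + dx) moves a particle at 1.0 to 0.95.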

property init_fn: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

Return the kwargs for the initial state of the simulator.

property step_fn: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

Return the kwargs for the step_fn of the simulator.
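Despite their names, init_fn and step_fn return dictionaries, not functions: the dictionaries are meant to be unpacked as keyword arguments into the simulator's init and step calls. In jax_md, the simulate init functions accept keyword arguments such as mass, and the step functions forward extra keyword arguments to the energy or force function. The sketch below shows that unpacking. Which fields mythos places in each dictionary is an assumption here, and ToyParams, make_sim and bond_energy are hypothetical. To stay short, its apply_fn only evaluates the energy with the forwarded arguments instead of integrating.

```python
from dataclasses import dataclass


def bond_energy(positions, *, bonded_neighbors, r0=1.0):
    # Hypothetical 1D harmonic bond energy: sum of (|x_j - x_i| - r0)^2 / 2 over bonded pairs.
    return sum(0.5 * (abs(positions[j] - positions[i]) - r0) ** 2 for i, j in bonded_neighbors)


def make_sim(energy_fn):
    def init_fn(positions, *, mass):
        return {"positions": list(positions), "mass": mass, "energy": None}

    def apply_fn(state, **energy_kwargs):
        # Forward extra keyword arguments to the energy function, as jax_md step functions do.
        # This sketch does not move particles; it only records the energy.
        return {**state, "energy": energy_fn(state["positions"], **energy_kwargs)}

    return init_fn, apply_fn


@dataclass(frozen=True)
class ToyParams:
    mass: float
    bonded_neighbors: tuple

    @property
    def init_fn(self) -> dict:
        return {"mass": self.mass}  # kwargs for the initial state

    @property
    def step_fn(self) -> dict:
        return {"bonded_neighbors": self.bonded_neighbors}  # kwargs for every step
```

For example, apply_fn(init_fn([0.0, 1.5, 2.5], **params.init_fn), **params.step_fn) with bonded pairs (0, 1) and (1, 2) records an energy of 0.125, all of it from the stretched first bond.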