mythos.simulators.jax_md
========================

.. py:module:: mythos.simulators.jax_md

.. autoapi-nested-parse::

   jax_md sampler implementation for mythos.


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/mythos/simulators/jax_md/jaxmd/index
   /autoapi/mythos/simulators/jax_md/utils/index


Classes
-------

.. autoapisummary::

   mythos.simulators.jax_md.JaxMDSimulator
   mythos.simulators.jax_md.NeighborList
   mythos.simulators.jax_md.NoNeighborList
   mythos.simulators.jax_md.SimulationState
   mythos.simulators.jax_md.StaticSimulatorParams


Package Contents
----------------

.. py:class:: JaxMDSimulator

   Bases: :py:obj:`mythos.simulators.base.Simulator`

   A sampler based on running a jax_md simulation routine.

   .. py:attribute:: energy_fn
      :type: mythos.energy.base.EnergyFunction

   .. py:attribute:: simulator_params
      :type: mythos.simulators.jax_md.utils.StaticSimulatorParams

   .. py:attribute:: space
      :type: jax_md.space.Space

   .. py:attribute:: simulator_init
      :type: collections.abc.Callable[[collections.abc.Callable, collections.abc.Callable], jax_md.simulate.Simulator]

   .. py:attribute:: neighbors
      :type: mythos.simulators.jax_md.utils.NeighborHelper

   .. py:method:: __post_init__() -> None

      Build the run function from the provided parameters.


.. py:class:: NeighborList

   Bases: :py:obj:`NeighborHelper`

   Neighbor list for managing unbonded neighbors.

   .. py:attribute:: displacement_fn
      :type: collections.abc.Callable

   .. py:attribute:: topology
      :type: mythos.input.topology.Topology

   .. py:attribute:: r_cutoff
      :type: float

   .. py:attribute:: dr_threshold
      :type: float

   .. py:attribute:: box_size
      :type: jax.numpy.ndarray

   .. py:attribute:: init_positions
      :type: jax_md.rigid_body.RigidBody

   .. py:method:: __post_init__() -> None

      Initialize the neighbor list.

   .. py:property:: idx
      :type: jax.numpy.ndarray

      Return the indices of the unbonded neighbors.

   .. py:method:: allocate(locs: jax_md.rigid_body.RigidBody) -> NeighborList

      Allocate memory for the neighbor list.

   .. py:method:: update(locs: jax_md.rigid_body.RigidBody) -> NeighborList

      Update the neighbor list.


.. py:class:: NoNeighborList

   Bases: :py:obj:`NeighborHelper`

   A dummy neighbor list that does nothing.

   .. py:attribute:: unbonded_nbrs
      :type: jax.numpy.ndarray

   .. py:property:: idx
      :type: jax.numpy.ndarray

      Return the indices of the unbonded neighbors.

   .. py:method:: allocate(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

      Allocate memory for the neighbor list.

   .. py:method:: update(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

      Update the neighbor list.


.. py:class:: SimulationState

   Bases: :py:obj:`Protocol`

   Protocol used for type annotations of simulator states.

   Every state implements at least ``position`` and ``mass``. More information about the
   specific states can be found in the jax_md source:
   https://github.com/jax-md/jax-md/blob/main/jax_md/simulate.py

   .. py:attribute:: position
      :type: jax_md.rigid_body.RigidBody

   .. py:attribute:: mass
      :type: jax_md.rigid_body.RigidBody


.. py:class:: StaticSimulatorParams

   Static parameters for the simulator.

   .. py:attribute:: seq
      :type: mythos.utils.types.Arr_Nucleotide

   .. py:attribute:: mass
      :type: jax_md.rigid_body.RigidBody

   .. py:attribute:: gamma
      :type: jax_md.rigid_body.RigidBody

   .. py:attribute:: bonded_neighbors
      :type: jax.numpy.ndarray

   .. py:attribute:: checkpoint_every
      :type: int

   .. py:attribute:: dt
      :type: float

   .. py:attribute:: kT
      :type: float

   .. py:property:: sim_init_fn
      :type: collections.abc.Callable

      Return the simulator init function.

   .. py:property:: init_fn
      :type: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

      Return the kwargs for the initial state of the simulator.

   .. py:property:: step_fn
      :type: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

      Return the kwargs for the ``step_fn`` of the simulator.
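
``SimulationState`` relies on structural typing: any object that has ``position`` and ``mass`` attributes satisfies it, so jax_md's concrete state classes need no shared base class. The sketch below illustrates the idea with the standard library only. ``HasPositionAndMass``, ``ToyNVEState``, ``NotAState`` and ``total_mass`` are hypothetical names, and plain floats stand in for ``RigidBody`` values. The ``@runtime_checkable`` decorator is an extra added only so the example can demonstrate the check with ``isinstance``; nothing above says ``SimulationState`` itself is runtime-checkable.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasPositionAndMass(Protocol):
    """Hypothetical stand-in for SimulationState: any object with these attributes."""

    position: object
    mass: object


@dataclass
class ToyNVEState:
    """Hypothetical state; floats replace the RigidBody values used by jax_md."""

    position: list[float]
    momentum: list[float]
    mass: float


@dataclass
class NotAState:
    """Lacks ``mass``, so it does not satisfy the protocol."""

    position: list[float]


def total_mass(state: HasPositionAndMass, n_particles: int) -> float:
    # Uses only the attributes the protocol guarantees.
    return state.mass * n_particles


state = ToyNVEState(position=[0.0, 1.0], momentum=[0.0, 0.0], mass=2.0)
print(isinstance(state, HasPositionAndMass))             # True
print(isinstance(NotAState([0.0]), HasPositionAndMass))  # False
print(total_mass(state, len(state.position)))            # 4.0
```

A static type checker accepts ``ToyNVEState`` wherever ``HasPositionAndMass`` is expected for the same reason: the match is by attributes, not by inheritance.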
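
``NeighborList`` and ``NoNeighborList`` share an ``idx`` / ``allocate`` / ``update`` interface, and both methods return a neighbor-list object rather than modifying one in place. The following is a minimal, dependency-free sketch of that interface, not the mythos implementation. It assumes a skin scheme similar in spirit to ``jax_md.partition.neighbor_list``: pairs are collected within ``r_cutoff + dr_threshold``, and ``update`` rebuilds only once some particle has moved more than ``dr_threshold / 2`` since the last build. Excluding a set of bonded pairs stands in for what ``topology`` presumably provides. ``ToyNeighborList``, ``ToyNoNeighborList``, ``min_image_distance``, the cubic periodic box and the O(N²) pair search are all illustrative assumptions.

```python
import math
from dataclasses import dataclass, replace


def min_image_distance(a, b, box_size):
    """Euclidean distance under the minimum-image convention in a cubic periodic box."""
    d2 = 0.0
    for x, y in zip(a, b):
        dx = x - y
        dx -= box_size * round(dx / box_size)  # wrap into [-box/2, box/2]
        d2 += dx * dx
    return math.sqrt(d2)


@dataclass(frozen=True)
class ToyNeighborList:
    r_cutoff: float
    dr_threshold: float
    box_size: float
    bonded: frozenset = frozenset()  # (i, j) pairs with i < j to exclude
    ref_positions: tuple = ()        # positions at the last rebuild
    pairs: tuple = ()                # unbonded pairs inside cutoff + skin
    n_builds: int = 0

    @property
    def idx(self):
        return self.pairs

    def allocate(self, positions):
        # Full O(N^2) rebuild, keeping pairs within the cutoff plus the skin.
        cutoff = self.r_cutoff + self.dr_threshold
        n = len(positions)
        pairs = tuple(
            (i, j)
            for i in range(n)
            for j in range(i + 1, n)
            if (i, j) not in self.bonded
            and min_image_distance(positions[i], positions[j], self.box_size) < cutoff
        )
        return replace(
            self,
            ref_positions=tuple(tuple(p) for p in positions),
            pairs=pairs,
            n_builds=self.n_builds + 1,
        )

    def update(self, positions):
        # Reuse the list until some particle has moved more than half the skin.
        max_move = max(
            min_image_distance(p, q, self.box_size)
            for p, q in zip(positions, self.ref_positions)
        )
        if max_move > self.dr_threshold / 2:
            return self.allocate(positions)
        return self


@dataclass(frozen=True)
class ToyNoNeighborList:
    unbonded_nbrs: tuple  # fixed pair list that never changes

    @property
    def idx(self):
        return self.unbonded_nbrs

    def allocate(self, positions):
        return self  # nothing to build

    def update(self, positions):
        return self  # nothing to update


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.5, 0.0, 0.0), (9.5, 0.0, 0.0)]
nl = ToyNeighborList(r_cutoff=1.2, dr_threshold=0.4, box_size=10.0,
                     bonded=frozenset({(0, 1)})).allocate(positions)
print(nl.idx)  # ((0, 3), (1, 2), (1, 3)): (0, 1) is bonded; (0, 3) wraps around the box
```

Returning new frozen instances mirrors the ``-> NeighborList`` return annotations and the functional style JAX code favors. jax_md itself stores neighbors in fixed-capacity index arrays so that updates can run under ``jax.jit``, which this list-based toy does not attempt.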
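
``JaxMDSimulator.simulator_init`` takes two callables and returns a ``jax_md.simulate.Simulator``, which in jax_md is an ``(init_fn, apply_fn)`` pair, while ``StaticSimulatorParams`` exposes ``sim_init_fn`` plus keyword-argument dicts through ``init_fn`` and ``step_fn``. The toy below shows one plausible way these pieces fit together. It assumes the two callables are an energy function and a shift function, as in jax_md's ``simulate`` constructors. The deterministic gradient-flow integrator, ``toy_gradient_flow``, ``ToyStaticParams`` and the ``k`` keyword are hypothetical and do not describe what mythos actually passes.

```python
from dataclasses import dataclass
from functools import partial


def toy_gradient_flow(energy_fn, shift_fn, dt):
    """Hypothetical simulator factory returning an (init_fn, apply_fn) pair."""

    def init_fn(position, mass=1.0):
        return {"position": position, "mass": mass}

    def apply_fn(state, **energy_kwargs):
        h = 1e-6
        x = state["position"]
        # Central-difference force; extra kwargs are forwarded to the energy.
        force = -(energy_fn(x + h, **energy_kwargs)
                  - energy_fn(x - h, **energy_kwargs)) / (2 * h)
        return {**state, "position": shift_fn(x, dt * force / state["mass"])}

    return init_fn, apply_fn


@dataclass(frozen=True)
class ToyStaticParams:
    mass: float
    dt: float
    k: float  # hypothetical energy parameter forwarded on every step

    @property
    def sim_init_fn(self):
        # Bind the static step size; the caller supplies energy and shift functions.
        return partial(toy_gradient_flow, dt=self.dt)

    @property
    def init_fn(self):
        return {"mass": self.mass}  # kwargs for the initial state

    @property
    def step_fn(self):
        return {"k": self.k}  # kwargs for each step


def harmonic(x, k):
    return 0.5 * k * x * x


def shift(x, dx):
    return x + dx


params = ToyStaticParams(mass=1.0, dt=0.1, k=2.0)
init, step = params.sim_init_fn(harmonic, shift)
state = init(1.0, **params.init_fn)
for _ in range(50):
    state = step(state, **params.step_fn)  # x <- (1 - dt * k / mass) * x
```

Keeping the kwargs in plain dicts lets the same ``(init_fn, apply_fn)`` pair be reused with different static parameters without rebuilding the simulator.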