mythos.utils.helpers

Helper functions for the mythos package.

Attributes

ERR_BATCHED_N

Functions

batched(→ collections.abc.Iterable[Any])

Batch an iterable into chunks of size n.

tree_stack(→ jaxtyping.PyTree)

Stacks corresponding leaves of PyTrees into arrays along a new axis.

tree_concatenate(→ jaxtyping.PyTree)

Concatenates corresponding leaves of PyTrees along the first axis.

Module Contents

mythos.utils.helpers.ERR_BATCHED_N = 'n must be at least one'
mythos.utils.helpers.batched(iterable: collections.abc.Iterable[Any], n: int) → collections.abc.Iterable[Any]

Batch an iterable into chunks of size n.

Parameters:
  • iterable (iter[Any]) – the iterable to split into batches

  • n (int) – batch size; must be at least one (compare ERR_BATCHED_N)

Returns:

an iterable over the batches

Return type:

iter[Any]
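
The behaviour described above can be sketched without importing mythos. The function below, batched_sketch, is a hypothetical stand-in and not the package's actual implementation. It assumes three things the reference does not state: batches are returned as tuples, the final batch may be shorter than n, and an invalid n raises ValueError with the ERR_BATCHED_N message. This mirrors how itertools.batched behaves in Python 3.12.

```python
from collections.abc import Iterable, Iterator
from itertools import islice
from typing import Any

# Message text copied from the documented module attribute.
ERR_BATCHED_N = "n must be at least one"


def batched_sketch(iterable: Iterable[Any], n: int) -> Iterator[tuple[Any, ...]]:
    """Hypothetical stand-in for mythos.utils.helpers.batched."""
    if n < 1:
        # Assumption: the real helper rejects n < 1 with this message.
        raise ValueError(ERR_BATCHED_N)
    iterator = iter(iterable)
    # Pull up to n items at a time until the source is exhausted.
    while batch := tuple(islice(iterator, n)):
        yield batch


chunks = list(batched_sketch(range(7), 3))
print(chunks)  # [(0, 1, 2), (3, 4, 5), (6,)]
```

Because the sketch is a generator, the check on n runs only once iteration starts.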

mythos.utils.helpers.tree_stack(trees: list[jaxtyping.PyTree]) → jaxtyping.PyTree

Stacks corresponding leaves of PyTrees into arrays along a new axis.
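
A minimal sketch of this kind of stacking, using only public JAX calls. tree_stack_sketch is a hypothetical name and may differ from the real helper; in particular, placing the new axis at position 0 is an assumption, since the reference only says "a new axis". All input trees are assumed to share the same structure and leaf shapes.

```python
import jax
import jax.numpy as jnp


def tree_stack_sketch(trees):
    """Hypothetical stand-in for mythos.utils.helpers.tree_stack."""
    # tree_map passes the matching leaf from every tree to the lambda;
    # jnp.stack joins them along a new leading axis (assumed to be axis 0).
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


trees = [
    {"pos": jnp.zeros((4, 3)), "mass": jnp.ones(4)},
    {"pos": jnp.ones((4, 3)), "mass": jnp.ones(4)},
]
stacked = tree_stack_sketch(trees)
print(stacked["pos"].shape)   # (2, 4, 3)
print(stacked["mass"].shape)  # (2, 4)
```

The result has the same tree structure as each input, with every leaf gaining one dimension whose length equals the number of input trees.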

mythos.utils.helpers.tree_concatenate(trees: list[jaxtyping.PyTree]) → jaxtyping.PyTree

Concatenates corresponding leaves of PyTrees along the first axis.
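
A comparable sketch for concatenation, again with a hypothetical name (tree_concatenate_sketch) and no claim to match the package's implementation. It assumes all trees share the same structure and that matching leaves agree on every dimension except the first.

```python
import jax
import jax.numpy as jnp


def tree_concatenate_sketch(trees):
    """Hypothetical stand-in for mythos.utils.helpers.tree_concatenate."""
    # Matching leaves are joined end to end along their existing first axis,
    # so no new dimension is created.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=0), *trees
    )


trees = [
    {"pos": jnp.zeros((4, 3)), "mass": jnp.ones(4)},
    {"pos": jnp.ones((2, 3)), "mass": jnp.ones(2)},
]
merged = tree_concatenate_sketch(trees)
print(merged["pos"].shape)   # (6, 3)
print(merged["mass"].shape)  # (6,)
```

Unlike stacking, concatenation lets the first-axis lengths differ between trees: the output length is their sum, and the number of dimensions of each leaf is unchanged.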