# Source code for mythos.utils.helpers

"""Helper functions for the mythos package."""

import itertools
import sys
from collections.abc import Iterable
from typing import Any

import jax
import jax.numpy as jnp
import jaxtyping as jaxtyp

ERR_BATCHED_N = "n must be at least one"


def batched(iterable: Iterable[Any], n: int) -> Iterable[tuple[Any, ...]]:
    """Batch an iterable into tuples of length ``n``.

    The final batch is shorter than ``n`` when the number of items is not a
    multiple of ``n``. On Python 3.12+ this delegates to ``itertools.batched``;
    on older versions it uses the equivalent recipe from the ``itertools``
    documentation.

    Args:
        iterable (Iterable[Any]): iterable to batch
        n (int): batch size; must be at least one

    Returns:
        Iterable[tuple[Any, ...]]: lazy iterator over the batches

    Raises:
        ValueError: if ``n`` is less than one. Raised at call time on every
            Python version, not deferred until the result is iterated.
    """
    # Validate eagerly: the pre-3.12 fallback is a generator, so a check inside
    # it would only fire on first iteration, unlike itertools.batched.
    if n < 1:
        raise ValueError(ERR_BATCHED_N)

    if sys.version_info >= (3, 12):
        return itertools.batched(iterable, n)

    # Recipe from https://docs.python.org/3/library/itertools.html#itertools.batched
    # iter() is called here, outside the generator, so a non-iterable argument
    # fails immediately, just as it does with itertools.batched.
    it = iter(iterable)

    def _generate() -> Iterable[tuple[Any, ...]]:
        # batched('ABCDEFG', 3) → ABC DEF G
        while batch := tuple(itertools.islice(it, n)):
            yield batch

    return _generate()
def tree_stack(trees: list[jaxtyp.PyTree], axis: int = 0) -> jaxtyp.PyTree:
    """Stack corresponding leaves of PyTrees into arrays along a new axis.

    Args:
        trees (list[PyTree]): PyTrees whose structures match, as required by
            ``jax.tree.map``; leaves at the same position are stacked together.
        axis (int): position of the new axis in each stacked leaf. Defaults to
            0, so every leaf of the result has ``len(trees)`` as its leading
            dimension.

    Returns:
        PyTree: a PyTree with the same structure whose leaves are the stacked
        leaves of the inputs.

    Raises:
        ValueError: if ``trees`` is empty (there is nothing to stack).
    """
    if not trees:
        # Without this guard, jax.tree.map would be called with no tree argument
        # and fail with an unhelpful TypeError.
        raise ValueError("tree_stack requires at least one PyTree")
    return jax.tree.map(lambda *leaves: jnp.stack(leaves, axis=axis), *trees)
def tree_concatenate(
    trees: list[jaxtyp.PyTree], axis: int = 0
) -> jaxtyp.PyTree:
    """Concatenate corresponding leaves of PyTrees along an existing axis.

    Args:
        trees (list[PyTree]): PyTrees whose structures match, as required by
            ``jax.tree.map``; leaves at the same position are concatenated.
            ``jnp.concatenate`` cannot join zero-dimensional arrays, so leaves
            must be at least 1-D.
        axis (int): existing axis along which leaves are joined. Defaults to
            0 (the first axis).

    Returns:
        PyTree: a PyTree with the same structure whose leaves are the
        concatenated leaves of the inputs.

    Raises:
        ValueError: if ``trees`` is empty (there is nothing to concatenate).
    """
    if not trees:
        # Without this guard, jax.tree.map would be called with no tree argument
        # and fail with an unhelpful TypeError.
        raise ValueError("tree_concatenate requires at least one PyTree")
    return jax.tree.map(
        lambda *leaves: jnp.concatenate(leaves, axis=axis), *trees
    )