Source code for mythos.input.tree

"""Functions for saving and loading pytrees."""

import pickle
from pathlib import Path

import jax

import mythos.utils.types as jdna_types


[docs] def save_pytree(data: jdna_types.PyTree, filename: jdna_types.PathOrStr) -> None: """Save a pytree to a file.""" save_path = Path(filename) leaves, treedef = jax.tree_util.tree_flatten(data) with save_path.open("wb") as f: pickle.dump((leaves, treedef), f)
[docs] def load_pytree(filename: jdna_types.PathOrStr) -> jdna_types.PyTree: """Load a pytree to a file.""" save_path = Path(filename) with save_path.open("rb") as f: # Though this is labeled as a security issue by Bandit we only open # files that we write. So we can ignore this for now, but if there # another way we should consider switching to that. # TODO(ryanhausen): Investigate a more secure way to load the file. # https://github.com/mythos-bio/mythos/issues/7 leaves, treedef = pickle.load(f) # nosec B301 # noqa: S301 return jax.tree_util.tree_unflatten(treedef, leaves)