Source code for mythos.simulators.base
"""Base class for a simulation."""
import shutil
import uuid
from abc import ABC, abstractmethod
from dataclasses import field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, ClassVar
import chex
from typing_extensions import override
[docs]
@chex.dataclass(frozen=True)
class SimulatorOutput:
    """Immutable container for what a simulator hands back to its caller.

    Arguments:
        observables: List of observables produced by the simulator. The
            element type is unconstrained (``Any``).
        state: String-keyed mapping of additional state. When omitted, each
            instance gets its own new empty dict.
    """

    observables: list[Any]
    # default_factory (rather than a literal ``{}``) so instances never share
    # one mutable default dict.
    state: dict[str, Any] = field(default_factory=dict)
[docs]
@chex.dataclass(frozen=True, kw_only=True)
class Simulator:
    """Common base for simulators.

    Arguments:
        name: Identifier of this simulator instance. When omitted, a random
            UUID4 string is generated for each instance.

    Class attributes:
        exposed_observables: Names of the observables exposed by this
            simulator class. The base class lists only ``"trajectory"``.
    """

    name: str = field(default_factory=lambda: str(uuid.uuid4()))
    exposed_observables: ClassVar[list[str]] = ["trajectory"]

    def run(self, *_args, opt_params: dict[str, Any], **_kwargs) -> SimulatorOutput:
        """Run the simulation.

        The base implementation consists of this docstring only, so calling
        it performs no work and evaluates to ``None``. Subclasses (for example
        ``InputDirSimulator``) override it with real behaviour.

        Arguments:
            *_args: Extra positional arguments; unused by the base
                implementation.
            opt_params: Required keyword-only mapping of string keys to
                values; unused by the base implementation.
            **_kwargs: Extra keyword arguments; unused by the base
                implementation.
        """

    def exposes(self) -> list[str]:
        """Return the fully qualified names of the exposed observables.

        Every entry has the form ``"<observable>.<ClassName>.<name>"``, where
        ``<ClassName>`` is the concrete class of this instance and ``<name>``
        is its ``name`` field. Entries follow the order of
        ``exposed_observables``.
        """
        # The "<ClassName>.<name>" part is the same for every observable.
        qualifier = f"{self.__class__.__name__}.{self.name}"
        return [f"{observable}.{qualifier}" for observable in self.exposed_observables]

    @classmethod
    def create_n(cls, n: int, name: str | None = None, **kwargs) -> list["Simulator"]:
        """Create ``n`` simulators of this class with distinct names.

        Arguments:
            n: Number of simulators to create.
            name: Shared name prefix. A falsy value (``None`` or ``""``) is
                replaced by a random UUID4 string.
            **kwargs: Forwarded unchanged to the constructor of every
                simulator.

        Returns:
            A list of ``n`` instances named ``"<prefix>.0"`` through
            ``"<prefix>.<n-1>"``.
        """
        prefix = name or str(uuid.uuid4())
        return [cls(name=f"{prefix}.{index}", **kwargs) for index in range(n)]
[docs]
@chex.dataclass(frozen=True, kw_only=True)
class InputDirSimulator(Simulator, ABC):
    """Base class for simulators that are driven by an input directory.

    By default the simulation does not touch ``input_dir``: its contents are
    copied into a temporary directory and the simulation runs on that copy.
    Setting ``overwrite_input`` to True runs the simulation directly in
    ``input_dir`` instead.

    Subclasses must implement ``run_simulation``, which holds the actual
    simulation logic for a given directory.

    Arguments:
        input_dir: Path to the input directory.
        overwrite_input: Whether the simulation may run in (and thereby
            overwrite) the input directory. If this is False (default), the
            contents of ``input_dir`` are copied to a temporary directory
            and the simulation runs there.
    """

    input_dir: str
    overwrite_input: bool = False

    @override
    def run(self, *args, **kwargs) -> SimulatorOutput:
        """Run the simulation, on a temporary copy of the inputs by default.

        All positional and keyword arguments are forwarded unchanged to
        ``run_simulation``, and its result is returned.
        """
        if not self.overwrite_input:
            # Work on a scratch copy; the TemporaryDirectory context manager
            # removes the directory once this block is left.
            with TemporaryDirectory(prefix=f"mythos-sim-{self.name}") as scratch_dir:
                self.copy_inputs(scratch_dir)
                return self.run_simulation(Path(scratch_dir), *args, **kwargs)
        # Overwriting was explicitly allowed: run in the input directory itself.
        return self.run_simulation(Path(self.input_dir), *args, **kwargs)

    def copy_inputs(self, temp_dir: str) -> None:
        """Copy the contents of ``input_dir`` into ``temp_dir``.

        Arguments:
            temp_dir: Destination directory. It may already exist.
        """
        # dirs_exist_ok=True because ``run`` passes a directory that
        # TemporaryDirectory has already created.
        shutil.copytree(self.input_dir, temp_dir, dirs_exist_ok=True)

    @abstractmethod
    def run_simulation(self, input_dir: Path, *args, **kwargs) -> SimulatorOutput:
        """Run the simulation in the given input directory.

        Arguments:
            input_dir: Directory to run in. ``run`` passes either the
                original input directory or a temporary copy of it.
            *args: Positional arguments forwarded from ``run``.
            **kwargs: Keyword arguments forwarded from ``run``.
        """