# Source code for mythos.optimization.optimization

"""Runs an optimization loop using Ray actors for objectives and simulators."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import field
from typing import Any

import chex
import optax
import ray
from ray import ObjectRef as RayRef
from typing_extensions import override

from mythos.optimization.objective import Objective, ObjectiveOutput
from mythos.simulators.base import Simulator
from mythos.ui.loggers import logger as jdna_logger
from mythos.utils.types import Grads, Params

# Error messages raised by optimizer validation (see RayOptimizer.__post_init__).
ERR_MISSING_OBJECTIVES = "At least one objective is required."
ERR_MISSING_SIMULATORS = "At least one simulator is required."
ERR_MISSING_AGG_GRAD_FN = "An aggregate gradient function is required."
ERR_MISSING_OPTIMIZER = "An optimizer is required."
# Caps how many times a single objective may be attempted within one call to
# step, to prevent infinite unresolvable loops. The first call may use cached
# observables, so it may require a rerun of simulators. After this, we don't
# expect any new information, so a not-ready state past this limit is an error.
OBJECTIVE_PER_STEP_CALL_LIMIT = 2


@chex.dataclass(frozen=True, kw_only=True)
class OptimizerState:
    """State container for optimization loops.

    This dataclass stores all mutable state needed during optimization,
    allowing optimizers to work with frozen objective dataclasses.

    Attributes:
        observables: Current observable values from simulators.
        component_state: Per-objective/simulator state (keyed by name). Both
            objective and simulator state share this namespace, so names
            should be unique across objectives and simulators.
        optimizer_state: Current optax optimizer state.
    """

    observables: dict[str, Any] = field(default_factory=dict)
    component_state: dict[str, dict[str, Any]] = field(default_factory=dict)
    # typed as Any to avoid committing to a concrete optax.OptState pytree
    optimizer_state: Any | None = None  # optax.OptState
@chex.dataclass(frozen=True, kw_only=True)
class OptimizerOutput:
    """Output container for optimization steps.

    Attributes:
        grads: The computed (aggregate) gradients from the optimization step.
        opt_params: The updated parameters after the optimization step.
        state: The updated optimizer state after the optimization step. This
            data structure should be passed back into the next call to step.
        observables: The logged observables from the optimization step. These
            are keyed by component name (e.g. objective) and each value should
            itself be a dict of observable name to value.
    """

    grads: Grads
    opt_params: Params
    state: OptimizerState
    observables: dict[str, dict[str, Any]] = field(default_factory=dict)
@chex.dataclass(frozen=True, kw_only=True)
class Optimizer(ABC):
    """Abstract base class for optimizers."""

    @abstractmethod
    def step(self, params: Params, state: OptimizerState | None = None) -> OptimizerOutput:
        """Perform a single optimization step.

        Args:
            params: The current parameters.
            state: The current optimization state. If None, an empty state is
                initialized.

        Returns:
            An optimizer output including params, new state, grads, and
            observables.
        """
@chex.dataclass(frozen=True, kw_only=True)
class RayOptimizer(Optimizer):
    """Optimization of a list of objectives using a list of simulators.

    Each simulator run and each objective evaluation is dispatched as a Ray
    task; step() coordinates them via ObjectRefs.

    Parameters:
        objectives: A list of objectives to optimize.
        simulators: A list of simulators to use for the optimization.
        aggregate_grad_fn: A function that aggregates the gradients from the
            objectives into a single Grads structure.
        optimizer: An optax optimizer.
        logger: A logger to use for the optimization.
    """

    objectives: list[Objective]
    simulators: list[Simulator]
    aggregate_grad_fn: Callable[[list[Grads]], Grads]
    optimizer: optax.GradientTransformation
    logger: jdna_logger.Logger = field(default_factory=jdna_logger.NullLogger)
[docs] def __post_init__(self) -> None: """Validate the initialization of the Optimization.""" if not self.objectives: raise ValueError(ERR_MISSING_OBJECTIVES) if not self.simulators: raise ValueError(ERR_MISSING_SIMULATORS) if self.aggregate_grad_fn is None: raise ValueError(ERR_MISSING_AGG_GRAD_FN) if self.optimizer is None: raise ValueError(ERR_MISSING_OPTIMIZER) # Check for conflicts in global namespaces that we use for coordination all_names = [obj.name for obj in self.objectives] \ + [sim.name for sim in self.simulators] \ + [exp for sim in self.simulators for exp in sim.exposes()] if len(all_names) != len(set(all_names)): raise ValueError("All objective, simulator, and exposes names must be unique")
[docs] def _create_and_run_remote(self, fun: callable, ray_options: dict, *args) -> RayRef|list[RayRef]: remote_fun = ray.remote(fun).options(**ray_options) return remote_fun.remote(*args)
[docs] def _run_simulator( self, simulator: Simulator, params: Params, **state ) -> tuple[list[RayRef], RayRef]: def simulator_run_fn(params: Params, state: dict[str, Any]) -> tuple[list[RayRef], RayRef]: output = simulator.run(opt_params=params, **state) return *output.observables, output.state ray_opts = {"name": "simulator_run:" + simulator.name, "num_returns": 1 + len(simulator.exposes())} refs = self._create_and_run_remote(simulator_run_fn, ray_opts, params, state) return refs[:-1], refs[-1] # observables as a list, state
[docs] def _run_objective( self, objective: Objective, observables: dict[str, RayRef], params: Params, **state ) -> RayRef: def objective_compute_fn(obs: dict[str, RayRef], params: Params, state: dict[str, Any]) -> ObjectiveOutput: obs = {k: ray.get(v) for k, v in obs.items()} return objective.calculate(observables=obs, opt_params=params, **state) ray_opts = {"name": "objective_compute:" + objective.name} return self._create_and_run_remote(objective_compute_fn, ray_opts, observables, params, state)
[docs] def _wait_remotes(self, refs: list[RayRef]) -> list[RayRef]: ref_list = list(refs) ray.wait(ref_list, fetch_local=False, num_returns=1) # The below is to maximize our chance of getting multiple at once # (for example multiple observables and state from a simulator) ready, _ = ray.wait(ref_list, fetch_local=False, timeout=0.1) return ready
    @override
    def step(self, params: Params, state: OptimizerState | None = None) -> OptimizerOutput:  # noqa: C901, PLR0912
        """Run one optimization step, scheduling simulators/objectives on Ray.

        Repeatedly schedules whatever each unfinished objective needs (either
        the objective itself, when its required observables are available, or
        the simulators that expose them), waits for results, and folds them
        into local bookkeeping until every objective has produced gradients.
        Then aggregates gradients and applies an optax update.

        Args:
            params: The current parameters.
            state: The current optimization state. If None, an empty state is
                initialized.

        Returns:
            An OptimizerOutput with updated params, state, aggregate grads,
            and per-objective observables.

        Raises:
            RuntimeError: If an objective is still not ready after
                OBJECTIVE_PER_STEP_CALL_LIMIT attempts.
        """
        state = state or OptimizerState()
        # local working copies so the incoming frozen state is never mutated
        state_observables, component_state = state.observables.copy(), state.component_state.copy()
        obj_lookup = {obj.name: obj for obj in self.objectives}
        call_count = {obj.name: 0 for obj in self.objectives}
        sim_lookup = {sim.name: sim for sim in self.simulators}
        # observable name -> simulator that produces it
        expose_lookup = {exp: sim for sim in self.simulators for exp in sim.exposes()}
        # ref_map: in-flight ObjectRef -> producer name (objective name,
        # observable name, or simulator name for its state ref)
        ref_map, grads_completed, output_observables = {}, {}, {}
        # schedule all objectives that already have their observables in state
        while needed_objectives := set(obj_lookup) - set(grads_completed):
            for obj_name in needed_objectives:
                objective = obj_lookup[obj_name]
                # skip if we are currently running it
                if objective.name in ref_map.values():
                    continue
                # It is an unresolvable state if we have called the objective
                # more than the per-step limit
                if call_count[objective.name] > OBJECTIVE_PER_STEP_CALL_LIMIT:
                    raise RuntimeError(f"Objective {objective.name} could not be resolved after multiple attempts.")
                # If we have all the observables in state, we make an attempt at
                # the objective. This may return a not ready signal, in which
                # case observables will be cleared to trigger this logic again.
                # NOTE(review): state_observables may hold ObjectRefs (set
                # below) or whatever the caller seeded state.observables with;
                # _run_objective ray.get()s every value, so seeded values are
                # presumably refs too — confirm against callers.
                if set(objective.required_observables).issubset(state_observables):
                    obj_observables = {k: state_observables[k] for k in objective.required_observables}
                    obj_state = component_state.get(objective.name, {})
                    ref = self._run_objective(objective, obj_observables, params, **obj_state)
                    ref_map[ref] = objective.name
                    call_count[objective.name] += 1
                # there are simulators running that provide some of what we
                # need, so we have gone through the scheduling step
                elif set(objective.required_observables).intersection(ref_map.values()):
                    continue
                else:
                    needed_sims = {expose_lookup[exp].name for exp in objective.required_observables}
                    # filter out sims we know are running based on state ref
                    for sim_name in needed_sims - set(ref_map.values()):
                        sim = sim_lookup[sim_name]
                        # make sure we aren't waiting on any of the observables
                        # this provides. It may be possible that the ref of
                        # state or some observables become available separately
                        if set(sim.exposes()).intersection(ref_map.values()):
                            continue
                        sim_state = component_state.get(sim.name, {})
                        refs, md_ref = self._run_simulator(sim, params, **sim_state)
                        for r, exp in zip(refs, sim.exposes(), strict=True):
                            ref_map[r] = exp
                        ref_map[md_ref] = sim.name
            # wait for anything to finish. We do a second wait without num
            # returns but non-blocking to gather as many as we can at once
            ready = self._wait_remotes(ref_map.keys())
            for ref in ready:
                producer = ref_map.pop(ref)
                if producer in obj_lookup:
                    output = ray.get(ref)
                    component_state[producer] = output.state
                    if output.is_ready:
                        grads_completed[producer] = output.grads
                        output_observables[producer] = output.observables
                    else:
                        # remove the needs from the state observables so the
                        # above loop check will schedule the providing simulator
                        state_observables = {k: v for k, v in state.observables.items() if k not in output.needs_update}
                elif producer in expose_lookup:
                    # keep the observable as an (un-fetched) ref; it is only
                    # resolved inside objective tasks
                    state_observables[producer] = ref
                else:
                    # finally it must be simulator state
                    component_state[producer] = ray.get(ref)
        # NOTE(review): if the loop exits while some refs are still in-flight
        # (e.g. a simulator state ref for an objective that resolved from
        # cache), those results are silently dropped — confirm intended.
        grads = self.aggregate_grad_fn(list(grads_completed.values()))
        # lazily initialize the optax state on the first step
        opt_state = state.optimizer_state or self.optimizer.init(params)
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return OptimizerOutput(
            opt_params=new_params,
            state=state.replace(
                optimizer_state=opt_state,
                component_state=component_state,
                observables=state_observables,
            ),
            grads=grads,
            observables=output_observables,
        )
# NOTE(review): unlike the sibling dataclasses, this one is not kw_only —
# confirm whether positional construction is relied upon anywhere.
@chex.dataclass(frozen=True)
class SimpleOptimizer(Optimizer):
    """A simple optimizer that uses a single objective and simulator.

    This optimizer manages the state for a frozen Objective dataclass, passing
    observables and state through the compute method. State is managed via
    OptimizationState which is passed in and out.
    """

    objective: Objective
    simulator: Simulator
    optimizer: optax.GradientTransformation
    logger: jdna_logger.Logger = field(default_factory=jdna_logger.NullLogger)
    @override
    def step(self, params: Params, state: OptimizerState | None = None) -> OptimizerOutput:
        """Run one optimization step with a single objective and simulator.

        First attempts the objective against cached observables (if any); if
        the objective reports not-ready (or there were no observables), runs
        the simulator once and retries the objective with fresh observables.

        Args:
            params: The current parameters.
            state: The current optimization state. If None, an empty state is
                initialized.

        Returns:
            An OptimizerOutput with updated params, state, grads, and the
            objective's observables.

        Raises:
            ValueError: If the objective is still not ready after a fresh
                simulator run (should be impossible; guards against looping).
        """
        state = state or OptimizerState()
        obj_state = state.component_state.get(self.objective.name, {})
        sim_state = state.component_state.get(self.simulator.name, {})
        obj_output = None
        # try cached observables first; the simulator run below is skipped
        # entirely when the objective can resolve from cache
        if state.observables:
            obj_output = self.objective.calculate(state.observables, opt_params=params, **obj_state)
            obj_state = obj_output.state
        if obj_output is None or not obj_output.is_ready:
            # NOTE(review): params passed positionally here, while RayOptimizer
            # calls run(opt_params=params, ...) — confirm run()'s first
            # parameter is opt_params.
            sim_output = self.simulator.run(params, **sim_state)
            sim_state = sim_output.state
            exposes = self.simulator.exposes()
            # observables are keyed by the simulator's exposes(), in order
            state = state.replace(observables=dict(zip(exposes, sim_output.observables, strict=True)))
            # Try again with updated observables
            obj_output = self.objective.calculate(state.observables, opt_params=params, **obj_state)
            obj_state = obj_output.state
            if not obj_output.is_ready:
                # this should be an impossible state, could end up in infinite loop
                raise ValueError("Objective readiness check failed after simulation run.")
        grads = obj_output.grads
        # lazily initialize the optax state on the first step
        opt_state = state.optimizer_state or self.optimizer.init(params)
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        # should this be filtered? also be allowed to return filtered from
        # simulators?
        output_observables = {self.objective.name: obj_output.observables}
        return OptimizerOutput(
            opt_params=new_params,
            state=state.replace(
                optimizer_state=opt_state,
                component_state={
                    **state.component_state,
                    self.objective.name: obj_state,
                    self.simulator.name: sim_state,
                },
            ),
            grads=grads,
            observables=output_observables
        )