"""Configuration class for energy models."""
import warnings
from typing import Any, Union
import chex
import mythos.utils.types as jdt
# Raised by BaseConfiguration.__post_init__ when a required parameter is None.
# Placeholder: {props} -> comma-separated names of the uninitialized parameters.
ERR_MISSING_REQUIRED_PARAMS = (
    "Required properties {props} "
    "are not initialized."
)

# Raised by BaseConfiguration.__post_init__ when params_to_optimize names a
# parameter that is not an optimizable required parameter.
# Placeholders: {req_params} -> permitted names, {given_params} -> offending names.
ERR_OPT_DEPENDENT_PARAMS = (
    "Only {req_params} permitted for optimization, "
    "but found {given_params}"
)

# Warning emitted by the default BaseConfiguration.init_params.
WARN_INIT_PARAMS_NOT_IMPLEMENTED = (
    "init_params "
    "not implemented"
)

# Warning emitted by BaseConfiguration.to_dictionary when a dependent
# parameter is still None.
WARN_DEPENDENT_PARAMS_NOT_INITIALIZED = (
    "Dependent parameters "
    "not initialized"
)
@chex.dataclass(frozen=True)
class BaseConfiguration:
    """Base class for configuration classes.

    This class should not be used directly.

    Parameters:
        params_to_optimize (tuple[str]): parameters to optimize. If it contains the
            wildcard ``"*"`` (``OPT_ALL``), every optimizable required parameter is
            selected.
        required_params (tuple[str]): required parameters
        non_optimizable_required_params (tuple[str]): required parameters that are not optimizable
        dependent_params (tuple[str]): dependent parameters, these are calculated from the independent parameters
        OPT_ALL (tuple[str]): CONSTANT, is a wild card for all parameters
    """

    params_to_optimize: tuple[str] = ()
    required_params: tuple[str] = ()
    non_optimizable_required_params: tuple[str] = ()
    dependent_params: tuple[str] = ()
    OPT_ALL: tuple[str] = ("*",)

    @property
    def opt_params(self) -> dict[str, jdt.Scalar]:
        """Returns the parameters to optimize.

        When ``params_to_optimize`` contains the ``OPT_ALL`` wildcard, every required
        parameter not listed in ``non_optimizable_required_params`` is returned.
        Otherwise only the explicitly listed parameters are returned.

        Returns:
            Mapping of parameter name to its current value, in field order.
        """
        # Test membership rather than exact tuple equality: __post_init__ accepts the
        # wildcard in any container (tuple or list) and next to explicit names, so an
        # equality check against the tuple ("*",) would silently select nothing for
        # ["*"] or drop the wildcard for ("*", "name").
        if set(self.OPT_ALL).issubset(self.params_to_optimize):
            selected = set(self.required_params) - set(self.non_optimizable_required_params)
        else:
            selected = set(self.params_to_optimize)
        return {k: v for k, v in self.items() if k in selected}

    def __post_init__(self) -> None:
        """Checks validity of the configuration.

        Raises:
            ValueError: if a required parameter is ``None``, or if ``params_to_optimize``
                names a parameter other than an optimizable required parameter or the
                ``OPT_ALL`` wildcard.
        """
        non_initialized_props = [param for param in self.required_params if getattr(self, param) is None]
        if non_initialized_props:
            raise ValueError(ERR_MISSING_REQUIRED_PARAMS.format(props=",".join(non_initialized_props)))

        optimizable_params = set(self.required_params) - set(self.non_optimizable_required_params)
        # The wildcard is permitted, so it is neither a reason to raise nor reported as
        # an offending name.
        unoptimizable_params = set(self.params_to_optimize) - optimizable_params - set(self.OPT_ALL)
        if unoptimizable_params:
            raise ValueError(
                ERR_OPT_DEPENDENT_PARAMS.format(
                    req_params=",".join(sorted(optimizable_params)),
                    given_params=",".join(sorted(unoptimizable_params)),
                )
            )

    def init_params(self) -> "BaseConfiguration":
        """Initializes the dependent parameters in configuration.

        Should be implemented in the subclass if dependent parameters are present.
        This default implementation only emits a warning and returns ``self`` unchanged.
        """
        # stacklevel=2 attributes the warning to the caller rather than to this line.
        warnings.warn(WARN_INIT_PARAMS_NOT_IMPLEMENTED, stacklevel=2)
        return self

    @classmethod
    def from_dict(cls, params: dict[str, float], params_to_optimize: tuple[str] = ()) -> "BaseConfiguration":
        """Creates a configuration from a dictionary.

        Args:
            params: field values forwarded to the constructor as keyword arguments.
                Any mapping is accepted.
            params_to_optimize: value for the ``params_to_optimize`` field. It overrides
                a ``"params_to_optimize"`` key in ``params``.

        Returns:
            A new instance of ``cls``.
        """
        # Unpacking (rather than ``params | {...}``) works for any Mapping, not only dict.
        return cls(**{**params, "params_to_optimize": params_to_optimize})

    def to_dictionary(
        self,
        *,
        include_dependent: bool,
        exclude_non_optimizable: bool,
    ) -> dict[str, jdt.ARR_OR_SCALAR]:
        """Converts the configuration to a dictionary.

        Args:
            include_dependent: also include dependent parameters that are not ``None``.
                A single warning is emitted if any of them is still ``None``.
            exclude_non_optimizable: drop entries listed in ``non_optimizable_required_params``.

        Returns:
            Mapping of parameter name to value.
        """
        params = {k: getattr(self, k) for k in self.required_params}
        if include_dependent:
            any_missing = False
            for k in self.dependent_params:
                if (val := getattr(self, k)) is not None:
                    params[k] = val
                else:
                    any_missing = True
            if any_missing:
                # Warn once per call (not once per missing parameter), attributed to the caller.
                warnings.warn(WARN_DEPENDENT_PARAMS_NOT_INITIALIZED, stacklevel=2)
        if exclude_non_optimizable:
            for k in self.non_optimizable_required_params:
                params.pop(k, None)
        return params

    def __merge__baseconfig(self, other: "BaseConfiguration") -> "BaseConfiguration":
        """Merges two BaseConfiguration objects.

        Every entry of ``other.items()`` whose value is not ``None`` replaces the
        corresponding field of ``self``. This includes bookkeeping fields such as
        ``params_to_optimize``.
        """
        filtered = {k: v for k, v in other.items() if v is not None}
        return self.__merge__dict(filtered)

    def __merge__dict(self, other: dict[str, Any]) -> "BaseConfiguration":
        """Merges a dictionary with the configuration.

        Returns a new configuration built by ``self.replace(**other)``. Presumably this
        raises for keys that are not fields of ``self`` -- confirm against chex's
        ``replace`` semantics.
        """
        return self.replace(**other)

    # A string forward reference cannot be combined with ``|`` at runtime
    # ("BaseConfiguration" | dict[...] raises TypeError), so Union is used here.
    def __or__(self, other: Union["BaseConfiguration", dict[str, jdt.ARR_OR_SCALAR]]) -> "BaseConfiguration":
        """Convenience method to merge a configuration or a dictionary with the current configuration.

        Returns a new configuration object. Returns ``NotImplemented`` for any other
        operand type so Python can raise the usual ``TypeError``.
        """
        if isinstance(other, BaseConfiguration):
            merge_fn = self.__merge__baseconfig
        elif isinstance(other, dict):
            merge_fn = self.__merge__dict
        else:
            return NotImplemented
        return merge_fn(other)