"""Base Implementation for objective functions."""
from abc import abstractmethod
from typing import ClassVar, Any, Literal, Union, Optional
import logging
from pydantic import computed_field, Field, field_validator, model_validator, PrivateAttr
import SimpleITK as sitk
import array_api_compat
from ...core.xp_utils.typing import Array
from ...core.xp_utils import to_numpy, from_numpy
from ...core.datamodel import PyRadPlanBaseModel
from ...quantities import get_available_quantities
from ...core import Grid
from ...core.resample import resample_numpy_array
# Admissible values for ``ParameterMetadata.kind``: one of the named kinds listed in the
# Literal below, or a list of strings.
ParameterType = Union[
    Literal["reference", "numeric", "relative_volume", "image_reference"],
    list[str],
]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ParameterMetadata:
    """
    Parameter metadata to attach to objective function parameters to designate
    configurability and type.

    Parameters
    ----------
    configurable : bool, optional, default=True
        Whether the parameter is intended to be configurable (e.g. a reference dose value for
        optimization).
    kind : ParameterType or None, optional, default='numeric'
        Type/meaning of the parameter. Options are 'reference' (relates to the input vector,
        e.g. the dose when you optimize on dose; used for normalization), 'numeric',
        'relative_volume', 'image_reference' (a reference image that is resampled to the
        target grid before use), or a list of strings. The kind is used in case the parameter
        needs to be pre-transformed in context.

    Attributes
    ----------
    configurable : bool
        Whether the parameter is intended to be configurable.
    kind : ParameterType or None
        Type/meaning of the parameter, as passed to the constructor.

    Note
    ----
    Intended for use with 'Annotated' when defining an objective attribute.
    """

    # Whether the parameter is intended to be configurable.
    configurable: bool
    # Semantic kind of the parameter (see class docstring); ``None`` is accepted.
    kind: Optional[ParameterType]

    def __init__(self, configurable: bool = True, kind: Optional[ParameterType] = "numeric"):
        self.configurable = configurable
        self.kind = kind

    def __repr__(self) -> str:
        # Readable, constructor-like representation (the class object itself renders as
        # "<class '...'>", which is noisy inside pydantic field metadata reprs).
        return f"{type(self).__name__}(configurable={self.configurable!r}, kind={self.kind!r})"
class Objective(PyRadPlanBaseModel):
    """
    Base class for objective functions in the optimization problem.

    Attributes
    ----------
    name : ClassVar[str]
        Name of the objective function.
    has_hessian : ClassVar[bool]
        Whether the objective function has a Hessian implementation.
    priority : float
        Weight/Priority assigned to the objective function (must be >= 0; field alias
        'penalty').
    quantity : str
        The quantity this objective is connected to (e.g. 'physical_dose', 'RBExDose').
    """

    name: ClassVar[str]
    has_hessian: ClassVar[bool] = False

    priority: float = Field(default=1.0, ge=0.0, alias="penalty")
    quantity: str = Field(default="physical_dose")

    # Resampled (and optionally index-selected) image reference parameters, keyed by the
    # parameter name. Filled by ``preprocess_image_reference_parameters``.
    _resampled_image_reference_cache: dict[str, Array] = PrivateAttr(default_factory=dict)

    def preprocess_image_reference_parameters(
        self, target_grid: Grid, index_list: Optional[Array] = None
    ):
        """
        Preprocess image reference parameters if existing in the objective definition.

        Preprocessing of reference image parameters is necessary to align with the
        corresponding target dose/optimization grid. The function resamples each reference
        image to the target grid, flattens it, optionally selects ``index_list`` and caches
        the result under the parameter name.

        Parameters
        ----------
        target_grid : Grid
            The target grid that the parameter should match.
        index_list : Optional[Array], optional
            Array containing indices for which the objective needs to be cached.

        Raises
        ------
        TypeError
            If an image reference parameter is neither a ``SimpleITK.Image`` nor an
            ``(array, Grid)`` tuple.
        """
        for param_name, param_type in zip(self.parameter_names, self.parameter_types):
            if param_type != "image_reference":
                continue

            param_value = getattr(self, param_name)

            # Get array and grid from the parameter value
            if isinstance(param_value, sitk.Image):
                ref_grid = Grid.from_sitk_image(param_value)
                param_value = sitk.GetArrayViewFromImage(param_value)
            elif isinstance(param_value, tuple):
                param_value, ref_grid = param_value
            else:
                # Fail explicitly: otherwise ``ref_grid`` would be unbound, or silently reused
                # from a previous image reference parameter of this loop.
                raise TypeError(
                    f"Image reference parameter '{param_name}' must be a SimpleITK.Image or "
                    f"an (array, Grid) tuple, got {type(param_value).__name__}."
                )

            # Get array namespace from parameter value / index list
            xp = array_api_compat.array_namespace(param_value, index_list)

            # Resample in numpy, then convert back to the namespace and flatten
            resampled_array = resample_numpy_array(
                input_array=to_numpy(param_value),
                reference_grid=ref_grid,
                target_grid=target_grid,
            )
            resampled_array = xp.reshape(
                from_numpy(xp, resampled_array),
                (array_api_compat.size(resampled_array),),
                copy=False,
            )

            if index_list is not None:
                resampled_array = resampled_array[index_list]

            # Always (re)compute: the cache key does not encode the target grid or index list,
            # so a previously cached entry cannot be trusted to match this call.
            self._resampled_image_reference_cache[param_name] = xp.asarray(resampled_array)

    @abstractmethod
    def compute_objective(self, values):
        """Compute the objective function."""

    @abstractmethod
    def compute_gradient(self, values):
        """Compute the objective gradient."""

    def compute_hessian(self, values):
        """
        Compute the objective Hessian.

        The base implementation returns ``None``; subclasses providing a Hessian are expected
        to override this and set ``has_hessian``.
        """
        return None

    @computed_field
    @property
    def parameter_names(self) -> list[str]:
        """list[str]: Parameter names."""
        return self._parameter_names()

    @classmethod
    def _parameter_names(cls) -> list[str]:
        """list[str]: Names of fields annotated with ``ParameterMetadata``, as classmethod."""
        return [
            name
            for name, info in cls.model_fields.items()
            if any(isinstance(meta, ParameterMetadata) for meta in info.metadata)
        ]

    @computed_field
    @property
    def parameter_types(self) -> list[ParameterType]:
        """list[ParameterType]: Parameter types, aligned with ``parameter_names``."""
        return self._parameter_types()

    @classmethod
    def _parameter_types(cls) -> list[ParameterType]:
        """list[ParameterType]: Parameter types (one per parameter name) as classmethod."""
        # Take the first ParameterMetadata of each field so the result always aligns
        # one-to-one with ``_parameter_names`` (which contains each field once).
        return [
            next(
                meta.kind
                for meta in cls.model_fields[name].metadata
                if isinstance(meta, ParameterMetadata)
            )
            for name in cls._parameter_names()
        ]

    @computed_field
    @property
    def parameters(self) -> list[Any]:
        """list[Any]: Parameter values, aligned with ``parameter_names``."""
        return [getattr(self, name) for name in self.parameter_names]

    @field_validator("quantity")
    @classmethod
    def _validate_quantity(cls, v):
        """Validate that the quantity is one of the available quantities."""
        available = get_available_quantities()
        if v not in available:
            raise ValueError(f"Quantity {v} not available. Choose from {available}")
        return v

    @model_validator(mode="before")
    @classmethod
    def _validate_model(cls, data: Any) -> Any:
        """
        Pre-validate the input and perform conversions if necessary.

        A matRad-like dict (containing 'className' and a positional 'parameters' entry) is
        converted to keyword fields matching the objective's parameter names. The input dict
        and its 'parameters' list are not modified.
        """
        # Check if this is a matRad-like objective
        if isinstance(data, dict) and "className" in data:
            # Shallow copy so the caller's dict is not modified
            data = data.copy()

            # Should we confirm once more we have the correct objective?
            data.pop("className")

            params = data.pop("parameters", [])

            # A single parameter is usually not wrapped in a list, so put it into one
            if not isinstance(params, list):
                params = [params]

            # obtain the parameter names
            param_names = cls._parameter_names()

            if len(params) != len(param_names):
                logger.warning(
                    "Objective '%s' expects %d parameters, but %d were provided.",
                    getattr(cls, "name", cls.__name__),
                    len(param_names),
                    len(params),
                )

            # zip stops at the shorter sequence: surplus values are ignored and missing ones
            # are left to field defaults / validation. This avoids popping from (mutating) the
            # caller's list and failing with an IndexError right after the warning above.
            for param_name, param_value in zip(param_names, params):
                data[param_name] = param_value

        return data