Source code for pyRadPlan.stf._steeringinformation

from typing import Any, Union
from typing_extensions import Self
import numpy as np
from pydantic import (
    SerializationInfo,
    model_validator,
    field_serializer,
    ValidatorFunctionWrapHandler,
    ValidationInfo,
    ValidationError,
)
from numpydantic import NDArray, Shape
from pyRadPlan.stf._beam import Beam
from pyRadPlan.core import PyRadPlanBaseModel
from pyRadPlan.util.helpers import models2recarray


class SteeringInformation(PyRadPlanBaseModel):
    """
    A class representing the Steering Information (stf).

    This class extends PyRadPlanBaseModel (based on pydantic) and bundles the
    individual beams of a plan. The beam properties themselves are defined by
    the ``Beam`` model.

    Attributes
    ----------
    beams : list[Beam]
        The beams making up the steering information.

    Methods
    -------
    validate_model_input(data, handler, info)
        Validates the input data and, if plain validation fails, tries to
        convert it into the ``{"beams": [...]}`` layout first.
    to_matrad(context="mat-file")
        Returns the ``"beams"`` entry of the parent class' matRad export.
    """

    beams: list[Beam]

    # Validation
    @model_validator(mode="wrap")
    @classmethod
    def validate_model_input(
        cls, data: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
    ) -> Self:
        """
        Validate the input data for creating the model instance.

        Pydantic's own handler is tried first. Only if it fails with a
        ``ValidationError`` is the input converted to the expected layout and
        validated again.

        Parameters
        ----------
        data : Any
            Raw input. Besides the regular ``{"beams": [...]}`` layout, the
            following inputs are converted:

            * a dictionary containing the keys ``__header__``, ``__version__``
              and ``__globals__`` (mat-file import): its fourth entry is used
              as the actual input,
            * a dictionary without a ``"beams"`` key whose values are
              equally long columns, at least one of them a list
              (one entry per beam): it is split into one dictionary per beam,
            * a dictionary without a ``"beams"`` key containing a value
              without a length: it is treated as a single beam,
            * a list: it is treated as the list of beams.
        handler : ValidatorFunctionWrapHandler
            Pydantic's inner validation handler.
        info : ValidationInfo
            Validation info, passed on to the handler.

        Returns
        -------
        Self
            The validated model instance.

        Raises
        ------
        ValidationError
            If the (possibly converted) input still cannot be validated.
        """
        try:
            return handler(data, info)
        except ValidationError:
            # Check if import is from matlab
            # If from matlab but stf already chosen, the code is executed as normal
            matrad_format = ("__header__", "__version__", "__globals__")
            try:
                from_mat_file = all(key in data for key in matrad_format)
            except TypeError:
                # Input without membership support (e.g. a bare number) is not
                # a mat-file dictionary; the final handler call below reports
                # it as a ValidationError instead of leaking this TypeError.
                from_mat_file = False

            if from_mat_file:
                # struct of matlab usual in 4. position. #TODO: are there exceptions?
                data = data[list(data.keys())[3]]

            if isinstance(data, dict) and "beams" not in data:
                # This code is needed to pass the right format of the stf from matRad
                try:
                    lengths = [len(value) for value in data.values()]
                except (TypeError, ValueError):
                    # A value without a length means the dictionary describes
                    # exactly one beam.
                    return handler({"beams": [data]}, info)

                same_length = len(set(lengths)) == 1 and len(lengths) > 1
                # TODO: only works if its a list??
                # Without any list column there is nothing to split per beam;
                # fall through so that the handler raises a ValidationError
                # rather than this method failing with an IndexError.
                if same_length and any(isinstance(value, list) for value in data.values()):
                    # All columns share one length, which is the number of beams.
                    num_beams = lengths[0]
                    beam_list = [
                        {key: column[i] for key, column in data.items()}
                        for i in range(num_beams)
                    ]
                    return handler({"beams": beam_list}, info)

            if isinstance(data, list):
                return handler({"beams": data}, info)

            return handler(data, info)

    @field_serializer("beams")
    def custom_beams_serializer(self, v: list[Beam], info: SerializationInfo) -> Any:
        """
        Serialize the beams fields in various contexts.

        Parameters
        ----------
        v : list[Beam]
            The beams to serialize.
        info : SerializationInfo
            Serialization info; its ``context`` selects the output format.

        Returns
        -------
        Any
            The result of ``models2recarray`` if the context entry ``"matRad"``
            equals ``"mat-file"``, otherwise a list with one ``model_dump``
            result per beam.
        """
        context = info.context
        if context and context.get("matRad") == "mat-file":
            # The "rays" field is exported as a record array as well.
            override_types = {"rays": np.recarray}
            return models2recarray(
                v, override_types=override_types, serialization_context=context
            )

        return [beam.model_dump(by_alias=info.by_alias) for beam in v]

    def to_matrad(self, context: str = "mat-file") -> Any:
        """
        Export the steering information for matRad.

        Parameters
        ----------
        context : str
            Export context passed on to the parent class' ``to_matrad``.

        Returns
        -------
        Any
            The ``"beams"`` entry of the parent class' export.
        """
        export = super().to_matrad(context=context)
        return export["beams"]

    @property
    def num_of_beams(self) -> int:
        """Number of beams."""
        return len(self.beams)

    @property
    def num_of_rays(self) -> int:
        """Number of rays summed over all beams."""
        return sum(beam.num_of_rays for beam in self.beams)

    @property
    def total_number_of_bixels(self) -> int:
        """Number of bixels summed over all beams."""
        return sum(beam.total_number_of_bixels for beam in self.beams)

    @property
    def bixel_beam_index_map(self) -> NDArray[Shape["1-*"], np.int64]:
        """Mapping of bixels to their respective beam index."""
        tmp_map = np.zeros(self.total_number_of_bixels, dtype=np.int64)
        start = 0
        for b, beam in enumerate(self.beams):
            stop = start + beam.total_number_of_bixels
            tmp_map[start:stop] = b
            start = stop
        return tmp_map

    @property
    def bixel_ray_index_per_beam_map(self) -> NDArray[Shape["1-*"], np.int64]:
        """Mapping of bixels to the ray index in the individual beams."""
        tmp_map = np.zeros(self.total_number_of_bixels, dtype=np.int64)
        start = 0
        for beam in self.beams:
            stop = start + beam.total_number_of_bixels
            tmp_map[start:stop] = beam.bixel_ray_map
            start = stop
        return tmp_map

    @property
    def bixel_index_per_beam_map(self) -> NDArray[Shape["1-*"], np.int64]:
        """Mapping of bixels to their bixel index in the respective beam."""
        tmp_map = np.zeros(self.total_number_of_bixels, dtype=np.int64)
        start = 0
        for beam in self.beams:
            stop = start + beam.total_number_of_bixels
            # Bixels are numbered from zero again within each beam.
            tmp_map[start:stop] = np.arange(beam.total_number_of_bixels)
            start = stop
        return tmp_map
def create_stf(
    stf: Union[dict[str, Any], SteeringInformation, None] = None, **kwargs
) -> SteeringInformation:
    """
    Create a Steering Information object.

    Parameters
    ----------
    stf : Union[dict[str, Any], SteeringInformation, None]
        Data to build the stf object from, or an existing
        SteeringInformation object.
    **kwargs
        Arbitrary keyword arguments. They are only used when ``stf`` is
        falsy and are then passed to the SteeringInformation constructor.

    Returns
    -------
    SteeringInformation
        ``stf`` itself if it already is a SteeringInformation object,
        otherwise a newly validated SteeringInformation object.

    Notes
    -----
    ``stf`` is checked for truthiness, so falsy input such as ``None`` or an
    empty dictionary is treated as "not given".
    """
    if not stf:
        # Nothing to validate: build the model from the keyword arguments.
        # NOTE: this path is marked as "not tested" in the original code.
        return SteeringInformation(**kwargs)

    # If data is already a Stf object, return it directly
    if isinstance(stf, SteeringInformation):
        return stf

    return SteeringInformation.model_validate(stf)
def validate_stf(
    stf: Union[dict[str, Any], SteeringInformation, None] = None, **kwargs
) -> SteeringInformation:
    """
    Validate a Steering Information object.

    Synonym to create_stf but should be used in validation context.

    Parameters
    ----------
    stf : Union[dict[str, Any], SteeringInformation, None]
        Data to build the stf object from, or an existing
        SteeringInformation object.
    **kwargs
        Arbitrary keyword arguments, passed on to create_stf unchanged.

    Returns
    -------
    SteeringInformation
        A validated SteeringInformation class object.
    """
    validated = create_stf(stf, **kwargs)
    return validated