Source code for pyRadPlan.cst._cst

"""Structure Set Implementation."""

from typing import Any, Union, Optional
from typing_extensions import Self
from pydantic import (
    Field,
    model_validator,
    ValidationInfo,
)

import numpy as np
from scipy import ndimage
import SimpleITK as sitk

from ..core import PyRadPlanBaseModel, Grid
from ..core.resample import resample_numpy_array
from ..ct import CT, validate_ct
from . import VOI, ExternalVOI, validate_voi, DEFAULT_VOI_COLORS


class StructureSet(PyRadPlanBaseModel):
    """
    Represents a Structure Set for a Patient.

    A structure set bundles a list of volumes of interest (VOIs) with the CT
    image they are defined on.

    Attributes
    ----------
    vois : list[VOI]
        List of VOIs in the Structure Set.
    ct_image : CT
        Reference to the CT Image.
    """

    vois: list[VOI] = Field(init=False, description="List of VOIs in the Structure Set")
    ct_image: CT = Field(init=False, description="Reference to the CT Image")

    @classmethod
    def _process_matrad_data(cls, data: dict, info: ValidationInfo) -> dict:
        """
        Handle data coming from matRad.

        Each entry of ``data["vois"]`` is treated as one row of a matRad cst
        cell array: element 1 is the name, element 2 the VOI type, element 3
        the voxel index list(s), element 4 (optional, dict) extra VOI
        properties and element 5 (optional) the objectives.

        Parameters
        ----------
        data : dict
            Input data holding the raw cst rows under ``"vois"`` and the
            already validated CT under ``"ct_image"``.
        info : ValidationInfo
            Pydantic validation info (unused here).

        Returns
        -------
        dict
            Dictionary with the validated ``"vois"`` list and ``"ct_image"``.

        Raises
        ------
        ValueError
            If the number of masks does not match the 4D CT, or if the CT is
            neither 3D nor 4D.
        """
        voi_list = []
        cst_data = data["vois"]
        ct = data["ct_image"]

        def get_idx_list(vdata_item):
            # Wrap a single array into a list if needed
            arr = vdata_item[3]
            # return only one scenario (3D) else: Multi-Scenario (4D)
            return (
                [np.asarray(arr, dtype=int).ravel()]
                if not isinstance(arr, list)
                else [np.asarray(a, dtype=int).ravel() for a in arr]
            )

        def create_mask(idx):
            # Create the ITK mask for a given index list.
            tmp_mask = np.zeros((ct.size[2], ct.size[0], ct.size[1]), dtype=np.uint8)
            # The incoming indices are shifted by one to obtain 0-based indices.
            tmp_mask.flat[idx - 1] = 1
            tmp_mask = np.swapaxes(tmp_mask, 1, 2)
            mask_image = sitk.GetImageFromArray(tmp_mask)
            # NOTE(review): for a 4D CT this copies 4D meta data onto a 3D mask --
            # confirm that this is supported by SimpleITK.
            mask_image.CopyInformation(ct.cube_hu)
            return mask_image

        for vdata in cst_data:
            idx_list = get_idx_list(vdata)
            masks = [create_mask(idx) for idx in idx_list]

            # For 4D, we need to join the masks. We also check here if the number of
            # masks we could extract matches the number of scenarios in the CT image
            dim = ct.cube_hu.GetDimension()
            if dim == 4:
                num_scenarios = ct.cube_hu.GetSize()[3]

                # If only one mask is given, it is the same for all 4D scenarios,
                # so it is replicated once per scenario (not once per dimension).
                if len(masks) == 1:
                    masks = [masks[0]] * num_scenarios

                # Now do a sanity check that we don't have an incompatible number of masks
                if len(masks) != num_scenarios:
                    raise ValueError("Incompatible number of masks for 4D CT")

                # Pass the masks as one list so that any number of scenarios works
                masks = sitk.JoinSeries(masks)

            # If it is a 3D CT, we just drop the list
            elif dim == 3:
                masks = masks[0]
            else:
                raise ValueError("Sanity Check failed -- unsupported CT dimensionality")

            # Check Objectives
            objectives = vdata[5] if len(vdata) > 5 else []
            if not isinstance(objectives, list):
                objectives = [objectives]

            props = vdata[4] if len(vdata) > 4 and isinstance(vdata[4], dict) else {}

            voi = validate_voi(
                name=str(vdata[1]),
                voi_type=str(vdata[2]),
                mask=masks,
                ct_image=ct,
                objectives=objectives,
                **props,
            )
            voi_list.append(voi)

        return {"vois": voi_list, "ct_image": ct}

    @model_validator(mode="before")
    @classmethod
    def aggregate_dynamic_quantities(cls, data: Any, info: ValidationInfo) -> Any:
        """
        Normalize the raw input before field validation.

        Validates the CT, checks a CT stored inside a ``vois`` dictionary
        against it, converts an ndarray of VOIs to a list, and dispatches
        list input that is neither VOI dicts nor VOI objects to the matRad
        import.

        Raises
        ------
        ValueError
            If ``vois`` or ``ct_image`` is missing, or if a CT stored inside a
            ``vois`` dictionary differs from the provided CT.
        """
        # Validate required keys.
        # Not needed since pydantic validation will take care of this.
        # But error prompt is in more detail.
        if data.get("vois") is None:
            raise ValueError("No cst provided. Please provide a cst.")
        if data.get("ct_image") is None:
            raise ValueError("No reference CT provided. Please provide a CT.")

        data["ct_image"] = validate_ct(data["ct_image"])

        # Handle the case where vois is supplied as a dict.
        if isinstance(data["vois"], dict):
            # Both spellings are removed from the dict; "ct_image" takes precedence.
            ct_from_vois = data["vois"].pop("ct_image", data["vois"].pop("ctImage", None))
            if ct_from_vois is not None and ct_from_vois != data["ct_image"]:
                raise ValueError("CT image mismatch between StructureSet and provided CT")
            return data

        # Convert ndarray to list if needed.
        # needed for matRad cell array
        if isinstance(data["vois"], np.ndarray):
            data["vois"] = data["vois"].tolist()

        # If vois is a list and not already a list of VOI dicts, process it
        # -> assume matrad data
        if (
            isinstance(data["vois"], list)
            and data["vois"]
            and not isinstance(data["vois"][0], dict)
            and not all(isinstance(i, VOI) for i in data["vois"])
        ):
            data = cls._process_matrad_data(data, info)

        return data

    @model_validator(mode="after")
    def check_cst(self) -> Self:
        """
        Check if the VOIs are valid and reference the same CT.

        Also assigns default colors to VOIs without one.

        Raises
        ------
        ValueError
            If a VOI references a CT different from the structure set's CT.
        """
        if isinstance(self.vois, list):
            for voi in self.vois:
                if voi.ct_image != self.ct_image:
                    raise ValueError("All VOIs must reference the same CT image.")
        self.set_colors()
        return self

    def to_matrad(self, context: str = "mat-file") -> Any:
        """
        Convert the StructureSet to a matRad writeable format.

        Parameters
        ----------
        context : str, optional
            Export context. Only ``"mat-file"`` is supported.

        Returns
        -------
        list
            One exported list per VOI, in the order of ``self.vois``.

        Raises
        ------
        ValueError
            If ``context`` is not ``"mat-file"``.
        """
        if context != "mat-file":
            raise ValueError(f"Context {context} not supported")

        export_cell_list = []
        for i, voi in enumerate(self.vois):
            voi_list = voi.to_matrad(context=context)
            # First entry carries the position of the VOI in the structure set
            voi_list[0] = i
            # TODO: set objectives here
            voi_list[5] = {}
            export_cell_list.append(voi_list)

        return export_cell_list

    # Additional Properties
    @property
    def voi_types(self) -> list:
        """Return the unique VOI types in the Structure Set (in no particular order)."""
        return list({voi.voi_type for voi in self.vois})

    def target_union_voxels(self, order="sitk") -> np.ndarray:
        """
        Return the union of all target indices.

        Parameters
        ----------
        order : str, optional
            Index ordering passed on to ``VOI.get_indices``.

        Returns
        -------
        np.ndarray
            Sorted unique indices of all VOIs whose type is ``"TARGET"``.

        Raises
        ------
        ValueError
            If the structure set contains no target VOI.
        """
        target_indices = [
            voi.get_indices(order=order) for voi in self.vois if voi.voi_type == "TARGET"
        ]
        if not target_indices:
            # np.concatenate would raise a far less helpful ValueError here
            raise ValueError("No target VOI found in the StructureSet.")
        return np.unique(np.concatenate(target_indices))

    def target_union_mask(self) -> sitk.Image:
        """
        Return the union mask of all target indices.

        For a 4D CT the mask is built on a 3D image extracted at index 0 of
        the fourth dimension.
        """
        target_indices = self.target_union_voxels(order="numpy")

        # Creates a copy of the CT image with all zeros
        if self.ct_image.cube_hu.GetDimension() == 4:
            sz = np.array(self.ct_image.cube_hu.GetSize())
            # A size of zero collapses the fourth dimension in the extraction
            sz[3] = 0
            tmpmask3d = sitk.Extract(
                self.ct_image.cube_hu,
                size=sz.tolist(),
                index=[0, 0, 0, 0],
                directionCollapseToStrategy=sitk.ExtractImageFilter.DIRECTIONCOLLAPSETOSUBMATRIX,
            )
        else:
            tmpmask3d = self.ct_image.cube_hu

        # astype creates a new array, so the CT data itself is not modified
        mask = sitk.GetArrayViewFromImage(tmpmask3d).astype(np.uint8)
        mask.fill(0)
        mask.ravel()[target_indices] = 1
        mask_image = sitk.GetImageFromArray(mask)
        mask_image.CopyInformation(tmpmask3d)

        return mask_image

    def patient_voxels(self, order="sitk") -> np.ndarray:
        """
        Return the union of all patient indices.

        If an ``ExternalVOI`` is present, the indices of the first one found
        are returned; otherwise the union over all VOIs is used.

        Parameters
        ----------
        order : str, optional
            Index ordering passed on to ``VOI.get_indices``.

        Raises
        ------
        ValueError
            If the structure set contains no VOI.
        """
        # First check if we have an "EXTERNAL" VOI designating the outer contour
        for voi in self.vois:
            if isinstance(voi, ExternalVOI):
                return voi.get_indices(order=order)

        patient_indices = [voi.get_indices(order=order) for voi in self.vois]
        if not patient_indices:
            # np.concatenate would raise a far less helpful ValueError here
            raise ValueError("No VOI found in the StructureSet.")
        return np.unique(np.concatenate(patient_indices))

    def patient_mask(self) -> sitk.Image:
        """Return the union mask of all patient contours (or the EXTERNAL contour if provided)."""
        patient_indices = self.patient_voxels(order="numpy")

        # Creates a copy of the CT image with all zeros
        mask = sitk.GetArrayFromImage(self.ct_image.cube_hu).astype(np.uint8)
        mask.fill(0)
        mask.ravel()[patient_indices] = 1
        mask_image = sitk.GetImageFromArray(mask)
        mask_image.CopyInformation(self.ct_image.cube_hu)

        return mask_image

    def target_center_of_mass(self) -> np.ndarray:
        """Return the center of mass of the target as a physical point."""
        mask_image = self.target_union_mask()
        mask = sitk.GetArrayViewFromImage(
            mask_image
        ).T  # Transpose allows us to use sitk indexing

        if mask.ndim == 4:
            mask = mask[:, :, :, 0]

        cm_index = ndimage.center_of_mass(mask)
        cm = mask_image.TransformContinuousIndexToPhysicalPoint(cm_index)

        return np.array(cm)

    def resample_on_new_ct(self, new_ct: CT) -> Self:
        """
        Resample the StructureSet on a new CT.

        Parameters
        ----------
        new_ct : CT
            The new CT to resample the StructureSet on.

        Returns
        -------
        StructureSet
            The resampled StructureSet.
        """
        new_model = self.model_dump()
        # NOTE(review): this compares the dumped CT with a CT object -- confirm that
        # the comparison can ever report equality; otherwise this always resamples.
        if new_model["ct_image"] != new_ct:
            new_model["ct_image"] = new_ct
            new_model["vois"] = [voi.resample_on_new_ct(new_ct) for voi in self.vois]

        return self.model_validate(new_model)

    def apply_overlap_priorities(self) -> Self:
        """
        Apply overlaps to the StructureSet.

        VOIs are processed in ascending order of ``overlap_priority``. Voxels
        already covered by VOIs with a strictly smaller priority value are
        removed from the mask of each subsequent VOI; VOIs sharing the same
        priority value do not mask each other.

        Returns
        -------
        StructureSet
            The StructureSet with overlaps applied.
        """
        # Nothing to do without VOIs
        if not self.vois:
            return self.model_copy(deep=True)

        # gather overlaps
        overlaps = [v.overlap_priority for v in self.vois]

        # sort by overlap priority
        ix_sorted = np.argsort(overlaps)

        overlap_mask = self.vois[ix_sorted[0]].mask  # Will aggregate the overlap masks using OR

        # Mask of everything with a strictly smaller priority value; starts out empty
        last_priority_mask = sitk.Image(overlap_mask.GetSize(), overlap_mask.GetPixelID())
        last_priority_mask.CopyInformation(overlap_mask)

        new_vois = [None] * len(self.vois)
        new_vois[ix_sorted[0]] = self.vois[ix_sorted[0]].model_copy()

        for i, ix_voi in enumerate(ix_sorted[1:], 1):
            curr_voi = self.vois[ix_voi].model_copy()
            curr_mask = curr_voi.mask
            prev_priority = new_vois[ix_sorted[i - 1]].overlap_priority

            # if the overlap priority is higher than we need to apply overlap
            if curr_voi.overlap_priority >= prev_priority:
                # A new priority level begins: everything aggregated so far masks it
                if curr_voi.overlap_priority > prev_priority:
                    last_priority_mask = overlap_mask

                curr_mask = sitk.MaskNegated(curr_mask, last_priority_mask)
                curr_voi.mask = curr_mask

            overlap_mask = sitk.Or(overlap_mask, curr_mask)

            new_vois[ix_voi] = curr_voi

        return self.model_copy(deep=True, update={"vois": new_vois})

    def create_body_seg(
        self, threshold: float = -200.0, name: str = "BODY", voi_type="OAR"
    ) -> Self:
        """
        Create a body segmentation from CT data based on a HU threshold.

        This method generates a body contour by thresholding the CT image,
        identifying the largest connected component (main body), and filling
        internal air cavities (like lungs) slice by slice to handle cases where
        the scan is cut off. The new VOI is appended to ``self.vois`` in place.

        Parameters
        ----------
        threshold : float, optional
            HU threshold value for body segmentation. Voxels with HU values at or
            above this threshold are considered part of the body.
            Default is -200.0 HU (approximately air/tissue boundary).
        name : str, optional
            Name for the body VOI. Default is "BODY".
        voi_type : str, optional
            Type of the VOI to create. Default is "OAR".

        Returns
        -------
        Self
            Updated StructureSet with the body segmentation added.

        Raises
        ------
        ValueError
            If no voxel of the CT reaches the threshold.
        """
        # Segment the Body Contour using the provided threshold
        max_value = sitk.GetArrayViewFromImage(self.ct_image.cube_hu).max()
        if max_value < threshold:
            raise ValueError("No voxel reaches the HU threshold, cannot segment the body.")

        # The upper bound is inclusive, so the maximum itself keeps every voxel at or
        # above the threshold (scaling the maximum would cut voxels off when it is negative).
        binary_segmentation = sitk.BinaryThreshold(
            self.ct_image.cube_hu, threshold, float(max_value)
        )

        # Step 1: Label connected components
        labeled = sitk.ConnectedComponent(binary_segmentation)

        # Step 2: Compute statistics
        stats = sitk.LabelShapeStatisticsImageFilter()
        stats.Execute(labeled)

        # Step 3: Identify label with largest number of pixels
        largest_label = max(stats.GetLabels(), key=stats.GetNumberOfPixels)

        # Step 4: Create binary mask of largest component
        body_segmentation = sitk.Equal(labeled, largest_label)

        # Step 5: Fill holes slice by slice to handle cases where scan is cut off
        # (e.g. cutting of lungs results in not enclosing the entire lung volume)
        size = body_segmentation.GetSize()
        filled_slices = []

        for slice_idx in range(size[2]):  # Iterate through z-direction
            # Extract 2D slice
            slice_2d = sitk.Extract(
                body_segmentation,
                size=[size[0], size[1], 0],  # 2D slice
                index=[0, 0, slice_idx],
            )

            # Fill holes in this 2D slice
            filled_slice = sitk.BinaryFillhole(slice_2d)
            filled_slices.append(filled_slice)

        # Join all slices back into 3D volume
        filled_segmentation = sitk.JoinSeries(filled_slices)
        # Joining uses default geometry along the stacking axis, so restore the
        # geometry of the unfilled segmentation.
        filled_segmentation.CopyInformation(body_segmentation)
        body_segmentation = filled_segmentation

        body_voi_data = {
            "name": name,
            "ct_image": self.ct_image,
            "mask": body_segmentation,
            "voi_type": voi_type,
            "alpha_x": 0.1,
            "beta_x": 0.05,
        }

        body_voi = validate_voi(body_voi_data)

        # Add to existing VOIs
        self.vois = self.vois + [body_voi]

        return self

    def get_reference_lq_params(
        self, overlap_is_applied: bool = False, resample_grid: Optional[Grid] = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Get the reference LQ parameters (alpha_x and beta_x) for the given CT.

        Parameters
        ----------
        overlap_is_applied : bool, optional
            If True, the VOI masks are used as they are. If False (default),
            overlap priorities are applied first.
        resample_grid : Grid, optional
            If given, both parameter cubes are resampled onto this grid using
            nearest-neighbor interpolation.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Flat arrays with ``alpha_x`` and ``beta_x`` per voxel. Voxels not
            covered by any VOI are zero.
        """
        # Use this structure set directly if the overlaps were already resolved
        cst = self if overlap_is_applied else self.apply_overlap_priorities()

        num_voxels = np.prod(cst.ct_image.size)
        alpha = np.zeros(num_voxels)
        beta = np.zeros(num_voxels)

        for voi in cst.vois:
            alpha[voi.indices_numpy] = voi.alpha_x
            beta[voi.indices_numpy] = voi.beta_x

        if resample_grid is not None:
            original_grid = cst.ct_image.grid
            alpha = resample_numpy_array(
                alpha.reshape(original_grid.dimensions[::-1]),
                reference_grid=original_grid,
                interpolator=sitk.sitkNearestNeighbor,
                target_grid=resample_grid,
            ).ravel()
            beta = resample_numpy_array(
                beta.reshape(original_grid.dimensions[::-1]),
                reference_grid=original_grid,
                interpolator=sitk.sitkNearestNeighbor,
                target_grid=resample_grid,
            ).ravel()

        return alpha, beta

    def set_colors(self) -> Self:
        """Assign colors from ``DEFAULT_VOI_COLORS`` to VOIs lacking ``visible_color``.

        Colors are popped in order from the predefined list for each VOI
        type, skipping any color already taken by another VOI of that type.
        If the list is exhausted the last color in the list is reused.
        Existing ``visible_color`` values are preserved.
        """
        # Per-type pool of unused predefined colors.
        pool: dict[str, list[tuple[int, int, int]]] = {
            t: list(palette) for t, palette in DEFAULT_VOI_COLORS.items()
        }

        # First pass: remove the colors that are already in use from the pools
        for voi in self.vois:
            if voi.visible_color is not None:
                color = tuple(voi.visible_color)
                if color in pool.get(voi.voi_type, []):
                    pool[voi.voi_type].remove(color)

        # Second pass: hand out the remaining colors
        for voi in self.vois:
            if voi.visible_color is None:
                # VOI types without a predefined palette fall back to gray
                palette = DEFAULT_VOI_COLORS.get(voi.voi_type, [(128, 128, 128)])
                remaining = pool.setdefault(voi.voi_type, list(palette))
                voi.visible_color = remaining.pop(0) if remaining else palette[-1]

        return self
def create_cst(
    cst_data: Union[dict[str, Any], StructureSet, None] = None,
    ct: Union[CT, dict, None] = None,
    **kwargs,
) -> StructureSet:
    """
    Create a StructureSet from various input types.

    Parameters
    ----------
    cst_data : Union[dict[str, Any], StructureSet, None], optional
        The input data to create the StructureSet from. An existing
        StructureSet is returned as is, a dictionary is used as the field
        dictionary, and any other value is used as the ``vois`` entry.
    ct : Union[CT, dict, None], optional
        The CT the StructureSet refers to. If given, it takes precedence over
        a ``ct_image`` entry in ``cst_data`` or ``kwargs``.
    **kwargs
        Additional keyword arguments to create the StructureSet.

    Returns
    -------
    StructureSet
        A StructureSet object created from the input data or keyword arguments.

    Raises
    ------
    ValueError
        If ``cst_data`` is a StructureSet whose CT differs from ``ct``.
    """
    if isinstance(cst_data, StructureSet):
        if ct and cst_data.ct_image != validate_ct(ct):
            raise ValueError("CT image mismatch between StructureSet and provided CT")
        return cst_data

    if isinstance(cst_data, dict):
        # Work on a copy so that the caller's dictionary is left untouched
        cst_data = {**cst_data, **kwargs}
        # Only override the CT if one was passed explicitly; otherwise a
        # "ct_image" entry already present in the data must be kept.
        if ct is not None:
            cst_data["ct_image"] = ct
    else:
        cst_data = {"vois": cst_data, "ct_image": ct, **kwargs}

    return StructureSet(**cst_data)
def validate_cst(
    cst_data: Union[dict[str, Any], StructureSet, None] = None,
    ct: Union[CT, dict, None] = None,
    **kwargs,
) -> StructureSet:
    """
    Validate StructureSet.

    Thin wrapper that forwards all arguments to ``create_cst``.

    Parameters
    ----------
    cst_data : Union[dict[str, Any], StructureSet, None], optional
        The input data to obtain the StructureSet from. Can be a dictionary,
        an existing StructureSet, or None.
    ct : Union[CT, dict, None], optional
        The CT the StructureSet refers to. Can be a dictionary, an existing
        CT object, or None.
    **kwargs
        Additional keyword arguments passed on to ``create_cst``.

    Returns
    -------
    StructureSet
        The StructureSet returned by ``create_cst``.
    """
    validated = create_cst(cst_data=cst_data, ct=ct, **kwargs)
    return validated