"""Structure Set Implementation."""
from typing import Any, Union, Optional
from typing_extensions import Self
from pydantic import (
Field,
model_validator,
ValidationInfo,
)
import numpy as np
from scipy import ndimage
import SimpleITK as sitk
from ..core import PyRadPlanBaseModel, Grid
from ..core.resample import resample_numpy_array
from ..ct import CT, validate_ct
from . import VOI, ExternalVOI, validate_voi, DEFAULT_VOI_COLORS
class StructureSet(PyRadPlanBaseModel):
    """Represents a Structure Set for a Patient.

    A StructureSet bundles a list of VOIs with the CT image all of them are defined
    on. Input may be given as VOI objects / VOI dicts, or as a matRad ``cst`` cell
    array, which is converted by :meth:`_process_matrad_data`.
    """

    vois: list[VOI] = Field(init=False, description="List of VOIs in the Structure Set")
    ct_image: CT = Field(init=False, description="Reference to the CT Image")

    @staticmethod
    def _reference_image_3d(image: sitk.Image) -> sitk.Image:
        """Return a 3D image carrying the spatial geometry of ``image``.

        Non-4D images are returned unchanged. For a 4D image, the first phase is
        extracted so that 3D masks can copy its origin, spacing and direction
        (``CopyInformation`` requires matching dimension and size).
        """
        if image.GetDimension() != 4:
            return image
        size = list(image.GetSize())
        size[3] = 0  # a zero extent collapses the 4th axis
        return sitk.Extract(
            image,
            size=size,
            index=[0, 0, 0, 0],
            directionCollapseToStrategy=sitk.ExtractImageFilter.DIRECTIONCOLLAPSETOSUBMATRIX,
        )

    @staticmethod
    def _union_indices(index_arrays: list) -> np.ndarray:
        """Return the sorted unique union of index arrays (an empty int array if none)."""
        if not index_arrays:
            return np.empty(0, dtype=int)
        return np.unique(np.concatenate(index_arrays))

    @classmethod
    def _process_matrad_data(cls, data: dict, info: ValidationInfo) -> dict:
        """Convert a matRad ``cst`` cell array into a list of VOI objects.

        Per row, index 1 is used as the name, index 2 as the VOI type, and index 3
        as the 1-based linear voxel indices (one array, or a list of arrays for
        multiple scenarios). The optional index 4 is a property dict and the
        optional index 5 holds the objectives.

        Parameters
        ----------
        data : dict
            Input with ``"vois"`` (the cst rows) and ``"ct_image"`` (an already
            validated CT).
        info : ValidationInfo
            Pydantic validation info; currently unused.

        Returns
        -------
        dict
            ``{"vois": <list of VOIs>, "ct_image": <the CT>}``.

        Raises
        ------
        ValueError
            If the CT is neither 3D nor 4D, or if a VOI's number of masks does not
            match the number of phases of a 4D CT.
        """
        ct = data["ct_image"]
        cube = ct.cube_hu
        dim = cube.GetDimension()
        if dim not in (3, 4):
            raise ValueError("Sanity Check failed -- unsupported CT dimensionality")

        # 3D geometry that every per-scenario mask is placed on (first phase of a 4D CT).
        reference = cls._reference_image_3d(cube)
        num_phases = cube.GetSize()[3] if dim == 4 else 1

        def get_idx_list(vdata_item):
            # A single index array means one scenario; a list means one array per scenario.
            arr = vdata_item[3]
            if isinstance(arr, list):
                return [np.asarray(a, dtype=int).ravel() for a in arr]
            return [np.asarray(arr, dtype=int).ravel()]

        def create_mask(idx):
            # matRad indices are 1-based (hence ``idx - 1``). The (size[2], size[0], size[1])
            # layout followed by the axis swap presumably maps matRad's column-major cube
            # order onto sitk's numpy order -- kept exactly as before.
            tmp_mask = np.zeros((ct.size[2], ct.size[0], ct.size[1]), dtype=np.uint8)
            tmp_mask.flat[idx - 1] = 1
            tmp_mask = np.swapaxes(tmp_mask, 1, 2)
            mask_image = sitk.GetImageFromArray(tmp_mask)
            mask_image.CopyInformation(reference)
            return mask_image

        voi_list = []
        for vdata in data["vois"]:
            masks = [create_mask(idx) for idx in get_idx_list(vdata)]

            if dim == 4:
                # One mask shared by all phases: replicate it once per phase.
                if len(masks) == 1:
                    masks = masks * num_phases
                if len(masks) != num_phases:
                    raise ValueError("Incompatible number of masks for 4D CT")
                mask = sitk.JoinSeries(masks)
                # JoinSeries uses default spacing/origin on the new axis; adopt the CT's.
                mask.CopyInformation(cube)
            else:
                mask = masks[0]

            objectives = vdata[5] if len(vdata) > 5 else []
            if not isinstance(objectives, list):
                objectives = [objectives]
            props = vdata[4] if len(vdata) > 4 and isinstance(vdata[4], dict) else {}

            voi_list.append(
                validate_voi(
                    name=str(vdata[1]),
                    voi_type=str(vdata[2]),
                    mask=mask,
                    ct_image=ct,
                    objectives=objectives,
                    **props,
                )
            )

        return {"vois": voi_list, "ct_image": ct}

    @model_validator(mode="before")
    @classmethod
    def aggregate_dynamic_quantities(cls, data: Any, info: ValidationInfo) -> Any:
        """Normalize raw constructor input before field validation.

        Validates the CT, strips a CT embedded in a dict of VOIs (checking it agrees
        with the provided CT), converts numpy cell arrays to lists, and converts
        matRad cst rows to VOIs. Non-dict input is passed through for pydantic to
        handle. The caller's dict is not modified.

        Raises
        ------
        ValueError
            If no VOIs or no CT are provided, or if the CT embedded in a dict of
            VOIs differs from the provided CT.
        """
        if not isinstance(data, dict):
            return data
        data = dict(data)  # shallow copy: never mutate the caller's input

        # Explicit checks give clearer messages than pydantic's missing-field errors.
        if data.get("vois") is None:
            raise ValueError("No cst provided. Please provide a cst.")
        if data.get("ct_image") is None:
            raise ValueError("No reference CT provided. Please provide a CT.")
        data["ct_image"] = validate_ct(data["ct_image"])

        vois = data["vois"]
        if isinstance(vois, dict):
            vois = dict(vois)
            # "ctImage" is always removed; "ct_image" takes precedence if both exist.
            ct_alias = vois.pop("ctImage", None)
            ct_from_vois = vois.pop("ct_image", ct_alias)
            if ct_from_vois is not None and validate_ct(ct_from_vois) != data["ct_image"]:
                raise ValueError("CT image mismatch between StructureSet and provided CT")
            data["vois"] = vois
            return data

        # matRad cell arrays arrive as numpy object arrays.
        if isinstance(vois, np.ndarray):
            vois = vois.tolist()
            data["vois"] = vois

        # A non-empty list that is neither VOI dicts nor VOI objects is taken as matRad data.
        if (
            isinstance(vois, list)
            and vois
            and not isinstance(vois[0], dict)
            and not all(isinstance(v, VOI) for v in vois)
        ):
            data = cls._process_matrad_data(data, info)

        return data

    @model_validator(mode="after")
    def check_cst(self) -> Self:
        """Ensure all VOIs reference this set's CT, then assign missing colors.

        Raises
        ------
        ValueError
            If any VOI references a different CT image.
        """
        if any(voi.ct_image != self.ct_image for voi in self.vois):
            raise ValueError("All VOIs must reference the same CT image.")
        self.set_colors()
        return self

    def to_matrad(self, context: str = "mat-file") -> Any:
        """Convert the StructureSet to a matRad-writeable cst (list of rows).

        Parameters
        ----------
        context : str, optional
            Export context; only ``"mat-file"`` is supported.

        Returns
        -------
        list
            One row per VOI, as produced by ``VOI.to_matrad``, with the first entry
            set to the 0-based VOI index and the objectives entry emptied.

        Raises
        ------
        ValueError
            If ``context`` is not supported.
        """
        if context != "mat-file":
            raise ValueError(f"Context {context} not supported")

        export_cell_list = []
        for i, voi in enumerate(self.vois):
            voi_list = voi.to_matrad(context=context)
            voi_list[0] = i
            # TODO: set objectives here
            voi_list[5] = {}
            export_cell_list.append(voi_list)
        return export_cell_list

    # Additional Properties
    @property
    def voi_types(self) -> list:
        """Return the unique VOI types, in order of first appearance."""
        return list(dict.fromkeys(voi.voi_type for voi in self.vois))

    def target_union_voxels(self, order="sitk") -> np.ndarray:
        """Return the sorted union of the voxel indices of all TARGET VOIs.

        ``order`` is forwarded to ``VOI.get_indices``. Returns an empty integer
        array if there is no TARGET VOI.
        """
        return self._union_indices(
            [voi.get_indices(order=order) for voi in self.vois if voi.voi_type == "TARGET"]
        )

    def target_union_mask(self) -> sitk.Image:
        """Return a binary (uint8) 3D mask of the union of all TARGET VOIs.

        For a 4D CT, the mask uses the geometry of the first phase.
        """
        target_indices = self.target_union_voxels(order="numpy")
        reference = self._reference_image_3d(self.ct_image.cube_hu)

        mask = np.zeros(reference.GetSize()[::-1], dtype=np.uint8)
        mask.flat[target_indices] = 1

        mask_image = sitk.GetImageFromArray(mask)
        mask_image.CopyInformation(reference)
        return mask_image

    def patient_voxels(self, order="sitk") -> np.ndarray:
        """Return the patient voxel indices.

        If an ``ExternalVOI`` exists, the indices of the first one found are
        returned. Otherwise the sorted union of all VOI indices is returned (empty
        if there are no VOIs). ``order`` is forwarded to ``VOI.get_indices``.
        """
        external = next((voi for voi in self.vois if isinstance(voi, ExternalVOI)), None)
        if external is not None:
            return external.get_indices(order=order)
        return self._union_indices([voi.get_indices(order=order) for voi in self.vois])

    def patient_mask(self) -> sitk.Image:
        """Return the union mask of all patient contours (or the EXTERNAL contour if provided)."""
        patient_indices = self.patient_voxels(order="numpy")
        cube = self.ct_image.cube_hu

        # Allocate zeros directly instead of copying the whole CT array and clearing it.
        mask = np.zeros(cube.GetSize()[::-1], dtype=np.uint8)
        mask.flat[patient_indices] = 1

        mask_image = sitk.GetImageFromArray(mask)
        mask_image.CopyInformation(cube)
        return mask_image

    def target_center_of_mass(self) -> np.ndarray:
        """Return the physical-space center of mass of the target union mask.

        Raises
        ------
        ValueError
            If the target union mask contains no voxels.
        """
        mask_image = self.target_union_mask()
        # Transposing the numpy view reverses the axes into sitk index order.
        mask = sitk.GetArrayViewFromImage(mask_image).T
        if not mask.any():
            raise ValueError("No target voxels found in the StructureSet.")
        cm_index = ndimage.center_of_mass(mask)
        return np.array(mask_image.TransformContinuousIndexToPhysicalPoint(cm_index))

    def resample_on_new_ct(self, new_ct: CT) -> Self:
        """
        Resample the StructureSet on a new CT.

        Parameters
        ----------
        new_ct : CT
            The new CT to resample the StructureSet on.

        Returns
        -------
        StructureSet
            The resampled StructureSet. If ``new_ct`` equals the current CT, a deep
            copy is returned without resampling.
        """
        # Compare against the CT object itself: the previous comparison used the
        # model_dump() dict, which never equals a CT instance.
        if self.ct_image == new_ct:
            return self.model_copy(deep=True)
        return self.model_validate(
            {
                "ct_image": new_ct,
                "vois": [voi.resample_on_new_ct(new_ct) for voi in self.vois],
            }
        )

    def apply_overlap_priorities(self) -> Self:
        """
        Apply overlap priorities to the StructureSet.

        VOIs are processed in ascending ``overlap_priority``. Each VOI loses the
        voxels already claimed by VOIs with a strictly lower priority value. VOIs
        sharing a priority value are not trimmed against each other.

        Returns
        -------
        StructureSet
            A deep copy with the trimmed masks (an unchanged copy if there are no VOIs).
        """
        if not self.vois:
            return self.model_copy(deep=True)

        priorities = [voi.overlap_priority for voi in self.vois]
        # A stable sort makes the processing order of ties deterministic.
        order = np.argsort(priorities, kind="stable")

        first = order[0]
        new_vois = [None] * len(self.vois)
        new_vois[first] = self.vois[first].model_copy()

        # Union of all (already trimmed) masks processed so far.
        claimed = self.vois[first].mask
        # Union of masks with a strictly lower priority value than the current VOI.
        blocked = sitk.Image(claimed.GetSize(), claimed.GetPixelID())
        blocked.CopyInformation(claimed)

        prev_priority = priorities[first]
        for ix in order[1:]:
            voi = self.vois[ix].model_copy()
            if voi.overlap_priority > prev_priority:
                # Entering a new priority level: everything claimed so far now blocks.
                blocked = claimed
            trimmed = sitk.MaskNegated(voi.mask, blocked)
            voi.mask = trimmed
            claimed = sitk.Or(claimed, trimmed)
            new_vois[ix] = voi
            prev_priority = voi.overlap_priority

        return self.model_copy(deep=True, update={"vois": new_vois})

    def create_body_seg(
        self, threshold: float = -200.0, name: str = "BODY", voi_type: str = "OAR"
    ) -> Self:
        """
        Create a body segmentation from CT data based on a HU threshold.

        The CT is thresholded and the largest connected component is kept. Holes
        (e.g. lungs) are then filled slice by slice, which also handles cases where
        the scan cuts through a cavity. The slice extraction indexes three axes, so
        a 3D CT is assumed.

        Parameters
        ----------
        threshold : float, optional
            Voxels with values at or above this threshold are considered body.
            Default is -200.0 HU.
        name : str, optional
            Name for the body VOI. Default is "BODY".
        voi_type : str, optional
            Type of the VOI to create. Default is "OAR".

        Returns
        -------
        Self
            This StructureSet (modified in place) with the body segmentation appended.

        Raises
        ------
        ValueError
            If no CT voxel reaches the threshold.
        """
        cube = self.ct_image.cube_hu

        # BinaryThreshold is inclusive on both ends, so the CT maximum is a safe upper
        # bound. Unlike ``max * 1.1``, it neither shrinks below the maximum for negative
        # values nor overflows integer pixel types.
        max_value = float(sitk.GetArrayViewFromImage(cube).max())
        lower = float(threshold)
        binary_segmentation = sitk.BinaryThreshold(cube, lower, max(max_value, lower))

        # Keep only the largest connected component (the main body).
        labeled = sitk.ConnectedComponent(binary_segmentation)
        stats = sitk.LabelShapeStatisticsImageFilter()
        stats.Execute(labeled)
        labels = stats.GetLabels()
        if not labels:
            raise ValueError(
                f"No CT voxels at or above {threshold}; cannot create a body segmentation."
            )
        largest_label = max(labels, key=stats.GetNumberOfPixels)
        body_segmentation = sitk.Equal(labeled, largest_label)

        # Fill holes slice by slice: a scan cut through the lungs does not fully
        # enclose them in 3D, so a volumetric fill would miss them.
        size = body_segmentation.GetSize()
        filled_slices = [
            sitk.BinaryFillhole(
                sitk.Extract(body_segmentation, size=[size[0], size[1], 0], index=[0, 0, z])
            )
            for z in range(size[2])
        ]
        body_segmentation = sitk.JoinSeries(filled_slices)
        # JoinSeries assigns default spacing/origin along z; restore the CT geometry.
        body_segmentation.CopyInformation(cube)

        body_voi = validate_voi(
            {
                "name": name,
                "ct_image": self.ct_image,
                "mask": body_segmentation,
                "voi_type": voi_type,
                "alpha_x": 0.1,
                "beta_x": 0.05,
            }
        )

        self.vois = self.vois + [body_voi]
        # Give the new VOI a default color like the ones assigned at validation time.
        self.set_colors()
        return self

    def get_reference_lq_params(
        self, overlap_is_applied: bool = False, resample_grid: Optional[Grid] = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get voxel-wise reference LQ parameters (alpha_x and beta_x).

        Parameters
        ----------
        overlap_is_applied : bool, optional
            If True, the VOI masks are used as-is. Otherwise
            :meth:`apply_overlap_priorities` is applied first.
        resample_grid : Grid, optional
            If given, the parameter cubes are resampled (nearest neighbour) from
            the CT grid onto this grid.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Flattened alpha and beta arrays; voxels outside all VOIs are 0.
        """
        cst = self if overlap_is_applied else self.apply_overlap_priorities()

        num_voxels = np.prod(cst.ct_image.size)
        alpha = np.zeros(num_voxels)
        beta = np.zeros(num_voxels)
        # Later VOIs overwrite earlier ones wherever masks still overlap.
        for voi in cst.vois:
            alpha[voi.indices_numpy] = voi.alpha_x
            beta[voi.indices_numpy] = voi.beta_x

        if resample_grid is not None:
            original_grid = cst.ct_image.grid
            cube_shape = original_grid.dimensions[::-1]

            def _resample(values: np.ndarray) -> np.ndarray:
                # Nearest neighbour keeps the parameters piecewise constant per VOI.
                return resample_numpy_array(
                    values.reshape(cube_shape),
                    reference_grid=original_grid,
                    interpolator=sitk.sitkNearestNeighbor,
                    target_grid=resample_grid,
                ).ravel()

            alpha = _resample(alpha)
            beta = _resample(beta)

        return alpha, beta

    def set_colors(self) -> Self:
        """Assign colors from ``DEFAULT_VOI_COLORS`` to VOIs lacking ``visible_color``.

        Colors are taken in order from the predefined list of each VOI's type. A
        color already used by a VOI of the same type is skipped. When the list is
        exhausted, its last color is reused. Types without a predefined list get
        gray. Existing ``visible_color`` values are preserved.
        """
        # Per-type pool of unused predefined colors.
        pool: dict[str, list[tuple[int, int, int]]] = {
            t: list(palette) for t, palette in DEFAULT_VOI_COLORS.items()
        }
        for voi in self.vois:
            if voi.visible_color is not None:
                color = tuple(voi.visible_color)
                if color in pool.get(voi.voi_type, []):
                    pool[voi.voi_type].remove(color)

        for voi in self.vois:
            if voi.visible_color is None:
                palette = DEFAULT_VOI_COLORS.get(voi.voi_type, [(128, 128, 128)])
                remaining = pool.setdefault(voi.voi_type, list(palette))
                voi.visible_color = remaining.pop(0) if remaining else palette[-1]
        return self
def create_cst(
    cst_data: Union[dict[str, Any], StructureSet, None] = None,
    ct: Union[CT, dict, None] = None,
    **kwargs,
) -> StructureSet:
    """
    Create a StructureSet from various input types.

    Parameters
    ----------
    cst_data : Union[dict[str, Any], StructureSet, None], optional
        Structure data. An existing StructureSet is returned as-is, after an
        optional CT consistency check. A dict is used as the model input.
        Anything else (e.g. a list of VOIs or a matRad cst) is used as the
        ``vois`` field.
    ct : Union[CT, dict, None], optional
        Reference CT. If given, it takes precedence over a ``"ct_image"`` entry
        of a dict ``cst_data`` and over ``kwargs``.
    **kwargs
        Additional fields passed to the StructureSet constructor.

    Returns
    -------
    StructureSet
        A StructureSet object created from the input data or keyword arguments.

    Raises
    ------
    ValueError
        If ``cst_data`` is a StructureSet whose CT differs from ``ct``.
    """
    if isinstance(cst_data, StructureSet):
        if ct is not None and cst_data.ct_image != validate_ct(ct):
            raise ValueError("CT image mismatch between StructureSet and provided CT")
        return cst_data

    if isinstance(cst_data, dict):
        # Work on a copy so the caller's dict is left untouched.
        model_data = {**cst_data, **kwargs}
        # Only override the CT when one was passed explicitly. Otherwise keep any
        # "ct_image" already contained in cst_data instead of clobbering it with None.
        if ct is not None:
            model_data["ct_image"] = ct
    else:
        model_data = {"vois": cst_data, "ct_image": ct, **kwargs}

    return StructureSet(**model_data)
def validate_cst(
    cst_data: Union[dict[str, Any], StructureSet, None] = None,
    ct: Union[CT, dict, None] = None,
    **kwargs,
) -> StructureSet:
    """
    Validate (or build) a StructureSet.

    This is an alias of ``create_cst``; all arguments are forwarded unchanged.

    Parameters
    ----------
    cst_data : Union[dict[str, Any], StructureSet, None], optional
        An existing StructureSet, a dict of model fields, or VOI data (e.g. a
        matRad cst) to use as the ``vois`` field.
    ct : Union[CT, dict, None], optional
        Reference CT for the StructureSet.
    **kwargs
        Additional fields passed to the StructureSet constructor.

    Returns
    -------
    StructureSet
        The validated StructureSet.
    """
    structure_set = create_cst(cst_data, ct, **kwargs)
    return structure_set