"""Interface for voxel geometry ray tracers."""
from abc import ABC, abstractmethod
from typing import Union, Any, Optional
import logging
import time
import numpy as np
import SimpleITK as sitk
import array_api_compat
from ..core.xp_utils.typing import Array
from ..core.np2sitk import linear_indices_to_image_coordinates
from ..geometry import lps
from ..stf._beam import Beam
# Optional numba-accelerated candidate-ray lookup. When the compiled helper cannot
# be imported, the name is set to None and RayTracerBase._get_candidate_ray_matrix
# falls back to its pure array-API implementation.
try:
    from ._numba_perf import fast_spatial_circle_lookup
except ImportError:
    fast_spatial_circle_lookup = None

# Module-level logger used for timing/debug output of the ray tracing steps.
logger = logging.getLogger(__name__)
class RayTracerBase(ABC):
    """Base class for all voxel-geometry ray tracers.

    Subclasses must implement :meth:`trace_ray` (trace a single ray) and
    :meth:`_initialize_geometry` (called every time :attr:`cubes` is assigned).
    They may override :meth:`trace_rays` with a vectorized implementation.

    Attributes
    ----------
    lateral_cut_off : float
        Maximum lateral distance (in the BEV x/z plane of the ray matrix) between a
        grid ray and a beam reference ray for the grid ray to be traced. Also widens
        the ray grid extent. Defaults to 50.0.
    precision : numpy floating type
        Floating point type used for coordinates and depths. Defaults to np.float32.
    fixed_ray_spacing_length : float or None
        If set, the ray grid spans ``[-fixed_ray_spacing_length,
        fixed_ray_spacing_length]`` in both lateral BEV directions instead of an
        extent derived from the reference ray positions. Defaults to None.
    """

    lateral_cut_off: float
    precision: Union[np.dtype, type[np.floating]]
    fixed_ray_spacing_length: Optional[float]

    @property
    def cubes(self) -> list[sitk.Image]:
        """CT or other arbitrary cubes of similar resolution to be traced."""
        return self._cubes

    @cubes.setter
    def cubes(self, cubes: Union[sitk.Image, list[sitk.Image]]):
        # Normalize to a list: a single image is wrapped, a list/tuple is copied.
        if isinstance(cubes, (list, tuple)):
            cubes = list(cubes)
        else:
            cubes = [cubes]
        self._cubes = cubes
        self._initialize_geometry()
        # Cached voxel coordinates belong to the previous geometry; recompute lazily.
        self._coords = None

    def __init__(self, cubes: Union[sitk.Image, list[sitk.Image]]):
        self.lateral_cut_off = 50.0
        self.precision = np.float32
        self.fixed_ray_spacing_length = None
        self._coords = None
        # Assigning the property triggers _initialize_geometry() of the subclass.
        self.cubes = cubes

    def trace_rays(
        self,
        isocenter: Union[list, np.ndarray],
        source_points: Union[list, np.ndarray],
        target_points: Union[list, np.ndarray],
    ) -> tuple[np.ndarray, np.ndarray, list[np.ndarray], np.ndarray, np.ndarray]:
        """
        Trace multiple rays through the cubes.

        Parameters
        ----------
        isocenter : Union[list, np.ndarray]
            Isocenter coordinates (1x3) array or list. Passed unchanged to trace_ray.
        source_points : Union[list, np.ndarray]
            Source point coordinates, (nx3) or a single (1x3)/(3,) point that is
            used for all rays.
        target_points : Union[list, np.ndarray]
            Target point coordinates, (nx3) array or list.

        Returns
        -------
        alphas : ndarray
            Alpha values of each ray, shape (n_rays, max_len), NaN padded.
        lengths : ndarray
            Segment lengths of each ray, shape (n_rays, max_len), NaN padded.
        rho : list[ndarray]
            One (n_rays, max_len) array of rho values per cube, NaN padded.
        d12 : ndarray
            Full length of each ray.
        ix : ndarray
            Linear indices (in numpy ordering) of the voxels intersected by each ray,
            shape (n_rays, max_len). Integer indices are padded with -1, non-integer
            indices with NaN.

        Raises
        ------
        ValueError
            If the number of source points is neither one nor the number of targets.

        Notes
        -----
        The default implementation loops over the trace_ray function. The separate implementation is
        here to enable more performant implementations for specific ray tracers, e.g. through
        vectorization.
        """
        # Accept lists and single points as documented.
        source_points = np.atleast_2d(np.asarray(source_points))
        target_points = np.atleast_2d(np.asarray(target_points))
        num_rays = target_points.shape[0]
        num_sources = source_points.shape[0]
        if num_sources not in (num_rays, 1):
            raise ValueError(
                f"Number of source points ({num_sources}) needs to be one "
                f"or equal to number of target points ({num_rays})!"
            )
        if num_sources == 1:
            source_points = np.tile(source_points, (num_rays, 1))

        alphas, lengths, rho, d12, ix = [], [], [], [], []
        for source_point, target_point in zip(source_points, target_points):
            alpha, l_val, rho_val, d12_val, ix_val = self.trace_ray(
                isocenter, source_point, target_point
            )
            alphas.append(alpha)
            lengths.append(l_val)
            rho.append(rho_val)
            d12.append(d12_val)
            ix.append(ix_val)

        # rho was collected ray-major (rho[ray][cube]); return it cube-major so that
        # rho[cube] is a (n_rays, max_len) array matching `lengths`.
        rho_per_cube = [
            self._stack_padded([rho_ray[c] for rho_ray in rho])
            for c in range(len(self.cubes))
        ]
        return (
            self._stack_padded(alphas),
            self._stack_padded(lengths),
            rho_per_cube,
            np.array(d12),
            self._stack_padded(ix),
        )

    @staticmethod
    def _stack_padded(arrays: list) -> np.ndarray:
        """Stack ragged 1-D per-ray arrays into one (n_rays, max_len) array.

        Each array is padded at its end to the longest length among ``arrays``.
        Integer arrays (voxel indices) are cast to int64 and padded with -1, since
        NaN cannot be stored in integer dtypes; all other arrays are padded with NaN.
        An empty input yields an array of shape (0, 0).
        """
        arrays = [np.asarray(values) for values in arrays]
        if not arrays:
            return np.empty((0, 0))
        max_len = max(values.shape[0] for values in arrays)
        padded = []
        for values in arrays:
            if np.issubdtype(values.dtype, np.integer):
                values = values.astype(np.int64, copy=False)
                fill = -1
            else:
                fill = np.nan
            padded.append(
                np.pad(values, (0, max_len - values.shape[0]), constant_values=fill)
            )
        return np.stack(padded)

    @abstractmethod
    def trace_ray(
        self,
        isocenter: Union[list, np.ndarray],
        source_points: Union[list, np.ndarray],
        target_points: Union[list, np.ndarray],
    ) -> tuple[np.ndarray, np.ndarray, list[np.ndarray], np.ndarray, np.ndarray]:
        """
        Trace a single ray through cubes.

        Abstract Method to be implemented in subclasses. The default
        :meth:`trace_rays` calls it with a single source point and a single target
        point (one row each) and expects ``(alpha, lengths, rho, d12, ix)``, where
        ``rho`` holds one 1-D array per cube.
        """

    def trace_cubes(self, beam: Union[dict[str, Any], Beam]) -> list[sitk.Image]:
        """
        Automatically calculate depth by tracing rays through cubes.

        Sets up a regular ray grid in beam's eye view (BEV) with a spacing of
        ``min(resolution) / sqrt(2)``. The grid spans ``[-ray_extent, ray_extent]``
        laterally, where ``ray_extent`` is ``fixed_ray_spacing_length`` if set and
        otherwise ``2 * (max |reference position| + lateral_cut_off)``. Only grid rays
        within ``lateral_cut_off`` of a beam reference ray are traced. Every voxel then
        receives the cumulative sum of ``segment length * cube value`` (evaluated at the
        segment midpoint) along the grid ray whose cell its projection falls into.

        Parameters
        ----------
        beam : dict or Beam
            Beam description; dictionaries are validated via ``Beam.model_validate``.

        Returns
        -------
        list[sitk.Image]
            One depth image per cube with the cube's geometry copied over. Voxels not
            assigned by any ray are NaN.
        """
        if not isinstance(beam, Beam):
            beam = Beam.model_validate(beam)

        t_trace_start = time.perf_counter()
        logger.debug("Computing coordinates...")
        if self._coords is None:
            # Coordinates of every voxel; cached until the cubes are reassigned.
            cube_ix = np.arange(self.cubes[0].GetNumberOfPixels(), dtype=np.int64)
            self._coords = linear_indices_to_image_coordinates(
                cube_ix, self.cubes[0], index_type="sitk", dtype=self.precision
            )

        # Rotate voxel coordinates into BEV, relative to the BEV source point.
        rot_mat = lps.get_beam_rotation_matrix(beam.gantry_angle, beam.couch_angle)
        coords = (self._coords - beam.iso_center) @ rot_mat - beam.source_point_bev
        t_trace_end = time.perf_counter()
        logger.debug("took %s seconds!", t_trace_end - t_trace_start)

        logger.debug("Setting up Ray matrix...")
        t_trace_start = time.perf_counter()
        ray_spacing = np.min(self._resolution) / np.sqrt(2.0, dtype=self.precision)
        # BEV y position of the ray matrix plane: one (largest) voxel size beyond
        # the deepest voxel.
        ray_matrix_bev_y = (
            np.max(coords[:, 1]) + np.max(self._resolution) + beam.source_point_bev[1]
        )
        ray_matrix_scale = 1 + ray_matrix_bev_y / beam.sad
        # Scale the reference ray positions to the ray matrix plane; they restrict the
        # region in which rays are traced.
        reference_positions_bev = ray_matrix_scale * np.array(
            [ray.ray_pos_bev for ray in beam.rays]
        )
        if self.fixed_ray_spacing_length is not None:
            ray_extent = self.fixed_ray_spacing_length
        else:
            # look at max ray_positions in bev and add lateral cutoff
            ray_extent = 2.0 * (
                np.max(np.abs(reference_positions_bev[:, [0, 2]])) + self.lateral_cut_off
            )
        spacing_range = ray_spacing * np.arange(
            np.floor(-ray_extent / ray_spacing),
            np.ceil(ray_extent / ray_spacing) + 1,
            dtype=self.precision,
        )
        candidate_ray_mx = self._get_candidate_ray_matrix(spacing_range, reference_positions_bev)
        # Candidate matrix is indexed [z, x].
        ray_idx_z, ray_idx_x = np.nonzero(candidate_ray_mx)
        ray_matrix_bev = np.column_stack(
            (
                spacing_range[ray_idx_x],
                np.full(ray_idx_x.shape[0], ray_matrix_bev_y, dtype=self.precision),
                spacing_range[ray_idx_z],
            )
        )
        ray_matrix_lps = ray_matrix_bev @ rot_mat.T
        t_trace_end = time.perf_counter()
        logger.debug("took %s seconds!", t_trace_end - t_trace_start)

        logger.debug("Tracing %d rays through the cubes", np.count_nonzero(candidate_ray_mx))
        t_trace_start = time.perf_counter()
        _, lengths, rho, _, ix = self.trace_rays(
            beam.iso_center.reshape(1, 3), beam.source_point.reshape(1, 3), ray_matrix_lps
        )
        t_trace_end = time.perf_counter()
        logger.debug("Cube ray tracing took %s seconds...", t_trace_end - t_trace_start)

        # Normalize voxel indices: padding may be NaN (float indices) or -1 (integer
        # indices). Use int64 with -1 for invalid entries so indices can address
        # `coords`. Indices outside [0, n_voxels) must never be used, because e.g.
        # coords[-1] would silently read the last voxel.
        ix = np.asarray(ix)
        if not np.issubdtype(ix.dtype, np.integer):
            ix = np.where(np.isfinite(ix), ix, -1).astype(np.int64)
        valid_ix = (ix >= 0) & (ix < coords.shape[0])

        # Now we compute which rays will respectively give the voxel value for radiological depth:
        # project each traced voxel onto the ray matrix plane and measure its lateral
        # distance to the ray that traced it. Invalid entries stay NaN.
        voxel_coords = coords[ix[valid_ix]]
        scale_factor = ((ray_matrix_bev_y + beam.sad) / voxel_coords[:, 1]).astype(
            self.precision, copy=False
        )
        x_dist = np.full(ix.shape, np.nan, dtype=self.precision)
        z_dist = np.full(ix.shape, np.nan, dtype=self.precision)
        x_dist[valid_ix] = voxel_coords[:, 0] * scale_factor
        z_dist[valid_ix] = voxel_coords[:, 2] * scale_factor
        x_dist = x_dist - ray_matrix_bev[:, 0, np.newaxis]
        z_dist = z_dist - ray_matrix_bev[:, 2, np.newaxis]

        # Keep a voxel for this ray only if its projection lies in the ray's half-open
        # cell of width ray_spacing (NaN distances compare False).
        ray_selection = ray_spacing / 2.0
        ix_remember_from_tracing = (
            (x_dist > -ray_selection)
            & (x_dist <= ray_selection)
            & (z_dist > -ray_selection)
            & (z_dist <= ray_selection)
        )
        t_remember_end = time.perf_counter()
        logger.debug(
            "Found %d ray indices for radiological depth calculation (took %s seconds)",
            np.count_nonzero(ix_remember_from_tracing),
            t_remember_end - t_trace_end,
        )

        selected_ix = ix[ix_remember_from_tracing]
        rad_depth_cubes = []
        for i, cube_image in enumerate(self.cubes):
            depth_array = np.full(
                sitk.GetArrayViewFromImage(cube_image).shape, np.nan, dtype=self.precision
            )
            segment_depths = lengths * rho[i]
            # Replace NaN with 0 before cumsum to prevent a single invalid voxel from
            # poisoning the rest of the ray:
            # np.cumsum([[0.5, 0.33, NaN, 0.18, 0.22]]) -> [0.5, 0.83, NaN, NaN, NaN]
            segment_depths = np.where(np.isfinite(segment_depths), segment_depths, 0.0)
            # Depth at each segment midpoint: cumulative sum minus half the own segment.
            rel_depths = np.cumsum(segment_depths, axis=1) - segment_depths / 2.0
            selected_depths = rel_depths[ix_remember_from_tracing]

            # Guard against cubes with fewer voxels than the reference cube.
            in_bounds = selected_ix < depth_array.size
            if not np.all(in_bounds):
                logger.warning(
                    "Discarding %d voxel indices outside of cube %d during depth assignment",
                    np.count_nonzero(~in_bounds),
                    i,
                )
            # NOTE(review): `ix` indexes self._coords (built from sitk-ordered linear
            # indices) but is unravelled here with order="F" against the numpy array
            # view shape; confirm both conventions agree.
            ix_assign = np.unravel_index(
                selected_ix[in_bounds], depth_array.shape, order="F"
            )
            depth_array[ix_assign] = selected_depths[in_bounds]

            depth_image = sitk.GetImageFromArray(depth_array)
            depth_image.CopyInformation(cube_image)
            rad_depth_cubes.append(depth_image)

        t_createcubes_end = time.perf_counter()
        logger.debug(
            "Radiological depth cube filling took %s seconds",
            t_createcubes_end - t_remember_end,
        )
        return rad_depth_cubes

    def _get_candidate_ray_matrix(self, spacing_range: Array, ref_pos_bev: Array) -> Array:
        """Mark grid rays lying within ``lateral_cut_off`` of any reference position.

        Parameters
        ----------
        spacing_range : Array
            1-D lateral grid positions, used for both the BEV x and z axes.
        ref_pos_bev : Array
            Reference positions in BEV; columns 0 and 2 are used as x and z.

        Returns
        -------
        Array
            Boolean matrix of shape ``(len(spacing_range), len(spacing_range))``
            indexed ``[z, x]``. An entry is True where
            ``(x - x_ref)**2 + (z - z_ref)**2 <= lateral_cut_off**2`` for at least one
            reference position (the numba path is assumed to compute the same matrix).
        """
        xp = array_api_compat.array_namespace(spacing_range, ref_pos_bev)
        # Use numba accelerated code if possible
        if array_api_compat.is_numpy_namespace(xp) and fast_spatial_circle_lookup is not None:
            candidate_ray_coords_x, candidate_ray_coords_z = np.meshgrid(
                spacing_range, spacing_range
            )
            return fast_spatial_circle_lookup(
                candidate_ray_coords_x, candidate_ray_coords_z, ref_pos_bev, self.lateral_cut_off
            )

        # Array API compliant code
        n_candidates = array_api_compat.size(spacing_range)
        candidate_ray_mx = xp.zeros((n_candidates, n_candidates), dtype=xp.bool)
        r2 = self.lateral_cut_off**2
        # use buffers to avoid repeated allocations in the loop
        buffer_x = xp.empty_like(spacing_range)
        buffer_z = xp.empty_like(spacing_range)
        for i in range(ref_pos_bev.shape[0]):
            # Squared per-axis distances to the current reference position, in place.
            buffer_x[:] = spacing_range
            buffer_z[:] = spacing_range
            buffer_x -= ref_pos_bev[i, 0]
            buffer_z -= ref_pos_bev[i, 2]
            buffer_x *= buffer_x
            buffer_z *= buffer_z
            # Restrict the update to the contiguous z range that can be within reach.
            z_mask = buffer_z <= r2
            if not xp.any(z_mask):
                continue
            z_ok = xp.astype(z_mask, xp.uint8)  # argmax only works on numeric types
            z0 = int(xp.argmax(z_ok))
            # xp.flip instead of z_ok[::-1]: negative-step slicing is not supported by
            # every array library (e.g. torch).
            z1 = n_candidates - int(xp.argmax(xp.flip(z_ok)))
            candidate_ray_mx[z0:z1, :] |= (buffer_z[z0:z1, None] + buffer_x[None, :]) <= r2
        return candidate_ray_mx

    @abstractmethod
    def _initialize_geometry(self):
        """
        Initialize geometry of the ray tracer.

        Will be automatically called when the cubes are set. trace_cubes reads
        ``self._resolution``, which this base class never sets; presumably
        implementations set it here (TODO confirm for each subclass).
        """