"""Segmentation utilities used by the reusable colocalization pipeline.
author: Fabrizio Musacchio
date: May/June 2026
"""
# %% IMPORTS
from __future__ import annotations
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
import numpy as np
from cellpose import models
from skimage.filters import gaussian, threshold_li, threshold_otsu
from skimage.measure import label
from skimage.morphology import ball, closing, disk, remove_small_holes, remove_small_objects
from .config import CellposeModelConfig, OptionalRegionSegmentationConfig
from .schemas import CellposeRefinementRoiCache, OptionalRegionSegmentationResult
# %% SEGMENTATION HELPERS
def normalize_segmentation_method(method: str) -> str:
"""Normalize and validate a configured segmentation method name."""
normalized = method.strip().lower()
allowed = {"cellpose", "otsu", "li", "percentile"}
if normalized not in allowed:
raise ValueError(
"`CellposeModelConfig.segmentation_method` must be one of "
f"{sorted(allowed)}, got {method!r}."
)
return normalized
def resolve_cellpose_anisotropy(
model_config: CellposeModelConfig,
voxel_scale_zyx: tuple[float, float, float],
do_3d: bool,
) -> float | None:
"""Resolve the Cellpose anisotropy factor for one segmentation run.
The function only applies anisotropy handling to genuine 3D runs. Users
can disable the feature with ``False``, enable automatic derivation from
voxel spacing with ``True``, or provide a numeric factor explicitly.
The automatic rule follows a practical microscopy heuristic: if the z-step
is appreciably larger than the in-plane sampling, Cellpose benefits from an
anisotropy factor of ``z_spacing / mean(xy_spacing)``. If z-spacing is not
larger than the in-plane spacing, no anisotropy value is forwarded.
"""
if not do_3d:
return None
anisotropy_setting = model_config.anisotropy
if anisotropy_setting is False:
return None
if isinstance(anisotropy_setting, (int, float)) and not isinstance(
anisotropy_setting,
bool,
):
anisotropy_value = float(anisotropy_setting)
if anisotropy_value <= 0:
raise ValueError(
"A manually configured Cellpose anisotropy value must be "
f"greater than 0, got {anisotropy_value}."
)
print(f"Using manually configured Cellpose anisotropy: {anisotropy_value:.4f}")
return anisotropy_value
if anisotropy_setting is not True:
raise ValueError(
"`CellposeModelConfig.anisotropy` must be set to False, True, or "
f"a positive numeric value, got {anisotropy_setting!r}."
)
z_spacing, y_spacing, x_spacing = voxel_scale_zyx
if z_spacing <= 0 or y_spacing <= 0 or x_spacing <= 0:
raise ValueError(
"Voxel spacing values must be strictly positive to derive "
f"anisotropy automatically, got {voxel_scale_zyx}."
)
xy_spacing = (y_spacing + x_spacing) / 2.0
anisotropy_activation_ratio = 1.25
if z_spacing <= xy_spacing * anisotropy_activation_ratio:
print(
"Skipping Cellpose anisotropy auto-correction because z-spacing "
f"({z_spacing:.4f}) is not sufficiently larger than mean "
f"xy-spacing ({xy_spacing:.4f})."
)
return None
anisotropy_value = z_spacing / xy_spacing
print(
"Using automatically derived Cellpose anisotropy: "
f"{anisotropy_value:.4f} (z={z_spacing:.4f}, mean_xy={xy_spacing:.4f})"
)
return anisotropy_value
[docs]
def get_cellpose_major_version() -> int | None:
"""Return the installed Cellpose major version when it can be determined."""
try:
raw_version = version("cellpose")
except PackageNotFoundError:
return None
major_token = raw_version.split(".", maxsplit=1)[0]
try:
return int(major_token)
except ValueError:
return None
[docs]
def get_available_cellpose_model_names() -> list[str]:
"""Return locally available built-in and user-registered Cellpose models."""
available_models = set(getattr(models, "MODEL_NAMES", []))
get_user_models = getattr(models, "get_user_models", None)
if callable(get_user_models):
try:
available_models.update(get_user_models())
except Exception:
pass
return sorted(model_name for model_name in available_models if model_name)
[docs]
def create_cellpose_model(model_name_or_path: str, use_gpu: bool) -> models.CellposeModel:
"""Create a Cellpose model from either a built-in name or a custom path.
The helper is intentionally strict: if a requested built-in model name is
not locally available in the installed Cellpose version, it raises a clear
error instead of silently falling back to a different model. It also adapts
model construction to the installed Cellpose major version so that newer
Cellpose 4 setups do not receive the deprecated ``model_type`` argument.
"""
candidate_path = Path(model_name_or_path).expanduser()
if candidate_path.exists():
print(f"Loading Cellpose custom model from:\n{candidate_path.resolve()}")
return models.CellposeModel(
gpu=use_gpu,
pretrained_model=str(candidate_path.resolve()),
)
available_models = get_available_cellpose_model_names()
if model_name_or_path not in available_models:
raise ValueError(
"The requested Cellpose model is not available in the current "
f"environment: {model_name_or_path!r}. Available model names: "
f"{available_models or ['<none>']}. Please provide a local custom "
"model path or use a Cellpose installation that exposes the "
"required built-in model."
)
cellpose_major = get_cellpose_major_version()
print(f"Loading built-in Cellpose model: {model_name_or_path}")
if cellpose_major is not None and cellpose_major >= 4:
if model_name_or_path == "cpsam":
return models.CellposeModel(gpu=use_gpu)
return models.CellposeModel(
gpu=use_gpu,
pretrained_model=model_name_or_path,
)
return models.CellposeModel(
gpu=use_gpu,
pretrained_model=model_name_or_path,
model_type=model_name_or_path,
)
[docs]
def create_cellpose_models_for_channels(
cell_model_config: CellposeModelConfig,
marker_model_config: CellposeModelConfig,
use_gpu: bool,
) -> tuple[models.CellposeModel | None, models.CellposeModel | None]:
"""Create Cellpose model instances for the cell and marker channels.
Models are only created for channels configured with
``segmentation_method="cellpose"``. For Cellpose 4 and newer, the same
model instance is reused when both such channels request the same built-in
model or custom model path. For older Cellpose versions, the previous
behavior is preserved and separate model instances are created for each
Cellpose channel configuration.
"""
cell_uses_cellpose = normalize_segmentation_method(cell_model_config.segmentation_method) == "cellpose"
marker_uses_cellpose = normalize_segmentation_method(marker_model_config.segmentation_method) == "cellpose"
if not cell_uses_cellpose and not marker_uses_cellpose:
return None, None
if cell_uses_cellpose and not marker_uses_cellpose:
return create_cellpose_model(cell_model_config.model_name_or_path, use_gpu), None
if marker_uses_cellpose and not cell_uses_cellpose:
return None, create_cellpose_model(marker_model_config.model_name_or_path, use_gpu)
cellpose_major = get_cellpose_major_version()
if cellpose_major is not None and cellpose_major >= 4:
cell_model_name = cell_model_config.model_name_or_path
marker_model_name = marker_model_config.model_name_or_path
if cell_model_name == marker_model_name:
shared_model = create_cellpose_model(cell_model_name, use_gpu)
print("Reusing one shared Cellpose model instance for both channels.")
return shared_model, shared_model
return (
create_cellpose_model(cell_model_config.model_name_or_path, use_gpu),
create_cellpose_model(marker_model_config.model_name_or_path, use_gpu),
)
def segment_threshold_channel(
image_zyx: np.ndarray,
model_config: CellposeModelConfig,
) -> np.ndarray:
"""Segment one analysis channel via thresholding and connected components.
True 2D data may enter this function either as a raw ``(Y, X)`` image or
as the pipeline's normalized singleton-z representation ``(1, Y, X)``. In
both cases, thresholding and morphology are executed in 2D and the final
label image is returned in ``(1, Y, X)`` form.
"""
method = normalize_segmentation_method(model_config.segmentation_method)
if method == "cellpose":
raise ValueError("Threshold segmentation was requested with the 'cellpose' method.")
image_float = np.asarray(image_zyx, dtype=np.float32, copy=False)
is_singleton_z_2d = image_float.ndim == 3 and image_float.shape[0] == 1
if image_float.ndim == 2:
image_working = image_float
is_3d = False
elif is_singleton_z_2d:
image_working = image_float[0]
is_3d = False
else:
image_working = image_float
is_3d = image_working.shape[0] > 1
if model_config.threshold_background_sigma is not None and model_config.threshold_background_sigma > 0:
print(
"Threshold segmentation: background subtraction with sigma="
f"{model_config.threshold_background_sigma}..."
)
background = gaussian(
image_working,
sigma=model_config.threshold_background_sigma,
preserve_range=True,
)
image_work = image_working - background
image_work[image_work < 0] = 0
else:
image_work = image_working
values = image_work[np.isfinite(image_work)]
values = values[values > 0]
if values.size == 0:
zero_labels = np.zeros_like(image_work, dtype=np.uint32)
if zero_labels.ndim == 2:
zero_labels = zero_labels[np.newaxis, :, :]
return zero_labels
print(f"Threshold segmentation: computing threshold with method='{method}'...")
if method == "otsu":
threshold = float(threshold_otsu(values))
elif method == "li":
threshold = float(threshold_li(values))
elif method == "percentile":
threshold = float(np.percentile(values, model_config.threshold_percentile))
else:
raise ValueError(f"Unsupported threshold segmentation method: {method}")
binary_mask = image_work > threshold
if model_config.threshold_apply_closing:
binary_mask = closing(binary_mask, footprint=ball(1) if is_3d else disk(1))
if model_config.threshold_min_object_voxels > 0:
binary_mask = remove_small_objects(
binary_mask,
max_size=_legacy_threshold_to_max_size(model_config.threshold_min_object_voxels),
)
if model_config.threshold_min_hole_voxels > 0:
binary_mask = remove_small_holes(
binary_mask,
max_size=_legacy_threshold_to_max_size(model_config.threshold_min_hole_voxels),
)
labels = label(binary_mask)
labels = np.asarray(labels, dtype=np.uint32)
if labels.ndim == 2:
labels = labels[np.newaxis, :, :]
return labels
def evaluate_segmentation_method(
model: models.CellposeModel | None,
image_zyx: np.ndarray,
model_config: CellposeModelConfig,
voxel_scale_zyx: tuple[float, float, float],
) -> tuple[np.ndarray, CellposeRefinementRoiCache | None]:
"""Segment one image channel using the configured backend."""
method = normalize_segmentation_method(model_config.segmentation_method)
if method == "cellpose":
if model is None:
raise ValueError(
"A Cellpose segmentation was requested, but no Cellpose model "
"instance was created for this channel."
)
return evaluate_cellpose_model(model, image_zyx, model_config, voxel_scale_zyx)
return segment_threshold_channel(image_zyx, model_config), None
[docs]
def evaluate_cellpose_model(
model: models.CellposeModel,
image_zyx: np.ndarray,
model_config: CellposeModelConfig,
voxel_scale_zyx: tuple[float, float, float],
) -> tuple[np.ndarray, CellposeRefinementRoiCache | None]:
"""Run Cellpose and return the resulting label image as ``uint32``.
The function accepts both 3D ``ZYX`` arrays and 2D images represented as a
singleton-z ``(1, Y, X)`` volume. When ``model_config.do_3d`` is ``None``,
the function auto-detects the correct Cellpose mode from the z-size.
"""
do_3d = model_config.do_3d
if do_3d is None:
do_3d = image_zyx.shape[0] > 1
if model_config.flow3d_smooth < 0 or model_config.flow3d_smooth > 10:
raise ValueError(
"`CellposeModelConfig.flow3d_smooth` must be between 0 and 10, "
f"got {model_config.flow3d_smooth}."
)
anisotropy = resolve_cellpose_anisotropy(
model_config=model_config,
voxel_scale_zyx=voxel_scale_zyx,
do_3d=do_3d,
)
cellpose_major = get_cellpose_major_version()
if not do_3d and image_zyx.shape[0] != 1:
raise ValueError(
"A 2D Cellpose run was requested for an image with more than one z "
"slice. Please keep automatic 3D detection enabled or project the "
"image to 2D before segmentation."
)
cellpose_input = image_zyx if do_3d else image_zyx[0]
shape_for_masks = image_zyx.shape if do_3d else (1, image_zyx.shape[1], image_zyx.shape[2])
eval_kwargs = {
"do_3D": do_3d,
"z_axis": model_config.z_axis if do_3d else None,
"channel_axis": model_config.channel_axis,
}
if anisotropy is not None:
eval_kwargs["anisotropy"] = anisotropy
if do_3d:
eval_kwargs["flow3D_smooth"] = model_config.flow3d_smooth
if cellpose_major is not None and cellpose_major >= 4:
eval_kwargs["cellprob_threshold"] = model_config.cellprob_threshold
eval_kwargs["flow_threshold"] = model_config.flow_threshold
eval_kwargs["compute_masks"] = False
if model_config.diameter is not None:
eval_kwargs["diameter"] = model_config.diameter
else:
if model_config.diameter is None:
raise ValueError(
"For Cellpose versions below 4, an explicit diameter is "
"currently required by this pipeline. Please set "
"`CellposeModelConfig.diameter` in the user script."
)
eval_kwargs["diameter"] = model_config.diameter
masks, flows, _ = model.eval(cellpose_input, **eval_kwargs)
if cellpose_major is not None and cellpose_major >= 4:
dP = np.asarray(flows[1])
cellprob = np.asarray(flows[2])
if do_3d:
if dP.ndim != 4:
raise ValueError(
"Cellpose returned unexpected 3D flow dimensions for "
f"`dP`: {dP.shape}."
)
if cellprob.ndim != 3:
raise ValueError(
"Cellpose returned unexpected 3D cellprob dimensions: "
f"{cellprob.shape}."
)
else:
if dP.ndim == 3:
dP = dP[:, np.newaxis, :, :]
elif dP.ndim != 4:
raise ValueError(
"Cellpose returned unexpected 2D flow dimensions for "
f"`dP`: {dP.shape}."
)
if cellprob.ndim == 2:
cellprob = cellprob[np.newaxis, :, :]
elif cellprob.ndim != 3:
raise ValueError(
"Cellpose returned unexpected 2D cellprob dimensions: "
f"{cellprob.shape}."
)
image_scaling = 30.0 / model_config.diameter if model_config.diameter is not None and model_config.diameter > 0 else 1.0
niter = int(200 / image_scaling)
masks = model._compute_masks(
shape_for_masks,
dP,
cellprob,
flow_threshold=model_config.flow_threshold,
cellprob_threshold=model_config.cellprob_threshold,
min_size=15,
max_size_fraction=0.4,
niter=niter,
do_3D=do_3d,
stitch_threshold=0.0,
)
refinement_cache = CellposeRefinementRoiCache(
roi_id=-1,
y_min=-1,
y_max=-1,
x_min=-1,
x_max=-1,
roi_mask_crop_2d=np.zeros((1, 1), dtype=bool),
shape_for_masks=tuple(int(v) for v in shape_for_masks),
dP=dP,
cellprob=cellprob,
do_3d=do_3d,
niter=niter,
min_size=15,
max_size_fraction=0.4,
flow_threshold=model_config.flow_threshold,
cellprob_threshold=model_config.cellprob_threshold,
)
else:
refinement_cache = None
masks_array = np.asarray(masks, dtype=np.uint32)
if do_3d:
return masks_array, refinement_cache
if masks_array.ndim != 2:
raise ValueError(
"Cellpose returned an unexpected 2D mask shape: "
f"{masks_array.shape}."
)
return masks_array[np.newaxis, :, :], refinement_cache
[docs]
def relabel_with_offset(mask: np.ndarray, offset: int) -> np.ndarray:
"""Shift all non-zero labels by a fixed offset."""
out = mask.copy()
valid = out > 0
out[valid] += offset
return out
[docs]
def filter_labels_by_size(label_image: np.ndarray, min_size: int) -> np.ndarray:
"""Remove labels smaller than the configured voxel threshold."""
labels, counts = np.unique(label_image, return_counts=True)
keep_labels = labels[(labels != 0) & (counts >= min_size)]
lookup = np.zeros(int(label_image.max()) + 1, dtype=label_image.dtype)
lookup[keep_labels] = keep_labels
return lookup[label_image]
def _legacy_threshold_to_max_size(value: int) -> int:
"""Convert the old strict threshold semantics to the new scikit-image API."""
return max(int(value) - 1, 0)
[docs]
def segment_optional_region(
image_zyx: np.ndarray,
roi_labels_2d: np.ndarray | None,
config: OptionalRegionSegmentationConfig,
) -> OptionalRegionSegmentationResult:
"""Threshold an optional third channel and compute a cleaned 3D mask.
The function keeps the old prototype behavior while avoiding the deprecated
scikit-image arguments by mapping strict minimum thresholds to the new
inclusive ``max_size`` semantics.
"""
print("\nSegmenting optional region channel...")
image_float = image_zyx.astype(np.float32, copy=False)
roi_mask_3d = None
if roi_labels_2d is not None:
roi_mask_3d = np.repeat((roi_labels_2d > 0)[np.newaxis, :, :], image_float.shape[0], axis=0)
image_work = image_float.copy()
if roi_mask_3d is not None:
image_work[~roi_mask_3d] = 0
if config.gaussian_sigma is not None and config.gaussian_sigma > 0:
print(f" Optional region: Gaussian smoothing with sigma={config.gaussian_sigma}...")
image_smooth = gaussian(image_work, sigma=config.gaussian_sigma, preserve_range=True)
if roi_mask_3d is not None:
image_smooth[~roi_mask_3d] = 0
else:
image_smooth = image_work
if config.background_sigma is not None and config.background_sigma > 0:
print(f" Optional region: background subtraction with sigma={config.background_sigma}...")
background = gaussian(image_smooth, sigma=config.background_sigma, preserve_range=True)
image_corrected = image_smooth - background
image_corrected[image_corrected < 0] = 0
if roi_mask_3d is not None:
image_corrected[~roi_mask_3d] = 0
else:
image_corrected = image_smooth
if roi_mask_3d is not None:
values = image_corrected[roi_mask_3d]
else:
values = image_corrected.ravel()
values = values[np.isfinite(values)]
values = values[values > 0]
if len(values) == 0:
raise ValueError("Cannot compute an optional-region threshold because no positive values were found.")
method = config.method.lower()
print(f" Optional region: computing threshold with method='{method}'...")
if method == "otsu":
threshold = float(threshold_otsu(values))
elif method == "li":
threshold = float(threshold_li(values))
elif method == "percentile":
threshold = float(np.percentile(values, config.percentile))
else:
raise ValueError(f"Unknown optional region segmentation method: {config.method}")
if roi_mask_3d is not None:
region_mask = np.zeros_like(image_corrected, dtype=bool)
region_mask[roi_mask_3d] = image_corrected[roi_mask_3d] > threshold
else:
region_mask = image_corrected > threshold
if config.apply_closing:
print(" Optional region: applying morphological closing...")
region_mask = closing(region_mask, footprint=ball(1))
if roi_mask_3d is not None:
region_mask[~roi_mask_3d] = False
if config.min_object_voxels > 0:
print(f" Optional region: removing objects smaller than {config.min_object_voxels} voxels...")
region_mask = remove_small_objects(
region_mask,
max_size=_legacy_threshold_to_max_size(config.min_object_voxels),
)
if config.min_hole_voxels > 0:
print(f" Optional region: removing holes smaller than {config.min_hole_voxels} voxels...")
region_mask = remove_small_holes(
region_mask,
max_size=_legacy_threshold_to_max_size(config.min_hole_voxels),
)
if roi_mask_3d is not None:
region_mask[~roi_mask_3d] = False
region_labels = label(region_mask).astype(np.uint32)
print(f" Optional region: threshold={threshold}")
print(f" Optional region: n_objects={int(region_labels.max())}")
return OptionalRegionSegmentationResult(
mask=region_mask,
labels=region_labels,
threshold=threshold,
corrected_image=np.asarray(image_corrected),
)
# %% END