Source code for cellcoloc.analysis

"""
Core Cellpose colocalization analysis logic.

author: Fabrizio Musacchio
date: May/June 2026
"""
# %% IMPORTS
from __future__ import annotations

import numpy as np
import pandas as pd
from skimage.measure import regionprops, regionprops_table

from .config import CellposeModelConfig, ColocalizationConfig, RuntimeConfig
from .filtering import apply_postfilters, apply_prefilter
from .roi import get_bbox_2d
from .schemas import (
    CellposeChannelRefinementContext,
    CellposeRefinementRoiCache,
    ColocalizationRunResult,
    ColocalizationTables,
    LoadedImageChannels,
    OptionalRegionSegmentationResult,
)
from .segmentation import (
    create_cellpose_model,
    create_cellpose_models_for_channels,
    evaluate_segmentation_method,
    filter_labels_by_size,
    normalize_segmentation_method,
    relabel_with_offset,
)

# %% MAIN ANALYSIS LOGIC
def analyze_label_overlaps(
    cell_masks: np.ndarray,
    marker_masks: np.ndarray,
    roi_id: int,
) -> list[dict[str, int | float]]:
    """Compute per-cell overlap rows against marker labels within one ROI.

    Every non-zero label in ``cell_masks`` yields at least one row. A cell
    that overlaps no marker label yields a single row with ``marker_label``
    set to NaN and zero overlap. A cell that overlaps ``k`` distinct marker
    labels yields ``k`` rows, one per marker label, in ascending marker-label
    order. Cells are emitted in ascending cell-label order.

    Parameters
    ----------
    cell_masks:
        Label image of the cell channel (0 = background).
    marker_masks:
        Label image of the marker channel (0 = background). Must have the
        same shape as ``cell_masks``.
    roi_id:
        ROI identifier copied into every output row.

    Returns
    -------
    list[dict[str, int | float]]
        One dictionary per (cell, overlapping marker) pair with the keys
        ``roi_id``, ``cell_label``, ``cell_voxels``,
        ``n_overlapping_markers``, ``marker_label``, ``overlap_voxels`` and
        ``overlap_fraction_of_cell``.

    Raises
    ------
    ValueError
        If ``cell_masks`` and ``marker_masks`` differ in shape.
    """

    cell_masks = np.asarray(cell_masks)
    marker_masks = np.asarray(marker_masks)
    if cell_masks.shape != marker_masks.shape:
        raise ValueError(
            "cell_masks and marker_masks must have the same shape, got "
            f"{cell_masks.shape} and {marker_masks.shape}."
        )

    cell_flat = cell_masks.ravel()
    marker_flat = marker_masks.ravel()

    # Group all foreground voxel indices by cell label with one sort. This
    # replaces a full-volume comparison per cell (O(n_cells * n_voxels)).
    foreground_idx = np.flatnonzero(cell_flat)
    voxel_order = foreground_idx[np.argsort(cell_flat[foreground_idx], kind="stable")]
    # After sorting, each label occupies one contiguous run:
    # voxel_order[start:start + size].
    cell_labels, group_starts, group_sizes = np.unique(
        cell_flat[voxel_order], return_index=True, return_counts=True
    )

    rows: list[dict[str, int | float]] = []
    for cell_label, group_start, group_size in zip(cell_labels, group_starts, group_sizes):
        cell_voxels = int(group_size)
        cell_voxel_idx = voxel_order[group_start:group_start + group_size]

        overlapping_markers = marker_flat[cell_voxel_idx]
        overlapping_markers = overlapping_markers[overlapping_markers != 0]
        unique_markers, counts = np.unique(overlapping_markers, return_counts=True)

        if len(unique_markers) == 0:
            # Keep marker-negative cells in the table so they are counted.
            rows.append(
                {
                    "roi_id": roi_id,
                    "cell_label": int(cell_label),
                    "cell_voxels": cell_voxels,
                    "n_overlapping_markers": 0,
                    "marker_label": np.nan,
                    "overlap_voxels": 0,
                    "overlap_fraction_of_cell": 0.0,
                }
            )
            continue

        for marker_label, overlap_voxels in zip(unique_markers, counts):
            rows.append(
                {
                    "roi_id": roi_id,
                    "cell_label": int(cell_label),
                    "cell_voxels": cell_voxels,
                    "n_overlapping_markers": int(len(unique_markers)),
                    "marker_label": int(marker_label),
                    "overlap_voxels": int(overlap_voxels),
                    "overlap_fraction_of_cell": float(overlap_voxels / cell_voxels),
                }
            )

    return rows


def build_positive_cell_mask(cell_masks: np.ndarray, summary_table: pd.DataFrame) -> np.ndarray:
    """Return a label image that keeps only marker-positive cells.

    Cell labels flagged ``marker_positive`` in ``summary_table`` keep their
    original label value. Every other voxel (background, marker-negative
    cells, and labels missing from the table) is set to 0.

    Parameters
    ----------
    cell_masks:
        Cell label image.
    summary_table:
        Per-cell summary with at least a ``cell_label`` column and a boolean
        ``marker_positive`` column.

    Returns
    -------
    np.ndarray
        ``uint32`` label image with the same shape as ``cell_masks``.
    """
    if not summary_table.empty:
        positive_rows = summary_table["marker_positive"]
        kept_labels = (
            summary_table.loc[positive_rows, "cell_label"]
            .astype(np.uint32)
            .to_numpy()
        )
        highest_label = int(cell_masks.max())
        # Labels larger than anything in the image cannot index the lookup table.
        kept_labels = kept_labels[kept_labels <= highest_label]
        # Lookup table: maps each kept label to itself and everything else to 0.
        label_lut = np.zeros(highest_label + 1, dtype=np.uint32)
        label_lut[kept_labels] = kept_labels
        return label_lut[cell_masks]
    return np.zeros_like(cell_masks, dtype=np.uint32)
def _compute_3d_surface_area_um2( object_mask_zyx: np.ndarray, voxel_scale_zyx: tuple[float, float, float], ) -> float: """Compute voxel-surface area of one 3D object in squared micrometers.""" z_size_um, y_size_um, x_size_um = voxel_scale_zyx padded = np.pad(np.asarray(object_mask_zyx, dtype=bool), 1, mode="constant", constant_values=False) transitions_z = np.count_nonzero(padded[1:, :, :] != padded[:-1, :, :]) transitions_y = np.count_nonzero(padded[:, 1:, :] != padded[:, :-1, :]) transitions_x = np.count_nonzero(padded[:, :, 1:] != padded[:, :, :-1]) face_area_z = y_size_um * x_size_um face_area_y = z_size_um * x_size_um face_area_x = z_size_um * y_size_um return float( transitions_z * face_area_z + transitions_y * face_area_y + transitions_x * face_area_x ) def _compute_3d_ellipticity( object_mask_zyx: np.ndarray, voxel_scale_zyx: tuple[float, float, float], ) -> float: """Estimate a 3D ellipticity-like elongation score from voxel coordinates.""" coordinates = np.column_stack(np.where(object_mask_zyx)) if coordinates.shape[0] < 3: return np.nan spacing = np.asarray(voxel_scale_zyx, dtype=float) coordinates_um = coordinates * spacing centered = coordinates_um - coordinates_um.mean(axis=0, keepdims=True) covariance = np.cov(centered, rowvar=False) eigenvalues = np.sort(np.linalg.eigvalsh(covariance))[::-1] eigenvalues = np.clip(eigenvalues, 0.0, None) if eigenvalues[0] <= 0: return np.nan major_axis = np.sqrt(eigenvalues[0]) minor_axis = np.sqrt(eigenvalues[-1]) if major_axis <= 0: return np.nan return float(1.0 - (minor_axis / major_axis)) def _channel_metric_column_names(prefix: str) -> list[str]: """Return the standard morphology metric column names for one channel.""" return [ f"{prefix}_area_px_2d", f"{prefix}_area_um2_2d", f"{prefix}_perimeter_px_2d", f"{prefix}_perimeter_um_2d", f"{prefix}_roundness_2d", f"{prefix}_eccentricity_2d", f"{prefix}_volume_voxels_3d", f"{prefix}_volume_um3_3d", f"{prefix}_surface_area_um2_3d", f"{prefix}_sphericity_3d", 
f"{prefix}_ellipticity_3d", ] def _empty_channel_properties_table( label_column: str, metric_prefix: str, ) -> pd.DataFrame: """Create an empty per-object morphology table for one segmented channel.""" return pd.DataFrame( columns=[ "roi_id", label_column, "centroid_z", "centroid_y", "centroid_x", *_channel_metric_column_names(metric_prefix), ] ) def _build_channel_properties_table( label_image: np.ndarray, voxel_scale_zyx: tuple[float, float, float], roi_labels_2d: np.ndarray, label_column: str, metric_prefix: str, ) -> pd.DataFrame: """Create one morphology row per segmented label of one channel.""" if np.max(label_image) == 0: return _empty_channel_properties_table(label_column, metric_prefix) is_effective_2d = label_image.shape[0] == 1 z_size_um, y_size_um, x_size_um = voxel_scale_zyx pixel_area_um2 = y_size_um * x_size_um voxel_volume_um3 = z_size_um * y_size_um * x_size_um rows: list[dict[str, int | float]] = [] roi_labels_3d = np.broadcast_to(roi_labels_2d, label_image.shape) if is_effective_2d: for region in regionprops(label_image[0]): object_label = int(region.label) object_mask = label_image == object_label roi_values = roi_labels_3d[object_mask] roi_values = roi_values[roi_values != 0] roi_id = int(np.unique(roi_values)[0]) if roi_values.size > 0 else 0 object_area_px = float(region.area) object_area_um2 = float(object_area_px * pixel_area_um2) perimeter_px = float(getattr(region, "perimeter", np.nan)) perimeter_um = ( float(perimeter_px * ((y_size_um + x_size_um) / 2.0)) if np.isfinite(perimeter_px) else np.nan ) roundness = ( float((4.0 * np.pi * object_area_px) / (perimeter_px ** 2)) if np.isfinite(perimeter_px) and perimeter_px > 0 else np.nan ) rows.append( { "roi_id": roi_id, label_column: object_label, "centroid_z": 0.0, "centroid_y": float(region.centroid[0]), "centroid_x": float(region.centroid[1]), f"{metric_prefix}_area_px_2d": object_area_px, f"{metric_prefix}_area_um2_2d": object_area_um2, f"{metric_prefix}_perimeter_px_2d": perimeter_px, 
f"{metric_prefix}_perimeter_um_2d": perimeter_um, f"{metric_prefix}_roundness_2d": roundness, f"{metric_prefix}_eccentricity_2d": float(getattr(region, "eccentricity", np.nan)), f"{metric_prefix}_volume_voxels_3d": np.nan, f"{metric_prefix}_volume_um3_3d": np.nan, f"{metric_prefix}_surface_area_um2_3d": np.nan, f"{metric_prefix}_sphericity_3d": np.nan, f"{metric_prefix}_ellipticity_3d": np.nan, } ) else: for region in regionprops(label_image): object_label = int(region.label) object_mask = label_image == object_label roi_values = roi_labels_3d[object_mask] roi_values = roi_values[roi_values != 0] roi_id = int(np.unique(roi_values)[0]) if roi_values.size > 0 else 0 object_volume_voxels = float(region.area) object_volume_um3 = float(object_volume_voxels * voxel_volume_um3) object_surface_area_um2 = _compute_3d_surface_area_um2( object_mask, voxel_scale_zyx, ) object_sphericity = ( float( (np.pi ** (1.0 / 3.0)) * ((6.0 * object_volume_um3) ** (2.0 / 3.0)) / object_surface_area_um2 ) if object_surface_area_um2 > 0 and object_volume_um3 > 0 else np.nan ) rows.append( { "roi_id": roi_id, label_column: object_label, "centroid_z": float(region.centroid[0]), "centroid_y": float(region.centroid[1]), "centroid_x": float(region.centroid[2]), f"{metric_prefix}_area_px_2d": np.nan, f"{metric_prefix}_area_um2_2d": np.nan, f"{metric_prefix}_perimeter_px_2d": np.nan, f"{metric_prefix}_perimeter_um_2d": np.nan, f"{metric_prefix}_roundness_2d": np.nan, f"{metric_prefix}_eccentricity_2d": np.nan, f"{metric_prefix}_volume_voxels_3d": object_volume_voxels, f"{metric_prefix}_volume_um3_3d": object_volume_um3, f"{metric_prefix}_surface_area_um2_3d": object_surface_area_um2, f"{metric_prefix}_sphericity_3d": object_sphericity, f"{metric_prefix}_ellipticity_3d": _compute_3d_ellipticity( object_mask, voxel_scale_zyx, ), } ) return pd.DataFrame(rows).sort_values(by=["roi_id", label_column]).reset_index(drop=True) def _build_channel_roi_summary_table( object_table: pd.DataFrame | None, 
roi_labels_2d: np.ndarray, label_column: str, count_column: str, metric_prefix: str, ) -> pd.DataFrame: """Create one per-ROI morphology-mean row for one segmented channel.""" metric_columns = _channel_metric_column_names(metric_prefix) if object_table is None: object_table = _empty_channel_properties_table(label_column, metric_prefix) rows: list[dict[str, int | float]] = [] for roi_id in np.unique(roi_labels_2d): if roi_id == 0: continue object_rows = object_table[object_table["roi_id"] == roi_id] row: dict[str, int | float] = { "roi_id": int(roi_id), count_column: int(len(object_rows)), } for column_name in metric_columns: mean_value = object_rows[column_name].dropna().mean() if column_name in object_rows.columns else np.nan row[f"average_{column_name}"] = float(mean_value) if pd.notna(mean_value) else np.nan rows.append(row) return pd.DataFrame(rows) def _normalize_z_crop_bounds( z_crop: tuple[int | None, int | None], z_size: int, ) -> tuple[int, int]: """Validate and normalize one user-supplied z-crop against a stack size. Parameters ---------- z_crop: Tuple of ``(start, stop)`` indices. ``None`` endpoints mean "from the start" or "to the end" respectively. z_size: Full z depth of the currently loaded stack. Returns ------- tuple[int, int] Clipped and validated z bounds suitable for Python slicing. """ start_raw, stop_raw = z_crop start = 0 if start_raw is None else int(start_raw) stop = z_size if stop_raw is None else int(stop_raw) start = max(0, min(start, z_size)) stop = max(0, min(stop, z_size)) if start >= stop: raise ValueError( "Invalid z-crop bounds. Expected a tuple like ``(start, stop)`` " f"with start < stop after clipping to the stack size, got {z_crop!r} " f"for z size {z_size}." ) return start, stop def _resolve_analysis_z_bounds( z_size: int, *model_configs: CellposeModelConfig | None, fallback: tuple[int, int] | None = None, ) -> tuple[int, int] | None: """Resolve one global analysis z-crop from one or more channel configs. 
The pipeline treats z-cropping as a global analysis constraint. Individual channel configs may expose the same ``z_crop`` field for user convenience, but conflicting bounds across channels are rejected to keep all internal computations aligned. """ normalized_bounds: list[tuple[int, int]] = [] for model_config in model_configs: if model_config is None or model_config.z_crop is None: continue normalized_bounds.append(_normalize_z_crop_bounds(model_config.z_crop, z_size)) if not normalized_bounds: return fallback first_bounds = normalized_bounds[0] if any(bounds != first_bounds for bounds in normalized_bounds[1:]): raise ValueError( "Conflicting z-crop bounds were provided across channel configs. " "Please use the same z-crop for all participating channels." ) return first_bounds def _normalize_z_projection_method(z_projection: str | None) -> str | None: """Normalize and validate an optional global z-projection method. Supported methods are ``"max"``, ``"mean"``, ``"median"``, ``"std"``, and ``"var"``. ``None`` disables projection. """ if z_projection is None: return None normalized = str(z_projection).strip().lower() allowed = {"max", "mean", "median", "std", "var"} if normalized not in allowed: raise ValueError( "`z_projection` must be one of None, 'max', 'mean', 'median', " f"'std', or 'var', got {z_projection!r}." ) return normalized def _resolve_analysis_z_projection_method( *model_configs: CellposeModelConfig | None, ) -> str | None: """Resolve one global z-projection method from one or more channel configs. The pipeline treats z-projection as a global preprocessing choice. Channel configs may expose the same field for convenience, but conflicting methods across channels are rejected. 
""" normalized_methods: list[str] = [] for model_config in model_configs: if model_config is None: continue normalized_method = _normalize_z_projection_method(model_config.z_projection) if normalized_method is not None: normalized_methods.append(normalized_method) if not normalized_methods: return None first_method = normalized_methods[0] if any(method != first_method for method in normalized_methods[1:]): raise ValueError( "Conflicting z-projection methods were provided across channel " "configs. Please use the same z-projection for all participating " "channels." ) return first_method def _project_zyx_volume( image_zyx: np.ndarray, projection_method: str, ) -> np.ndarray: """Project one ``ZYX`` image volume along z and keep singleton-z shape. Returns a float ``(1, Y, X)`` array so downstream code can continue to use the same ``ZYX`` interface even after a nominally 2D projection step. """ image_float = np.asarray(image_zyx, dtype=np.float32) if projection_method == "max": projection_yx = np.max(image_float, axis=0) elif projection_method == "mean": projection_yx = np.mean(image_float, axis=0) elif projection_method == "median": projection_yx = np.median(image_float, axis=0) elif projection_method == "std": projection_yx = np.std(image_float, axis=0) elif projection_method == "var": projection_yx = np.var(image_float, axis=0) else: raise ValueError(f"Unsupported z-projection method: {projection_method!r}.") return np.asarray(projection_yx, dtype=np.float32)[np.newaxis, :, :]
def prepare_loaded_images_for_analysis(
    loaded_images: LoadedImageChannels,
    *model_configs: CellposeModelConfig | None,
) -> LoadedImageChannels:
    """Prepare a loaded dataset for downstream analysis according to configs.

    This helper resolves an optional global z-projection from the provided
    channel configs.

    - When no projection method is configured, the original
      ``loaded_images`` object is returned unchanged.
    - When a projection is requested, the globally configured z-crop (if any)
      is applied first. Every available channel is then projected along z,
      and a new ``LoadedImageChannels`` bundle is returned. This bundle
      behaves like a 2D dataset with singleton-z image arrays.
    - When ``loaded_images`` is already a projected bundle produced with the
      same method, it is returned unchanged.

    All later ROI drawing, segmentation, quantification, and visualization
    steps should use the prepared bundle.

    Parameters
    ----------
    loaded_images:
        Previously loaded channel bundle from
        :func:`cellcoloc.io.load_analysis_images`.
    *model_configs:
        One or more participating channel configs. Any configured
        ``z_crop`` and ``z_projection`` values are resolved globally across
        them.

    Returns
    -------
    LoadedImageChannels
        Either the original loaded image bundle or a projected analysis view.

    Raises
    ------
    ValueError
        If ``loaded_images`` was already projected with a different method
        than the one now requested, or if the configured z-crop /
        z-projection settings are invalid or conflict across channels.
    """
    projection_method = _resolve_analysis_z_projection_method(*model_configs)
    if projection_method is None:
        return loaded_images

    # A bundle from an earlier call already holds singleton-z projections.
    # Projecting it again is wrong: "std"/"var" of a single plane is all
    # zeros, and z-crop bounds would be checked against a z size of 1.
    previous_method = loaded_images.z_projection_method
    if previous_method is not None:
        if previous_method == projection_method:
            return loaded_images
        raise ValueError(
            "The loaded images were already z-projected with method "
            f"{previous_method!r}; cannot re-project them with "
            f"{projection_method!r}. Prepare the original, unprojected "
            "images instead."
        )

    analysis_z_bounds = _resolve_analysis_z_bounds(
        loaded_images.cell_image.shape[0],
        *model_configs,
    )
    z_slice = slice(*analysis_z_bounds) if analysis_z_bounds is not None else slice(None)

    projected_cell_image = _project_zyx_volume(loaded_images.cell_image[z_slice], projection_method)
    projected_marker_image = _project_zyx_volume(loaded_images.marker_image[z_slice], projection_method)
    projected_optional_region_image = None
    if loaded_images.optional_region_image is not None:
        projected_optional_region_image = _project_zyx_volume(
            loaded_images.optional_region_image[z_slice],
            projection_method,
        )

    return LoadedImageChannels(
        source_path=loaded_images.source_path,
        paths=loaded_images.paths,
        # The projected view has one z plane, so its z spacing is set to 1.0.
        voxel_scale_zyx=(1.0, loaded_images.voxel_scale_zyx[1], loaded_images.voxel_scale_zyx[2]),
        cell_image=projected_cell_image,
        marker_image=projected_marker_image,
        optional_region_image=projected_optional_region_image,
        raw_shape_tzcyx=loaded_images.raw_shape_tzcyx,
        raw_z_size=loaded_images.raw_z_size,
        is_3d=False,
        metadata=loaded_images.metadata,
        analysis_z_bounds=analysis_z_bounds,
        z_projection_method=projection_method,
    )
def _apply_analysis_z_bounds( label_image: np.ndarray | None, analysis_z_bounds: tuple[int, int] | None, ) -> np.ndarray | None: """Zero label content outside the active analysis z range. This helper keeps all mask arrays in full-stack shape for visualization and export, while ensuring that quantification and later refinement only see labels inside the chosen analysis z interval. """ if label_image is None: return None if analysis_z_bounds is None: return np.asarray(label_image, dtype=np.uint32).copy() cropped = np.zeros_like(label_image, dtype=np.uint32) z_start, z_stop = analysis_z_bounds cropped[z_start:z_stop] = np.asarray(label_image[z_start:z_stop], dtype=np.uint32) return cropped
def _collect_detailed_overlap_table(
    cell_masks: np.ndarray,
    marker_masks: np.ndarray,
    roi_labels_2d: np.ndarray,
    roi_ids: list,
) -> pd.DataFrame:
    """Gather per-ROI cell/marker overlap rows into one sorted table.

    Each ROI is analyzed on its 2D bounding box across all z planes. The box
    extent is appended to every row as ``y_min``/``y_max``/``x_min``/``x_max``.
    """
    collected: list[dict[str, int | float]] = []
    for roi_id in roi_ids:
        bbox = get_bbox_2d(roi_labels_2d == roi_id)
        if bbox is None:
            continue
        y_slice, x_slice = bbox
        roi_rows = analyze_label_overlaps(
            cell_masks[:, y_slice, x_slice],
            marker_masks[:, y_slice, x_slice],
            roi_id=int(roi_id),
        )
        bbox_columns = {
            "y_min": int(y_slice.start),
            "y_max": int(y_slice.stop),
            "x_min": int(x_slice.start),
            "x_max": int(x_slice.stop),
        }
        collected.extend({**row, **bbox_columns} for row in roi_rows)

    table = pd.DataFrame(collected)
    if table.empty:
        return table
    # Within each cell, list the largest marker overlap first.
    return table.sort_values(
        by=["roi_id", "cell_label", "overlap_voxels"],
        ascending=[True, True, False],
    )


def analyze_existing_masks(
    loaded_images: LoadedImageChannels,
    roi_labels_2d: np.ndarray,
    cell_masks: np.ndarray,
    marker_masks: np.ndarray,
    colocalization_config: ColocalizationConfig,
    optional_region_result: OptionalRegionSegmentationResult | None = None,
    optional_region_masks: np.ndarray | None = None,
    analysis_z_bounds: tuple[int, int] | None = None,
    cell_refinement_context: CellposeChannelRefinementContext | None = None,
    marker_refinement_context: CellposeChannelRefinementContext | None = None,
    optional_region_refinement_context: CellposeChannelRefinementContext | None = None,
    cell_model_config: CellposeModelConfig | None = None,
    marker_model_config: CellposeModelConfig | None = None,
    optional_region_model_config: CellposeModelConfig | None = None,
) -> ColocalizationRunResult:
    """Recompute colocalization tables from existing label masks.

    Used both after the initial Cellpose segmentation and after any later
    manual or threshold-based refinement of the label masks. Cell masks are
    size-filtered, and each channel is postfiltered when its config defines
    postfilters. Then the per-cell overlap tables, per-object morphology
    tables, per-ROI summaries, the overview table, and the marker-positive
    cell mask are rebuilt.

    Parameters
    ----------
    loaded_images:
        Loaded raw analysis channels and dataset metadata.
    roi_labels_2d:
        Drawn or generated 2D ROI label mask.
    cell_masks, marker_masks:
        Full-stack label masks for the two primary analysis channels. They may
        originate from Cellpose, thresholding, or manual relabeling.
    colocalization_config:
        Thresholds controlling how per-cell overlaps are interpreted.
    optional_region_result, optional_region_masks:
        Optional third-channel segmentation, supplied either as the legacy
        result wrapper or directly as a label image. When both are provided,
        ``optional_region_masks`` takes precedence.
    analysis_z_bounds:
        Optional global z interval used for the current analysis. Labels
        outside this interval are ignored internally, but the stored masks
        keep full-stack shape. Ignored when ``loaded_images`` is a
        z-projected view.
    cell_refinement_context, marker_refinement_context, optional_region_refinement_context:
        Optional cached Cellpose network outputs. They are passed through to
        the returned result for later threshold-only refinement.
    cell_model_config, marker_model_config, optional_region_model_config:
        Optional channel configs, reused here mainly so postfilters are
        applied consistently when masks are reanalyzed.

    Returns
    -------
    ColocalizationRunResult
        Structured masks and tables reflecting the provided segmentation state.
    """
    # A z-projected view has a single plane, so z bounds do not apply.
    if loaded_images.z_projection_method is None:
        z_bounds = analysis_z_bounds
    else:
        z_bounds = None

    cells = _apply_analysis_z_bounds(cell_masks, z_bounds)
    markers = _apply_analysis_z_bounds(marker_masks, z_bounds)
    roi_ids = [roi_id for roi_id in np.unique(roi_labels_2d) if roi_id != 0]

    min_cell_voxels = colocalization_config.min_cell_voxels
    print(f"\nFiltering cell labels smaller than {min_cell_voxels} voxels...")
    cells = filter_labels_by_size(cells, min_cell_voxels)

    if cell_model_config is not None and cell_model_config.postfilters is not None:
        print("Applying configured postfilters to cell masks...")
        cells = apply_postfilters(cells, loaded_images.cell_image, cell_model_config)

    if marker_model_config is not None and marker_model_config.postfilters is not None:
        print("Applying configured postfilters to marker masks...")
        markers = apply_postfilters(markers, loaded_images.marker_image, marker_model_config)

    detailed_table = _collect_detailed_overlap_table(cells, markers, roi_labels_2d, roi_ids)

    # Direct masks win over the legacy result wrapper.
    if optional_region_masks is not None:
        optional_labels = np.asarray(optional_region_masks, dtype=np.uint32)
    elif optional_region_result is not None:
        optional_labels = np.asarray(optional_region_result.labels, dtype=np.uint32)
    else:
        optional_labels = None
    optional_labels = _apply_analysis_z_bounds(optional_labels, z_bounds)

    wants_optional_postfilters = (
        optional_region_model_config is not None
        and optional_region_model_config.postfilters is not None
        and optional_labels is not None
    )
    if wants_optional_postfilters:
        print("Applying configured postfilters to optional third-channel masks...")
        optional_labels = apply_postfilters(
            optional_labels,
            loaded_images.optional_region_image,
            optional_region_model_config,
        )

    summary_table = _build_summary_table(
        detailed_table=detailed_table,
        cell_masks=cells,
        voxel_scale_zyx=loaded_images.voxel_scale_zyx,
        config=colocalization_config,
        roi_labels_2d=roi_labels_2d,
        optional_region_masks=optional_labels,
    )

    marker_properties = _build_channel_properties_table(
        label_image=markers,
        voxel_scale_zyx=loaded_images.voxel_scale_zyx,
        roi_labels_2d=roi_labels_2d,
        label_column="marker_label",
        metric_prefix="marker",
    )
    if optional_labels is None:
        third_channel_properties = None
    else:
        third_channel_properties = _build_channel_properties_table(
            label_image=optional_labels,
            voxel_scale_zyx=loaded_images.voxel_scale_zyx,
            roi_labels_2d=roi_labels_2d,
            label_column="optional_region_label",
            metric_prefix="optional_region",
        )

    roi_cell_summary = _build_channel_roi_summary_table(
        object_table=summary_table,
        roi_labels_2d=roi_labels_2d,
        label_column="cell_label",
        count_column="n_cells",
        metric_prefix="cell",
    )
    roi_marker_summary = _build_channel_roi_summary_table(
        object_table=marker_properties,
        roi_labels_2d=roi_labels_2d,
        label_column="marker_label",
        count_column="n_marker_objects",
        metric_prefix="marker",
    )
    if third_channel_properties is None:
        roi_third_channel_summary = None
    else:
        roi_third_channel_summary = _build_channel_roi_summary_table(
            object_table=third_channel_properties,
            roi_labels_2d=roi_labels_2d,
            label_column="optional_region_label",
            count_column="n_3rd_channel_objects",
            metric_prefix="optional_region",
        )

    overview_table = _build_overview_table(
        roi_labels_2d=roi_labels_2d,
        loaded_images=loaded_images,
        cell_masks=cells,
        marker_masks=markers,
        summary_table=summary_table,
        optional_region_masks=optional_labels,
        analysis_z_bounds=z_bounds,
    )
    positive_cell_masks = build_positive_cell_mask(cells, summary_table)

    tables = ColocalizationTables(
        detailed=detailed_table,
        summary=summary_table,
        overview=overview_table,
        marker_properties=marker_properties,
        third_channel_properties=third_channel_properties,
        roi_cell_summary=roi_cell_summary,
        roi_marker_summary=roi_marker_summary,
        roi_third_channel_summary=roi_third_channel_summary,
    )
    return ColocalizationRunResult(
        cell_masks=cells,
        marker_masks=markers,
        positive_cell_masks=positive_cell_masks,
        optional_region_masks=optional_labels,
        analysis_z_bounds=z_bounds,
        tables=tables,
        cell_refinement_context=cell_refinement_context,
        marker_refinement_context=marker_refinement_context,
        optional_region_refinement_context=optional_region_refinement_context,
    )
def _rebuild_masks_from_refinement_context(
    image_shape: tuple[int, int, int],
    refinement_context: CellposeChannelRefinementContext,
    flow_threshold: float | None = None,
    cellprob_threshold: float | None = None,
) -> np.ndarray:
    """Stitch a full-size label image from cached per-ROI Cellpose outputs.

    The neural-network forward pass is skipped. For every cached ROI, masks
    are recomputed from the stored flow (``dP``) and cell-probability arrays
    via the model's ``_compute_masks``, then restricted to the ROI footprint.
    Each ROI is relabeled after the highest label placed so far and merged
    into the output with an element-wise maximum.

    Parameters
    ----------
    image_shape:
        Shape of the full ``ZYX`` output label image.
    refinement_context:
        Cached Cellpose model and per-ROI outputs.
    flow_threshold, cellprob_threshold:
        Overrides for the thresholds stored in each ROI cache; ``None``
        keeps the cached value.

    Returns
    -------
    np.ndarray
        ``uint32`` label image of shape ``image_shape``.
    """
    stitched = np.zeros(image_shape, dtype=np.uint32)
    next_offset = 0

    for cache in refinement_context.roi_caches:
        if flow_threshold is None:
            roi_flow_threshold = cache.flow_threshold
        else:
            roi_flow_threshold = flow_threshold
        if cellprob_threshold is None:
            roi_cellprob_threshold = cache.cellprob_threshold
        else:
            roi_cellprob_threshold = cellprob_threshold

        computed = refinement_context.model._compute_masks(
            cache.shape_for_masks,
            cache.dP,
            cache.cellprob,
            flow_threshold=roi_flow_threshold,
            cellprob_threshold=roi_cellprob_threshold,
            min_size=cache.min_size,
            max_size_fraction=cache.max_size_fraction,
            niter=cache.niter,
            do_3D=cache.do_3d,
            stitch_threshold=0.0,
        )
        roi_labels = np.asarray(computed, dtype=np.uint32)
        if not cache.do_3d:
            # 2D results get a leading singleton z axis to match ZYX layout.
            roi_labels = roi_labels[np.newaxis, :, :]
        # Discard labels outside the drawn ROI polygon within its bounding box.
        roi_labels[:, ~cache.roi_mask_crop_2d] = 0

        roi_labels = relabel_with_offset(roi_labels, next_offset)
        roi_max = int(roi_labels.max())
        if roi_max > 0:
            next_offset = roi_max

        window = (
            slice(None),
            slice(cache.y_min, cache.y_max),
            slice(cache.x_min, cache.x_max),
        )
        stitched[window] = np.maximum(stitched[window], roi_labels)

    return stitched
def _rebuild_channel_from_cache(
    image: np.ndarray,
    refinement_context: CellposeChannelRefinementContext | None,
    missing_cache_message: str,
    flow_threshold: float | None,
    cellprob_threshold: float | None,
) -> np.ndarray:
    """Rebuild one channel's masks from its Cellpose cache.

    Raises ``ValueError(missing_cache_message)`` when no cache is available.
    """
    if refinement_context is None:
        raise ValueError(missing_cache_message)
    return _rebuild_masks_from_refinement_context(
        image_shape=image.shape,
        refinement_context=refinement_context,
        flow_threshold=flow_threshold,
        cellprob_threshold=cellprob_threshold,
    )


def refine_run_result_from_cellpose_cache(
    loaded_images: LoadedImageChannels,
    roi_labels_2d: np.ndarray,
    run_result: ColocalizationRunResult,
    colocalization_config: ColocalizationConfig,
    cell_model_config: CellposeModelConfig | None = None,
    marker_model_config: CellposeModelConfig | None = None,
    optional_region_model_config: CellposeModelConfig | None = None,
    cell_cellprob_threshold: float | None = None,
    cell_flow_threshold: float | None = None,
    marker_cellprob_threshold: float | None = None,
    marker_flow_threshold: float | None = None,
    optional_region_cellprob_threshold: float | None = None,
    optional_region_flow_threshold: float | None = None,
    optional_region_result: OptionalRegionSegmentationResult | None = None,
) -> ColocalizationRunResult:
    """Recompute masks and tables from cached Cellpose outputs.

    The neural-network forward pass is not rerun. Only the mask-generation
    stage is repeated, from the cached ``dP`` and ``cellprob`` arrays.

    Passing ``cell_model_config=None``, ``marker_model_config=None``, and/or
    ``optional_region_model_config=None`` leaves that channel unchanged and
    reuses the masks already stored in ``run_result``. For a channel whose
    config is given, the matching ``*_flow_threshold`` /
    ``*_cellprob_threshold`` override the cached thresholds (``None`` keeps
    them).

    Any z-crop in the supplied configs is interpreted as one global analysis
    z range and applied to all channels. When no config specifies a z-crop,
    the z bounds stored in ``run_result`` are kept. When the loaded images
    are already a z-projection, z-cropping is ignored because the data have
    been collapsed to a singleton-z analysis view.

    Raises
    ------
    ValueError
        If refinement is requested for a channel without cached Cellpose
        outputs, or if the configured z-crops conflict or are invalid.
    """
    if loaded_images.z_projection_method is None:
        analysis_z_bounds = _resolve_analysis_z_bounds(
            loaded_images.cell_image.shape[0],
            cell_model_config,
            marker_model_config,
            optional_region_model_config,
            fallback=run_result.analysis_z_bounds,
        )
    else:
        analysis_z_bounds = None

    if cell_model_config is not None:
        cell_masks = _rebuild_channel_from_cache(
            loaded_images.cell_image,
            run_result.cell_refinement_context,
            "Cell refinement was requested, but this run result does not "
            "contain Cellpose refinement caches for the cell channel. "
            "Threshold-only refinement is currently available only when "
            "the initial segmentation was produced with a supported "
            "Cellpose 4 run.",
            cell_flow_threshold,
            cell_cellprob_threshold,
        )
    else:
        cell_masks = np.asarray(run_result.cell_masks, dtype=np.uint32).copy()

    if marker_model_config is not None:
        marker_masks = _rebuild_channel_from_cache(
            loaded_images.marker_image,
            run_result.marker_refinement_context,
            "Marker refinement was requested, but this run result does not "
            "contain Cellpose refinement caches for the marker channel. "
            "Threshold-only refinement is currently available only when "
            "the initial segmentation was produced with a supported "
            "Cellpose 4 run.",
            marker_flow_threshold,
            marker_cellprob_threshold,
        )
    else:
        marker_masks = np.asarray(run_result.marker_masks, dtype=np.uint32).copy()

    if optional_region_model_config is not None:
        optional_masks = _rebuild_channel_from_cache(
            loaded_images.optional_region_image,
            run_result.optional_region_refinement_context,
            "Optional third-channel refinement was requested, but this run "
            "result does not contain Cellpose refinement caches for that "
            "channel. Threshold-only refinement is currently available "
            "only when the initial third-channel segmentation was produced "
            "with a supported Cellpose 4 run.",
            optional_region_flow_threshold,
            optional_region_cellprob_threshold,
        )
    elif run_result.optional_region_masks is None:
        optional_masks = None
    else:
        optional_masks = np.asarray(run_result.optional_region_masks, dtype=np.uint32).copy()

    cell_masks = _apply_analysis_z_bounds(cell_masks, analysis_z_bounds)
    marker_masks = _apply_analysis_z_bounds(marker_masks, analysis_z_bounds)
    optional_masks = _apply_analysis_z_bounds(optional_masks, analysis_z_bounds)

    return analyze_existing_masks(
        loaded_images=loaded_images,
        roi_labels_2d=roi_labels_2d,
        cell_masks=cell_masks,
        marker_masks=marker_masks,
        colocalization_config=colocalization_config,
        optional_region_result=optional_region_result,
        optional_region_masks=optional_masks,
        analysis_z_bounds=analysis_z_bounds,
        cell_refinement_context=run_result.cell_refinement_context,
        marker_refinement_context=run_result.marker_refinement_context,
        optional_region_refinement_context=run_result.optional_region_refinement_context,
        cell_model_config=cell_model_config,
        marker_model_config=marker_model_config,
        optional_region_model_config=optional_region_model_config,
    )
# Column schema of the per-cell optional third-channel summary. The first two
# entries are the merge keys shared with the main summary table.
_OPTIONAL_REGION_SUMMARY_COLUMNS: list[str] = [
    "roi_id",
    "cell_label",
    "optional_region_positive",
    "n_overlapping_optional_region_objects",
    "best_optional_region_label",
    "best_optional_region_overlap_voxels",
    "best_optional_region_overlap_fraction",
]


def _summarize_best_overlap(
    detailed_cell: pd.DataFrame,
    config: ColocalizationConfig,
) -> dict[str, int | float | bool]:
    """Reduce one cell's detailed overlap rows to its strongest overlap.

    ``detailed_cell`` holds the rows produced by ``analyze_label_overlaps`` for
    a single cell (columns ``n_overlapping_markers``, ``overlap_voxels``,
    ``overlap_fraction_of_cell``, ``marker_label``). The strongest overlap is
    the first row with maximal ``overlap_voxels``. The cell counts as positive
    when it overlaps at least one object and that best overlap reaches both
    ``config.min_overlap_voxels`` and ``config.overlap_fraction_threshold``.

    Returns
    -------
    dict
        Keys ``n_overlapping``, ``best_label`` (int, or NaN when the cell
        overlaps nothing), ``best_overlap_voxels``, ``best_overlap_fraction``
        and ``positive``.
    """
    n_overlapping = int(detailed_cell["n_overlapping_markers"].max())

    # Position-based lookup: argmax returns the first maximal row, like
    # ``idxmax``, but does not depend on the frame's index labels being unique.
    best_pos = int(np.argmax(detailed_cell["overlap_voxels"].to_numpy()))
    best_overlap_voxels = int(detailed_cell["overlap_voxels"].iloc[best_pos])
    best_overlap_fraction = float(detailed_cell["overlap_fraction_of_cell"].iloc[best_pos])
    best_label = detailed_cell["marker_label"].iloc[best_pos]

    positive = (
        n_overlapping > 0
        and best_overlap_voxels >= config.min_overlap_voxels
        and best_overlap_fraction >= config.overlap_fraction_threshold
    )
    return {
        "n_overlapping": n_overlapping,
        "best_label": int(best_label) if not pd.isna(best_label) else np.nan,
        "best_overlap_voxels": best_overlap_voxels,
        "best_overlap_fraction": best_overlap_fraction,
        "positive": bool(positive),
    }


def _build_summary_table(
    detailed_table: pd.DataFrame,
    cell_masks: np.ndarray,
    voxel_scale_zyx: tuple[float, float, float],
    config: ColocalizationConfig,
    roi_labels_2d: np.ndarray,
    optional_region_masks: np.ndarray | None = None,
) -> pd.DataFrame:
    """Aggregate detailed overlap rows into one summary row per cell.

    The summary retains the strongest marker overlap for each cell, classifies
    positivity according to ``ColocalizationConfig``, and optionally augments
    the result with third-channel positivity columns.

    Parameters
    ----------
    detailed_table
        Per-(ROI, cell, marker) overlap rows as produced by
        ``analyze_label_overlaps``.
    cell_masks
        Full cell label image. Region properties (area, centroid) are taken
        from it with ``regionprops_table``.
    voxel_scale_zyx
        Physical voxel size, forwarded to the channel-property table.
    config
        Positivity thresholds and the
        ``evaluate_optional_region_cell_positivity`` switch.
    roi_labels_2d
        2D ROI label image.
    optional_region_masks
        Optional third-channel label image used for per-cell third-channel
        positivity.

    Returns
    -------
    pandas.DataFrame
        One row per (roi_id, cell_label). If ``detailed_table`` is empty, an
        empty frame with the same column schema is returned. That schema
        includes the optional-region columns when that evaluation is enabled.
    """
    if detailed_table.empty:
        columns = [
            "roi_id",
            "cell_label",
            "cell_voxels",
            "marker_positive",
            "n_overlapping_markers",
            "best_marker_label",
            "best_overlap_voxels",
            "best_overlap_fraction",
            "cell_voxels_props",
            "centroid_z",
            "centroid_y",
            "centroid_x",
            "cell_voxels_delta",
            *_channel_metric_column_names("cell"),
        ]
        if config.evaluate_optional_region_cell_positivity:
            # Keep the schema identical to the non-empty case below.
            columns += [
                *_OPTIONAL_REGION_SUMMARY_COLUMNS[2:],
                "marker_and_optional_region_positive",
            ]
        return pd.DataFrame(columns=columns)

    cell_props = regionprops_table(cell_masks, properties=("label", "area", "centroid"))
    props_table = pd.DataFrame(cell_props).rename(
        columns={
            "label": "cell_label",
            "area": "cell_voxels_props",
            "centroid-0": "centroid_z",
            "centroid-1": "centroid_y",
            "centroid-2": "centroid_x",
        }
    )
    props_table["cell_label"] = props_table["cell_label"].astype(int)

    summary_rows: list[dict[str, int | float | bool]] = []
    # groupby(sort=True) visits (roi_id, cell_label) pairs in ascending order,
    # one pass instead of re-filtering the whole table for every cell.
    for (roi_id, cell_label), detailed_cell in detailed_table.groupby(["roi_id", "cell_label"], sort=True):
        best = _summarize_best_overlap(detailed_cell, config)
        summary_rows.append(
            {
                "roi_id": int(roi_id),
                "cell_label": int(cell_label),
                "cell_voxels": int(detailed_cell["cell_voxels"].iloc[0]),
                "marker_positive": best["positive"],
                "n_overlapping_markers": best["n_overlapping"],
                "best_marker_label": best["best_label"],
                "best_overlap_voxels": best["best_overlap_voxels"],
                "best_overlap_fraction": best["best_overlap_fraction"],
            }
        )

    cell_properties = _build_channel_properties_table(
        label_image=cell_masks,
        voxel_scale_zyx=voxel_scale_zyx,
        roi_labels_2d=roi_labels_2d,
        label_column="cell_label",
        metric_prefix="cell",
    )

    summary_table = pd.DataFrame(summary_rows).merge(props_table, on="cell_label", how="left")
    # Discrepancy between the overlap-derived voxel count and regionprops' area.
    summary_table["cell_voxels_delta"] = summary_table["cell_voxels"] - summary_table["cell_voxels_props"]
    summary_table = summary_table.merge(
        cell_properties,
        on=["roi_id", "cell_label"],
        how="left",
    )

    if config.evaluate_optional_region_cell_positivity:
        optional_region_summary = _build_optional_region_summary_table(
            roi_labels_2d=roi_labels_2d,
            cell_masks=cell_masks,
            optional_region_masks=optional_region_masks,
            config=config,
        )
        summary_table = summary_table.merge(
            optional_region_summary,
            on=["roi_id", "cell_label"],
            how="left",
        )
        # Cells missing from the optional summary come back as NaN from the
        # left merge, often in object-dtype columns. Normalise them explicitly
        # instead of relying on pandas' deprecated object-dtype fillna downcast.
        summary_table["optional_region_positive"] = summary_table["optional_region_positive"].eq(True)
        summary_table["n_overlapping_optional_region_objects"] = (
            pd.to_numeric(summary_table["n_overlapping_optional_region_objects"]).fillna(0).astype(int)
        )
        summary_table["best_optional_region_overlap_voxels"] = (
            pd.to_numeric(summary_table["best_optional_region_overlap_voxels"]).fillna(0).astype(int)
        )
        summary_table["best_optional_region_overlap_fraction"] = (
            pd.to_numeric(summary_table["best_optional_region_overlap_fraction"]).fillna(0.0).astype(float)
        )
        summary_table["marker_and_optional_region_positive"] = (
            summary_table["marker_positive"] & summary_table["optional_region_positive"]
        )

    return summary_table


def _build_optional_region_summary_table(
    roi_labels_2d: np.ndarray,
    cell_masks: np.ndarray,
    optional_region_masks: np.ndarray | None,
    config: ColocalizationConfig,
) -> pd.DataFrame:
    """Summarize which cells overlap an optional third-channel segmentation.

    This produces one row per cell with overlap statistics against the
    segmented optional third channel, so the main summary table can expose
    both separate third-channel positivity and marker-and-third-channel
    double-positivity.

    Each non-zero ROI is cropped to its 2D bounding box (all z planes of
    ``cell_masks`` / ``optional_region_masks``), and overlaps are computed with
    ``analyze_label_overlaps`` using the same thresholds as marker positivity.

    Returns
    -------
    pandas.DataFrame
        Columns ``_OPTIONAL_REGION_SUMMARY_COLUMNS``. The frame always carries
        these columns, even when empty, so callers can merge on
        ``["roi_id", "cell_label"]`` unconditionally.
    """
    if optional_region_masks is None:
        return pd.DataFrame(columns=_OPTIONAL_REGION_SUMMARY_COLUMNS)

    rows: list[dict[str, int | float | bool]] = []
    roi_ids = np.unique(roi_labels_2d)
    for roi_id in roi_ids[roi_ids != 0]:
        bbox = get_bbox_2d(roi_labels_2d == roi_id)
        if bbox is None:
            continue
        y_slice, x_slice = bbox

        detailed_table = pd.DataFrame(
            analyze_label_overlaps(
                cell_masks[:, y_slice, x_slice],
                optional_region_masks[:, y_slice, x_slice],
                roi_id=int(roi_id),
            )
        )
        if detailed_table.empty:
            continue

        for cell_label, detailed_cell in detailed_table.groupby("cell_label", sort=True):
            best = _summarize_best_overlap(detailed_cell, config)
            rows.append(
                {
                    "roi_id": int(roi_id),
                    "cell_label": int(cell_label),
                    "optional_region_positive": best["positive"],
                    "n_overlapping_optional_region_objects": best["n_overlapping"],
                    "best_optional_region_label": best["best_label"],
                    "best_optional_region_overlap_voxels": best["best_overlap_voxels"],
                    "best_optional_region_overlap_fraction": best["best_overlap_fraction"],
                }
            )

    # Passing ``columns`` keeps the schema when ``rows`` is empty; a bare
    # ``pd.DataFrame([])`` would have no merge keys at all.
    return pd.DataFrame(rows, columns=_OPTIONAL_REGION_SUMMARY_COLUMNS)


def _build_overview_table(
    roi_labels_2d: np.ndarray,
    loaded_images: LoadedImageChannels,
    cell_masks: np.ndarray,
    marker_masks: np.ndarray,
    summary_table: pd.DataFrame,
    optional_region_masks: np.ndarray | None,
    analysis_z_bounds: tuple[int, int] | None,
) -> pd.DataFrame:
    """Create one ROI overview row per ROI.

    The overview combines ROI geometry, counts of segmented objects, counts of
    positive cells, and channel-wise occupancy metrics. When a global analysis
    z-crop is active, ROI volume and all 3D occupancy metrics are computed only
    inside that z interval.

    Returns
    -------
    pandas.DataFrame
        One row per non-zero ROI label in ``roi_labels_2d``.
    """
    z_size_um, y_size_um, x_size_um = loaded_images.voxel_scale_zyx
    pixel_area_um2 = y_size_um * x_size_um
    voxel_volume_um3 = z_size_um * y_size_um * x_size_um

    n_z = loaded_images.cell_image.shape[0]
    z_start, z_stop = analysis_z_bounds if analysis_z_bounds is not None else (0, n_z)
    analysis_depth = z_stop - z_start

    # Restrict the label volumes to the analysis z interval once; per ROI only
    # the 2D footprint changes. Indexing the trailing (y, x) axes with the 2D
    # ROI mask selects the same voxels as a full 3D ROI mask, without
    # allocating one per ROI.
    cell_slab = cell_masks[z_start:z_stop]
    marker_slab = marker_masks[z_start:z_stop]

    rows: list[dict[str, int | float]] = []
    for roi_id in np.unique(roi_labels_2d):
        if roi_id == 0:
            continue

        roi_mask_2d = roi_labels_2d == roi_id
        roi_area_px = int(roi_mask_2d.sum())
        roi_area_um2 = float(roi_area_px * pixel_area_um2)
        roi_volume_voxels = int(roi_area_px * analysis_depth)
        roi_volume_um3 = float(roi_volume_voxels * voxel_volume_um3)

        cell_labels_roi = np.unique(cell_slab[:, roi_mask_2d])
        cell_labels_roi = cell_labels_roi[cell_labels_roi != 0]
        marker_labels_roi = np.unique(marker_slab[:, roi_mask_2d])
        marker_labels_roi = marker_labels_roi[marker_labels_roi != 0]

        summary_roi = summary_table[summary_table["roi_id"] == roi_id]

        row: dict[str, int | float] = {
            "roi_id": int(roi_id),
            "n_cells": int(len(cell_labels_roi)),
            "n_marker_positive_cells": int(summary_roi["marker_positive"].sum()) if not summary_roi.empty else 0,
            "n_marker_objects": int(len(marker_labels_roi)),
            "drawn_roi_area_px": roi_area_px,
            "drawn_roi_area_um2": roi_area_um2,
            "roi_volume_voxels": roi_volume_voxels,
            "roi_volume_um3": roi_volume_um3,
        }
        # Third-channel counts exist only when the summary evaluated them.
        if "optional_region_positive" in summary_roi.columns:
            row["n_optional_region_positive_cells"] = (
                int(summary_roi["optional_region_positive"].sum()) if not summary_roi.empty else 0
            )
        if "marker_and_optional_region_positive" in summary_roi.columns:
            row["n_marker_and_optional_region_positive_cells"] = (
                int(summary_roi["marker_and_optional_region_positive"].sum()) if not summary_roi.empty else 0
            )

        row.update(
            _compute_mask_occupancy_metrics(
                "cell",
                cell_masks,
                roi_mask_2d,
                loaded_images.voxel_scale_zyx,
                analysis_z_bounds,
            )
        )
        row.update(
            _compute_mask_occupancy_metrics(
                "marker",
                marker_masks,
                roi_mask_2d,
                loaded_images.voxel_scale_zyx,
                analysis_z_bounds,
            )
        )
        if optional_region_masks is not None:
            row.update(
                _compute_mask_occupancy_metrics(
                    "optional_region",
                    optional_region_masks,
                    roi_mask_2d,
                    loaded_images.voxel_scale_zyx,
                    analysis_z_bounds,
                )
            )
        rows.append(row)

    return pd.DataFrame(rows)


def _compute_mask_occupancy_metrics(
    prefix: str,
    label_image: np.ndarray,
    roi_mask_2d: np.ndarray,
    voxel_scale_zyx: tuple[float, float, float],
    analysis_z_bounds: tuple[int, int] | None,
) -> dict[str, int | float]:
    """Compute generic ROI occupancy metrics for one segmented channel.

    Both 2D projection coverage and true 3D occupancy are reported. When an
    ``analysis_z_bounds`` interval is provided, only voxels inside that z range
    contribute to the 3D denominator and numerator.

    Parameters
    ----------
    prefix
        Prepended to every returned key (e.g. ``"cell"``).
    label_image
        Label volume indexed ``[z, y, x]``; any value > 0 counts as occupied.
    roi_mask_2d
        Boolean ROI footprint in ``(y, x)``.
    voxel_scale_zyx
        Physical voxel size ``(z, y, x)`` used for area/volume conversion.
    analysis_z_bounds
        Half-open ``(z_start, z_stop)`` interval, or ``None`` for all planes.

    Returns
    -------
    dict
        Occupied area/volume in pixels/voxels and physical units, plus
        coverage percentages. Percentages are NaN when the ROI area or volume
        is zero.
    """
    z_size_um, y_size_um, x_size_um = voxel_scale_zyx
    pixel_area_um2 = y_size_um * x_size_um
    voxel_volume_um3 = z_size_um * y_size_um * x_size_um

    n_z = label_image.shape[0]
    z_start, z_stop = analysis_z_bounds if analysis_z_bounds is not None else (0, n_z)
    analysis_depth = z_stop - z_start

    roi_area_px = int(roi_mask_2d.sum())
    roi_volume_voxels = int(roi_area_px * analysis_depth)

    # Foreground within the analysis z interval. Selecting the ROI footprint on
    # the trailing axes is equivalent to AND-ing with a full 3D ROI mask.
    foreground_slab = label_image[z_start:z_stop] > 0
    occupied_volume_voxels = int(foreground_slab[:, roi_mask_2d].sum())
    occupied_volume_um3 = float(occupied_volume_voxels * voxel_volume_um3)
    occupancy_3d_percent = (
        float(100 * occupied_volume_voxels / roi_volume_voxels) if roi_volume_voxels > 0 else np.nan
    )

    # A pixel is covered in projection if any analysed plane is foreground there.
    occupied_area_px = int((foreground_slab.any(axis=0) & roi_mask_2d).sum())
    occupied_area_um2 = float(occupied_area_px * pixel_area_um2)
    occupancy_2d_percent = float(100 * occupied_area_px / roi_area_px) if roi_area_px > 0 else np.nan

    return {
        f"{prefix}_occupancy_area_px_2d_projection": occupied_area_px,
        f"{prefix}_occupancy_area_um2_2d_projection": occupied_area_um2,
        f"{prefix}_occupancy_coverage_2d_percent": occupancy_2d_percent,
        f"{prefix}_occupancy_volume_voxels_3d": occupied_volume_voxels,
        f"{prefix}_occupancy_volume_um3_3d": occupied_volume_um3,
        f"{prefix}_occupancy_coverage_3d_percent": occupancy_3d_percent,
    }
def _attach_roi_geometry_to_cache(
    cache: CellposeRefinementRoiCache,
    roi_id: int,
    y_slice: slice,
    x_slice: slice,
    roi_mask_crop_2d: np.ndarray,
) -> CellposeRefinementRoiCache:
    """Stamp a per-ROI refinement cache with its ROI id and XY crop geometry.

    The bounding-box bounds and a copy of the cropped ROI mask are stored on
    the cache. These are what a later refinement rebuild needs to map the
    cached ROI result back into full-image coordinates (presumably consumed by
    ``_rebuild_masks_from_refinement_context``; verify there). The same,
    mutated, cache object is returned for convenience.
    """
    cache.roi_id = int(roi_id)
    cache.y_min = int(y_slice.start)
    cache.y_max = int(y_slice.stop)
    cache.x_min = int(x_slice.start)
    cache.x_max = int(x_slice.stop)
    cache.roi_mask_crop_2d = roi_mask_crop_2d.copy()
    return cache


def run_roi_cellpose_colocalization(
    loaded_images: LoadedImageChannels,
    roi_labels_2d: np.ndarray,
    cell_model_config: CellposeModelConfig,
    marker_model_config: CellposeModelConfig,
    colocalization_config: ColocalizationConfig,
    runtime_config: RuntimeConfig,
    optional_region_model_config: CellposeModelConfig | None = None,
    optional_region_result: OptionalRegionSegmentationResult | None = None,
) -> ColocalizationRunResult:
    """Run the configured ROI-wise segmentation workflow and build result tables.

    The pipeline always segments ROI crops in ``XY`` and may additionally
    apply one global analysis z-crop resolved from the participating channel
    configs. That z-crop affects all channels, all ROIs, and all downstream
    quantification consistently, while the exported and visualized arrays keep
    full-stack shape.

    When the input ``loaded_images`` bundle already represents a prepared
    z-projection, segmentation and quantification operate on that projected
    2D analysis view instead of the original full stack.

    The two primary analysis channels can each use either Cellpose or one of
    the supported threshold-based backends. An optional third channel can be
    segmented through the same mechanism and contributes occupancy metrics,
    and optionally per-cell positivity, to the result tables.

    Labels are kept unique across ROIs for every segmented channel (cell,
    marker and optional third channel). Each ROI's labels are shifted past the
    largest label already written before being stitched into the full-size
    label arrays.

    Raises
    ------
    ValueError
        If ``runtime_config.process_rois`` is disabled, or if an optional-region
        segmentation config is given while no optional region channel was
        loaded.
    """
    if not runtime_config.process_rois:
        raise ValueError("ROI processing is disabled in RuntimeConfig.")

    has_optional_region = optional_region_model_config is not None
    # Validate up front so a misconfiguration fails before any (potentially
    # slow) model construction or ROI segmentation happens.
    if has_optional_region and loaded_images.optional_region_image is None:
        raise ValueError(
            "An optional-region segmentation config was provided, but "
            "no optional region channel was loaded."
        )

    roi_ids = np.unique(roi_labels_2d)
    roi_ids = roi_ids[roi_ids != 0]
    print(f"Found {len(roi_ids)} ROIs: {roi_ids}")

    # A prepared z-projection already is the analysis view: no extra z-crop.
    if loaded_images.z_projection_method is not None:
        analysis_z_bounds = None
    else:
        analysis_z_bounds = _resolve_analysis_z_bounds(
            loaded_images.cell_image.shape[0],
            cell_model_config,
            marker_model_config,
            optional_region_model_config,
        )
    z_slice = slice(*analysis_z_bounds) if analysis_z_bounds is not None else slice(None)

    cell_model, marker_model = create_cellpose_models_for_channels(
        cell_model_config=cell_model_config,
        marker_model_config=marker_model_config,
        use_gpu=runtime_config.use_gpu,
    )
    optional_region_model = None
    if (
        has_optional_region
        and normalize_segmentation_method(optional_region_model_config.segmentation_method) == "cellpose"
    ):
        optional_region_model = create_cellpose_model(
            optional_region_model_config.model_name_or_path,
            runtime_config.use_gpu,
        )

    # Full-stack output arrays; each ROI result is written into its bbox and
    # the analysis z interval only.
    full_cell_masks = np.zeros(loaded_images.cell_image.shape, dtype=np.uint32)
    full_marker_masks = np.zeros(loaded_images.marker_image.shape, dtype=np.uint32)
    full_optional_region_masks = (
        np.zeros(loaded_images.optional_region_image.shape, dtype=np.uint32)
        if has_optional_region
        else None
    )

    cell_roi_caches: list[CellposeRefinementRoiCache] = []
    marker_roi_caches: list[CellposeRefinementRoiCache] = []
    optional_region_roi_caches: list[CellposeRefinementRoiCache] = []

    cell_label_offset = 0
    marker_label_offset = 0
    optional_region_label_offset = 0

    for roi_id in roi_ids:
        print(f"\nProcessing ROI {int(roi_id)}...")

        roi_mask_2d = roi_labels_2d == roi_id
        bbox = get_bbox_2d(roi_mask_2d)
        if bbox is None:
            print(f"Skipping ROI {int(roi_id)}: empty ROI")
            continue

        y_slice, x_slice = bbox
        roi_mask_crop_2d = roi_mask_2d[y_slice, x_slice]
        crop = (z_slice, y_slice, x_slice)

        # Copy before prefiltering/zeroing so the loaded images are not mutated
        # through the slice views.
        cell_crop = loaded_images.cell_image[crop].copy()
        marker_crop = loaded_images.marker_image[crop].copy()
        cell_crop = apply_prefilter(cell_crop, cell_model_config)
        marker_crop = apply_prefilter(marker_crop, marker_model_config)
        # Blank everything in the bounding box that lies outside the drawn ROI.
        cell_crop[:, ~roi_mask_crop_2d] = 0
        marker_crop[:, ~roi_mask_crop_2d] = 0

        optional_region_crop = None
        if has_optional_region:
            optional_region_crop = loaded_images.optional_region_image[crop].copy()
            optional_region_crop = apply_prefilter(optional_region_crop, optional_region_model_config)
            optional_region_crop[:, ~roi_mask_crop_2d] = 0

        cell_masks_roi, cell_refinement_cache = evaluate_segmentation_method(
            cell_model,
            cell_crop,
            cell_model_config,
            loaded_images.voxel_scale_zyx,
        )
        marker_masks_roi, marker_refinement_cache = evaluate_segmentation_method(
            marker_model,
            marker_crop,
            marker_model_config,
            loaded_images.voxel_scale_zyx,
        )
        optional_region_masks_roi = None
        optional_region_refinement_cache = None
        if has_optional_region:
            optional_region_masks_roi, optional_region_refinement_cache = evaluate_segmentation_method(
                optional_region_model,
                optional_region_crop,
                optional_region_model_config,
                loaded_images.voxel_scale_zyx,
            )

        # Only some backends return a refinement cache; keep those that do.
        for cache, roi_caches in (
            (cell_refinement_cache, cell_roi_caches),
            (marker_refinement_cache, marker_roi_caches),
            (optional_region_refinement_cache, optional_region_roi_caches),
        ):
            if cache is not None:
                roi_caches.append(
                    _attach_roi_geometry_to_cache(cache, roi_id, y_slice, x_slice, roi_mask_crop_2d)
                )

        # Shift each ROI's labels past those already written so labels stay
        # unique across ROIs in every stitched channel.
        cell_masks_roi = relabel_with_offset(cell_masks_roi, cell_label_offset)
        marker_masks_roi = relabel_with_offset(marker_masks_roi, marker_label_offset)
        if cell_masks_roi.max() > 0:
            cell_label_offset = int(cell_masks_roi.max())
        if marker_masks_roi.max() > 0:
            marker_label_offset = int(marker_masks_roi.max())

        # Where bounding boxes overlap, the larger label wins per voxel.
        full_cell_masks[crop] = np.maximum(full_cell_masks[crop], cell_masks_roi)
        full_marker_masks[crop] = np.maximum(full_marker_masks[crop], marker_masks_roi)

        if full_optional_region_masks is not None and optional_region_masks_roi is not None:
            optional_region_masks_roi = relabel_with_offset(
                optional_region_masks_roi,
                optional_region_label_offset,
            )
            if optional_region_masks_roi.max() > 0:
                optional_region_label_offset = int(optional_region_masks_roi.max())
            full_optional_region_masks[crop] = np.maximum(
                full_optional_region_masks[crop],
                optional_region_masks_roi,
            )

    cell_refinement_context = None
    marker_refinement_context = None
    if cell_roi_caches:
        cell_refinement_context = CellposeChannelRefinementContext(
            model=cell_model,
            model_name_or_path=cell_model_config.model_name_or_path,
            roi_caches=cell_roi_caches,
        )
    if marker_roi_caches:
        marker_refinement_context = CellposeChannelRefinementContext(
            model=marker_model,
            model_name_or_path=marker_model_config.model_name_or_path,
            roi_caches=marker_roi_caches,
        )

    optional_region_refinement_context = None
    if optional_region_roi_caches and optional_region_model_config is not None:
        optional_region_refinement_context = CellposeChannelRefinementContext(
            model=optional_region_model,
            model_name_or_path=optional_region_model_config.model_name_or_path,
            roi_caches=optional_region_roi_caches,
        )

    return analyze_existing_masks(
        loaded_images=loaded_images,
        roi_labels_2d=roi_labels_2d,
        cell_masks=full_cell_masks,
        marker_masks=full_marker_masks,
        colocalization_config=colocalization_config,
        optional_region_result=optional_region_result,
        optional_region_masks=full_optional_region_masks,
        analysis_z_bounds=analysis_z_bounds,
        cell_refinement_context=cell_refinement_context,
        marker_refinement_context=marker_refinement_context,
        optional_region_refinement_context=optional_region_refinement_context,
        cell_model_config=cell_model_config,
        marker_model_config=marker_model_config,
        optional_region_model_config=optional_region_model_config,
    )
# %% END