Source code for exo2micro.alignment

"""
alignment.py
============
Image registration for pre/post-stain microscopy image pairs.

The multiscale registration pipeline:
  1. Coarse boundary correlation (translation + rotation + isotropic scale)
  2. ICP affine refinement of boundary contour correspondences
  3. Optional fine homography ECC pass (disabled by default)

All functions assume images are 2D float32 arrays.
"""

import numpy as np
import cv2
from .utils import equalize_pair


# ==============================================================================
# INTERNAL HELPERS
# ==============================================================================

def _extract_boundary(im_eq, boundary_width, boundary_smooth):
    """
    Build a soft and a hard outer-boundary ring from an equalised image.

    Steps, in order:
        1. Blur with a sigma chosen from the image height and from how much
           of the image is above 0.05 ("density").
        2. Threshold the normalised blur into a binary tissue footprint
           (fixed 0.1, or a percentile-based value for sparse images).
        3. Morphological closing (two iterations) to bridge gaps.
        4. Keep only the largest connected component and fill its holes.
        5. For dense images, a second closing with a much larger kernel,
           followed by another hole fill.
        6. Erode the footprint and subtract, leaving a ring.
        7. Gaussian-soften and renormalise the ring.

    Parameters
    ----------
    im_eq : ndarray
        Equalised float32 image normalised to [0, 1].
    boundary_width : int
        Erosion radius in pixels; sets the ring thickness and the closing
        kernel sizes.
    boundary_smooth : float
        Sigma of the final Gaussian softening (values below 1.0 are
        raised to 1.0).

    Returns
    -------
    boundary_ecc : ndarray
        Softened ring, float32, divided by its own maximum.
    boundary_raw : ndarray
        Hard ring, float32, with values 0 or 255.
    """
    from scipy.ndimage import binary_fill_holes as _bfh

    def _disk(radius):
        # Elliptical structuring element with the given pixel radius.
        side = radius * 2 + 1
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))

    def _filled(mask):
        # Fill interior holes, then return to the uint8 mask convention.
        return _bfh(mask.astype(bool)).astype(np.uint8)

    rows = im_eq.shape[0]
    nonzero_frac = float(np.mean(im_eq > 0.05))
    is_dense = nonzero_frac > 0.1

    # Dense images get a wider blur (and a higher cap) than sparse ones.
    divisor, sigma_lo, sigma_hi = (50, 3, 20) if is_dense else (100, 2, 10)
    blur_sigma = min(max(rows // divisor, sigma_lo), sigma_hi)
    print(f"    _extract_boundary: blur_sigma={blur_sigma:.1f}  "
          f"(nonzero_frac={nonzero_frac:.3f})")

    smoothed = cv2.GaussianBlur(im_eq, (0, 0), float(blur_sigma))
    smoothed = smoothed / (smoothed.max() + 1e-8)

    # Fixed threshold unless the image is sparse and has positive pixels,
    # in which case the 30th percentile (floored at 0.02) is used.
    threshold = 0.1
    positive = smoothed[smoothed > 0]
    if nonzero_frac < 0.1 and positive.size > 0:
        threshold = max(float(np.percentile(positive, 30)), 0.02)
        print(f"    _extract_boundary: sparse — adaptive threshold={threshold:.3f}")
    footprint = (smoothed > threshold).astype(np.uint8)

    # Gap-bridging closing; very dense images use a 3x larger radius.
    closing_radius = boundary_width * 3 if nonzero_frac > 0.3 else boundary_width
    footprint = cv2.morphologyEx(footprint, cv2.MORPH_CLOSE,
                                 _disk(closing_radius), iterations=2)

    # Discard everything but the largest foreground component.
    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(
        footprint, connectivity=8)
    if n_labels > 1:
        biggest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
        footprint = (label_map == biggest).astype(np.uint8)
    footprint = _filled(footprint)

    # Extra large-kernel closing for dense images only.
    if is_dense:
        envelope_radius = max(rows // 20, boundary_width * 2)
        footprint = cv2.morphologyEx(footprint, cv2.MORPH_CLOSE,
                                     _disk(envelope_radius), iterations=1)
        footprint = _filled(footprint)

    # Ring = footprint minus its erosion.
    shrunk = cv2.erode(footprint, _disk(boundary_width), iterations=1)
    ring = (footprint - shrunk).astype(np.float32)

    boundary_raw = ring * 255.0

    # Soft version of the ring, renormalised to a peak of ~1.
    soft_ring = cv2.GaussianBlur(
        ring, (0, 0), max(float(boundary_smooth), 1.0))
    boundary_ecc = soft_ring / (soft_ring.max() + 1e-8)

    return boundary_ecc, boundary_raw


def _prepare_pair_for_ecc(post_small, pre_small, gauss_sigma, usharp,
                           use_edges, boundary_width, boundary_smooth):
    """
    Preprocess a downsampled post/pre image pair for ECC registration.

    The pair is first equalised jointly via ``equalize_pair`` and, when
    ``gauss_sigma > 0``, Gaussian-smoothed. Exactly one of three
    representations is then produced, checked in this order:

    * ``use_edges`` true  -> softened boundary rings from
      ``_extract_boundary`` (the hard rings are returned as well);
    * ``usharp`` truthy   -> positive part of (image - blurred image),
      divided by its maximum;
    * otherwise           -> the equalised (and possibly smoothed) images.

    Parameters
    ----------
    post_small : ndarray
        Downsampled post-stain image (float32).
    pre_small : ndarray
        Downsampled pre-stain image (float32).
    gauss_sigma : float
        Gaussian pre-smoothing sigma (0 = disabled).
    usharp : float or False
        Unsharp-mask sigma (False = disabled).
    use_edges : bool
        If True, use boundary rings instead of image content.
    boundary_width : int
        Boundary ring thickness in pixels.
    boundary_smooth : float
        Gaussian softening sigma applied to the boundary ring.

    Returns
    -------
    post_ecc, pre_ecc : ndarray
        Images to feed to ECC.
    post_edges, pre_edges : ndarray or None
        Hard boundary rings when ``use_edges`` is true, otherwise None.
    """
    post_eq, pre_eq = equalize_pair(post_small, pre_small)

    if gauss_sigma > 0:
        sigma = float(gauss_sigma)
        post_eq = cv2.GaussianBlur(post_eq, (0, 0), sigma)
        pre_eq = cv2.GaussianBlur(pre_eq, (0, 0), sigma)

    # Boundary-ring mode takes precedence over unsharp masking.
    if use_edges:
        post_soft, post_ring = _extract_boundary(
            post_eq, boundary_width, boundary_smooth)
        pre_soft, pre_ring = _extract_boundary(
            pre_eq, boundary_width, boundary_smooth)
        return post_soft, pre_soft, post_ring, pre_ring

    if not usharp:
        return post_eq, pre_eq, None, None

    def _positive_highpass(image):
        # Subtract the blurred image, keep the positive part, renormalise.
        low_pass = cv2.GaussianBlur(image, (0, 0), float(usharp))
        detail = np.clip(image - low_pass, 0, None)
        return detail / (detail.max() + 1e-8)

    return _positive_highpass(post_eq), _positive_highpass(pre_eq), None, None


# ==============================================================================
# COARSE ALIGNMENT — BOUNDARY CORRELATION
# ==============================================================================

def boundary_correlation_coarse(post_full, pre_full, coarse_scale,
                                boundary_width, boundary_smooth,
                                rotation_search=True, angle_range=20,
                                angle_step=1, scale_search=True,
                                scale_min=0.85, scale_max=1.15,
                                scale_step=0.05):
    """
    Find coarse rigid alignment by maximising boundary ring overlap.

    Uses phase correlation and brute-force rotation/scale search over
    extracted tissue boundary rings: for every (scale, angle) candidate the
    pre-stain ring is rotated/scaled about its centroid and phase-correlated
    against the post-stain ring; the candidate with the highest response
    wins.

    Parameters
    ----------
    post_full : ndarray
        Full-resolution post-stain image (float32).
    pre_full : ndarray
        Full-resolution pre-stain image (float32).
    coarse_scale : float
        Downsample factor for boundary extraction.
    boundary_width : int
        Boundary ring thickness in pixels at coarse resolution.
    boundary_smooth : float
        Gaussian softening sigma on the boundary ring.
    rotation_search : bool
        Search over rotations (default True).
    angle_range : float
        Rotation search range ±degrees (default 20). Candidate angles
        never exceed this range.
    angle_step : float
        Rotation search step in degrees (default 1).
    scale_search : bool
        Search over isotropic scale factors (default True).
    scale_min, scale_max, scale_step : float
        Scale search range and step.

    Returns
    -------
    warp_coarse_full : ndarray
        3x3 similarity homography at full resolution.
    best_angle : float
        Best rotation angle found (degrees).
    best_dx, best_dy : float
        Best translation in coarse-scale pixels.
    post_boundary_raw, pre_boundary_raw : ndarray
        Boundary rings at coarse scale (for diagnostics).
    """
    # Downsample
    post_small = cv2.resize(post_full, None, fx=coarse_scale, fy=coarse_scale,
                            interpolation=cv2.INTER_AREA)
    pre_small = cv2.resize(pre_full, None, fx=coarse_scale, fy=coarse_scale,
                           interpolation=cv2.INTER_AREA)

    # Joint equalisation and boundary extraction
    post_eq, pre_eq = equalize_pair(post_small, pre_small)
    post_boundary, post_boundary_raw = _extract_boundary(
        post_eq, boundary_width, boundary_smooth)
    pre_boundary, pre_boundary_raw = _extract_boundary(
        pre_eq, boundary_width, boundary_smooth)

    h, w = post_boundary.shape

    # Centroid of the pre-stain ring is the rotation/scale pivot; fall back
    # to the image centre when the ring is empty.
    pre_moments = cv2.moments((pre_boundary_raw > 0).astype(np.uint8))
    if pre_moments['m00'] > 0:
        cx = pre_moments['m10'] / pre_moments['m00']
        cy = pre_moments['m01'] / pre_moments['m00']
    else:
        cx, cy = w / 2.0, h / 2.0

    best_response = -np.inf
    best_angle = 0.0
    best_scale = 1.0
    best_dx = 0.0
    best_dy = 0.0

    # A half-step margin on the stop value includes the upper endpoint
    # without overshooting it. Using a full step (as np.arange's exclusive
    # stop might suggest) would search beyond ±angle_range whenever the
    # step does not divide the range exactly.
    angles = ([0.0] if not rotation_search else
              np.arange(-angle_range, angle_range + angle_step * 0.5,
                        angle_step))
    scales = (np.arange(scale_min, scale_max + scale_step * 0.5, scale_step)
              if scale_search else [1.0])

    def _l2norm(im):
        # Unit-L2-norm copy; near-zero images are returned unchanged.
        n = np.linalg.norm(im)
        return im / n if n > 1e-8 else im

    post_boundary_norm = _l2norm(post_boundary.astype(np.float32))
    # Loop-invariant conversion, done once instead of per candidate.
    pre_boundary_f32 = pre_boundary.astype(np.float32)

    scale_responses = {}

    for scale in scales:
        scale_best_response = -np.inf
        for angle in angles:
            if angle == 0.0 and scale == 1.0:
                pre_transformed = pre_boundary_f32
            else:
                M_rs = cv2.getRotationMatrix2D((cx, cy), angle, scale)
                pre_transformed = cv2.warpAffine(
                    pre_boundary_f32, M_rs, (w, h),
                    flags=cv2.INTER_LINEAR,
                    borderMode=cv2.BORDER_CONSTANT, borderValue=0)

            pre_norm = _l2norm(pre_transformed)
            shift, response = cv2.phaseCorrelate(post_boundary_norm, pre_norm)

            if response > best_response:
                best_response = response
                best_angle = float(angle)
                best_scale = float(scale)
                best_dx = float(shift[0])
                best_dy = float(shift[1])
            if response > scale_best_response:
                scale_best_response = response
        scale_responses[float(scale)] = scale_best_response

    # Print scale response surface
    print(" Scale response surface (best response at each scale):")
    for sc, resp in sorted(scale_responses.items()):
        marker = ' <-- BEST' if abs(sc - best_scale) < 1e-6 else ''
        print(f" scale={sc:.3f} response={resp:.4f}{marker}")

    print(f" Boundary correlation: dx={best_dx:.1f}, dy={best_dy:.1f} px, "
          f"angle={best_angle:.1f} deg scale={best_scale:.3f} "
          f"(response={best_response:.4f})")

    # Build full-resolution similarity transform: pivot and shift are
    # rescaled from coarse pixels to full-resolution pixels.
    cx_full = cx / coarse_scale
    cy_full = cy / coarse_scale
    dx_full = best_dx / coarse_scale
    dy_full = best_dy / coarse_scale

    theta = np.radians(best_angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    s = best_scale

    warp_coarse_full = np.array([
        [s * cos_t, -s * sin_t,
         (1 - s * cos_t) * cx_full + s * sin_t * cy_full + dx_full],
        [s * sin_t, s * cos_t,
         (1 - s * cos_t) * cy_full - s * sin_t * cx_full + dy_full],
        [0, 0, 1],
    ], dtype=np.float32)

    return (warp_coarse_full, best_angle, best_dx, best_dy,
            post_boundary_raw, pre_boundary_raw)
# ==============================================================================
# ICP REFINEMENT
# ==============================================================================
def _icp_subsample(pts, n_max=800):
    """Return at most ``n_max`` evenly spaced rows of ``pts``."""
    if pts is None or len(pts) <= n_max:
        return pts
    idx = np.round(np.linspace(0, len(pts) - 1, n_max)).astype(int)
    return pts[idx]


def _icp_fit_linear(src, dst):
    """Least-squares 2x2 linear map taking centred ``src`` onto centred ``dst``.

    Returns ``((a, b, c, d), src_mean, dst_mean)`` where the fitted map is
    ``[[a, b], [c, d]]``; the caller derives the translation from the means.
    """
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)
    src_c = src - src_mean
    dst_c = dst - dst_mean
    n = len(src_c)
    # Interleaved rows: even rows constrain x' = a*x + b*y,
    # odd rows constrain y' = c*x + d*y.
    A = np.zeros((2 * n, 4), dtype=np.float64)
    A[0::2, 0] = src_c[:, 0]
    A[0::2, 1] = src_c[:, 1]
    A[1::2, 2] = src_c[:, 0]
    A[1::2, 3] = src_c[:, 1]
    rhs = np.empty(2 * n, dtype=np.float64)
    rhs[0::2] = dst_c[:, 0]
    rhs[1::2] = dst_c[:, 1]
    coeffs, _, _, _ = np.linalg.lstsq(A, rhs, rcond=None)
    return coeffs, src_mean, dst_mean


def _icp_apply(M, pts):
    """Apply the affine part (top two rows) of 3x3 ``M`` to Nx2 ``pts``."""
    ones = np.ones((len(pts), 1))
    return (M[:2] @ np.hstack([pts, ones]).T).T


def refine_icp(post_full, warp_coarse, post_bnd_pts=None, pre_bnd_pts=None,
               bnd_scale=None, max_translation=200, max_rotation=5.0,
               max_scale_delta=0.1, max_scale_diff=0.05,
               close_threshold_px=30, close_threshold_floor=5,
               max_icp_iter=20):
    """
    Refine coarse alignment using ICP on tissue boundary points.

    Only well-matched boundary points (those already close after the coarse
    pass) are used, avoiding regions where tissue shapes genuinely differ.
    If the initial median nearest-neighbour distance exceeds 15 px, a global
    affine pre-correction over all points is attempted first; when it is
    accepted, subsequent ICP steps have their per-axis scale frozen to 1.

    Parameters
    ----------
    post_full : ndarray
        Full-resolution post-stain image (float32). Not read by this
        function; kept for interface compatibility.
    warp_coarse : ndarray
        3x3 coarse homography at full resolution.
    post_bnd_pts : ndarray or None
        Nx2 boundary points for post-stain in coarse-scale coords.
    pre_bnd_pts : ndarray or None
        Nx2 boundary points for pre-stain in post-coarse frame.
    bnd_scale : float or None
        Conversion factor from boundary-point coords to full-res pixels
        (None is treated as 1.0).
    max_translation : float
        Maximum allowed translation correction (default 200 px).
    max_rotation : float
        Maximum allowed rotation correction (default 5 degrees).
    max_scale_delta : float
        Maximum deviation of scale from 1.0 (default 0.1).
    max_scale_diff : float
        Maximum allowed ``|scale_x - scale_y|`` (default 0.05).
    close_threshold_px : float
        Initial max NN distance for well-matched pairs (default 30).
    close_threshold_floor : float
        Tightest threshold allowed (default 5).
    max_icp_iter : int
        Maximum ICP iterations (default 20).

    Returns
    -------
    warp_refined : ndarray
        3x3 refined homography (a copy of ``warp_coarse`` when rejected or
        skipped).
    accepted : bool
        True if ICP correction was accepted.
    """
    from scipy.spatial import cKDTree

    if post_bnd_pts is None or pre_bnd_pts is None:
        print(" ICP: no boundary points provided — skipping")
        return warp_coarse.copy(), False
    # Empty (but non-None) point sets cannot be matched: the nearest-
    # neighbour median would be NaN/inf and the least-squares fit degenerate.
    if len(post_bnd_pts) == 0 or len(pre_bnd_pts) == 0:
        print(" ICP: empty boundary point set — skipping")
        return warp_coarse.copy(), False

    if bnd_scale is None:
        bnd_scale = 1.0

    post_pts = _icp_subsample(np.asarray(post_bnd_pts)).astype(np.float64)
    pre_pts_current = _icp_subsample(
        np.asarray(pre_bnd_pts)).astype(np.float64)

    # Auto-set initial threshold from actual NN distances
    dists_init, _ = cKDTree(pre_pts_current).query(post_pts, k=1)
    median_dist = float(np.median(dists_init))
    adaptive_start = float(np.clip(2.0 * median_dist,
                                   close_threshold_floor, close_threshold_px))
    print(f" ICP: post={len(post_pts)} pts, pre={len(pre_pts_current)} pts, "
          f"median_dist={median_dist:.1f} px "
          f"threshold={adaptive_start:.1f} → {close_threshold_floor} px")

    # Global affine pre-correction for large misalignment
    M_accum = np.eye(3, dtype=np.float64)
    M_pre_corr = None
    freeze_scale = False

    if median_dist > 15.0:
        print(f" ICP: median_dist={median_dist:.1f} px > 15 px — "
              f"running global affine pre-correction")
        # Match every pre point to its nearest post point (no distance gate).
        _, idx_pre = cKDTree(post_pts).query(pre_pts_current, k=1)
        (a_p, b_p, c_p, d_p), src_mean_pre, dst_mean_pre = _icp_fit_linear(
            pre_pts_current, post_pts[idx_pre])
        tx_p = dst_mean_pre[0] - (a_p * src_mean_pre[0] + b_p * src_mean_pre[1])
        ty_p = dst_mean_pre[1] - (c_p * src_mean_pre[0] + d_p * src_mean_pre[1])
        sx_p = np.sqrt(a_p**2 + c_p**2)
        sy_p = np.sqrt(b_p**2 + d_p**2)
        ang_p = np.degrees(np.arctan2(c_p, a_p))
        print(f" pre-correction: tx={tx_p:.1f} ty={ty_p:.1f} "
              f"sx={sx_p:.4f} sy={sy_p:.4f} angle={ang_p:.2f} deg")

        pre_ok = (abs(tx_p) < close_threshold_px * 5 and
                  abs(ty_p) < close_threshold_px * 5 and
                  abs(sx_p - 1.0) < 0.3 and
                  abs(sy_p - 1.0) < 0.3 and
                  abs(ang_p) < 10.0)
        if pre_ok:
            M_pre_corr = np.array([[a_p, b_p, tx_p],
                                   [c_p, d_p, ty_p],
                                   [0, 0, 1]], dtype=np.float64)
            pre_pts_current = _icp_apply(M_pre_corr, pre_pts_current)
            dists_post, _ = cKDTree(pre_pts_current).query(post_pts, k=1)
            median_dist_post = float(np.median(dists_post))
            print(f" pre-correction accepted — "
                  f"median_dist: {median_dist:.1f} → "
                  f"{median_dist_post:.1f} px")
            adaptive_start = float(np.clip(2.0 * median_dist_post,
                                           close_threshold_floor,
                                           close_threshold_px))
            # Scale was already fitted globally; later steps stay rigid-ish.
            freeze_scale = True
        else:
            M_pre_corr = None
            print(" pre-correction REJECTED (implausible) — proceeding without it")

    # ICP main loop
    best_mean_dist = np.inf
    stagnation_count = 0

    for iteration in range(max_icp_iter):
        # Gate tightens linearly from adaptive_start down to the floor.
        frac = iteration / max(max_icp_iter - 1, 1)
        threshold_iter = (adaptive_start * (1 - frac) +
                          close_threshold_floor * frac)

        dists, idx = cKDTree(pre_pts_current).query(post_pts, k=1)
        keep = dists < threshold_iter
        n_close = keep.sum()

        if n_close < 10:
            print(f" ICP iter {iteration}: only {n_close} close pairs "
                  f"(threshold={threshold_iter:.1f} px) — stopping")
            break

        (a, b, c, d), src_mean, dst_mean = _icp_fit_linear(
            pre_pts_current[idx[keep]], post_pts[keep])

        if freeze_scale:
            # Normalise each column so the step carries no scale change.
            col0_norm = np.sqrt(a * a + c * c)
            col1_norm = np.sqrt(b * b + d * d)
            if col0_norm > 1e-8:
                a /= col0_norm
                c /= col0_norm
            if col1_norm > 1e-8:
                b /= col1_norm
                d /= col1_norm

        tx = dst_mean[0] - (a * src_mean[0] + b * src_mean[1])
        ty = dst_mean[1] - (c * src_mean[0] + d * src_mean[1])
        M_iter = np.array([[a, b, tx], [c, d, ty], [0, 0, 1]])

        sx = np.sqrt(a * a + c * c)
        sy = np.sqrt(b * b + d * d)
        step_too_large = (abs(tx) > threshold_iter * 3 or
                          abs(ty) > threshold_iter * 3 or
                          abs(sx - 1.0) > 0.15 or
                          abs(sy - 1.0) > 0.15)
        if step_too_large:
            print(f" ICP iter {iteration}: step rejected "
                  f"(tx={tx:.1f} ty={ty:.1f} "
                  f"sx={sx:.3f} sy={sy:.3f})")
            continue

        M_accum = M_iter @ M_accum
        pre_pts_current = _icp_apply(M_iter, pre_pts_current)

        # Mean distance of the pairs used for this step (measured before
        # the step was applied).
        mean_close = float(dists[keep].mean())
        print(f" ICP iter {iteration}: n_close={n_close} "
              f"threshold={threshold_iter:.1f} px mean_dist={mean_close:.2f} px "
              f"tx={tx:.2f} ty={ty:.2f} scale_x={sx:.4f} scale_y={sy:.4f}")

        if mean_close < 0.5:
            print(f" converged at iteration {iteration}")
            break

        # Stop after five iterations without a >=2% improvement.
        if mean_close < best_mean_dist * 0.98:
            best_mean_dist = mean_close
            stagnation_count = 0
        else:
            stagnation_count += 1
            if stagnation_count >= 5:
                print(f" ICP stagnated at iteration {iteration} "
                      f"(mean_dist={mean_close:.2f} px) — stopping")
                break

    # Decompose and sanity-check (the accumulated ICP part only; the
    # pre-correction was vetted separately above).
    a = M_accum[0, 0]
    b = M_accum[0, 1]
    c = M_accum[1, 0]
    d = M_accum[1, 1]
    tx_bnd = M_accum[0, 2]
    ty_bnd = M_accum[1, 2]

    angle_deg = np.degrees(np.arctan2(c, a))
    scale_x = np.sqrt(a * a + c * c)
    scale_y = np.sqrt(b * b + d * d)
    tx_full = tx_bnd * bnd_scale
    ty_full = ty_bnd * bnd_scale

    print(f" ICP accumulated: tx={tx_full:.1f} px ty={ty_full:.1f} px "
          f"angle={angle_deg:.3f} deg scale_x={scale_x:.4f} "
          f"scale_y={scale_y:.4f}")

    reasons = []
    if abs(tx_full) > max_translation:
        reasons.append(f"tx={tx_full:.1f} px > {max_translation}")
    if abs(ty_full) > max_translation:
        reasons.append(f"ty={ty_full:.1f} px > {max_translation}")
    if abs(angle_deg) > max_rotation:
        reasons.append(f"angle={angle_deg:.2f} deg > {max_rotation}")
    if abs(scale_x - 1.0) > max_scale_delta:
        reasons.append(f"scale_x={scale_x:.4f} outside [1±{max_scale_delta}]")
    if abs(scale_y - 1.0) > max_scale_delta:
        reasons.append(f"scale_y={scale_y:.4f} outside [1±{max_scale_delta}]")
    if abs(scale_x - scale_y) > max_scale_diff:
        reasons.append(f"|scale_x-scale_y|={abs(scale_x - scale_y):.4f} "
                       f"> {max_scale_diff}")

    if reasons:
        print(f" ICP REJECTED ({'; '.join(reasons)}) — keeping coarse")
        return warp_coarse.copy(), False

    # Lift to full resolution and compose
    S_up = np.diag([bnd_scale, bnd_scale, 1.0])
    S_down = np.diag([1.0 / bnd_scale, 1.0 / bnd_scale, 1.0])
    M_icp_full = S_up @ M_accum @ S_down
    M_icp_full_inv = np.linalg.inv(M_icp_full)

    # NOTE(review): the inverse corrections are left-multiplied onto the
    # coarse warp here, while refine_interior_sift right-multiplies its
    # equivalent correction (warp @ inv(H_delta)). Confirm which
    # composition convention is intended; the order is kept unchanged.
    if M_pre_corr is not None:
        M_pre_full = S_up @ M_pre_corr @ S_down
        M_pre_full_inv = np.linalg.inv(M_pre_full)
        warp_refined = (M_icp_full_inv @ M_pre_full_inv @
                        warp_coarse.astype(np.float64)).astype(np.float32)
    else:
        warp_refined = (M_icp_full_inv @
                        warp_coarse.astype(np.float64)).astype(np.float32)

    print(" ICP accepted")
    return warp_refined, True
# ==============================================================================
# PHASE CORRELATION PRE-ALIGNMENT
# ==============================================================================
def prealign_phase_correlation(post_im, pre_im):
    """
    Translate the pre-stain image by the shift found via phase correlation.

    Both inputs are converted to float32, ``cv2.phaseCorrelate`` is run on
    the (post, pre) pair, and the pre-stain image is resampled with a pure
    translation built from the reported shift.

    Parameters
    ----------
    post_im : ndarray
        Post-stain (fixed) image (2D).
    pre_im : ndarray
        Pre-stain image to be shifted (2D).

    Returns
    -------
    post_full : ndarray (float32)
        The post-stain image as float32.
    pre_shift : ndarray (float32)
        The pre-stain image after the translation, on the post-stain grid.
    shift : tuple
        The ``(dx, dy)`` shift that was applied.
    """
    fixed = post_im.astype(np.float32)
    moving = pre_im.astype(np.float32)

    (dx, dy), response = cv2.phaseCorrelate(fixed, moving)
    print(f"prealign_phase_correlation: dx={dx:.1f}, dy={dy:.1f} px "
          f"(response={response:.4f})")

    n_rows, n_cols = fixed.shape
    # NOTE(review): the measured shift is applied as a forward translation;
    # confirm this matches the sign convention of cv2.phaseCorrelate.
    translation = np.float32([[1, 0, dx],
                              [0, 1, dy]])
    shifted = cv2.warpAffine(moving, translation, (n_cols, n_rows),
                             flags=cv2.INTER_LINEAR)

    return fixed, shifted, (dx, dy)
# ==============================================================================
# INTERIOR FEATURE-MATCHING REFINEMENT
# ==============================================================================
def refine_interior_sift(post_full, pre_full, warp_init,
                         interior_blur_base=8.0,
                         interior_max_correction=500,
                         interior_min_inlier_ratio=0.4):
    """
    Refine an existing alignment with SIFT matches on interior content.

    The pre-stain image is warped with ``warp_init`` at a reduced working
    scale, both images are blurred and converted to 8-bit, SIFT keypoints
    are matched with FLANN plus Lowe's ratio test (0.7), and a correcting
    homography is estimated with RANSAC. The correction is rejected when
    too few features/matches are found, when RANSAC fails, when the inlier
    ratio is below ``interior_min_inlier_ratio``, or when the translation
    component exceeds ``interior_max_correction``; in every such case a
    copy of ``warp_init`` is returned.

    The original documentation states that this function used to be called
    ``refine_interior_ecc`` (pre-v2.1, ECC pyramid) and that the old name is
    kept as an alias for one version; the alias itself is not defined in
    this function.

    Parameters
    ----------
    post_full : ndarray
        Full-resolution post-stain image (float32).
    pre_full : ndarray
        Full-resolution pre-stain image (float32, un-warped).
    warp_init : ndarray
        3×3 homography from upstream alignment (boundary corr + ICP).
    interior_blur_base : float
        Base Gaussian sigma (default 8.0); half of it (at least 1.0) is
        applied before feature detection.
    interior_max_correction : float
        Maximum allowed translation correction in full-resolution pixels
        (default 500).
    interior_min_inlier_ratio : float
        Minimum RANSAC inlier ratio required to accept (default 0.4).

    Returns
    -------
    warp_refined : ndarray
        3×3 refined homography at full resolution.
    result : dict
        Diagnostics with keys ``'success'``, ``'levels_completed'`` (1 on
        success, else 0), ``'total_levels'`` (always 1),
        ``'estimated_accuracy_px'``, ``'level_details'`` (one-element list)
        and ``'failure_reason'`` (None on success).
    """
    rows, cols = post_full.shape
    longest_side = max(rows, cols)

    # Default to half resolution, then bound the working image size.
    work_scale = 0.5
    if longest_side * work_scale > 15000:
        work_scale = 10000.0 / longest_side
    if longest_side * work_scale < 2000:
        work_scale = min(1.0, 2000.0 / longest_side)

    print(f"\n Interior feature matching (work_scale={work_scale:.3f})...")

    resize_kwargs = dict(fx=work_scale, fy=work_scale,
                         interpolation=cv2.INTER_AREA)
    post_small = cv2.resize(post_full, None, **resize_kwargs)
    pre_small = cv2.resize(pre_full, None, **resize_kwargs)
    h_s, w_s = post_small.shape

    # Express the incoming warp in working-scale pixel coordinates.
    S = np.diag([work_scale, work_scale, 1.0]).astype(np.float32)
    S_inv = np.diag([1.0 / work_scale, 1.0 / work_scale,
                     1.0]).astype(np.float32)
    warp_work = S @ warp_init.astype(np.float32) @ S_inv

    pre_warped = cv2.warpPerspective(
        pre_small, warp_work, (w_s, h_s),
        flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)

    sigma = float(max(interior_blur_base * 0.5, 1.0))

    def _as_uint8(image):
        # Stretch the 1st..99th percentile of positive pixels onto 0..255.
        positive = image[image > 0]
        if len(positive) == 0:
            return np.zeros_like(image, dtype=np.uint8)
        lo = float(np.percentile(positive, 1))
        hi = float(np.percentile(positive, 99))
        if hi <= lo:
            hi = lo + 1.0
        return np.clip((image - lo) / (hi - lo) * 255, 0, 255).astype(np.uint8)

    post_u8 = _as_uint8(cv2.GaussianBlur(post_small, (0, 0), sigma))
    pre_u8 = _as_uint8(cv2.GaussianBlur(pre_warped, (0, 0), sigma))

    detector = cv2.SIFT_create(nfeatures=5000)
    kp_post, desc_post = detector.detectAndCompute(post_u8, None)
    kp_pre, desc_pre = detector.detectAndCompute(pre_u8, None)
    print(f" Features detected: post={len(kp_post)}, pre={len(kp_pre)}")

    detail = {
        'level': 1,
        'scale': work_scale,
        'n_features_post': len(kp_post),
        'n_features_pre': len(kp_pre),
    }

    def _blink_hint():
        # Tell the user which files to compare to judge the upstream result.
        print(" → To assess ICP alignment, blink:")
        print(" fits/01_padded_post.fits")
        print(" fits/02_icp_aligned_pre.fits")

    def _unrefined(status, reason):
        # Record the outcome and hand back the untouched input warp.
        detail['status'] = status
        detail['reason'] = reason
        return warp_init.copy(), {
            'success': False,
            'levels_completed': 0,
            'total_levels': 1,
            'estimated_accuracy_px': float('inf'),
            'level_details': [detail],
            'failure_reason': reason,
        }

    if len(kp_post) < 10 or len(kp_pre) < 10:
        print(" Too few features — keeping ICP result")
        _blink_hint()
        return _unrefined('failed', 'Too few features detected')

    # FLANN KD-tree matcher (algorithm=1), two nearest neighbours per query.
    matcher = cv2.FlannBasedMatcher(dict(algorithm=1, trees=5),
                                    dict(checks=50))
    candidate_pairs = matcher.knnMatch(desc_post, desc_pre, k=2)

    # Lowe's ratio test: keep a match only if clearly better than runner-up.
    good_matches = [
        pair[0] for pair in candidate_pairs
        if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
    ]
    n_good = len(good_matches)
    print(f" Good matches (Lowe ratio<0.7): {n_good}")
    detail['n_matches'] = n_good

    if n_good < 10:
        print(" Too few matches — keeping ICP result")
        _blink_hint()
        return _unrefined('failed',
                          f'Only {n_good} good matches (need ≥10)')

    pts_post = np.float32([kp_post[m.queryIdx].pt for m in good_matches])
    pts_pre = np.float32([kp_pre[m.trainIdx].pt for m in good_matches])

    # Homography taking warped-pre coordinates onto post coordinates.
    H_delta, inlier_mask = cv2.findHomography(
        pts_pre, pts_post, cv2.RANSAC,
        ransacReprojThreshold=3.0, maxIters=5000, confidence=0.999)

    if H_delta is None:
        print(" RANSAC failed — keeping ICP result")
        return _unrefined('failed', 'RANSAC failed to find homography')

    n_inliers = int(inlier_mask.sum())
    inlier_ratio = n_inliers / n_good
    print(f" RANSAC inliers: {n_inliers}/{n_good} "
          f"({inlier_ratio*100:.1f}%)")
    detail['n_inliers'] = n_inliers
    detail['inlier_ratio'] = inlier_ratio

    if inlier_ratio < interior_min_inlier_ratio:
        print(f" Inlier ratio {inlier_ratio:.1%} below minimum "
              f"({interior_min_inlier_ratio:.0%}) — REJECTED")
        print(" → Match quality too low to trust. This may indicate:")
        print(" - Too few shared features between pre and post")
        print(" - Try increasing interior_blur_base to suppress microbe detail")
        _blink_hint()
        return _unrefined(
            'rejected',
            f'inlier ratio {inlier_ratio:.1%} below minimum '
            f'{interior_min_inlier_ratio:.0%}')

    # Size of the correction: translation entries, in full-res pixels.
    dx_fullres = H_delta[0, 2] / work_scale
    dy_fullres = H_delta[1, 2] / work_scale
    correction_fullres = np.sqrt(dx_fullres**2 + dy_fullres**2)
    detail['dx_correction'] = float(dx_fullres)
    detail['dy_correction'] = float(dy_fullres)
    detail['correction_fullres_px'] = float(correction_fullres)

    if correction_fullres > interior_max_correction:
        print(f" Correction {correction_fullres:.1f}px exceeds limit "
              f"({interior_max_correction:.1f}px) — REJECTED")
        _blink_hint()
        return _unrefined(
            'rejected',
            f'correction {correction_fullres:.1f}px '
            f'exceeds limit {interior_max_correction:.1f}px')

    # warp_work is used with WARP_INVERSE_MAP (post -> pre). H_delta acts in
    # the already-warped space (pre_warped -> post), so the refined
    # working-scale warp is warp_work @ inv(H_delta).
    H_delta_inv = np.linalg.inv(H_delta.astype(np.float64))
    warp_work_refined = warp_work.astype(np.float64) @ H_delta_inv
    warp_refined = (S_inv @ warp_work_refined @ S).astype(np.float32)

    # Accuracy estimate: median reprojection error over RANSAC inliers.
    is_inlier = inlier_mask.ravel() == 1
    inlier_pts_post = pts_post[is_inlier]
    inlier_pts_pre = pts_pre[is_inlier]
    projected = cv2.perspectiveTransform(
        inlier_pts_pre.reshape(-1, 1, 2), H_delta).reshape(-1, 2)
    reproj_errors = np.sqrt(np.sum((inlier_pts_post - projected)**2, axis=1))
    median_reproj = float(np.median(reproj_errors))
    estimated_accuracy = median_reproj / work_scale  # full-res pixels

    detail['status'] = 'accepted'
    detail['median_reproj_error'] = median_reproj
    detail['estimated_accuracy_px'] = estimated_accuracy

    print(f" dx={dx_fullres:.1f}px dy={dy_fullres:.1f}px "
          f"(total={correction_fullres:.1f}px)")
    print(f" Median reprojection error: {median_reproj:.2f}px "
          f"(at work scale) → ±{estimated_accuracy:.1f}px (full res)")
    print(f" Interior feature matching: accepted "
          f"(estimated accuracy: ±{estimated_accuracy:.1f}px)")
    print(" → To verify, blink:")
    print(" fits/01_padded_post.fits")
    print(" fits/03_interior_aligned_pre.fits")

    return warp_refined, {
        'success': True,
        'levels_completed': 1,
        'total_levels': 1,
        'estimated_accuracy_px': estimated_accuracy,
        'level_details': [detail],
        'failure_reason': None,
    }
# ==============================================================================
# MAIN REGISTRATION API
# ==============================================================================
def register_highorder(post_im, pre_im,
                       stopit=500, stopdelta=1e-6,
                       down_scale=0.3,
                       usharp=False, gauss_sigma=0,
                       use_edges=True,
                       boundary_width=15, boundary_smooth=10,
                       coarse_stopit=1000, coarse_stopdelta=1e-4,
                       rotation_search=True, angle_range=20, angle_step=1,
                       scale_search=True,
                       scale_min=0.85, scale_max=1.15, scale_step=0.05,
                       multiscale=True,
                       fine_ecc=False,
                       max_translation=200, max_rotation=5.0,
                       max_scale_delta=0.2, max_scale_diff=0.15,
                       save_prefix=None):
    """
    Register pre-stain to post-stain using a multiscale strategy.

    Pipeline:
        1. Boundary correlation coarse pass (translation + rotation + scale)
        2. ICP affine refinement of boundary contour correspondences
        3. Optional fine homography ECC pass (``fine_ecc=True``)

    Steps 1 and 2 run only when ``multiscale=True``; otherwise the
    starting estimate is the identity.

    Parameters
    ----------
    post_im, pre_im : ndarray
        Post-stain (fixed) and pre-stain (moving) images (2D).
    stopit : int
        Max ECC iterations for fine pass (default 500).
    stopdelta : float
        ECC convergence threshold (default 1e-6).
    down_scale : float
        Downsample factor for fine ECC pass (default 0.3). The coarse
        pass runs at ``max(down_scale * 0.25, 0.02)``.
    usharp : float or False
        Unsharp mask sigma (default False). Forwarded to the fine-ECC
        pair preparation.
    gauss_sigma : float
        Gaussian pre-smoothing sigma (default 0). Forwarded to the
        fine-ECC pair preparation.
    use_edges : bool
        Extract boundary rings for ECC (default True).
    boundary_width : int
        Boundary ring thickness (default 15).
    boundary_smooth : float
        Boundary softening sigma (default 10).
    coarse_stopit : int
        Not read by this function; kept in the signature so existing
        callers keep working.
    coarse_stopdelta : float
        Not read by this function; kept in the signature so existing
        callers keep working.
    rotation_search : bool
        Search over rotations (default True).
    angle_range : float
        Rotation search ±degrees (default 20).
    angle_step : float
        Rotation search step (default 1).
    scale_search : bool
        Search over scale factors (default True).
    scale_min, scale_max, scale_step : float
        Scale search parameters.
    multiscale : bool
        Run coarse+ICP pipeline (default True).
    fine_ecc : bool
        Run fine ECC after ICP (default False).
    max_translation, max_rotation : float
        ICP sanity limits.
    max_scale_delta, max_scale_diff : float
        ICP scale sanity limits.
    save_prefix : str or None
        Intended as an output prefix for pipeline check plots. It is not
        read inside this function; plotting is presumably done by the
        caller from ``debug_data`` (verify against callers).

    Returns
    -------
    post_full : ndarray (float32)
        ``post_im`` cast to float32.
    pre_aligned : ndarray (float32)
        Pre-stain warped by full pipeline.
    pre_coarse_aligned : ndarray or None
        Pre-stain warped by coarse pass only (None if multiscale=False).
    warp_matrix : ndarray
        Final 3x3 homography in full-resolution pixel coordinates. It is
        applied with ``cv2.WARP_INVERSE_MAP``.
    debug_data : dict
        Intermediate data for pipeline check plots.
    """
    warp_mode = cv2.MOTION_HOMOGRAPHY
    warp_matrix = np.eye(3, 3, dtype=np.float32)
    warp_coarse = None
    criteria_fine = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                     stopit, stopdelta)

    post_full = post_im.astype(np.float32)
    pre_full = pre_im.astype(np.float32)

    debug_data = {
        'stages': [],
        'post_bnd_raw': None,
        'pre_bnd_raw': None,
        'pre_bnd_warped': None,
        'post_bnd_display': None,
    }

    if multiscale:
        coarse_scale = max(down_scale * 0.25, 0.02)
        print(f"Multiscale: boundary correlation (coarse_scale={coarse_scale:.3f}) "
              f"→ ICP → fine ECC (down_scale={down_scale:.3f})")

        # Step 1: Coarse boundary correlation
        (warp_coarse, _best_angle, _best_dx, _best_dy,
         post_bnd_raw, pre_bnd_raw) = boundary_correlation_coarse(
            post_full, pre_full,
            coarse_scale=coarse_scale,
            boundary_width=boundary_width,
            boundary_smooth=boundary_smooth,
            rotation_search=rotation_search,
            angle_range=angle_range,
            angle_step=angle_step,
            scale_search=scale_search,
            scale_min=scale_min,
            scale_max=scale_max,
            scale_step=scale_step,
        )
        warp_matrix = warp_coarse.copy()

        # Warp pre-stain boundary into post-coarse frame
        post_coarse = cv2.resize(post_full, None,
                                 fx=coarse_scale, fy=coarse_scale,
                                 interpolation=cv2.INTER_AREA)
        pre_coarse = cv2.resize(pre_full, None,
                                fx=coarse_scale, fy=coarse_scale,
                                interpolation=cv2.INTER_AREA)
        h_c, w_c = post_coarse.shape

        # Conjugating with the scale matrix expresses the full-resolution
        # homography in coarse-pixel coordinates.
        S_c = np.diag([coarse_scale, coarse_scale, 1.0]).astype(np.float32)
        S_c_inv = np.diag([1.0 / coarse_scale, 1.0 / coarse_scale,
                           1.0]).astype(np.float32)
        H_c = S_c @ warp_coarse @ S_c_inv

        pre_bnd_warped = cv2.warpPerspective(
            pre_bnd_raw.astype(np.float32), H_c, (w_c, h_c),
            flags=cv2.INTER_NEAREST | cv2.WARP_INVERSE_MAP)
        post_bnd_display = cv2.resize(post_bnd_raw, (w_c, h_c),
                                      interpolation=cv2.INTER_NEAREST)

        # Store debug data
        pre_warped_coarse = cv2.warpPerspective(
            pre_coarse, H_c, (w_c, h_c),
            flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP)
        pre_bnd_pre = cv2.resize(pre_bnd_raw, (w_c, h_c),
                                 interpolation=cv2.INTER_NEAREST)

        debug_data['stages'].append({
            'label': f'Boundary corr (scale={coarse_scale:.3f})',
            'post_raw': post_coarse,
            'pre_raw': pre_coarse,
            'post_ecc': post_bnd_display,
            'pre_ecc': pre_bnd_warped,
            'pre_warped': pre_warped_coarse,
            'post_edges': post_bnd_display,
            'pre_edges': pre_bnd_warped,
            'pre_edges_pre': pre_bnd_pre,
        })
        debug_data['post_bnd_raw'] = post_bnd_raw
        debug_data['pre_bnd_raw'] = pre_bnd_raw
        debug_data['pre_bnd_warped'] = pre_bnd_warped
        debug_data['post_bnd_display'] = post_bnd_display

        # Step 2: ICP refinement
        def _bnd_to_pts(bnd_raw):
            # Largest external contour of the thresholded ring as an
            # (N, 2) float32 point list; None when no contour is found.
            mask = (bnd_raw > 10).astype(np.uint8)
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_NONE)
            if not contours:
                return None
            largest = max(contours, key=cv2.contourArea)
            return largest.reshape(-1, 2).astype(np.float32)

        post_icp_pts = _bnd_to_pts(post_bnd_raw)
        pre_icp_pts = _bnd_to_pts(pre_bnd_warped)

        warp_matrix, _accepted = refine_icp(
            post_full, warp_coarse,
            post_bnd_pts=post_icp_pts,
            pre_bnd_pts=pre_icp_pts,
            bnd_scale=1.0 / coarse_scale,
            max_translation=max_translation,
            max_rotation=max_rotation,
            max_scale_delta=max_scale_delta,
            max_scale_diff=max_scale_diff,
        )

    # Step 3: Optional fine ECC
    if fine_ecc:
        post_small = cv2.resize(post_full, None,
                                fx=down_scale, fy=down_scale,
                                interpolation=cv2.INTER_AREA)
        pre_small = cv2.resize(pre_full, None,
                               fx=down_scale, fy=down_scale,
                               interpolation=cv2.INTER_AREA)

        post_ecc, pre_ecc, _post_fine_edges, _pre_fine_edges = \
            _prepare_pair_for_ecc(post_small, pre_small,
                                  gauss_sigma, usharp,
                                  use_edges, boundary_width, boundary_smooth)

        S = np.diag([down_scale, down_scale, 1.0]).astype(np.float32)
        S_inv = np.diag([1.0 / down_scale, 1.0 / down_scale,
                         1.0]).astype(np.float32)

        # The current estimate lives in full-resolution coordinates, but
        # ECC runs on the down-scaled images. Express the initial guess in
        # the down-scaled frame (same conjugation as the coarse stage);
        # ECC also needs a float32 matrix. A separate array keeps the
        # full-resolution estimate untouched if ECC fails part-way.
        warp_before_ecc = warp_matrix.copy()
        warp_init = np.ascontiguousarray(S @ warp_matrix @ S_inv,
                                         dtype=np.float32)

        try:
            cc, warp_small = cv2.findTransformECC(
                post_ecc, pre_ecc, warp_init, warp_mode, criteria_fine)
        except cv2.error:
            print(" Fine ECC failed — keeping coarse/ICP result")
            # Already full-resolution: it must NOT be rescaled again.
            warp_matrix = warp_before_ecc
        else:
            print(f" Fine ECC converged (cc={cc:.4f})")
            # Rescale from down_scale coords to full resolution
            warp_matrix = S_inv @ warp_small @ S
    else:
        print(" Fine ECC skipped (fine_ecc=False) — using coarse/ICP result")

    # Apply final homography
    h, w = post_full.shape
    pre_aligned = cv2.warpPerspective(
        pre_full, warp_matrix, (w, h),
        flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP)

    # Coarse-only warp for comparison
    if warp_coarse is not None:
        pre_coarse_aligned = cv2.warpPerspective(
            pre_full, warp_coarse, (w, h),
            flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP)
    else:
        pre_coarse_aligned = None

    # Report final homography decomposition (normalised so H[2, 2] == 1)
    H = warp_matrix / warp_matrix[2, 2]
    a, b_val = H[0, 0], H[0, 1]
    c, d = H[1, 0], H[1, 1]
    tx, ty = H[0, 2], H[1, 2]
    p, q = H[2, 0], H[2, 1]

    angle_deg = np.degrees(np.arctan2(c, a))
    scale_x = np.sqrt(a * a + c * c)
    scale_y = np.sqrt(b_val * b_val + d * d)
    shear = (a * b_val + c * d) / (scale_x * scale_y)
    perspective_mag = np.sqrt(p * p + q * q)

    print("\n=========== Homography Alignment Results ===========")
    print(f"Rotation angle : {angle_deg:10.4f} degrees")
    print(f"Translation X (dx) : {tx:10.2f} pixels")
    print(f"Translation Y (dy) : {ty:10.2f} pixels")
    print(f"Scale X (magnify) : {scale_x:10.6f}")
    print(f"Scale Y (magnify) : {scale_y:10.6f}")
    print(f"Shear factor : {shear:10.6f}")
    print(f"Perspective distortion: {perspective_mag:10.8f}")
    print("====================================================\n")

    return post_full, pre_aligned, pre_coarse_aligned, warp_matrix, debug_data
# ── Backward compatibility alias (removed in a future version) ──────── refine_interior_ecc = refine_interior_sift """Deprecated alias for :func:`refine_interior_sift`. The function was renamed because its body uses SIFT, not ECC, after the v2.1 rewrite. Prefer the new name for new code."""