"""
plotting.py
===========
Active visualization functions for the exo2micro pipeline.
Includes:
- Registration pipeline check plots (boundary extraction, alignment)
- Pre/post diagnostic plots (heatmap, histograms, ratio histogram)
- Excess signal heatmap (diagonal reflection)
- Difference image visualization
- Simple image display
Legacy plotting functions (plot_im_sub, plot_diff_comparison,
plot_stretch_comparison, plot_zoom_region, plot_signal_scatter,
plot_ratio_histogram, plot_residual_histogram) have been moved
to exo2micro.legacy.
All plots include sample and dye in the title when provided.
"""
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage import gaussian_filter
# ==============================================================================
# COLORMAP UTILITIES
# ==============================================================================
def _make_diverging_cmap():
    """
    Build a black-centred diverging colormap for signed difference data.

    Intended for images such as ``post − scale·pre`` or ``grid − grid.T``,
    where zero is the neutral value and the two signs must be easy to
    tell apart. The ramp runs

        saturated blue → dark blue → black (zero) → dark red → saturated red

    Most matplotlib diverging maps (RdBu, coolwarm, …) have a light/white
    centre. That suits data that should look "mostly blank with bright
    spots". Here the black centre keeps near-zero pixels dim, so strong
    signal of either sign stands out. That is the common situation for
    difference images of stained tissue, where most pixels are background
    or well-cancelled stain and the microbe signal sits at the extremes.

    Bad (NaN) values are drawn white. This makes background / masked
    pixels visibly distinct from genuine near-zero data.

    Returns
    -------
    LinearSegmentedColormap
        A 512-entry colormap named ``'dark_div'``.
    """
    negative_extreme, negative_dark = '#3361ff', '#1a3380'
    centre = '#000000'
    positive_dark, positive_extreme = '#801a33', '#ff3361'
    ramp = [
        negative_extreme,
        negative_dark,
        centre,
        positive_dark,
        positive_extreme,
    ]
    diverging = LinearSegmentedColormap.from_list('dark_div', ramp, N=512)
    diverging.set_bad('white')
    return diverging
def _make_inferno_cmap():
    """
    Return a private copy of matplotlib's 'inferno' colormap.

    The copy's "bad" colour is set to white so NaN pixels render white.
    The colormap is copied first, so the shared, registered 'inferno'
    colormap is never modified.
    """
    inferno = plt.get_cmap('inferno').copy()
    inferno.set_bad('white')
    return inferno
def _title_prefix(sample, dye):
    """
    Build a ``'Sample Dye — '`` prefix string for plot titles.

    Empty (or None) parts are skipped, so a lone sample or dye does not
    leave stray spaces. For example, ``('S1', '')`` gives ``'S1 — '``
    rather than ``'S1  — '``.

    Parameters
    ----------
    sample, dye : str or None
        Identifiers to show in the title.

    Returns
    -------
    str
        ``'<sample> <dye> — '`` with missing parts omitted, or ``''``
        when both are empty.
    """
    _id = ' '.join(str(part) for part in (sample, dye) if part)
    return f'{_id} — ' if _id else ''
# ==============================================================================
# REGISTRATION PIPELINE CHECK PLOTS
# ==============================================================================
[docs]
def plot_registration(stages, title='Registration', save_path=None,
                      sample='', dye=''):
    """
    Four-panel pipeline check figure for registration quality.

    Panel 1a: Post-stain boundary extraction (cyan contour)
    Panel 1b: Pre-stain boundary extraction (magenta contour)
    Panel 2:  Coarse alignment (both boundaries overlaid)
    Panel 3:  Final difference image (post - pre, unscaled)

    Background images in panels 1a, 1b and 2 are drawn on a
    ``log10(x + 1)`` grey scale. Its limits come from percentiles of the
    strictly positive pixels: 1st–99th for panels 1a/1b, and 1st–70th
    (darker) for panel 2 so the contours stand out.

    Parameters
    ----------
    stages : list of dict
        Stage dicts from register_highorder debug_data['stages'].
        ``stages[0]`` is read for ``'post_raw'`` and ``'pre_raw'``, plus
        the optional edge masks ``'post_edges'``, ``'pre_edges'`` and
        ``'pre_edges_pre'``. ``stages[-1]`` is read for ``'post_raw'``
        and ``'pre_warped'``.
    title : str
        Figure title (default 'Registration').
    save_path : str or None
        If set, save to this path and close the figure. Otherwise the
        figure is shown with ``plt.show()``.
    sample, dye : str
        For title prefix.

    Returns
    -------
    fig : matplotlib.Figure or None
        None when ``stages`` is empty.
    """
    if not stages:
        return None
    coarse = stages[0]
    fine = stages[-1]
    offset = 1.0  # log10(x + 1) keeps zero-valued pixels finite

    def _log_limits(raw_im, hi_pct):
        """(vmin, vmax) in log10(x + offset) space from the positive pixels."""
        px = raw_im[raw_im > 0]
        if len(px) == 0:
            return 0.0, 1.0
        vmax = np.log10(float(np.percentile(px, hi_pct)) + offset)
        vmin = np.log10(max(float(np.percentile(px, 1)), 1.0) + offset)
        # The 1.0 floor on the low percentile can push vmin above vmax
        # when most positive pixels are < 1 (fractional values).
        # matplotlib's Normalize raises ValueError if vmin > vmax.
        return min(vmin, vmax), vmax

    def _show_log_image(ax, raw_im, hi_pct, alpha=None):
        """Draw raw_im on a log grey scale clipped at the hi_pct percentile."""
        vmin, vmax = _log_limits(raw_im, hi_pct)
        ax.imshow(np.log10(raw_im.astype(np.float32) + offset), cmap='gray',
                  vmin=vmin, vmax=vmax, alpha=alpha)

    def _draw_edges(ax, edge_im, colour, linewidth=1.5, linestyle='solid',
                    alpha=None):
        """Contour the 0.5 level of an edge mask; skip missing/empty masks."""
        if edge_im is None or not edge_im.max() > 0:
            return
        h, w = edge_im.shape
        ax.contour(np.arange(w), np.arange(h), edge_im, levels=[0.5],
                   colors=[colour], linewidths=[linewidth],
                   linestyles=[linestyle], alpha=alpha)

    post_raw = coarse['post_raw']
    pre_raw = coarse['pre_raw']
    post_edges = coarse.get('post_edges')
    pre_edges = coarse.get('pre_edges')          # pre boundary after coarse alignment
    pre_edges_pre = coarse.get('pre_edges_pre')  # pre boundary before alignment

    prefix = _title_prefix(sample, dye)
    fig, axs = plt.subplots(1, 4, figsize=(28, 7))
    fig.suptitle(f'{prefix}{title}', fontsize=13)

    # Panel 1a: post-stain boundary
    ax = axs[0]
    _show_log_image(ax, post_raw, 99)
    _draw_edges(ax, post_edges, 'cyan')
    ax.set_title('1a. Post-Stain + boundary (cyan)', fontsize=9)
    ax.axis('off')

    # Panel 1b: pre-stain boundary before alignment. Fall back to the
    # aligned edges when no pre-alignment mask was recorded.
    ax = axs[1]
    _show_log_image(ax, pre_raw, 99)
    _draw_edges(ax,
                pre_edges_pre if pre_edges_pre is not None else pre_edges,
                'magenta')
    ax.set_title('1b. Pre-Stain + boundary (magenta, before alignment)',
                 fontsize=9)
    ax.axis('off')

    # Panel 2: coarse alignment overlay on a darker post-stain background
    ax = axs[2]
    ax.set_title('2. Coarse alignment '
                 '(cyan=post, magenta=pre after coarse)',
                 fontsize=9)
    _show_log_image(ax, post_raw, 70, alpha=0.7)
    _draw_edges(ax, post_edges, 'cyan')
    _draw_edges(ax, pre_edges, 'magenta')
    _draw_edges(ax, pre_edges_pre, 'magenta', linewidth=1.8,
                linestyle='dashed', alpha=0.9)
    boundary_legend = [
        Line2D([0], [0], color='cyan', linewidth=1.5,
               label='Post-Stain boundary'),
        Line2D([0], [0], color='magenta', linewidth=1.5, linestyle='solid',
               label='Pre-Stain boundary (after alignment)'),
        # Same width as the dashed contour actually drawn above.
        Line2D([0], [0], color='magenta', linewidth=1.8, linestyle='dashed',
               label='Pre-Stain boundary (before alignment)'),
    ]
    ax.legend(handles=boundary_legend, loc='lower left', fontsize=7,
              framealpha=0.7, facecolor='black', labelcolor='white')
    ax.axis('off')

    # Panel 3: final difference
    ax = axs[3]
    ax.set_title('3. Post - Pre difference (post-warp, downsampled)',
                 fontsize=9)
    diffim = (fine['post_raw'].astype(np.float32)
              - fine['pre_warped'].astype(np.float32))
    abs_nonzero = np.abs(diffim)
    abs_nonzero = abs_nonzero[abs_nonzero > 0]
    # Symmetric limits (95th percentile of |diff|) keep zero at white in 'bwr'.
    dv = np.nanpercentile(abs_nonzero, 95) if len(abs_nonzero) > 0 else 1.0
    ax.imshow(diffim, cmap='bwr', vmin=-dv, vmax=dv)
    diff_legend = [
        Patch(facecolor='red', label='Post > Pre (excess post-stain signal)'),
        Patch(facecolor='blue', label='Pre > Post (excess pre-stain signal)'),
        Patch(facecolor='white', label='Balanced (good local alignment)'),
    ]
    ax.legend(handles=diff_legend, loc='lower left', fontsize=7,
              framealpha=0.7, facecolor='black', labelcolor='white')
    ax.axis('off')

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f" saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig
[docs]
def plot_fine_alignment(post_raw, pre_coarse_raw, pre_refined_raw,
                        post_bnd, pre_coarse_bnd, pre_refined_bnd,
                        title='Fine alignment', save_path=None,
                        sample='', dye=''):
    """
    Two-panel comparison of coarse vs ICP-refined alignment.

    Both panels show the post-stain image on a ``log10(x + 1)`` grey
    scale, using the 1st–70th percentile of its positive pixels. Boundary
    contours are overlaid:

    - Left: post boundary (cyan) and coarse-aligned pre boundary (magenta).
    - Right: post boundary (cyan), ICP-refined pre boundary (yellow), and
      a faint coarse pre boundary (magenta) for comparison.

    Parameters
    ----------
    post_raw : ndarray
        Post-stain image (downsampled, float32). Its shape defines the
        contour grid, so the boundary masks must have the same shape.
    pre_coarse_raw : ndarray
        Pre-stain warped by coarse transform. Currently not drawn; it is
        accepted for interface compatibility.
    pre_refined_raw : ndarray
        Pre-stain warped by ICP-refined transform. Currently not drawn;
        it is accepted for interface compatibility.
    post_bnd : ndarray or None
        Post-stain boundary ring.
    pre_coarse_bnd : ndarray or None
        Pre-stain boundary after coarse alignment.
    pre_refined_bnd : ndarray or None
        Pre-stain boundary after ICP refinement.
    title : str
        Figure title.
    save_path : str or None
        If set, save to this path and close the figure. Otherwise show it.
    sample, dye : str
        For title prefix.

    Returns
    -------
    fig : matplotlib.Figure
    """
    prefix = _title_prefix(sample, dye)
    fig, axs = plt.subplots(1, 2, figsize=(14, 7))
    fig.suptitle(f'{prefix}{title}', fontsize=13, fontweight='bold', y=1.01)

    # Log-scaled background: both panels use the same post image, so the
    # image and its display limits are computed once.
    offset = 1.0
    log_post = np.log10(post_raw.astype(np.float32) + offset)
    px = post_raw[post_raw > 0]
    if len(px) > 0:
        vmax = np.log10(float(np.percentile(px, 70)) + offset)
        vmin = np.log10(max(float(np.percentile(px, 1)), 1.0) + offset)
        # matplotlib's Normalize rejects vmin > vmax. The 1.0 floor above
        # can produce that when most positive pixels are < 1.
        vmin = min(vmin, vmax)
    else:
        vmin, vmax = 0.0, 1.0

    h, w = post_raw.shape
    xs, ys = np.arange(w), np.arange(h)

    def _boundary(ax, bnd, colour, linewidth, **contour_kwargs):
        """Contour the 0.5 level of a boundary mask if it has any signal."""
        if bnd is not None and bnd.max() > 0:
            ax.contour(xs, ys, bnd, levels=[0.5], colors=[colour],
                       linewidths=[linewidth], **contour_kwargs)

    # Panel 1: after coarse alignment
    ax = axs[0]
    ax.imshow(log_post, cmap='gray', vmin=vmin, vmax=vmax, alpha=0.8)
    _boundary(ax, post_bnd, 'cyan', 1.5)
    _boundary(ax, pre_coarse_bnd, 'magenta', 1.5)
    ax.set_title('After coarse alignment\n'
                 'cyan = post | magenta = pre',
                 fontsize=10, pad=6)
    ax.axis('off')

    # Panel 2: after ICP refinement (coarse boundary kept faintly for reference)
    ax = axs[1]
    ax.imshow(log_post, cmap='gray', vmin=vmin, vmax=vmax, alpha=0.8)
    _boundary(ax, post_bnd, 'cyan', 1.5, alpha=0.8)
    _boundary(ax, pre_refined_bnd, 'yellow', 1.8,
              linestyles=['solid'], alpha=0.85)
    _boundary(ax, pre_coarse_bnd, 'magenta', 1.0,
              linestyles=['solid'], alpha=0.6)
    ax.set_title('After ICP refinement\n'
                 'cyan = post | yellow = refined | magenta = coarse',
                 fontsize=10, pad=6)
    ax.axis('off')

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f" saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig
# ==============================================================================
# SIMPLE IMAGE DISPLAY
# ==============================================================================
[docs]
def plot_im(im, lims=None):
    """
    Display a single image with an auto-scaled or user-specified colour range.

    Opens a new figure, draws ``im`` with ``imshow`` and adds a colorbar.
    The figure is not returned; use ``plt.gcf()`` to access it.

    Parameters
    ----------
    im : ndarray
        Image to display. NaN pixels are ignored when auto-scaling.
    lims : sequence of two floats, ndarray, or None
        ``[vmin, vmax]`` display limits. When None or empty, the limits
        default to the 10th and 90th percentiles of ``im``. A numpy array
        is accepted too; ``if not lims`` would raise on one.
    """
    plt.figure()
    if lims is None or len(lims) == 0:
        vmin = np.nanpercentile(im, 10)
        vmax = np.nanpercentile(im, 90)
    else:
        vmin, vmax = lims[0], lims[1]
    plt.imshow(im, vmin=vmin, vmax=vmax)
    plt.colorbar()
# ==============================================================================
# SIMPLE PRE/POST DIAGNOSTICS (Phase 3)
# ==============================================================================
def _integer_bin_edges(data, percentile_clip=99.9):
    """
    Build histogram bin edges aligned to integer values (-0.5, 0.5, 1.5, …).

    Each integer sits at the centre of its bin. This works whether the
    data are integer-valued (post-stain) or continuous floats (warped
    pre-stain).

    Parameters
    ----------
    data : array-like
        Pixel values. Non-finite entries (NaN/inf) are ignored, because
        they would otherwise turn the clip value into NaN.
    percentile_clip : float
        Percentile of the finite data used as the upper clip value
        (default 99.9), so a few bright outliers do not stretch the range.

    Returns
    -------
    edges : ndarray
        Bin edges from -0.5 up to ``ceil(clip_value) + 0.5``. The upper
        edge is never below 0.5, so the zero bin is always present. That
        single bin, ``[-0.5, 0.5]``, is also the result for empty,
        all-NaN or all-negative data.
    """
    values = np.asarray(data, dtype=np.float64).ravel()
    values = values[np.isfinite(values)]
    if values.size == 0:
        max_int = 0
    else:
        hi = float(np.percentile(values, percentile_clip))
        max_int = max(0, int(np.ceil(hi)))
    return np.arange(-0.5, max_int + 1.5, 1.0)
[docs]
def plot_pre_post_heatmap(post_im, pre_im,
                          sample='', dye='', save_path=None):
    """
    2-D density heatmap of pre-stain vs post-stain pixel brightness.

    Uses every pixel where both images are finite. NaN/inf pixels, such
    as masked background, are dropped instead of being cast to an
    arbitrary integer. Values are rounded to the nearest integer and
    clipped to [0, 255].

    The y-axis (post-stain) is binned by the integer values actually
    present in the data, which collapses empty quantization gaps without
    smoothing or interpolation. The x-axis (pre-stain) uses standard
    integer bins.

    Counts are shown as log10(count) on the 'inferno' colormap, with
    empty bins in white. The colour range ignores the (pre=0, post=0)
    corner bin, if present, so padding does not dominate the scale.

    Parameters
    ----------
    post_im, pre_im : ndarray (2-D, float-like)
        Post-stain and aligned pre-stain images.
    sample, dye : str
        For title.
    save_path : str or None
        If set, save figure to this path and close it. Otherwise show it.

    Returns
    -------
    fig : matplotlib.Figure
    """
    post_flat = np.asarray(post_im, dtype=np.float64).ravel()
    pre_flat = np.asarray(pre_im, dtype=np.float64).ravel()
    # NaN/inf have no integer value; casting them to int32 yields an
    # arbitrary number that clipping would fold into the 0 or 255 bin.
    finite = np.isfinite(post_flat) & np.isfinite(pre_flat)
    post = np.clip(np.rint(post_flat[finite]), 0, 255).astype(np.int32)
    pre = np.clip(np.rint(pre_flat[finite]), 0, 255).astype(np.int32)
    n_pixels = len(post)

    # y-axis: the post-stain values actually present (np.unique sorts)
    post_vals = np.unique(post)
    if post_vals.size == 0:
        post_vals = np.array([0], dtype=np.int32)  # no data: one empty bin
    # Edges at midpoints between consecutive real values, so each bin
    # is centred on a real value.
    y_edges = np.concatenate((
        [post_vals[0] - 0.5],
        (post_vals[:-1] + post_vals[1:]) / 2.0,
        [post_vals[-1] + 0.5],
    ))
    # x-axis: standard integer bins (pre-stain is continuous from warp)
    x_edges = np.arange(-0.5, 256.5, 1.0)

    h2d, x_out, y_out = np.histogram2d(pre, post, bins=[x_edges, y_edges])

    # Log colour scale; empty bins become NaN and render white
    with np.errstate(divide='ignore', invalid='ignore'):
        h2d_log = np.log10(h2d.astype(np.float64))
    h2d_log[~np.isfinite(h2d_log)] = np.nan

    # Colour limits, excluding the (pre=0, post=0) corner. x bin 0 is
    # always pre=0, but y bin 0 is post=0 only when 0 is actually present.
    valid = h2d_log.copy()
    if post_vals[0] == 0:
        valid[0, 0] = np.nan
    valid_vals = valid[np.isfinite(valid)]
    vmin = float(np.nanmin(valid_vals)) if len(valid_vals) > 0 else 0
    vmax = float(np.nanmax(valid_vals)) if len(valid_vals) > 0 else 1

    prefix = _title_prefix(sample, dye)
    fig, ax = plt.subplots(figsize=(9, 8))
    img = ax.pcolormesh(x_out, y_out, h2d_log.T,
                        cmap=_make_inferno_cmap(), shading='flat',
                        rasterized=True, vmin=vmin, vmax=vmax)

    # Colorbar pinned to plot height via make_axes_locatable
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='4%', pad=0.08)
    cb = fig.colorbar(img, cax=cax)
    cb.set_label('log₁₀(pixel count)', fontsize=10)

    # Identity line
    ax.plot([0, 255], [0, 255], color='white', linewidth=1.2,
            linestyle='--', alpha=0.8, label='scale = 1 (identity)')

    ax.set_xlim(-0.5, 255.5)
    ax.set_ylim(y_edges[0], y_edges[-1])
    ax.set_aspect('equal')
    ax.set_xlabel('pre-stain brightness', fontsize=11)
    ax.set_ylabel('post-stain brightness', fontsize=11)
    ax.set_title(f'{prefix}pre vs post pixel brightness '
                 f'(n={n_pixels:,} pixels)', fontsize=10)
    ax.legend(fontsize=9, loc='upper left')
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f" saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig
[docs]
def plot_excess_heatmap(post_im, pre_im, scale=None, scales=None,
                        sample='', dye='', save_path=None):
    """
    Excess post-stain signal heatmap (upper triangle only).

    Both axes are binned by the integer values actually present in the
    post-stain data. Pre-stain values are snapped to the nearest such
    value. The plot shows the post-excess asymmetry ``grid − grid.T``
    only above the diagonal, where post > pre.

    The lower triangle is masked to NaN because ``grid − grid.T`` is
    antisymmetric by construction. Every cell below the diagonal is the
    negative of its reflection above, so the lower half adds no
    information. Restricting the display to the upper triangle answers
    one question unambiguously: "where in brightness space is post
    brighter than pre, and by how much?".

    Cells with positive excess (post-stain pixels outnumber the reflected
    pre-stain pixels) are coloured by ``log₁₀(excess)`` on a sequential
    magma palette. Cells with zero or negative excess in the upper
    triangle are also masked to NaN. Those are brightness pairs where
    pre-stain pixels are actually MORE common than post, which may
    indicate bleaching, quenching, or an alignment artifact.

    Only pixels finite in both images are used. Values are rounded to
    the nearest integer and clipped to [0, 255].

    Optionally overplots one or more estimated scale lines.

    Parameters
    ----------
    post_im, pre_im : ndarray (2-D, float-like)
        Post-stain and aligned pre-stain images.
    scale : float or None
        Single scale line to overplot (legacy convenience). Equivalent to
        passing ``scales=[('scale', scale, '#00cc88')]``.
    scales : list of tuple or None
        List of ``(label, value, color)`` tuples to overplot as scale
        lines. Takes precedence over ``scale`` when both are given.
        Entries whose value is None, non-finite or <= 0 are skipped.
        Typical usage::

            scales=[
                ('Moffat fit', 1.123, '#00cc88'),
                ('ratio p99.1', 1.456, '#ff9933'),
                ('manual', 1.500, '#ff3366'),
            ]

    sample, dye : str
        For title.
    save_path : str or None
        If set, save figure to this path and close it. Otherwise show it.

    Returns
    -------
    fig : matplotlib.Figure
    """
    post_flat = np.asarray(post_im, dtype=np.float64).ravel()
    pre_flat = np.asarray(pre_im, dtype=np.float64).ravel()
    # NaN/inf pixels (e.g. masked background) have no integer value.
    # Casting them to int32 gives an arbitrary number that clipping would
    # silently fold into the 0 or 255 bin, so drop them.
    finite = np.isfinite(post_flat) & np.isfinite(pre_flat)
    post = np.clip(np.rint(post_flat[finite]), 0, 255).astype(np.int32)
    pre = np.clip(np.rint(pre_flat[finite]), 0, 255).astype(np.int32)
    n_pixels = len(post)

    # Bin both axes by the post-stain sensor values actually present
    vals = np.unique(post)  # sorted, unique
    if vals.size == 0:
        vals = np.array([0], dtype=np.int32)  # no data: one empty bin
    n_vals = len(vals)
    # Every post value occurs in vals, so searchsorted gives its exact index.
    post_idx = np.searchsorted(vals, post)

    # Snap each pre-stain value to the nearest post-stain value
    # (ties go to the higher value).
    idx_above = np.searchsorted(vals, pre, side='left').clip(0, n_vals - 1)
    idx_below = (idx_above - 1).clip(0, n_vals - 1)
    closer_below = (np.abs(pre - vals[idx_below])
                    < np.abs(pre - vals[idx_above]))
    pre_idx = np.where(closer_below, idx_below, idx_above)

    # Count grid: grid[i, j] = count(pre == vals[i], post == vals[j]).
    # bincount on the flattened 2-D index is much faster than np.add.at.
    grid = np.bincount(pre_idx * n_vals + post_idx,
                       minlength=n_vals * n_vals).reshape(n_vals, n_vals)
    grid = grid.astype(np.float64)

    # Bin edges for pcolormesh, at midpoints between consecutive values
    edges = np.concatenate((
        [vals[0] - 0.5],
        (vals[:-1] + vals[1:]) / 2.0,
        [vals[-1] + 0.5],
    ))

    # excess[i, j] > 0 when (pre=i, post=j) pixels outnumber
    # (pre=j, post=i) pixels. With j > i (post > pre, the upper triangle
    # in display) that is the post-excess signal we want. The lower
    # triangle is the negated mirror and is masked. Non-positive upper
    # values are also masked, since log/magma only conveys positive
    # magnitudes.
    excess = grid - grid.T
    # k=1 excludes the diagonal itself (always zero).
    upper_tri_mask = np.triu(np.ones((n_vals, n_vals), dtype=bool), k=1)
    excess_display = np.full_like(excess, np.nan)
    pos_in_upper = upper_tri_mask & (excess > 0)
    excess_display[pos_in_upper] = np.log10(excess[pos_in_upper])

    # Color limits from finite data
    finite_vals = excess_display[np.isfinite(excess_display)]
    if len(finite_vals) > 0:
        vmin = float(np.nanmin(finite_vals))
        vmax = float(np.nanmax(finite_vals))
    else:
        vmin, vmax = 0.0, 1.0

    # Plot — square data axes, sequential magma colormap
    prefix = _title_prefix(sample, dye)
    fig, ax = plt.subplots(figsize=(9, 8))
    cmap = plt.get_cmap('magma').copy()
    cmap.set_bad('white')
    img = ax.pcolormesh(edges, edges, excess_display.T,
                        cmap=cmap, shading='flat', rasterized=True,
                        vmin=vmin, vmax=vmax)

    # Colorbar pinned to the same height as the plot via
    # make_axes_locatable. The default fig.colorbar steals axes space
    # and ends up shorter than the parent on square axes.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='4%', pad=0.08)
    cbar = fig.colorbar(img, cax=cax)
    cbar.set_label('log₁₀(post excess count)', fontsize=10)

    # 1:1 identity line — black dashed, visible against magma
    ax.plot([vals[0], vals[-1]], [vals[0], vals[-1]],
            color='black', linewidth=1.0,
            linestyle='--', alpha=0.6, label='scale = 1')

    # Scale lines: post = value * pre, clipped so both ends stay in range
    if scales:
        line_specs = list(scales)
    elif scale is not None:
        line_specs = [('scale', float(scale), '#00cc88')]
    else:
        line_specs = []
    for label, value, color in line_specs:
        if value is None or not np.isfinite(value) or value <= 0:
            continue
        x_end = min(float(vals[-1]), float(vals[-1]) / value)
        ax.plot([0, x_end], [0, x_end * value],
                color=color, linewidth=1.5,
                linestyle='-', alpha=0.9,
                label=f'{label} = {value:.3f}')

    ax.set_xlim(edges[0], edges[-1])
    ax.set_ylim(edges[0], edges[-1])
    ax.set_aspect('equal')
    ax.set_xlabel('pre-stain brightness', fontsize=11)
    ax.set_ylabel('post-stain brightness', fontsize=11)
    ax.set_title(f'{prefix}post-stain excess (upper triangle only) '
                 f'(n={n_pixels:,} pixels)', fontsize=10)
    ax.legend(fontsize=8, loc='upper left')
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f" saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig
[docs]
def plot_pre_post_histograms(post_im, pre_im,
                             raw_pre_im=None,
                             sample='', dye='', save_path=None):
    """
    Overlapping histograms of pre-stain and post-stain pixel values.

    This is an apples-to-apples comparison. When ``raw_pre_im`` is
    provided, the foreground pre histogram uses the **raw padded
    pre-stain image** (discrete 8-bit values, no warp interpolation).
    That makes it directly comparable to the post histogram, which is
    also discrete 8-bit. The warped/interpolated pre is drawn underneath
    in faint grey for reference, so the effect of the alignment warp on
    the distribution stays visible.

    Padding-region zeros are excluded via a sample-region mask: a pixel
    is in the sample when any supplied image is > 0 there. Interior dark
    pixels (zero in the sample region) are retained. Non-finite values
    (NaN/inf) inside the sample region are dropped from each
    distribution.

    Bin edges are integer-aligned (-0.5, 0.5, …). They adapt to the
    actual range and density of observed values to avoid empty bins
    when the data are sparse.

    The y-axis is linear by default. If the zero-value bin is more than
    5× taller than the next-tallest bin in either foreground
    distribution, the y-axis switches to log so the rest of the
    distribution remains visible. This is common when the sample has
    many interior dark pixels.

    Parameters
    ----------
    post_im : ndarray (2-D, float-like)
        Post-stain image (the reference frame, always
        ``01_padded_post``).
    pre_im : ndarray (2-D, float-like)
        Aligned (warped, interpolated) pre-stain image
        (``03_interior_aligned_pre`` or ``02_icp_aligned_pre``).
        Plotted in faint grey as a background reference.
    raw_pre_im : ndarray (2-D, float-like) or None
        Raw padded pre-stain image (``01_padded_pre``): discrete 8-bit
        values, no warp interpolation. When provided, this is the
        apples-to-apples comparison to the post histogram and is plotted
        in the foreground. When ``None`` (legacy callers), the function
        falls back to two curves only: warped pre and post.
    sample, dye : str
        For title.
    save_path : str or None
        If set, save figure to this path and close it. Otherwise show it.

    Returns
    -------
    fig : matplotlib.Figure
    """
    # Sample-region mask: any image with signal marks a pixel as "in the
    # sample"; all-zero pixels are padding. This excludes padding without
    # throwing away interior dark pixels. ORing in the raw pre matters
    # because it lives in a slightly different (unwarped) frame. For
    # large warps its footprint can extend beyond the post/warped mask,
    # and we want to keep those pixels.
    sample_mask = (post_im > 0) | (pre_im > 0)
    if raw_pre_im is not None:
        sample_mask = sample_mask | (raw_pre_im > 0)

    def _in_sample(im):
        """In-sample pixels of im as flat float64, NaN/inf removed."""
        values = im[sample_mask].astype(np.float64).ravel()
        # A single NaN would make np.percentile (below) return NaN.
        return values[np.isfinite(values)]

    post = _in_sample(post_im)
    pre_warped = _in_sample(pre_im)
    raw_pre = _in_sample(raw_pre_im) if raw_pre_im is not None else None
    prefix = _title_prefix(sample, dye)

    # The range is set by the foreground data: post + raw pre if
    # available, otherwise post + warped pre.
    if raw_pre is not None and len(raw_pre):
        hi_data = np.concatenate([post, raw_pre])
    else:
        hi_data = np.concatenate([post, pre_warped])

    if len(hi_data) == 0:
        # Degenerate case: nothing to plot. Return an empty titled figure
        # rather than crashing, and save/show it like the normal path.
        fig, ax = plt.subplots(figsize=(9, 4))
        ax.set_title(f'{prefix}pre vs post: no in-sample pixels')
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f" saved: {save_path}")
            plt.close(fig)
        else:
            plt.show()
        return fig

    # ── Adaptive bin edges ────────────────────────────────────────
    # Integer-aligned bins, each integer at the centre of its bin. The
    # 99.9th percentile caps the upper edge so a few bright outliers
    # don't waste most of the plot width.
    lo = 0  # always start at 0 so interior dark pixels are visible
    hi = max(1, int(np.ceil(np.percentile(hi_data, 99.9))))

    # Sparsity check: count the distinct integers present in [lo, hi].
    # If they fill at least ~30% of the range, use one bin per integer.
    # Otherwise use one bin per *observed* value, which avoids long combs
    # of empty bins for heavily quantised data.
    foreground_ints = np.unique(np.rint(hi_data[hi_data <= hi]).astype(np.int64))
    foreground_ints = foreground_ints[foreground_ints >= lo]
    span = hi - lo + 1
    if len(foreground_ints) == 0 or len(foreground_ints) >= 0.30 * span:
        # Dense (or nothing observed in range): one bin per integer.
        edges = np.arange(lo - 0.5, hi + 1.5, 1.0)
        first_bin_value = lo
    else:
        # Sparse: edges sit halfway between consecutive observed values,
        # so each bin centres on its value. For a single value v this
        # gives [v - 0.5, v + 0.5].
        mids = (foreground_ints[:-1] + foreground_ints[1:]) / 2.0
        edges = np.concatenate((
            [foreground_ints[0] - 0.5],
            mids,
            [foreground_ints[-1] + 0.5],
        ))
        first_bin_value = foreground_ints[0]

    fig, ax = plt.subplots(figsize=(9, 4))

    # ── Background layer: warped (interpolated) pre, faint grey ──
    # Drawn first so it sits beneath the foreground, showing how much
    # the alignment warp smeared the distribution compared to raw pre.
    if raw_pre is not None:
        ax.hist(pre_warped, bins=edges,
                histtype='stepfilled', color='#888888', alpha=0.18,
                edgecolor='#888888', linewidth=0.8,
                label=f'pre-stain, warped (n={len(pre_warped):,})',
                zorder=1)

    # ── Foreground: raw pre (or warped pre as fallback) and post ──
    fg_pre = raw_pre if raw_pre is not None else pre_warped
    fg_pre_label = ('pre-stain, raw' if raw_pre is not None
                    else 'pre-stain')
    pre_counts, _, _ = ax.hist(
        fg_pre, bins=edges,
        histtype='stepfilled', color='#2196a0', alpha=0.45,
        edgecolor='#2196a0', linewidth=1.2,
        label=f'{fg_pre_label} (n={len(fg_pre):,})',
        zorder=2)
    post_counts, _, _ = ax.hist(
        post, bins=edges,
        histtype='stepfilled', color='#e05c2a', alpha=0.45,
        edgecolor='#e05c2a', linewidth=1.2,
        label=f'post-stain (n={len(post):,})',
        zorder=3)

    # ── Decide y-scale ────────────────────────────────────────────
    # Switch to log y if the zero bin (the first bin, when it is centred
    # on 0) is more than 5× taller than the tallest other bin in either
    # foreground distribution.
    use_log = False
    if first_bin_value == 0:
        for counts in (pre_counts, post_counts):
            if len(counts) < 2:
                continue
            second = counts[1:].max()
            if second > 0 and counts[0] > 5 * second:
                use_log = True
                break
    if use_log:
        ax.set_yscale('log')
        y_label = 'pixel count (log)'
    else:
        y_label = 'pixel count'

    ax.set_xlim(lo - 0.5, hi + 0.5)
    ax.set_xlabel('pixel brightness', fontsize=11)
    ax.set_ylabel(y_label, fontsize=10)
    ax.set_title(f'{prefix}pre vs post pixel value distributions',
                 fontsize=10)
    ax.legend(fontsize=9, loc='best')
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f" saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig
[docs]
def plot_difference_histogram(post_im, pre_im,
                              sample='', dye='', save_path=None):
    """
    Histogram of the raw pixel-wise difference (post - pre).

    Three overlaid distributions of ``rint(post - pre)``:

    - All pixels (outline only)
    - Pixels where post > 0 and pre == 0 (post-only signal)
    - Pixels where both post > 0 and pre > 0 (shared signal)

    Only pixels finite in both images are used. NaN/inf would otherwise
    be cast to an arbitrary int32 (typically -2**31), and the one-bin-
    per-integer range would then explode to billions of bins.

    Linear x-axis with integer-centred bins spanning the observed
    difference range; log y-axis.

    Parameters
    ----------
    post_im, pre_im : ndarray (2-D, float-like)
        Post-stain and aligned pre-stain images.
    sample, dye : str
        For title.
    save_path : str or None
        If set, save figure to this path and close it. Otherwise show it.

    Returns
    -------
    fig : matplotlib.Figure
    """
    post = np.asarray(post_im, dtype=np.float64).ravel()
    pre = np.asarray(pre_im, dtype=np.float64).ravel()
    finite = np.isfinite(post) & np.isfinite(pre)
    post, pre = post[finite], pre[finite]
    diff_all = np.rint(post - pre).astype(np.int32)

    # Subsets
    post_only = (post > 0) & (pre == 0)
    both = (post > 0) & (pre > 0)
    diff_post_only = diff_all[post_only]
    diff_both = diff_all[both]

    # Integer-aligned bin edges over the observed range (one zero bin
    # when there is no data, instead of crashing on .min()).
    if diff_all.size:
        lo, hi = int(diff_all.min()), int(diff_all.max())
    else:
        lo = hi = 0
    edges = np.arange(lo - 0.5, hi + 1.5, 1.0)

    prefix = _title_prefix(sample, dye)
    fig, ax = plt.subplots(figsize=(12, 4))
    # All pixels: outline only
    ax.hist(diff_all, bins=edges,
            histtype='step', color='#7b4fa6', linewidth=1.2,
            alpha=0.7,
            label=f'all pixels (n={len(diff_all):,})')
    # Both > 0: filled
    ax.hist(diff_both, bins=edges,
            histtype='stepfilled', color='#2196a0', alpha=0.45,
            edgecolor='#2196a0', linewidth=0.8,
            label=f'both > 0 (n={len(diff_both):,})')
    # Post > 0, pre == 0: filled
    ax.hist(diff_post_only, bins=edges,
            histtype='stepfilled', color='#e05c2a', alpha=0.45,
            edgecolor='#e05c2a', linewidth=0.8,
            label=f'post > 0, pre = 0 (n={len(diff_post_only):,})')
    ax.axvline(0, color='black', linewidth=1.2, linestyle='--',
               alpha=0.7, label='zero')
    ax.set_yscale('log')
    ax.set_xlabel('post − pre (pixel brightness difference)',
                  fontsize=11)
    ax.set_ylabel('pixel count', fontsize=10)
    ax.set_title(f'{prefix}post − pre difference distribution', fontsize=10)
    ax.legend(fontsize=9)
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f" saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig
[docs]
def plot_ratio_histogram_simple(post_im, pre_im, n_bins=200,
smooth_sigma=3,
sample='', dye='', save_path=None):
"""
Histogram of per-pixel post/pre ratio with scale estimation.
Plotted in log₁₀ space. Estimates the background scale factor
from the smoothed histogram peak, then mirrors the left wing
across the peak and fits a Moffat profile to model the noise.
Parameters
----------
post_im, pre_im : ndarray (2-D, float-like)
Post-stain and aligned pre-stain images.
n_bins : int
Number of histogram bins (default 200).
smooth_sigma : float
Gaussian smoothing sigma (in bins) for peak finding (default 3).
sample, dye : str
For title.
save_path : str or None
If set, save figure to this path.
Returns
-------
fig : matplotlib.Figure
scale_estimate : float
Estimated background scale factor (in linear units).
"""
from scipy.ndimage import uniform_filter1d
from scipy.optimize import curve_fit as _curve_fit
post = post_im.ravel().astype(np.float64)
pre = pre_im.ravel().astype(np.float64)
# Only pixels where both have signal
both = (post > 0) & (pre > 0)
post_b = post[both]
pre_b = pre[both]
ratio = post_b / pre_b
log_ratio = np.log10(ratio)
lo = float(log_ratio.min())
hi = float(log_ratio.max())
# Histogram
counts, edges_h = np.histogram(log_ratio, bins=n_bins, range=(lo, hi))
centres = (edges_h[:-1] + edges_h[1:]) / 2.0
counts_f = counts.astype(np.float64)
bin_width = float(edges_h[1] - edges_h[0])
# Smooth for initial peak finding
counts_smooth = uniform_filter1d(counts_f, size=max(int(smooth_sigma * 2 + 1), 3))
# Find peak of smoothed histogram, EXCLUDING bins near ratio=1
log_one = 0.0
exclude_radius = 3 * bin_width
peak_candidates = counts_smooth.copy()
near_one = np.abs(centres - log_one) <= exclude_radius
peak_candidates[near_one] = 0 # zero out ratio=1 neighborhood
peak_bin = int(np.argmax(peak_candidates))
log_scale_init = float(centres[peak_bin])
# --- Moffat noise fit from left wing mirrored across peak ---
# Exclude bins near ratio=1 from fitting
not_near_one = ~near_one
left_mask = (centres <= log_scale_init) & not_near_one
x_left = centres[left_mask]
y_left = counts_f[left_mask]
# Mirror left wing across initial peak
x_mirror = 2.0 * log_scale_init - x_left
y_mirror = y_left.copy()
# Exclude mirrored points near ratio=1
mirror_ok = np.abs(x_mirror - log_one) > exclude_radius
x_mirror = x_mirror[mirror_ok]
y_mirror = y_mirror[mirror_ok]
# Combine for fitting
x_fit = np.concatenate([x_left, x_mirror])
y_fit = np.concatenate([y_left, y_mirror])
# Moffat profile: amp * (1 + ((x - mu) / alpha)^2)^(-beta)
# beta controls peakedness: beta=1 is Lorentzian, large beta -> Gaussian
def _moffat(x, amp, mu, alpha, beta):
return amp * (1.0 + ((x - mu) / alpha) ** 2) ** (-beta)
fit_x = np.linspace(lo, hi, 500)
fit_y = None
log_scale = log_scale_init # will be refined by fit
if len(x_fit) > 5 and y_fit.max() > 0:
# Initial guesses
alpha0 = float(max(log_scale_init - lo, 0.1) / 3.0)
amp0 = float(y_fit.max())
p0 = [amp0, log_scale_init, alpha0, 2.5]
try:
popt, _ = _curve_fit(_moffat, x_fit, y_fit, p0=p0,
maxfev=10000,
bounds=([0, lo, 1e-6, 1.0],
[np.inf, hi, hi - lo, 20.0]))
amp_fit, mu_fit, alpha_fit, beta_fit = popt
fit_y = _moffat(fit_x, amp_fit, mu_fit, alpha_fit, beta_fit)
# Refine scale from the fit centre
log_scale = mu_fit
print(f" Moffat fit: mu={mu_fit:.4f} alpha={alpha_fit:.4f} "
f"beta={beta_fit:.2f}")
except Exception as e:
print(f" Moffat fit failed: {e} — using smoothed peak")
scale_estimate = float(10 ** log_scale)
# --- Plot ---
prefix = _title_prefix(sample, dye)
fig, ax = plt.subplots(figsize=(10, 4))
# Main histogram
ax.hist(log_ratio, bins=n_bins, range=(lo, hi),
histtype='stepfilled', color='#2196a0', alpha=0.5,
edgecolor='#2196a0', linewidth=0.8,
label=f'both > 0 (n={len(log_ratio):,})')
# Voigt noise estimate
if fit_y is not None:
ax.fill_between(fit_x, fit_y, alpha=0.35,
color='#888888', zorder=2,
label='noise estimate (Moffat)')
ax.plot(fit_x, fit_y, color='#888888', linewidth=1.5,
linestyle='-', zorder=3)
# Fitted points (left wing + mirror), excluding near ratio=1
if len(x_fit) > 0:
ax.scatter(x_left, y_left, color='#999999', s=8,
alpha=0.6, zorder=4, marker='o',
label='fit bins (left of peak)')
ax.scatter(x_mirror, y_mirror, color='#bbbbbb', s=8,
alpha=0.5, zorder=4, marker='^',
label='mirrored bins')
# Reference lines
ax.axvline(0, color='black', linewidth=1.0, linestyle='--',
alpha=0.5, label='ratio = 1')
ax.axvline(log_scale, color='#e05c2a', linewidth=2, linestyle='-',
alpha=0.9,
label=f'scale estimate = {scale_estimate:.3f}')
# Tick labels in ratio units
def _log_fmt(val, pos):
return f'{10**val:.2g}'
ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(_log_fmt))
ax.set_xlabel('post / pre ratio (log scale)', fontsize=11)
ax.set_ylabel('pixel count', fontsize=10)
ax.set_title(f'{prefix}post/pre ratio distribution '
f'(n={len(log_ratio):,} pixels where both > 0)',
fontsize=10)
ax.legend(fontsize=8)
plt.tight_layout()
if save_path:
fig.savefig(save_path, dpi=150, bbox_inches='tight')
print(f" saved: {save_path}")
plt.close(fig)
else:
plt.show()
print(f" Scale estimate (ratio peak): {scale_estimate:.4f}")
return fig, scale_estimate
[docs]
def plot_difference_image(post_im, pre_im, scale,
                          sample='', dye='', save_path=None):
    """
    Plot the scaled difference image: post − scale × pre.

    Shows the image with an asinh stretch and a diverging colormap
    so positive (microbe) signal is visually distinct from negative
    (over-subtraction) regions.

    "Tissue" pixels are those where either image is > 0; all other
    pixels are masked (NaN) in the display. The asinh knee is the
    median absolute difference over tissue pixels (floored at 1.0),
    and the symmetric colour limit is the 95th percentile of the
    absolute stretched tissue values. Non-finite pixels are ignored
    when computing both statistics, and a zero colour limit falls back
    to 1.0.

    Parameters
    ----------
    post_im, pre_im : ndarray (2-D, float-like)
        Post-stain and aligned pre-stain images.
    scale : float
        Scale factor to apply to pre-stain before subtraction.
    sample, dye : str
        For title.
    save_path : str or None
        If set, save figure to this path.

    Returns
    -------
    fig : matplotlib.Figure
    diff : ndarray
        The difference image (float32), unstretched and unmasked.
    """
    post = post_im.astype(np.float32)
    pre = pre_im.astype(np.float32)
    # Cast explicitly: under NumPy 2 promotion rules a NumPy float64
    # ``scale`` would otherwise upcast the result, breaking the
    # documented float32 return dtype.
    diff = (post - scale * pre).astype(np.float32, copy=False)

    # Asinh stretch for display, with statistics taken over tissue only.
    tissue = (post > 0) | (pre > 0)
    tissue_px = diff[tissue]
    # NaN/inf would make np.percentile return NaN (and max(nan, 1.0) is
    # NaN), turning the entire stretched image into NaN.
    tissue_px = tissue_px[np.isfinite(tissue_px)]
    knee = 1.0
    if tissue_px.size > 0:
        knee = max(float(np.percentile(np.abs(tissue_px), 50)), 1.0)
    stretched = np.arcsinh(diff / knee)

    # Symmetric color limits from (finite) tissue pixels
    stretched_tissue = stretched[tissue]
    stretched_tissue = stretched_tissue[np.isfinite(stretched_tissue)]
    sv = 1.0
    if stretched_tissue.size > 0:
        sv = float(np.percentile(np.abs(stretched_tissue), 95))
    if sv <= 0:
        # All-zero difference: avoid a degenerate vmin == vmax == 0 norm.
        sv = 1.0

    # Mask background
    display = stretched.copy()
    display[~tissue] = np.nan

    prefix = _title_prefix(sample, dye)
    fig, ax = plt.subplots(figsize=(12, 10))
    cmap = _make_diverging_cmap()
    img = ax.imshow(display, vmin=-sv, vmax=sv, cmap=cmap,
                    interpolation='nearest')

    # Colorbar pinned to plot height via make_axes_locatable
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='4%', pad=0.08)
    cb = fig.colorbar(img, cax=cax)
    cb.set_label(f'asinh((post − {scale:.3f}×pre) / {knee:.0f})',
                 fontsize=9)

    ax.set_title(f'{prefix}post − {scale:.3f} × pre', fontsize=11)
    ax.axis('off')
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"  saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig, diff
# ==============================================================================
# ZOOM / INSPECTION
# ==============================================================================
[docs]
def plot_zoom(image, row, col, size, sigma=0.0,
              cmap='gray', stretch_percentile=99.0,
              diverging=False, title='', save_path=None):
    """
    Crop a square region from an image, optionally smooth it, and display.

    Useful for inspecting fine structure in large microscopy images or
    difference images. The crop is bounds-checked so near-edge coordinates
    work without raising; the panel title reports the window actually
    displayed after clamping.

    Parameters
    ----------
    image : ndarray (2-D)
        Source image to crop from. Can be any float or integer dtype.
    row, col : int
        Top-left corner of the crop region in pixel coordinates.
    size : int
        Side length of the square crop in pixels (at least one pixel is
        always kept).
    sigma : float
        Gaussian blur sigma applied to the crop (default 0 = no blur).
        Non-positive values disable the blur and are not shown in the
        title.
    cmap : str
        Matplotlib colormap name (default 'gray'). Ignored if
        ``diverging=True``.
    stretch_percentile : float
        For non-diverging display, clip values above this percentile of
        the crop for display (default 99.0). Lower values push faint
        features harder.
    diverging : bool
        If True, use a symmetric diverging colormap centred at zero
        (appropriate for difference images). Overrides ``cmap``.
    title : str
        Figure title.
    save_path : str or None
        If set, save figure to this path.

    Returns
    -------
    fig : matplotlib.Figure
    crop : ndarray (2-D)
        The cropped (and smoothed, if sigma > 0) region, as float32.
    """
    h, w = image.shape[:2]
    # Clamp to image bounds; r1/c1 always leave at least one pixel.
    r0 = int(max(0, min(row, h - 1)))
    c0 = int(max(0, min(col, w - 1)))
    r1 = int(max(r0 + 1, min(r0 + size, h)))
    c1 = int(max(c0 + 1, min(c0 + size, w)))
    crop = image[r0:r1, c0:c1].astype(np.float32)

    # Single source of truth for "was a blur applied", used for both the
    # filtering and the title so they cannot disagree (e.g. sigma < 0).
    blurred = bool(sigma) and sigma > 0
    if blurred:
        crop = gaussian_filter(crop, sigma=float(sigma))

    fig, ax = plt.subplots(figsize=(8, 8))
    finite = crop[np.isfinite(crop)]
    if diverging:
        # Symmetric colour limits around zero
        sv = 1.0
        if len(finite) > 0:
            sv = float(np.percentile(np.abs(finite), stretch_percentile))
            if sv <= 0:
                sv = 1.0
        img = ax.imshow(crop, vmin=-sv, vmax=sv, cmap=_make_diverging_cmap(),
                        interpolation='nearest')
    else:
        if len(finite) > 0:
            vmin = float(np.percentile(finite, 100 - stretch_percentile))
            vmax = float(np.percentile(finite, stretch_percentile))
            if vmax <= vmin:
                # Flat crop (or inverted percentiles): keep a valid range.
                vmax = vmin + 1.0
        else:
            vmin, vmax = 0.0, 1.0
        img = ax.imshow(crop, vmin=vmin, vmax=vmax, cmap=cmap,
                        interpolation='nearest')
    fig.colorbar(img, ax=ax, pad=0.02, shrink=0.8)

    sigma_str = f' σ={sigma:g}' if blurred else ''
    full_title = title if title else 'zoom'
    ax.set_title(f'{full_title}  [{r0}:{r1}, {c0}:{c1}]{sigma_str}',
                 fontsize=10)
    ax.axis('off')
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"  saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig, crop
[docs]
def plot_zoom_multi(images, labels, row, col, size, sigma=0.0,
                    cmaps=None, diverging_flags=None,
                    stretch_percentile=99.0,
                    sample='', dye='', save_path=None):
    """
    Side-by-side zoomed crops from multiple images at the same coordinates.

    Useful for comparing post-stain, aligned pre-stain, and difference
    image at a common region of interest.

    Parameters
    ----------
    images : list of ndarray
        2-D images to crop from. All must have the same shape.
    labels : list of str or None
        Panel titles, one per image. Missing entries (a shorter list, or
        None) leave the corresponding panel untitled.
    row, col, size, sigma : int / int / int / float
        Crop geometry and blur, as in :func:`plot_zoom`. The window is
        clamped to image bounds; non-positive ``sigma`` disables the blur.
    cmaps : list of str or None
        Per-image colormap names. Default (and for missing entries): 'gray'.
    diverging_flags : list of bool or None
        If set, per-image flag to use a diverging symmetric colormap
        (typically True for the difference image). Missing entries: False.
    stretch_percentile : float
        Display stretch percentile.
    sample, dye : str
        For title prefix.
    save_path : str or None
        If set, save figure to this path.

    Returns
    -------
    fig : matplotlib.Figure or None
        None when ``images`` is empty.
    crops : list of ndarray
        One float32 crop (blurred if sigma > 0) per image, in input order.
    """
    n = len(images)
    if n == 0:
        return None, []

    def _per_panel(values, default):
        # Normalise an optional per-panel list to exactly n entries, so a
        # short list no longer makes zip() silently drop trailing panels.
        values = [] if values is None else list(values)
        return (values + [default] * n)[:n]

    labels = _per_panel(labels, '')
    cmaps = _per_panel(cmaps, 'gray')
    diverging_flags = _per_panel(diverging_flags, False)
    # Same condition drives both the blur and the title annotation.
    blurred = bool(sigma) and sigma > 0

    # squeeze=False returns a 2-D array of axes even when n == 1.
    fig, axs = plt.subplots(1, n, figsize=(6 * n, 6), squeeze=False)

    crops = []
    windows = []
    for ax, im, lab, cmap, is_div in zip(axs[0], images, labels,
                                         cmaps, diverging_flags):
        h, w = im.shape[:2]
        # Clamp to image bounds; r1/c1 always leave at least one pixel.
        r0 = int(max(0, min(row, h - 1)))
        c0 = int(max(0, min(col, w - 1)))
        r1 = int(max(r0 + 1, min(r0 + size, h)))
        c1 = int(max(c0 + 1, min(c0 + size, w)))
        windows.append((r0, r1, c0, c1))
        crop = im[r0:r1, c0:c1].astype(np.float32)
        if blurred:
            crop = gaussian_filter(crop, sigma=float(sigma))
        crops.append(crop)

        finite = crop[np.isfinite(crop)]
        if is_div:
            sv = (float(np.percentile(np.abs(finite), stretch_percentile))
                  if len(finite) > 0 else 1.0)
            if sv <= 0:
                sv = 1.0
            img = ax.imshow(crop, vmin=-sv, vmax=sv,
                            cmap=_make_diverging_cmap(),
                            interpolation='nearest')
        else:
            if len(finite) > 0:
                vmin = float(np.percentile(finite, 100 - stretch_percentile))
                vmax = float(np.percentile(finite, stretch_percentile))
                if vmax <= vmin:
                    vmax = vmin + 1.0
            else:
                vmin, vmax = 0.0, 1.0
            img = ax.imshow(crop, vmin=vmin, vmax=vmax, cmap=cmap,
                            interpolation='nearest')
        fig.colorbar(img, ax=ax, pad=0.02, shrink=0.8)
        ax.set_title(lab, fontsize=10)
        ax.axis('off')

    # Report the window actually displayed (after clamping), not the raw
    # request, so near-edge zooms are labelled correctly. Images are
    # expected to share a shape; the first panel's window is used.
    r0, r1, c0, c1 = windows[0]
    sigma_str = f' σ={sigma:g}' if blurred else ''
    prefix = _title_prefix(sample, dye)
    fig.suptitle(f'{prefix}zoom [{r0}:{r1}, {c0}:{c1}]{sigma_str}',
                 fontsize=11, y=1.02)
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"  saved: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig, crops