Source code for xflow.utils.visualization
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from .typing import ImageLike
def to_numpy_image(img: ImageLike) -> np.ndarray:
"""
Convert various image formats to a 2D/3D numpy array suitable for display.
Works for CPU/GPU PyTorch tensors, TF eager tensors, PIL, and numpy arrays.
"""
import torch
if isinstance(img, torch.Tensor): # PyTorch tensor
arr = img.detach().cpu().numpy()
elif hasattr(img, "numpy"): # TF tensor
arr = img.numpy()
elif isinstance(img, Image.Image): # PIL
arr = np.array(img)
elif isinstance(img, np.ndarray): # already numpy
arr = img
else: # fallback
arr = np.array(img)
# Normalize shape for display
if arr.ndim == 4: # (B, C, H, W) or (B, H, W, C) → take first
arr = arr[0]
if arr.ndim == 3 and arr.shape[0] in (1, 3): # channel-first → channel-last
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 3 and arr.shape[-1] == 1: # single channel → squeeze
arr = arr[..., 0]
return arr
[docs]
def plot_image(
img: ImageLike,
cmap: str = None,
title: str = None,
figsize: tuple = None,
vmin: float = None,
vmax: float = None,
colorbar: bool = True,
) -> None:
"""
Plot an image using matplotlib.
Args:
img: Image in any supported format (will be converted automatically)
cmap: Colormap to use (auto-detected if None)
title: Plot title
figsize: Figure size tuple
vmin: Minimum pixel value for color scaling (auto if None)
vmax: Maximum pixel value for color scaling (auto if None)
colorbar: Whether to show the colorbar (default True)
"""
arr = to_numpy_image(img)
if cmap is None:
cmap = "gray" if arr.ndim == 2 else None
if figsize:
plt.figure(figsize=figsize)
plt.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
plt.xlabel("X (pixel index)")
plt.ylabel("Y (pixel index)")
if colorbar:
plt.colorbar(label="Pixel value")
if title:
plt.title(title)
plt.tight_layout()
plt.show()
def save_image(
img: ImageLike,
path: str,
cmap: str | None = None,
title: str | None = None,
figsize: tuple[float, float] = (6, 4),
dpi: int = 150,
) -> None:
"""
Same as plot_image, but saves to disk instead of showing.
Defaults ensure it runs without extra args.
"""
arr = to_numpy_image(img)
if cmap is None:
cmap = "gray" if arr.ndim == 2 else None
Path(path).parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=figsize)
im = ax.imshow(arr, cmap=cmap)
ax.set_xlabel("X (pixel index)")
ax.set_ylabel("Y (pixel index)")
# Colorbar can fail for RGB; keep behavior but make it safe
try:
fig.colorbar(im, ax=ax, label="Pixel value")
except Exception:
pass
if title:
ax.set_title(title)
fig.tight_layout()
fig.savefig(path, dpi=dpi)
plt.close(fig)