from typing import Any, Callable, Dict, List, Optional, Tuple
import yaml
# Map unified event names to framework-specific hook method names
event_map = {
"train_start": {"tf": "on_train_begin", "pl": "on_train_start"},
"train_end": {"tf": "on_train_end", "pl": "on_train_end"},
"epoch_start": {"tf": "on_epoch_begin", "pl": "on_train_epoch_start"},
"epoch_end": {"tf": "on_epoch_end", "pl": "on_train_epoch_end"},
"batch_start": {"tf": "on_train_batch_begin", "pl": "on_train_batch_start"},
"batch_end": {"tf": "on_train_batch_end", "pl": "on_train_batch_end"},
}
class CallbackRegistry:
"""Registry for callback handlers or factories."""
_handlers: Dict[str, Callable] = {}
@classmethod
def register(cls, name: str):
def decorator(func: Callable):
cls._handlers[name] = func
return func
return decorator
@classmethod
def get_handler(cls, name: str) -> Callable:
if name not in cls._handlers:
raise ValueError(f"Handler '{name}' not found")
return cls._handlers[name]
@classmethod
def list_handlers(cls) -> List[str]:
return list(cls._handlers.keys())
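# Illustrative sketch (not part of the registry itself): registering a handler and
# looking it up by name. The "print_epoch" name and signature are examples only,
# not handlers shipped with this module.
#
#     @CallbackRegistry.register("print_epoch")
#     def print_epoch(cb, epoch, logs=None):
#         print(f"epoch {epoch} finished: {logs}")
#
#     handler = CallbackRegistry.get_handler("print_epoch")
#     CallbackRegistry.list_handlers()  # ["print_epoch", ...]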
def make_tf_callback(handlers: Dict[str, List[Callable]]):
from tensorflow.keras.callbacks import Callback
methods = {}
for event, fns in handlers.items():
if not isinstance(fns, (list, tuple)):
fns = [fns]
hook_name = event_map[event]["tf"]
def _make_hook(fns):
def _hook(self, *args, **kwargs):
for fn in fns:
fn(self, *args, **kwargs)
return _hook
methods[hook_name] = _make_hook(fns)
return type("UnifiedTFCallback", (Callback,), methods)()
def make_pl_callback(handlers: Dict[str, List[Callable]]):
import pytorch_lightning as pl
methods = {}
for event, fns in handlers.items():
if not isinstance(fns, (list, tuple)):
fns = [fns]
hook_name = event_map[event]["pl"]
def _make_hook(fns):
def _hook(self, trainer, pl_module, *args, **kwargs):
for fn in fns:
fn(self, trainer, pl_module, *args, **kwargs)
return _hook
methods[hook_name] = _make_hook(fns)
return type("UnifiedPLCallback", (pl.Callback,), methods)()
def build_callbacks_from_config(
config: List[Dict[str, Any]],
framework: str,
name_key: str = "name",
params_key: str = "params",
) -> List[Any]:
"""
Build a list of callbacks (native or unified) from a config list.
Args:
config: List of callback config dicts. Each dict should have at least a 'name' key and optionally a 'params' dict.
framework: Which framework to use ('tf', 'pl', or 'torch').
name_key: Key in each config dict for the callback/factory name (default: 'name').
params_key: Key in each config dict for the callback/factory parameters (default: 'params').
Returns:
List of instantiated callback objects.
Each config entry may either:
1) Define only 'name' + 'params' → handler must return a Callback instance (native/factory style)
2) Define 'events' (list of {event, name, params}) → use unified wrapper for event-based callbacks
"""
callbacks = []
for cb in config:
if name_key not in cb:
raise ValueError(f"Callback config missing '{name_key}' key: {cb}")
        name = cb[name_key]
        params = cb.get(params_key, {}) or {}
        # 1) Native callback factory: no events = direct instance
        if not cb.get("events"):
            handler = CallbackRegistry.get_handler(name)
            instance = handler(**params)
            callbacks.append(instance)
            continue
# 2) Unified hook functions
handlers: Dict[str, List[Callable]] = {}
for evt in cb["events"]:
evt_name = evt["event"]
evt_handler_name = evt[name_key]
evt_handler = CallbackRegistry.get_handler(evt_handler_name)
evt_params = evt.get(params_key, {})
fn = evt_handler(**evt_params) if evt_params else evt_handler
handlers.setdefault(evt_name, []).append(fn)
if framework in ("tf", "tensorflow"):
callbacks.append(make_tf_callback(handlers))
elif framework in ("pl", "pytorch_lightning"):
callbacks.append(make_pl_callback(handlers))
elif framework in ("torch", "pytorch"):
callbacks.append(make_torch_callback(handlers))
else:
raise ValueError(f"Unsupported framework: {framework}")
return callbacks
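# Illustrative sketch: a config list such as the one below (e.g. parsed with
# yaml.safe_load) yields one native Keras callback and one unified event-based
# callback. "log_epoch" stands in for a hook function registered elsewhere; the
# other handler names are registered later in this module.
#
#     config = [
#         {"name": "tf_early_stopping", "params": {"monitor": "val_loss", "patience": 5}},
#         {"name": "log_epoch", "events": [{"event": "epoch_end", "name": "log_epoch"}]},
#     ]
#     callbacks = build_callbacks_from_config(config, framework="tf")
#     # model.fit(x, y, epochs=10, callbacks=callbacks)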
# --- Handlers & Factories: TensorFlow native ---
@CallbackRegistry.register("tf_early_stopping")
def make_early_stopping(monitor: str = "val_loss", patience: int = 3, **kwargs):
from tensorflow.keras.callbacks import EarlyStopping
return EarlyStopping(monitor=monitor, patience=patience, **kwargs)
@CallbackRegistry.register("tf_model_checkpoint")
def make_model_checkpoint(
filepath: str, monitor: str = "val_loss", save_best_only: bool = True, **kwargs
):
from tensorflow.keras.callbacks import ModelCheckpoint
return ModelCheckpoint(
filepath=filepath, monitor=monitor, save_best_only=save_best_only, **kwargs
)
@CallbackRegistry.register("tf_model_checkpoint")
def make_tf_model_checkpoint(
filepath: str, monitor: str = "val_loss", save_best_only: bool = True, **kwargs
):
from tensorflow.keras.callbacks import ModelCheckpoint
return ModelCheckpoint(
filepath=filepath, monitor=monitor, save_best_only=save_best_only, **kwargs
)
@CallbackRegistry.register("tf_eta")
def make_eta_callback():
import time
import numpy as np
import tensorflow as tf
class ETACallback(tf.keras.callbacks.Callback):
def on_train_begin(self, logs=None):
self.times = []
def on_epoch_begin(self, epoch, logs=None):
self.start = time.time()
def on_epoch_end(self, epoch, logs=None):
elapsed = time.time() - self.start
self.times.append(elapsed)
avg = np.mean(self.times[-5:]) # smooth over last 5
remaining = (self.params["epochs"] - epoch - 1) * avg
if remaining > 3600:
hrs = remaining // 3600
mins = (remaining % 3600) // 60
print(f"Estimated time left: {hrs:.0f}h {mins:.0f}m")
else:
print(f"Estimated time left: {remaining:.1f}s")
return ETACallback()
# --- PyTorch (vanilla) Callback System ---
class PyTorchCallback:
"""Base class for PyTorch callbacks following common callback patterns."""
def on_train_begin(self, **kwargs):
"""Called at the beginning of training."""
pass
def on_train_end(self, **kwargs):
"""Called at the end of training."""
pass
def on_epoch_begin(self, epoch, **kwargs):
"""Called at the beginning of each epoch."""
pass
def on_epoch_end(self, epoch, logs=None, **kwargs):
"""Called at the end of each epoch."""
pass
def on_batch_begin(self, batch, **kwargs):
"""Called at the beginning of each batch."""
pass
    def on_batch_end(self, batch, logs=None, **kwargs):
        """Called at the end of each batch."""
        pass
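# Illustrative sketch (an assumption, not shipped machinery): how a plain PyTorch
# training loop might drive these callbacks. The keyword arguments passed to each
# hook (model=, optimizer=, logs=, epochs=, total_batches=, inputs=) match what the
# torch_* handlers below look for; the actual loop wiring is up to the caller.
def run_training_with_callbacks(model, optimizer, loss_fn, train_loader, epochs, callbacks):
    """Minimal vanilla-PyTorch loop that invokes PyTorchCallback hooks."""
    for cb in callbacks:
        cb.on_train_begin(model=model, optimizer=optimizer, epochs=epochs,
                          total_batches=len(train_loader))
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch, total_batches=len(train_loader))
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            for cb in callbacks:
                cb.on_batch_begin(batch_idx, inputs=inputs)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            for cb in callbacks:
                cb.on_batch_end(batch_idx, logs={"train_loss": loss.item()},
                                model=model, optimizer=optimizer, inputs=inputs)
        logs = {"train_loss": running_loss / max(1, len(train_loader))}
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs=logs, model=model)
        # honour early stopping if any callback requested it
        if any(getattr(cb, "should_stop", False) for cb in callbacks):
            break
    for cb in callbacks:
        cb.on_train_end(model=model)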
def make_torch_callback(handlers: Dict[str, List[Callable]]):
"""Create a unified PyTorch callback from event handlers."""
class UnifiedTorchCallback(PyTorchCallback):
def __init__(self):
super().__init__()
self.handlers = handlers
def _call_handlers(self, event, *args, **kwargs):
"""Call all handlers for a given event."""
if event in self.handlers:
for handler in self.handlers[event]:
handler(self, *args, **kwargs)
def on_train_begin(self, **kwargs):
self._call_handlers("train_start", **kwargs)
def on_train_end(self, **kwargs):
self._call_handlers("train_end", **kwargs)
def on_epoch_begin(self, epoch, **kwargs):
self._call_handlers("epoch_start", epoch, **kwargs)
def on_epoch_end(self, epoch, logs=None, **kwargs):
self._call_handlers("epoch_end", epoch, logs=logs, **kwargs)
def on_batch_begin(self, batch, **kwargs):
self._call_handlers("batch_start", batch, **kwargs)
def on_batch_end(self, batch, logs=None, **kwargs):
self._call_handlers("batch_end", batch, logs=logs, **kwargs)
return UnifiedTorchCallback()
# --- PyTorch Native Callback Implementations ---
@CallbackRegistry.register("torch_eta")
def make_torch_eta_callback(total_epochs=None, smoothing=5, sink=None):
"""Create PyTorch ETA callback to estimate remaining training time."""
import time
class TorchETACallback(PyTorchCallback):
def __init__(self, total_epochs, smoothing, sink):
super().__init__()
self.total_epochs = total_epochs
self.smoothing = max(1, int(smoothing))
self.sink = sink if callable(sink) else print
self.times = []
self.start_time = None
def on_train_begin(self, epochs=None, **kwargs):
self.times.clear()
if epochs is not None:
self.total_epochs = epochs
def on_epoch_begin(self, epoch, **kwargs):
self.start_time = time.time()
def on_epoch_end(self, epoch, **kwargs):
if self.start_time is None:
return
elapsed = time.time() - self.start_time
self.times.append(elapsed)
window = self.times[-self.smoothing :]
avg_time = sum(window) / len(window)
if self.total_epochs is None:
self.sink(f"Avg epoch: {avg_time:.2f}s | done {epoch+1}")
return
remaining = (self.total_epochs - (epoch + 1)) * avg_time
h, rem = divmod(int(remaining + 0.5), 3600)
m, s = divmod(rem, 60)
if h:
self.sink(f"ETA: {h}h {m}m")
elif m:
self.sink(f"ETA: {m}m {s}s")
else:
self.sink(f"ETA: {s}s")
return TorchETACallback(total_epochs, smoothing, sink)
@CallbackRegistry.register("torch_early_stopping")
def make_torch_early_stopping(
monitor: str = "val_loss",
patience: int = 20,
min_delta: float = 0.0,
restore_best: bool = True,
mode: str = "min",
):
"""Create PyTorch EarlyStopping callback."""
import copy
class TorchEarlyStopping(PyTorchCallback):
def __init__(self):
super().__init__()
self.monitor = monitor
self.patience = patience
self.min_delta = min_delta
self.restore_best = restore_best
self.mode = mode
self.best = float("inf") if mode == "min" else float("-inf")
self.wait = 0
self.best_state = None
self.should_stop = False
def _is_better(self, current, best):
"""Check if current metric is better than best."""
if self.mode == "min":
return current < best - self.min_delta
else:
return current > best + self.min_delta
def on_epoch_end(self, epoch, logs=None, model=None, **kwargs):
"""Check if we should stop training."""
if logs is None or self.monitor not in logs:
return
current = logs[self.monitor]
if self._is_better(current, self.best):
self.best = current
self.wait = 0
if self.restore_best and model is not None:
try:
import torch
self.best_state = copy.deepcopy(model.state_dict())
except ImportError:
pass
else:
self.wait += 1
if self.wait >= self.patience:
self.should_stop = True
print(f"Early stopping triggered after {epoch + 1} epochs")
def on_train_end(self, model=None, **kwargs):
"""Restore best weights if requested."""
if self.restore_best and self.best_state is not None and model is not None:
try:
model.load_state_dict(self.best_state)
print(f"Restored best weights with {self.monitor}={self.best:.4f}")
except Exception as e:
print(f"Warning: Could not restore best weights: {e}")
return TorchEarlyStopping()
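# Usage note (sketch): the callback only records state; the surrounding loop must
# feed the monitored metric into on_epoch_end and honour the should_stop flag
# (see run_training_with_callbacks above for one way to wire this up).
#
#     es = CallbackRegistry.get_handler("torch_early_stopping")(patience=10)
#     es.on_epoch_end(epoch, logs={"val_loss": val_loss}, model=model)
#     if es.should_stop:
#         ...  # break out of the epoch loop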
@CallbackRegistry.register("torch_model_checkpoint")
def make_torch_model_checkpoint(
filepath: str,
monitor: str = "val_loss",
save_best_only: bool = True,
mode: str = "min",
save_weights_only: bool = False,
):
"""Create PyTorch ModelCheckpoint callback."""
import os
class TorchModelCheckpoint(PyTorchCallback):
def __init__(self):
super().__init__()
self.filepath = filepath
self.monitor = monitor
self.save_best_only = save_best_only
self.mode = mode
self.save_weights_only = save_weights_only
self.best = float("inf") if mode == "min" else float("-inf")
            # Create directory if it doesn't exist (filepath may have no directory part)
            os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
def _is_better(self, current, best):
"""Check if current metric is better than best."""
if self.mode == "min":
return current < best
else:
return current > best
def on_epoch_end(self, epoch, logs=None, model=None, **kwargs):
"""Save model checkpoint if conditions are met."""
if model is None:
return
should_save = True
if self.save_best_only and logs is not None and self.monitor in logs:
current = logs[self.monitor]
if self._is_better(current, self.best):
self.best = current
should_save = True
else:
should_save = False
if should_save:
try:
import torch
# Format filepath with epoch number
formatted_path = self.filepath.format(epoch=epoch)
if self.save_weights_only:
torch.save(model.state_dict(), formatted_path)
else:
torch.save(model, formatted_path)
print(f"Saved model checkpoint to {formatted_path}")
except ImportError:
print("Warning: PyTorch not available for saving checkpoint")
except Exception as e:
print(f"Warning: Could not save checkpoint: {e}")
return TorchModelCheckpoint()
@CallbackRegistry.register("torch_lr_scheduler")
def make_torch_lr_scheduler(scheduler_class: str = "StepLR", **scheduler_kwargs):
"""Create PyTorch learning rate scheduler callback."""
class TorchLRScheduler(PyTorchCallback):
def __init__(self):
super().__init__()
self.scheduler_class = scheduler_class
self.scheduler_kwargs = scheduler_kwargs
self.scheduler = None
def on_train_begin(self, optimizer=None, **kwargs):
"""Initialize scheduler with optimizer."""
if optimizer is None:
print("Warning: No optimizer provided to LR scheduler")
return
try:
import torch.optim.lr_scheduler as lr_scheduler
scheduler_cls = getattr(lr_scheduler, self.scheduler_class)
self.scheduler = scheduler_cls(optimizer, **self.scheduler_kwargs)
print(f"Initialized {self.scheduler_class} scheduler")
except ImportError:
print("Warning: PyTorch not available for LR scheduling")
except AttributeError:
print(f"Warning: Unknown scheduler class: {self.scheduler_class}")
except Exception as e:
print(f"Warning: Could not initialize scheduler: {e}")
        def on_epoch_end(self, epoch, logs=None, **kwargs):
            """Step the learning rate scheduler."""
            if self.scheduler is None:
                return
            try:
                # ReduceLROnPlateau steps on the monitored metric; other schedulers
                # step unconditionally once per epoch.
                if "ReduceLR" in self.scheduler_class:
                    if logs is not None and "val_loss" in logs:
                        self.scheduler.step(logs["val_loss"])
                else:
                    self.scheduler.step()
                # Log current learning rate
                if hasattr(self.scheduler, "get_last_lr"):
                    current_lr = self.scheduler.get_last_lr()[0]
                    print(f"Learning rate: {current_lr:.6f}")
            except Exception as e:
                print(f"Warning: Error stepping scheduler: {e}")
return TorchLRScheduler()
@CallbackRegistry.register("torch_progress_bar")
def make_torch_progress_bar(desc: str = "Training"):
"""Create a simple progress bar callback for PyTorch."""
class TorchProgressBar(PyTorchCallback):
def __init__(self):
super().__init__()
self.desc = desc
self.total_epochs = None
def on_train_begin(self, epochs=None, **kwargs):
"""Initialize progress tracking."""
self.total_epochs = epochs
print(f"Starting {self.desc}")
def on_epoch_end(self, epoch, logs=None, **kwargs):
"""Update progress after each epoch."""
progress = f"Epoch {epoch + 1}"
if self.total_epochs:
progress += f"/{self.total_epochs}"
if logs:
metrics = " - ".join([f"{k}: {v:.4f}" for k, v in logs.items()])
progress += f" - {metrics}"
print(progress)
def on_train_end(self, **kwargs):
"""Finish progress tracking."""
print(f"Completed {self.desc}")
return TorchProgressBar()
@CallbackRegistry.register("torch_batch_progress_bar")
def make_torch_batch_progress_bar(
desc: str = "Training",
update_freq: int = 1,
show_metrics: bool = True,
bar_width: int = 30,
only_keys=None, # e.g. ["train_loss", "val_loss"]
hide_keys=None, # e.g. ["val_accuracy"]
):
"""Create a batch-level progress bar callback for PyTorch with detailed progress tracking.
Use:
make_torch_batch_progress_bar(only_keys=["train_loss", "val_loss"])
# or
make_torch_batch_progress_bar(hide_keys=["beam_param_metric"])
"""
import time
class TorchBatchProgressBar(PyTorchCallback):
def __init__(self):
super().__init__()
self.desc = desc
self.update_freq = max(1, update_freq)
self.show_metrics = show_metrics
self.bar_width = max(10, bar_width)
# filtering
self.only_keys = set(only_keys) if only_keys else None
self.hide_keys = set(hide_keys or [])
# Training state
self.total_epochs = None
self.current_epoch = 0
self.total_batches = None
self.current_batch = 0
self.epoch_start_time = None
self.batch_times = []
# ------------------------ helpers ------------------------
def _format_metrics(self, logs):
if not logs or not self.show_metrics:
return ""
items = list(logs.items())
if self.only_keys is not None:
items = [(k, v) for k, v in items if k in self.only_keys]
if self.hide_keys:
items = [(k, v) for k, v in items if k not in self.hide_keys]
if not items:
return ""
def _fmt_val(v):
try:
return f"{float(v):.4f}"
except Exception:
return str(v)
return " - " + " - ".join(f"{k}: {_fmt_val(v)}" for k, v in items)
# ------------------------ lifecycle ------------------------
def on_train_begin(self, epochs=None, **kwargs):
"""Initialize progress tracking."""
self.total_epochs = epochs
print(f"Starting {self.desc}")
if self.total_epochs:
print(f"Total epochs: {self.total_epochs}")
def on_epoch_begin(self, epoch, total_batches=None, **kwargs):
"""Start epoch progress tracking."""
self.current_epoch = epoch
self.total_batches = total_batches
self.current_batch = 0
self.epoch_start_time = time.time()
self.batch_times.clear()
epoch_info = f"Epoch {epoch + 1}"
if self.total_epochs:
epoch_info += f"/{self.total_epochs}"
if self.total_batches:
epoch_info += f" - {self.total_batches} batches"
print(f"\n{epoch_info}")
def on_batch_begin(self, batch=None, batch_idx=None, **kwargs):
"""Track batch start (supports either arg name)."""
b = batch if batch is not None else batch_idx
if b is not None:
self.current_batch = b
def on_batch_end(self, batch=None, batch_idx=None, logs=None, **kwargs):
"""Update progress bar after each batch."""
b = batch if batch is not None else batch_idx
if b is None:
b = self.current_batch
self.current_batch = b + 1
# Update every N batches or at the end
should_update = ((b + 1) % self.update_freq == 0) or (
self.total_batches and (b + 1) == self.total_batches
)
if should_update:
self._update_progress_bar(logs)
def on_epoch_end(self, epoch, logs=None, **kwargs):
"""Finalize epoch progress."""
# Ensure final update
self._update_progress_bar(logs, force_complete=True)
# Show epoch summary
epoch_time = (
time.time() - self.epoch_start_time if self.epoch_start_time else 0
)
summary = f"\nEpoch {epoch + 1} completed in {epoch_time:.2f}s"
summary += self._format_metrics(logs)
print(summary)
def on_train_end(self, **kwargs):
"""Finish progress tracking."""
print(f"\n{self.desc} completed!")
# ------------------------ rendering ------------------------
def _update_progress_bar(self, logs=None, force_complete=False):
"""Update the progress bar display."""
if self.total_batches is None:
# Simple counter if total unknown
progress = f"\rBatch {self.current_batch}"
progress += self._format_metrics(logs)
print(progress, end="", flush=True)
return
# Calculate progress
if force_complete:
progress_ratio = 1.0
current = self.total_batches
else:
progress_ratio = min(self.current_batch / self.total_batches, 1.0)
current = self.current_batch
# Create progress bar (TensorFlow style)
filled_width = int(self.bar_width * progress_ratio)
if progress_ratio < 1.0 and filled_width > 0:
bar = (
"=" * (filled_width - 1)
+ ">"
+ "." * (self.bar_width - filled_width)
)
elif progress_ratio >= 1.0:
bar = "=" * self.bar_width
else:
bar = "." * self.bar_width
# Percentage + ETA
percentage = progress_ratio * 100
if self.epoch_start_time and self.current_batch > 0:
elapsed = time.time() - self.epoch_start_time
if progress_ratio > 0:
total_estimated = elapsed / progress_ratio
remaining = max(0, total_estimated - elapsed)
eta = (
f"{remaining/60:.1f}m"
if remaining > 60
else f"{remaining:.0f}s"
)
else:
eta = "?"
else:
eta = "?"
# Build line
progress_str = f"\r[{bar}] {current}/{self.total_batches} ({percentage:5.1f}%) - ETA: {eta}"
progress_str += self._format_metrics(logs)
# Print with spacing to clear previous line
print(f"{progress_str:<120}", end="", flush=True)
return TorchBatchProgressBar()
@CallbackRegistry.register("torch_total_training_time")
def make_torch_training_time_only(save_dir: str):
import os, json, time
class TorchTrainingTimeOnly(PyTorchCallback):
def __init__(self):
super().__init__()
self.save_dir = save_dir
self._t0 = None
def on_train_begin(self, **kwargs):
self._t0 = time.time()
def on_train_end(self, **kwargs):
elapsed = 0.0 if self._t0 is None else (time.time() - self._t0)
os.makedirs(self.save_dir, exist_ok=True)
path = os.path.join(self.save_dir, "training_meta.json")
with open(path, "w", encoding="utf-8") as f:
json.dump({"framework": "torch", "total_wall_time_sec": round(elapsed, 4)}, f, indent=2)
print(f"[torch_training_time_only] Saved {path}")
return TorchTrainingTimeOnly()
@CallbackRegistry.register("torch_training_meta_info")
def make_torch_training_meta_info(
save_dir: str, # core arg required
example_input: Any = None, # e.g. a torch.Tensor or tuple of tensors
input_shape: Optional[Tuple[int, ...]] = None, # used to synthesize a dummy batch, core arg
method: str = "auto", # "auto" | "profiler" | "thop"
backward_factor: float = 2.0, # used when fallback to thop forward-only estimate
grad_accum_steps: int = 1,
model_name: Optional[str] = None,
):
    """
    PyTorch callback that records total wall-clock time and estimates training FLOPs,
    then writes JSON to <save_dir>/training_meta.json.
    Notes:
        - If neither example_input nor input_shape is provided (and batch tensors
          are not passed in kwargs), FLOPs may be null, but timing/metadata are still saved.
        - FLOPs estimation is performed once (lazily) on a single train step,
          then scaled by the total number of steps.
    """
    import os, json, time, datetime, traceback
class TorchTrainingMetaInfo(PyTorchCallback):
def __init__(self):
super().__init__()
self.save_dir = save_dir
self.example_input = example_input
self.input_shape = input_shape
self.method = method
self.backward_factor = float(backward_factor)
self.grad_accum_steps = max(1, int(grad_accum_steps))
self.model_name = model_name
# time bookkeeping
self._train_start_ts = None
self._train_end_ts = None
self._wall_time_sec = 0.0
# progress bookkeeping (best-effort; will be inferred if not provided)
self.epochs_declared = None
self.epochs_seen = 0
self.total_batches_declared = None
self.batches_seen_this_epoch = 0
self.total_batches_seen = 0
self.batch_size_seen = None
# flops bookkeeping
self.per_step_flops = None
self.per_forward_flops = None
self.per_epoch_flops = None
self.total_training_flops = None
self.flops_method = None
self.flops_notes = None
# model info snapshot
self.param_count = None
self.device = None
self.world_size = None
self.local_rank = None
# lazily profiled?
self._did_profile_once = False
# ------------------- helpers -------------------
def _atomic_write_json(self, path: str, payload: Dict):
os.makedirs(os.path.dirname(path), exist_ok=True)
tmp_path = path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
os.replace(tmp_path, path)
def _now_iso(self):
return datetime.datetime.now().isoformat(timespec="seconds")
def _count_params(self, model):
try:
import torch
return int(sum(p.numel() for p in model.parameters()))
except Exception:
return None
def _infer_device(self, model):
try:
import torch
p = next(model.parameters(), None)
if p is not None:
return str(p.device)
except Exception:
pass
return None
def _dist_info(self):
try:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
return dist.get_world_size(), dist.get_rank()
except Exception:
pass
return None, None
def _make_dummy_from_shape(self, shape, device, dtype=None):
try:
import torch
if dtype is None:
dtype = torch.float32
return torch.randn(*shape, device=device, dtype=dtype)
except Exception:
return None
def _as_tuple(self, x):
return x if isinstance(x, (tuple, list)) else (x,)
def _extract_batch_from_kwargs(self, **kwargs):
# Best-effort: look for a tensor or tuple of tensors in common keys
candidates = []
for key in ["inputs", "batch", "data", "x", "batch_data"]:
if key in kwargs:
candidates.append(kwargs[key])
for c in candidates:
# accept tensor, or (tensor, y) pairs, or dict with 'inputs'
try:
import torch
if isinstance(c, torch.Tensor):
return (c,)
if isinstance(c, (tuple, list)) and len(c) > 0:
return self._as_tuple(c[0]) # assume first is input
if isinstance(c, dict) and "inputs" in c:
return self._as_tuple(c["inputs"])
except Exception:
pass
return None
def _ensure_example_input(self, model, **kwargs):
if self.example_input is not None:
return self._as_tuple(self.example_input)
# try from kwargs
kw_batch = self._extract_batch_from_kwargs(**kwargs)
if kw_batch is not None:
self._maybe_set_batch_size(kw_batch)
return self._as_tuple(kw_batch)
# try from input_shape
if self.input_shape is not None:
dev = self.device or self._infer_device(model) or "cpu"
dummy = self._make_dummy_from_shape(self.input_shape, dev)
return (dummy,) if dummy is not None else None
return None
def _maybe_set_batch_size(self, inputs_tuple):
try:
import torch
x0 = inputs_tuple[0]
if isinstance(x0, torch.Tensor) and x0.dim() >= 1:
b = int(x0.shape[0])
if b > 0:
self.batch_size_seen = b
except Exception:
pass
def _profile_one_train_step(self, model, optimizer=None, **kwargs):
"""
Try to measure FLOPs of a full train step (fwd+backward+step) once.
Sets per_step_flops and per_forward_flops on success.
"""
if self._did_profile_once:
return
inputs = self._ensure_example_input(model, **kwargs)
if inputs is None:
self.flops_notes = "No example input available; FLOPs not estimated."
return
            # Note: no copy is made; profiling runs one real forward/backward (and an
            # optimizer step, if provided) on the caller's model.
            m = model
per_step = None
fwd = None
notes = []
used_method = None
# Prefer profiler if available
if self.method in ("auto", "profiler"):
try:
import torch
import torch.profiler as prof
activities = [prof.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(prof.ProfilerActivity.CUDA)
m_train = m.train()
                    for param in m_train.parameters():
                        param.requires_grad_(True)
with torch.enable_grad():
with prof.profile(
activities=activities,
record_shapes=True,
with_flops=True,
profile_memory=False,
) as p:
out = m_train(*inputs)
# generic scalar loss; if not scalar, sum it
if isinstance(out, (tuple, list)):
out = out[0]
loss = out.sum() if hasattr(out, "sum") else (out if out.ndim == 0 else None)
if loss is None:
raise RuntimeError("Cannot reduce model output to scalar for backward()")
loss.backward()
if optimizer is not None:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
ka = p.key_averages()
# sum flops across ops
total_flops = 0
for evt in ka:
# PyTorch profiler exposes flops via .flops (may be None)
val = getattr(evt, "flops", None)
if val is not None:
total_flops += int(val)
if total_flops > 0:
used_method = "torch.profiler"
per_step = int(total_flops)
                        # Approximate the forward-only cost from the measured full step:
                        # forward ≈ per_step / (1 + backward_factor)
                        fwd = int(round(per_step / max(1.0, 1.0 + self.backward_factor)))
else:
notes.append("Profiler returned zero flops; falling back.")
except Exception as e:
notes.append(f"Profiler failed: {e}")
# Fallback: THOP forward-only
if per_step is None and self.method in ("auto", "thop"):
try:
import torch
from thop import profile as thop_profile
m_eval = m.eval()
with torch.no_grad():
macs, _params = thop_profile(m_eval, inputs=inputs, verbose=False)
fwd = int(macs * 2) # MACs→FLOPs
used_method = "thop (forward-only)"
per_step = int(round(fwd * (1.0 + self.backward_factor)))
except Exception as e:
notes.append(f"THOP failed: {e}")
if per_step is None:
self.flops_notes = " | ".join(notes) if notes else "FLOPs estimation unavailable."
return
self.per_step_flops = per_step
self.per_forward_flops = fwd
self.flops_method = used_method
self.flops_notes = "Estimated from a single train step."
self._did_profile_once = True
def _finalize_flops_totals(self):
if self.per_step_flops is None:
return
steps = max(1, self.total_batches_seen // self.grad_accum_steps)
self.total_training_flops = int(self.per_step_flops * steps)
# Approximate per-epoch if we know per-epoch batches
if self.total_batches_declared:
epoch_steps = max(1, self.total_batches_declared // self.grad_accum_steps)
self.per_epoch_flops = int(self.per_step_flops * epoch_steps)
def _build_payload(self, torch_module=None):
versions = {}
try:
import torch
versions = {
"torch": getattr(torch, "__version__", None),
"cuda": getattr(torch.version, "cuda", None),
"cudnn": getattr(getattr(torch.backends, "cudnn", None), "version", lambda: None)(),
}
except Exception:
pass
payload = {
"framework": "torch",
"timestamp": self._now_iso(),
"model_name": self.model_name,
"device": self.device,
"ddp": {"world_size": self.world_size, "local_rank": self.local_rank},
"params": self.param_count,
"epochs": self.epochs_declared or self.epochs_seen or None,
"epochs_seen": self.epochs_seen,
"batches_per_epoch": self.total_batches_declared,
"total_batches_seen": self.total_batches_seen,
"batch_size": self.batch_size_seen,
"grad_accum_steps": self.grad_accum_steps,
"total_wall_time_sec": round(self._wall_time_sec, 4),
"flops": {
"method": self.flops_method,
"per_forward_pass": self.per_forward_flops,
"per_train_step": self.per_step_flops,
"per_epoch": self.per_epoch_flops,
"total_training": self.total_training_flops,
"notes": self.flops_notes,
},
"versions": versions,
}
return payload
# ------------------- lifecycle -------------------
def on_train_begin(self, model=None, epochs=None, total_batches=None, **kwargs):
self._train_start_ts = time.time()
if epochs is not None:
self.epochs_declared = int(epochs)
if total_batches is not None:
self.total_batches_declared = int(total_batches)
if model is not None:
self.param_count = self._count_params(model)
self.device = self._infer_device(model)
self.world_size, self.local_rank = self._dist_info()
def on_epoch_begin(self, epoch, total_batches=None, **kwargs):
if total_batches is not None:
self.total_batches_declared = int(total_batches)
self.batches_seen_this_epoch = 0
def on_batch_begin(self, batch, **kwargs):
# Capture batch size if tensors are provided
inputs = self._extract_batch_from_kwargs(**kwargs)
if inputs is not None:
self._maybe_set_batch_size(inputs)
def on_batch_end(self, batch, model=None, optimizer=None, **kwargs):
self.batches_seen_this_epoch += 1
self.total_batches_seen += 1
# lazily profile on the very first step if possible
if not self._did_profile_once and model is not None:
try:
self._profile_one_train_step(model, optimizer=optimizer, **kwargs)
except Exception:
# Keep going; timing will still be recorded
pass
def on_epoch_end(self, epoch, **kwargs):
self.epochs_seen = max(self.epochs_seen, epoch + 1)
def on_train_end(self, model=None, **kwargs):
self._train_end_ts = time.time()
self._wall_time_sec = float(self._train_end_ts - (self._train_start_ts or self._train_end_ts))
# finalize FLOPs totals
self._finalize_flops_totals()
# write JSON
target = os.path.join(self.save_dir, "training_meta.json")
try:
payload = self._build_payload()
self._atomic_write_json(target, payload)
print(f"[torch_training_meta_info] Saved {target}")
except Exception as e:
print(f"[torch_training_meta_info] Failed to write meta: {e}")
try:
# last resort: dump traceback to a sidecar
side = target + ".error.txt"
with open(side, "w", encoding="utf-8") as f:
f.write("Exception while saving training_meta.json:\n")
traceback.print_exc(file=f)
print(f"[torch_training_meta_info] Error details saved to {side}")
except Exception:
pass
return TorchTrainingMetaInfo()
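# Illustrative config entry (sketch; the directory and input shape are placeholders):
#
#     {"name": "torch_training_meta_info",
#      "params": {"save_dir": "runs/exp1", "input_shape": [1, 3, 224, 224],
#                 "method": "auto", "grad_accum_steps": 1}}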
@CallbackRegistry.register("torch_grad_monitor")
def make_torch_grad_monitor(
save_dir: str,
track_per_param: bool = False, # False = per-layer summary only (recommended)
log_every_n_batches: int = 1, # collect every n batches (increase to reduce overhead)
clip_large_values_at: float = 0, # 0 = no clipping; else clip abs grads for robust stats
):
"""
Gradient-flow monitor (safe & robust).
- Registers .register_hook() on each requires_grad parameter.
- For each batch, gathers:
* global_grad_norm (L2)
* any_nan / any_inf flags
* per-layer {count, l2, l1, max_abs, mean_abs, std, zero_frac, finite_frac}
(per-parameter detail can be enabled but increases JSON size)
- Saves one JSON per epoch: grads_epoch_{epoch+1}.json
"""
import torch
from collections import defaultdict
import os, json, math, traceback
class TorchGradMonitor(PyTorchCallback):
def __init__(self):
super().__init__()
self.save_dir = save_dir
self.track_per_param = bool(track_per_param)
self.n = max(1, int(log_every_n_batches))
self.clip = float(clip_large_values_at)
self._hooks = []
self._epoch = -1
self._batch = -1
# per-batch scratch (filled by hooks during backward)
self._batch_global_sumsq = 0.0
self._batch_any_nan = False
self._batch_any_inf = False
# per-batch per-layer accumulators
self._batch_layer = defaultdict(lambda: {
"count": 0, "l2": 0.0, "l1": 0.0, "max_abs": 0.0,
"mean_abs_sum": 0.0, "std_sum": 0.0,
"zeros": 0, "numel": 0, "finite": 0
})
# per-epoch aggregated stats
self._epoch_batches = 0
self._epoch_global_norms = [] # sample per logged batch
self._epoch_any_nan = 0
self._epoch_any_inf = 0
self._epoch_layer = defaultdict(lambda: {
"count": 0, "l2": 0.0, "l1": 0.0, "max_abs": 0.0,
"mean_abs_sum": 0.0, "std_sum": 0.0,
"zeros": 0, "numel": 0, "finite": 0
})
# ----------------- helpers -----------------
def _atomic_write_json(self, path: str, payload: Dict[str, Any]):
os.makedirs(os.path.dirname(path), exist_ok=True)
tmp = path + ".tmp"
with open(tmp, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
os.replace(tmp, path)
def _safe_stats(self, g: torch.Tensor):
# Work on a view; no grad change
t = g.detach()
if self.clip > 0:
t = t.clamp(min=-self.clip, max=self.clip)
isfinite = torch.isfinite(t)
finite = t[isfinite]
numel = t.numel()
zeros = (t == 0).sum().item()
if finite.numel() == 0:
                return {
                    "l2": 0.0, "l1": 0.0, "max_abs": 0.0,
                    "mean_abs": 0.0, "std": 0.0,
                    "zeros": zeros, "numel": numel, "finite": 0,
                    "any_nan": bool(torch.isnan(t).any().item()),
                    "any_inf": bool(torch.isinf(t).any().item()),
                }
absf = finite.abs()
l2 = float(torch.linalg.vector_norm(finite, ord=2).item())
l1 = float(absf.sum().item())
max_abs = float(absf.max().item())
mean_abs = float(absf.mean().item())
std = float(finite.std(unbiased=False).item()) if finite.numel() > 1 else 0.0
any_nan = bool(torch.isnan(t).any().item())
any_inf = bool(torch.isinf(t).any().item())
return {
"l2": l2, "l1": l1, "max_abs": max_abs,
"mean_abs": mean_abs, "std": std,
"zeros": int(zeros), "numel": int(numel),
"finite": int(finite.numel()),
"any_nan": any_nan, "any_inf": any_inf,
}
def _attach_hooks(self, model):
# Remove old hooks if any
for h in self._hooks:
try:
h.remove()
except Exception:
pass
self._hooks.clear()
# Name parameters for per-layer aggregation
for name, p in model.named_parameters():
if not p.requires_grad:
continue
def make_hook(pname):
def _hook(grad):
s = self._safe_stats(grad)
# global
self._batch_global_sumsq += (s["l2"] ** 2)
self._batch_any_nan |= s["any_nan"]
self._batch_any_inf |= s["any_inf"]
# per-layer (group by module/param prefix)
layer = pname.rsplit(".", 1)[0] if "." in pname else pname
L = self._batch_layer[layer]
L["count"] += 1
L["l2"] += s["l2"]
L["l1"] += s["l1"]
L["max_abs"] = max(L["max_abs"], s["max_abs"])
L["mean_abs_sum"] += s["mean_abs"]
L["std_sum"] += s["std"]
L["zeros"] += s["zeros"]
L["numel"] += s["numel"]
L["finite"] += s["finite"]
if self.track_per_param:
# store minimal per-param details lazily
PP = self._batch_layer.setdefault(f"{layer}::{pname}", {
"count": 0, "l2": 0.0, "l1": 0.0, "max_abs": 0.0,
"mean_abs_sum": 0.0, "std_sum": 0.0,
"zeros": 0, "numel": 0, "finite": 0
})
PP["count"] += 1
PP["l2"] += s["l2"]
PP["l1"] += s["l1"]
PP["max_abs"] = max(PP["max_abs"], s["max_abs"])
PP["mean_abs_sum"] += s["mean_abs"]
PP["std_sum"] += s["std"]
PP["zeros"] += s["zeros"]
PP["numel"] += s["numel"]
PP["finite"] += s["finite"]
return _hook
self._hooks.append(p.register_hook(make_hook(name)))
def _reset_batch_accum(self):
self._batch_global_sumsq = 0.0
self._batch_any_nan = False
self._batch_any_inf = False
self._batch_layer.clear()
def _merge_batch_into_epoch(self):
if self._batch_global_sumsq > 0.0:
gnorm = math.sqrt(self._batch_global_sumsq)
self._epoch_global_norms.append(gnorm)
if self._batch_any_nan:
self._epoch_any_nan += 1
if self._batch_any_inf:
self._epoch_any_inf += 1
for k, v in self._batch_layer.items():
E = self._epoch_layer[k]
E["count"] += v["count"]
E["l2"] += v["l2"]
E["l1"] += v["l1"]
E["max_abs"] = max(E["max_abs"], v["max_abs"])
E["mean_abs_sum"] += v["mean_abs_sum"]
E["std_sum"] += v["std_sum"]
E["zeros"] += v["zeros"]
E["numel"] += v["numel"]
E["finite"] += v["finite"]
def _epoch_payload(self, epoch_idx: int):
# Build per-layer means
layer_stats = {}
for k, v in self._epoch_layer.items():
c = max(1, v["count"])
layer_stats[k] = {
"count": v["count"],
"l2_sum": round(v["l2"], 6),
"l1_sum": round(v["l1"], 6),
"max_abs": round(v["max_abs"], 6),
"mean_abs_mean": round(v["mean_abs_sum"] / c, 6),
"std_mean": round(v["std_sum"] / c, 6),
"zero_frac": round((v["zeros"] / v["numel"]) if v["numel"] else 0.0, 6),
"finite_frac": round((v["finite"] / v["numel"]) if v["numel"] else 0.0, 6),
}
return {
"epoch": epoch_idx + 1,
"batches_logged": self._epoch_batches,
"global_grad_norm": {
"num_samples": len(self._epoch_global_norms),
"mean": round(float(sum(self._epoch_global_norms) / max(1, len(self._epoch_global_norms))), 6),
"max": round(float(max(self._epoch_global_norms)) if self._epoch_global_norms else 0.0, 6),
"min": round(float(min(self._epoch_global_norms)) if self._epoch_global_norms else 0.0, 6),
},
"any_nan_batches": self._epoch_any_nan,
"any_inf_batches": self._epoch_any_inf,
"layers": layer_stats,
}
# --------------- lifecycle ----------------
def on_train_begin(self, model=None, **kwargs):
if model is not None:
self._attach_hooks(model)
def on_epoch_begin(self, epoch, **kwargs):
self._epoch = epoch
self._batch = -1
self._epoch_batches = 0
self._epoch_any_nan = 0
self._epoch_any_inf = 0
self._epoch_global_norms.clear()
self._epoch_layer.clear()
def on_batch_begin(self, batch, **kwargs):
self._batch = batch
self._reset_batch_accum()
def on_batch_end(self, batch, **kwargs):
# Only log every n batches to reduce overhead/size
if ((batch + 1) % self.n) == 0:
self._merge_batch_into_epoch()
self._epoch_batches += 1
self._reset_batch_accum() # reset for next accumulation window
def on_epoch_end(self, epoch, **kwargs):
# Flush any remainder accumulation
self._merge_batch_into_epoch()
payload = self._epoch_payload(epoch_idx=epoch)
try:
os.makedirs(self.save_dir, exist_ok=True)
path = os.path.join(self.save_dir, f"grads_epoch_{epoch+1:04d}.json")
self._atomic_write_json(path, payload)
print(f"[torch_grad_monitor] Saved {path}")
except Exception as e:
print(f"[torch_grad_monitor] Failed to save epoch grads: {e}")
try:
with open(os.path.join(self.save_dir, f"grads_epoch_{epoch+1:04d}.error.txt"), "w") as f:
traceback.print_exc(file=f)
except Exception:
pass
def on_train_end(self, **kwargs):
# Remove hooks
for h in self._hooks:
try:
h.remove()
except Exception:
pass
self._hooks.clear()
return TorchGradMonitor()
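# Illustrative config entry (sketch; the directory and logging frequency are placeholders):
#
#     {"name": "torch_grad_monitor",
#      "params": {"save_dir": "runs/exp1/grads", "log_every_n_batches": 10}}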
# Update the event map to include PyTorch (vanilla) support
event_map.update(
{
"train_start": {**event_map["train_start"], "torch": "on_train_begin"},
"train_end": {**event_map["train_end"], "torch": "on_train_end"},
"epoch_start": {**event_map["epoch_start"], "torch": "on_epoch_begin"},
"epoch_end": {**event_map["epoch_end"], "torch": "on_epoch_end"},
"batch_start": {**event_map["batch_start"], "torch": "on_batch_begin"},
"batch_end": {**event_map["batch_end"], "torch": "on_batch_end"},
}
)