Source code for xflow.trainers.trainer

"""trainer with unified callbacks and delegated model I/O across frameworks."""

from __future__ import annotations

import collections
import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING

from ..utils.io import create_directory
from ..utils.typing import ModelType, PathLikeStr

if TYPE_CHECKING:
    import torch


# ============================== Callback core ==============================


@dataclass
class CallbackContext:
    trainer: "BaseTrainer"
    model: Any
    optimizer: Any = None
    scheduler: Any = None
    device: Any = None
    # progress
    epochs: int = 0  # total number of epochs for the current run
    epoch: int = 0
    batch_idx: int = -1
    batch: int = 0  # alias of batch_idx for PyTorch-style callbacks
    global_step: int = 0
    total_batches: int = 0
    phase: str = "train"
    logs: Dict[str, float] = field(default_factory=dict)
    request_stop: bool = False


class Callback:
    """Subclass and override the hooks you need; every hook receives a single CallbackContext."""
    def on_train_begin(self, ctx: CallbackContext): ...
    def on_train_end(self, ctx: CallbackContext): ...
    def on_epoch_begin(self, ctx: CallbackContext): ...
    def on_epoch_end(self, ctx: CallbackContext): ...
    def on_batch_begin(self, ctx: CallbackContext): ...
    def on_batch_end(self, ctx: CallbackContext): ...

    # Optional validation-specific hooks:
    def on_val_epoch_begin(self, ctx: CallbackContext): ...
    def on_val_epoch_end(self, ctx: CallbackContext): ...
    def on_val_batch_begin(self, ctx: CallbackContext): ...
    def on_val_batch_end(self, ctx: CallbackContext): ...
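
# Illustrative sketch (not part of the module): a minimal context-style callback.
# The class name and threshold are hypothetical; the hook name and the `should_stop`
# flag match Callback and CallbackDispatcher below.
#
#   class StopOnValLoss(Callback):
#       def __init__(self, threshold: float = 0.05):
#           self.threshold = threshold
#           self.should_stop = False  # polled via CallbackDispatcher.should_stop
#
#       def on_val_epoch_end(self, ctx: CallbackContext):
#           if ctx.logs.get("val_loss", float("inf")) < self.threshold:
#               self.should_stop = True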


class CallbackDispatcher:
    def __init__(self, callbacks: Optional[List[Callback]] = None):
        self.cbs = callbacks or []

    def call(self, name: str, ctx: CallbackContext):
        import inspect

        # Build a per-field view of the context; dataclasses.asdict() would
        # deep-copy heavy objects such as the trainer and model on every call.
        kw = {f.name: getattr(ctx, f.name) for f in fields(ctx)}
        for cb in self.cbs:
            fn = getattr(cb, name, None)
            if not callable(fn):
                continue
            try:
                params = inspect.signature(fn).parameters
                usable = {k: v for k, v in kw.items() if k in params}
                if usable:
                    fn(**usable)  # keyword-style hooks, e.g. on_epoch_end(epoch=..., logs=...)
                else:
                    fn(ctx)  # context-style hooks: def on_*(self, ctx)
            except TypeError:
                fn()  # last-resort fallback for argument-less hooks

    @property
    def should_stop(self) -> bool:
        return any(getattr(cb, "should_stop", False) for cb in self.cbs)
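
# Illustrative sketch (not part of the module): the dispatcher adapts keyword-style
# hooks by matching parameter names against CallbackContext fields, so a
# Keras-flavoured callback such as this hypothetical one works unchanged.
#
#   class PrintEpochLoss(Callback):
#       def on_epoch_end(self, epoch, logs):
#           # called as on_epoch_end(epoch=ctx.epoch, logs=ctx.logs)
#           print(f"epoch {epoch}: {logs}")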


# ============================== Model I/O (optional) ==============================


class ModelIO:
    """Tiny adapter so the trainer can save models when the model has no save_model()."""

    def save(self, model: Any, path: str, extra: Optional[Dict[str, Any]] = None):
        # 1) If model exposes save_model, use it (preferred).
        if hasattr(model, "save_model") and callable(getattr(model, "save_model")):
            model.save_model(path)
            return
        # 2) Try tf.keras
        try:
            import tensorflow as tf  # type: ignore

            if isinstance(model, tf.keras.Model):
                dirpath = os.path.dirname(path)
                if dirpath:
                    os.makedirs(dirpath, exist_ok=True)
                model.save(path)
                return
        except Exception:
            pass
        # 3) Try PyTorch state_dict
        try:
            import torch  # type: ignore

            if hasattr(model, "state_dict"):
                dirpath = os.path.dirname(path)
                if dirpath:
                    os.makedirs(dirpath, exist_ok=True)
                ckpt = {"model_state": model.state_dict()}
                if extra:
                    ckpt.update(extra)
                if not path.endswith((".pt", ".pth")):
                    path = path + ".pt"
                torch.save(ckpt, path)
                return
        except Exception:
            pass
        raise NotImplementedError(
            "No known way to save this model. Implement model.save_model(path) or provide a custom ModelIO."
        )
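
# Illustrative sketch (not part of the module): a custom ModelIO for models that are
# neither Keras nor PyTorch. The class name is hypothetical and it only assumes the
# model is picklable.
#
#   import pickle
#
#   class PickleModelIO(ModelIO):
#       def save(self, model, path, extra=None):
#           dirpath = os.path.dirname(path)
#           if dirpath:
#               os.makedirs(dirpath, exist_ok=True)
#           with open(path, "wb") as f:
#               pickle.dump({"model": model, **(extra or {})}, f)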


# ============================== Base trainer ==============================


class BaseTrainer(ABC):
    """
    Thin orchestrator: runs loops, dispatches callbacks, collects history.
    Creation/compilation of the model and optimizer happens outside; inject everything in.
    """

    def __init__(
        self,
        model: Any,
        data_pipeline: Any,
        output_dir: str,
        *,
        callbacks: Optional[List[Callback]] = None,
        model_io: Optional[ModelIO] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        if model is None:
            raise ValueError("model cannot be None")
        if data_pipeline is None:
            raise ValueError("data_pipeline cannot be None")
        if not output_dir:
            raise ValueError("output_dir is required")
        os.makedirs(output_dir, exist_ok=True)

        self.model = model
        self.data = data_pipeline
        self.output_dir = output_dir
        self.cb = CallbackDispatcher(callbacks)
        self.model_io = model_io or ModelIO()
        self.config = dict(config or {})
        self.history: Dict[str, List[Any]] = collections.defaultdict(list)

    def save_history(self, path: Optional[str] = None):
        path = path or os.path.join(self.output_dir, "history.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.history, f)

    def save_model(self, path: Optional[str] = None, **extra):
        path = path or os.path.join(self.output_dir, "model.pt")
        self.model_io.save(self.model, path, extra=extra or {})

    # Loader resolution helper (optional; keeps user code short)
    def _resolve_loaders(self, train_loader, val_loader):
        if train_loader is None:
            for name in ("train_loader", "train", "get_train_loader"):
                cand = getattr(self.data, name, None)
                train_loader = cand() if callable(cand) else cand
                if train_loader is not None:
                    break
        if val_loader is None:
            for name in ("val_loader", "val", "get_val_loader"):
                cand = getattr(self.data, name, None)
                val_loader = cand() if callable(cand) else cand
                if val_loader is not None:
                    break
        if train_loader is None:
            raise ValueError("No train_loader provided and data_pipeline has none.")
        return train_loader, val_loader

    @abstractmethod
    def fit(
        self,
        *,
        epochs: int,
        train_loader: Optional[Iterable] = None,
        val_loader: Optional[Iterable] = None,
    ) -> Dict[str, List[Any]]: ...

    @abstractmethod
    def predict(self, loader: Iterable, **kwargs) -> Any: ...


# ============================== PyTorch trainer ==============================


class TorchTrainer(BaseTrainer):
    def __init__(
        self,
        *,
        model: Any,
        data_pipeline: Any,
        output_dir: str,
        optimizer: Any,
        criterion: Any,
        device: Any = None,
        callbacks: Optional[List[Callback]] = None,
        model_io: Optional[ModelIO] = None,
        config: Optional[Dict[str, Any]] = None,
        val_metrics: Optional[List[Callable[[Any, Any], Dict[str, float]]]] = None,
        scheduler: Any = None,
        scheduler_step_per_batch: bool = False,
    ):
        super().__init__(
            model,
            data_pipeline,
            output_dir,
            callbacks=callbacks,
            model_io=model_io,
            config=config,
        )
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.val_metrics = val_metrics or []
        self.scheduler = scheduler
        self.scheduler_step_per_batch = scheduler_step_per_batch

    # ---- micro-steps (override if needed) ----
    def _to_device(self, batch):
        x, y = batch[:2]
        return x.to(self.device), y.to(self.device)

    def train_step(self, batch) -> Dict[str, float]:
        x, y = self._to_device(batch)
        self.optimizer.zero_grad(set_to_none=True)
        out = self.model(x)
        loss = self.criterion(out, y)
        loss.backward()
        self.optimizer.step()
        return {"loss": float(loss.item())}

    def val_step(self, batch) -> Dict[str, float]:
        import torch

        with torch.no_grad():
            x, y = self._to_device(batch)
            out = self.model(x)
            logs = {"val_loss": float(self.criterion(out, y).item())}
            for fn in self.val_metrics:
                extra = fn(out, y)  # each metric fn returns Dict[str, float]
                if extra:
                    logs.update(extra)
            return logs

    # ---- main loop ----
    def fit(
        self,
        *,
        epochs: int,
        train_loader: Optional[Iterable] = None,
        val_loader: Optional[Iterable] = None,
    ) -> Dict[str, List[Any]]:
        import torch

        self.device = self.device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        if hasattr(self, "discriminator") and self.discriminator is not None:
            self.discriminator.to(self.device)
        train_loader, val_loader = self._resolve_loaders(train_loader, val_loader)

        ctx = CallbackContext(
            trainer=self,
            model=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            device=self.device,
            total_batches=len(train_loader),
            logs={},
        )
        ctx.epochs = epochs  # make the total epoch count visible to callbacks
        self.cb.call("on_train_begin", ctx)

        global_step = 0
        for epoch in range(epochs):
            ctx.epoch, ctx.phase, ctx.logs = epoch, "train", {}
            self.cb.call("on_epoch_begin", ctx)

            # -------- train --------
            self.model.train()
            if hasattr(self, "discriminator") and self.discriminator is not None:
                self.discriminator.train()
            sum_loss = 0.0
            for i, batch in enumerate(train_loader):
                ctx.batch_idx = i
                ctx.batch = i  # PyTorch-style alias
                self.cb.call("on_batch_begin", ctx)

                logs = self.train_step(batch)
                logs["train_loss"] = float(logs.get("loss", 0.0))
                sum_loss += logs.get("loss", 0.0)
                global_step += 1

                ctx.logs = logs
                ctx.global_step = global_step
                self.cb.call("on_batch_end", ctx)

                # ---- scheduler per-batch (except ReduceLROnPlateau) ----
                if self.scheduler and self.scheduler_step_per_batch:
                    if not isinstance(
                        self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
                    ):
                        self.scheduler.step()

                if ctx.request_stop:
                    break
            avg_train = sum_loss / max(1, len(train_loader))

            # -------- validate --------
            val_logs_epoch = {}
            if val_loader is not None and not ctx.request_stop:
                ctx.phase, ctx.logs = "val", {}
                self.cb.call("on_val_epoch_begin", ctx)
                self.model.eval()
                if hasattr(self, "discriminator") and self.discriminator is not None:
                    self.discriminator.eval()
                acc = collections.defaultdict(float)
                for j, batch in enumerate(val_loader):
                    ctx.batch_idx = j
                    ctx.batch = j
                    self.cb.call("on_val_batch_begin", ctx)
                    logs = self.val_step(batch)
                    for k, v in logs.items():
                        acc[k] += float(v)
                    ctx.logs = logs
                    self.cb.call("on_val_batch_end", ctx)
                    if ctx.request_stop:
                        break
                val_logs_epoch = {k: acc[k] / max(1, len(val_loader)) for k in acc}
                ctx.logs = val_logs_epoch
                self.cb.call("on_val_epoch_end", ctx)

            # -------- end epoch --------
            epoch_logs = {"train_loss": avg_train, **val_logs_epoch}
            for k, v in epoch_logs.items():
                self.history[k].append(v)
            ctx.phase, ctx.logs = "train", epoch_logs
            self.cb.call("on_epoch_end", ctx)

            # ---- scheduler per-epoch (handles ReduceLROnPlateau) ----
            if self.scheduler and not self.scheduler_step_per_batch:
                if isinstance(
                    self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
                ):
                    metric = val_logs_epoch.get("val_loss", avg_train)
                    self.scheduler.step(metric)
                else:
                    self.scheduler.step()

            if ctx.request_stop or self.cb.should_stop:
                break

        self.cb.call("on_train_end", ctx)
        return dict(self.history)

    def predict(self, loader: Iterable, **_) -> List[Any]:
        import torch

        self.model.eval()
        preds = []
        with torch.no_grad():  # inference only; avoid building autograd graphs
            for batch in loader:
                x, _ = self._to_device(batch)
                preds.append(self.model(x))
        return preds


class TorchGANTrainer(TorchTrainer):
    def __init__(
        self,
        *,
        generator: torch.nn.Module,
        discriminator: torch.nn.Module,
        optimizer_g: torch.optim.Optimizer,
        optimizer_d: torch.optim.Optimizer,
        losses,  # e.g. Pix2PixLosses(lambda_l1=100)
        data_pipeline: Any,
        output_dir: str,
        device: Any = None,
        callbacks: Optional[List[Any]] = None,
        model_io: Optional[Any] = None,
        config: Optional[Dict[str, Any]] = None,
        val_metrics: Optional[List[Any]] = None,
    ):
        # Reuse the base fields: treat "model" as the generator.
        super().__init__(
            model=generator,
            data_pipeline=data_pipeline,
            output_dir=output_dir,
            optimizer=optimizer_g,
            criterion=losses,  # stored only; the base train_step is overridden below
            device=device,
            callbacks=callbacks,
            model_io=model_io,
            config=config,
            val_metrics=val_metrics or [],
        )
        self.discriminator = discriminator
        self.optimizer_d = optimizer_d
        self.losses = losses  # Pix2PixLosses

    def train_step(self, batch) -> Dict[str, float]:
        import torch

        x, y = self._to_device(batch)

        # ---- D step ----
        self.optimizer_d.zero_grad(set_to_none=True)
        with torch.no_grad():
            fake = self.model(x)
        d_real = self.discriminator(x, y)
        d_fake = self.discriminator(x, fake)
        d_loss = self.losses.discriminator_loss(d_real, d_fake)
        d_loss.backward()
        self.optimizer_d.step()

        # ---- G step ----
        self.optimizer.zero_grad(set_to_none=True)  # optimizer for G
        fake = self.model(x)
        d_fake_for_g = self.discriminator(x, fake)
        g_total, g_gan, g_l1 = self.losses.generator_loss(d_fake_for_g, fake, y)
        g_total.backward()
        self.optimizer.step()

        return {
            "loss": float(g_total.item()),
            "g_gan": float(g_gan.item()),
            "g_l1": float(g_l1.item()),
            "d_loss": float(d_loss.item()),
        }

    def val_step(self, batch) -> Dict[str, float]:
        import torch

        with torch.no_grad():
            x, y = self._to_device(batch)
            fake = self.model(x)
            d_fake = self.discriminator(x, fake)
            d_real = self.discriminator(x, y)
            g_total, g_gan, g_l1 = self.losses.generator_loss(d_fake, fake, y)
            d_loss = self.losses.discriminator_loss(d_real, d_fake)
            logs = {
                "val_loss": float(g_total.item()),
                "val_g_gan": float(g_gan.item()),
                "val_g_l1": float(g_l1.item()),
                "val_d_loss": float(d_loss.item()),
            }
            for fn in self.val_metrics:
                extra = fn(fake, y)  # each metric fn returns Dict[str, float]
                if extra:
                    logs.update(extra)
            return logs
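
# Illustrative usage sketch (not part of the module), assuming standard PyTorch
# objects; the output path, sizes, and hyperparameters are placeholders.
#
#   import torch
#   from torch.utils.data import DataLoader, TensorDataset
#
#   dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
#   loader = DataLoader(dataset, batch_size=16, shuffle=True)
#   model = torch.nn.Linear(8, 1)
#
#   trainer = TorchTrainer(
#       model=model,
#       data_pipeline=dataset,  # loaders are passed explicitly to fit() below
#       output_dir="runs/demo",
#       optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
#       criterion=torch.nn.MSELoss(),
#   )
#   history = trainer.fit(epochs=2, train_loader=loader, val_loader=loader)
#   trainer.save_model()    # delegated to ModelIO (torch.save of the state_dict)
#   trainer.save_history()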