Source code for xflow.models.base
"""Lightweight model coordination layer for plugging into training frameworks.

Provides the abstract interfaces ``InferenceModel`` (prediction and
persistence), ``Trainable`` (training/validation steps and optimizer
configuration) and ``BaseModel``, which combines the two.
"""
from abc import ABC, abstractmethod
from typing import Any
from ..utils.typing import PathLikeStr, Batch, LossOrMetrics
class InferenceModel(ABC):
    """Abstract contract for a model that can run inference and be persisted.

    Concrete subclasses supply :meth:`predict`, :meth:`save` and the
    alternate constructor :meth:`load`.
    """

    @abstractmethod
    def predict(self, inputs: Any, **kwargs: Any) -> Any:
        """Run a single forward (inference) pass.

        Args:
            inputs: Model input; its structure is defined by the subclass.
            **kwargs: Implementation-specific options.

        Returns:
            The model output; its structure is defined by the subclass.
        """

    @abstractmethod
    def save(self, path: PathLikeStr) -> None:
        """Write the model's weights, together with any config, to ``path``."""

    @classmethod
    @abstractmethod
    def load(cls, path: PathLikeStr, **kwargs: Any) -> "InferenceModel":
        """Reconstruct a model (weights and config) from disk.

        Args:
            path: On-disk location to read the model from.
            **kwargs: Implementation-specific loading options.

        Returns:
            The loaded model instance.
        """
class Trainable(ABC):
    """Abstract training-side contract used by a training framework.

    Subclasses implement a training step, a validation step and optimizer
    configuration. :meth:`set_train_mode` is an optional hook with a
    no-op default.
    """

    @abstractmethod
    def training_step(self, batch: Batch) -> LossOrMetrics:
        """Consume one ``(inputs, targets)`` batch and perform an update.

        Args:
            batch: One batch of training data.

        Returns:
            Either a loss value or a metrics dict. A dict must contain
            the key ``'loss'``.
        """

    @abstractmethod
    def validation_step(self, batch: Batch) -> LossOrMetrics:
        """Evaluate one batch in evaluation mode.

        Args:
            batch: One batch of validation data.

        Returns:
            Either a loss value or a metrics dict.
        """

    @abstractmethod
    def configure_optimizers(self) -> Any:
        """Build the optimizer(s)/scheduler(s) the training framework requires."""

    def set_train_mode(self, training: bool = True) -> None:
        """Switch between training (``True``) and evaluation (``False``) mode.

        The base implementation does nothing; override it when the model
        behaves differently in the two modes.
        """
        return None
[docs]
class BaseModel(InferenceModel, Trainable, ABC):
    """Single abstract base that merges the inference and training contracts.

    Derive one concrete class from this to implement both
    :class:`InferenceModel` and :class:`Trainable` at once.
    """