Source code for xflow.data.pipeline

"""Core abstractions for building reusable, named preprocessing pipelines:"""

import itertools
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union

from ..utils.decorator import with_progress
from .provider import DataProvider

if TYPE_CHECKING:
    import tensorflow as tf


@dataclass
class Transform:
    """Wrapper for a preprocessing function with metadata. (like len)"""

    fn: Callable[[Any], Any]
    name: str

    def __call__(self, item: Any) -> Any:
        return self.fn(item)

    def __repr__(self) -> str:
        return self.name
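

# Illustrative sketch, not part of the library: Transform pairs a callable with
# a readable name, so pipeline metadata and error logs can say "load_csv"
# instead of "<lambda>". Plain functions also work as transforms; BasePipeline
# wraps them and falls back to their __name__.
def _example_transforms() -> List[Transform]:
    import numpy as np

    return [
        Transform(lambda path: np.loadtxt(path, delimiter=","), "load_csv"),
        Transform(lambda data: (data[:-1], data[-1]), "split_features_target"),
    ]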


class BasePipeline(ABC):
    """Base class for data pipelines in scientific machine learning.

    Provides a simple interface for data sources with preprocessing pipelines,
    yielding preprocessed items for ML training.

    Args:
        data_provider: DataProvider instance that yields raw data items.
        transforms: List of functions (Transform-wrapped or named) applied sequentially.
        logger: Optional logger for debugging and error tracking.
        skip_errors: Whether to skip items that fail preprocessing vs. raise errors.

    Example:
        >>> # Using Transform wrapper for clear metadata
        >>> transforms = [
        ...     Transform(lambda path: np.loadtxt(path, delimiter=","), "load_csv"),
        ...     Transform(lambda data: (data[:-1], data[-1]), "split_features_target"),
        ...     Transform(lambda x: (normalize(x[0]), x[1]), "normalize_features"),
        ... ]
        >>>
        >>> files = ListProvider(["data1.csv", "data2.csv"])
        >>> pipeline = MyPipeline(files, transforms)
        >>>
        >>> # Clear, meaningful metadata
        >>> print(pipeline.get_metadata())
        >>> # {"pipeline_type": "MyPipeline", "dataset_size": 2,
        >>> #  "preprocessing_functions": ["load_csv", "split_features_target", "normalize_features"]}
    """

    def __init__(
        self,
        data_provider: DataProvider,
        transforms: Optional[List[Union[Callable[[Any], Any], Transform]]] = None,
        *,
        logger: Optional[logging.Logger] = None,
        skip_errors: bool = True,
    ) -> None:
        self.data_provider = data_provider
        self.transforms = [
            (
                fn
                if isinstance(fn, Transform)
                else Transform(fn, getattr(fn, "__name__", "unknown"))
            )
            for fn in (transforms or [])
        ]
        self.logger = logger or logging.getLogger(__name__)
        self.skip_errors = skip_errors
        self.error_count = 0
        self.in_memory_sample_count: Optional[int] = None

    def __iter__(self) -> Iterator[Any]:
        """Iterate over preprocessed items."""
        for raw_item in self.data_provider():
            current_transform: Optional[Transform] = None
            try:
                item = raw_item
                for fn in self.transforms:
                    current_transform = fn
                    item = fn(item)
                if item is not None:
                    yield item
                else:
                    self.error_count += 1
                    self.logger.warning("Preprocessed item is None, skipping.")
            except Exception as e:
                self.error_count += 1
                if current_transform is not None:
                    transform_name = getattr(
                        current_transform,
                        "name",
                        getattr(current_transform, "__name__", repr(current_transform)),
                    )
                    self.logger.warning(
                        "Failed to preprocess item in transform '%s': %s",
                        transform_name,
                        e,
                    )
                else:
                    self.logger.warning("Failed to preprocess item: %s", e)
                if not self.skip_errors:
                    raise

    def __len__(self) -> int:
        """Return the total number of items in the dataset."""
        return len(self.data_provider)

    def sample(self, n: int = 5) -> List[Any]:
        """Return up to n preprocessed items for inspection."""
        return list(itertools.islice(self.__iter__(), n))

    def shuffle(self, buffer_size: int) -> "BasePipeline":
        """Return a new pipeline that shuffles items with a reservoir buffer."""
        from .transform import ShufflePipeline

        return ShufflePipeline(self, buffer_size)

    def batch(self, batch_size: int) -> "BasePipeline":
        """Return a new pipeline that batches items into lists."""
        from .transform import BatchPipeline

        return BatchPipeline(self, batch_size)

    def prefetch(self) -> "BasePipeline":
        """Return a new pipeline that prefetches items in the background."""
        # TODO: Implement prefetching logic
        raise NotImplementedError("Prefetching is not implemented yet.")

    def reset_error_count(self) -> None:
        """Reset the error count to zero."""
        self.error_count = 0

    @abstractmethod
    def to_framework_dataset(self) -> Any:
        """Convert pipeline to framework-native dataset."""
        ...

    def to_numpy(self):
        """Convert the pipeline to NumPy arrays.

        If each item is a tuple, returns a tuple of arrays (one per component).
        If each item is a single array, returns a single array.
        """
        import numpy as np
        from IPython.display import clear_output
        from tqdm.auto import tqdm

        items = []
        pbar = tqdm(
            self, desc="Converting to numpy", leave=False, miniters=1, position=0
        )
        for x in pbar:
            items.append(x)
        pbar.close()
        clear_output(wait=True)

        if not items:
            return None

        first = items[0]
        if isinstance(first, (tuple, list)):
            return tuple(np.stack(c) for c in zip(*items))
        return np.stack(items)
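

# Illustrative sketch, not part of the library: the smallest concrete
# BasePipeline subclass only needs to_framework_dataset(); iteration, sample(),
# shuffle(), batch(), and to_numpy() are inherited. The provider argument is
# any DataProvider (e.g. the ListProvider shown in the BasePipeline docstring),
# and the example assumes it yields plain strings.
def _example_base_pipeline_usage(provider: DataProvider) -> None:
    class MinimalPipeline(BasePipeline):
        def to_framework_dataset(self) -> Any:
            raise NotImplementedError("sketch only")

    pipeline = MinimalPipeline(provider, [Transform(str.strip, "strip")])
    preview = pipeline.sample(3)  # first few preprocessed items
    batches = pipeline.shuffle(buffer_size=16).batch(batch_size=4)  # lazy views
    _ = (preview, list(batches))
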

class DataPipeline(BasePipeline):
    """Simple pipeline that processes data lazily without storing in memory."""

    def to_framework_dataset(self) -> Any:
        """Not supported for lazy processing."""
        raise NotImplementedError(
            "DataPipeline doesn't support framework conversion. "
            "Use InMemoryPipeline or TensorFlowPipeline instead."
        )
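

# Illustrative sketch, not part of the library: DataPipeline streams items one
# at a time, so it suits datasets that do not fit in memory. The CSV-loading
# transform and batch(8) are assumptions for the example.
def _example_lazy_pipeline(provider: DataProvider) -> None:
    import numpy as np

    pipeline = DataPipeline(
        provider,
        [Transform(lambda path: np.loadtxt(path, delimiter=","), "load_csv")],
    )
    for batch in pipeline.batch(8):
        # Each batch is a list of up to 8 preprocessed arrays.
        _ = batch
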

class InMemoryPipeline(BasePipeline):
    """In-memory pipeline that processes all data upfront."""

    def __init__(
        self,
        data_provider: DataProvider,
        transforms: Optional[List[Union[Callable[[Any], Any], Transform]]] = None,
        *,
        logger: Optional[logging.Logger] = None,
        skip_errors: bool = True,
    ) -> None:
        super().__init__(
            data_provider, transforms, logger=logger, skip_errors=skip_errors
        )
        from .transform import apply_transforms_to_dataset

        self.dataset, self.error_count = apply_transforms_to_dataset(
            self.data_provider(),
            self.transforms,
            logger=self.logger,
            skip_errors=self.skip_errors,
        )
        self.in_memory_sample_count = len(self.dataset)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.dataset)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index]

    def to_framework_dataset(
        self, framework: str = "tensorflow", dataset_ops: Optional[List[Dict]] = None
    ) -> Any:
        """Convert to framework-native dataset using already processed data."""
        if framework.lower() == "tensorflow":
            try:
                import tensorflow as tf

                dataset = tf.data.Dataset.from_tensor_slices(self.dataset)
                if dataset_ops:
                    from .transform import apply_dataset_operations_from_config

                    dataset = apply_dataset_operations_from_config(dataset, dataset_ops)
                return dataset
            except ImportError:
                raise RuntimeError("TensorFlow not available")
        elif framework.lower() in ("pytorch", "torch"):
            try:
                from .transform import (
                    TorchDataset,
                    apply_dataset_operations_from_config,
                )

                torch_dataset = TorchDataset(self)
                if dataset_ops:
                    torch_dataset = apply_dataset_operations_from_config(
                        torch_dataset, dataset_ops
                    )
                return torch_dataset
            except ImportError:
                raise RuntimeError("PyTorch not available")
        else:
            raise NotImplementedError(f"Framework {framework} not implemented")
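

# Illustrative sketch, not part of the library: because InMemoryPipeline
# materialises every preprocessed item in __init__, conversion only wraps the
# stored list. The identity transform is a placeholder; dataset_ops, if given,
# are forwarded to apply_dataset_operations_from_config.
def _example_in_memory_conversion(provider: DataProvider) -> None:
    pipeline = InMemoryPipeline(provider, [Transform(lambda x: x, "identity")])
    tf_dataset = pipeline.to_framework_dataset("tensorflow")
    torch_dataset = pipeline.to_framework_dataset("pytorch")
    _ = (tf_dataset, torch_dataset)
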

class TensorFlowPipeline(BasePipeline):
    """Pipeline that uses TensorFlow-native transforms without preprocessing."""

    def to_framework_dataset(
        self, framework: str = "tensorflow", dataset_ops: Optional[List[Dict]] = None
    ):
        """Convert to TensorFlow dataset."""
        if framework.lower() != "tensorflow":
            raise ValueError(
                f"TensorFlowPipeline only supports tensorflow, got {framework}"
            )
        try:
            import tensorflow as tf

            file_paths = list(self.data_provider())
            dataset = tf.data.Dataset.from_tensor_slices(file_paths)
            for transform in self.transforms:
                dataset = dataset.map(transform.fn, num_parallel_calls=tf.data.AUTOTUNE)
            if dataset_ops:
                from .transform import apply_dataset_operations_from_config

                dataset = apply_dataset_operations_from_config(dataset, dataset_ops)
            return dataset
        except ImportError:
            raise RuntimeError("TensorFlow not available")
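

# Illustrative sketch, not part of the library: transforms given to
# TensorFlowPipeline run inside tf.data.Dataset.map, so they must be
# graph-compatible TensorFlow ops. tf.io.read_file stands in for whatever
# TF-native parsing a project actually needs.
def _example_tensorflow_pipeline(provider: DataProvider) -> Any:
    import tensorflow as tf

    pipeline = TensorFlowPipeline(provider, [Transform(tf.io.read_file, "read_file")])
    return pipeline.to_framework_dataset("tensorflow")
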

class PyTorchPipeline(BasePipeline):
    """Pipeline that uses PyTorch-native transforms without preprocessing."""

    def to_framework_dataset(
        self, framework: str = "pytorch", dataset_ops: Optional[List[Dict]] = None
    ):
        """Convert to PyTorch dataset."""
        if framework.lower() not in ("pytorch", "torch"):
            raise ValueError(
                f"PyTorchPipeline only supports pytorch/torch, got {framework}"
            )
        try:
            from .transform import TorchDataset, apply_dataset_operations_from_config

            # Create a PyTorch-compatible dataset that applies transforms on-the-fly
            class PyTorchTransformDataset(TorchDataset):
                def __init__(self, data_provider, transforms):
                    self.data_provider = data_provider
                    self.transforms = transforms
                    self._file_paths = list(data_provider())

                def __len__(self):
                    return len(self._file_paths)

                def __getitem__(self, idx):
                    item = self._file_paths[idx]
                    for transform in self.transforms:
                        item = transform.fn(item)
                    return item

            dataset = PyTorchTransformDataset(self.data_provider, self.transforms)
            if dataset_ops:
                dataset = apply_dataset_operations_from_config(dataset, dataset_ops)
            return dataset
        except ImportError:
            raise RuntimeError("PyTorch not available")

    def to_memory_dataset(self, dataset_ops: Optional[List[Dict]] = None):
        """Load and process ALL data samples into memory as a PyTorch TensorDataset.

        Only use this for datasets that fit comfortably in your available RAM.

        This method:
            1. Processes all data samples through the complete transform pipeline
            2. Converts results to PyTorch tensors
            3. Stores everything in memory for ultra-fast O(1) random access
            4. Returns a native PyTorch TensorDataset

        Benefits:
            - Eliminates file I/O during training (much faster)
            - Enables efficient shuffling and random sampling
            - Optimized for GPU transfer

        Args:
            dataset_ops: Optional list of dataset operations to apply after loading.

        Returns:
            torch.utils.data.TensorDataset: In-memory dataset with all pre-processed tensors.

        Example:
            >>> pipeline = PyTorchPipeline(provider, transforms)
            >>> # Load entire dataset into memory (use carefully!)
            >>> memory_dataset = pipeline.to_memory_dataset()
            >>> dataloader = DataLoader(memory_dataset, batch_size=32, shuffle=True)
        """
        try:
            import torch
            from IPython.display import clear_output
            from torch.utils.data import Dataset as TorchDataset, TensorDataset
            from tqdm.auto import tqdm

            from .transform import apply_dataset_operations_from_config

            class _MemoryListDataset(TorchDataset):
                """Fallback dataset that returns preprocessed samples without stacking."""

                def __init__(self, data: List[Any]) -> None:
                    self._data = data

                def __len__(self) -> int:
                    return len(self._data)

                def __getitem__(self, idx: int) -> Any:
                    return self._data[idx]

            def _to_tensor_or_keep(value: Any) -> Any:
                """Convert value to tensor when possible, otherwise leave unchanged."""
                if isinstance(value, torch.Tensor):
                    return value
                if isinstance(value, (str, bytes)):
                    return value
                try:
                    return torch.as_tensor(value)
                except (TypeError, ValueError):
                    return value

            # Process all data through pipeline and collect results
            processed_data = []
            pbar = tqdm(
                self,
                desc="Loading data into memory",
                leave=False,
                miniters=1,
                position=0,
            )
            for item in pbar:
                # Convert to tensor if not already
                if not isinstance(item, torch.Tensor):
                    if isinstance(item, (tuple, list)):
                        # Handle multiple outputs (e.g., features, labels)
                        item = tuple(_to_tensor_or_keep(x) for x in item)
                    else:
                        item = _to_tensor_or_keep(item)
                processed_data.append(item)
            pbar.close()
            clear_output(wait=True)

            if not processed_data:
                raise ValueError("No data was processed from the pipeline")

            self.in_memory_sample_count = len(processed_data)

            # Handle the case where each item is a tuple/list (multiple tensors)
            first_item = processed_data[0]
            dataset: TorchDataset
            try:
                if isinstance(first_item, (tuple, list)):
                    tensors = []
                    for i in range(len(first_item)):
                        component_values = [item[i] for item in processed_data]
                        if not all(isinstance(x, torch.Tensor) for x in component_values):
                            raise TypeError("Non-tensor component detected")
                        tensors.append(torch.stack(component_values))
                    dataset = TensorDataset(*tensors)
                else:
                    if not all(isinstance(x, torch.Tensor) for x in processed_data):
                        raise TypeError("Non-tensor samples detected")
                    stacked_tensor = torch.stack(processed_data)
                    dataset = TensorDataset(stacked_tensor)
            except (TypeError, ValueError) as err:
                self.logger.warning(
                    "Falling back to list dataset because tensors could not be stacked: %s",
                    err,
                )
                dataset = _MemoryListDataset(processed_data)

            # Apply any additional dataset operations
            if dataset_ops:
                dataset = apply_dataset_operations_from_config(dataset, dataset_ops)

            return dataset
        except ImportError:
            raise RuntimeError("PyTorch not available")
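

# Illustrative sketch, not part of the library: to_framework_dataset keeps
# loading lazy (transforms run in __getitem__), while to_memory_dataset trades
# RAM for O(1) access and feeds naturally into a shuffling DataLoader. The
# identity transform and batch size are assumptions for the example.
def _example_pytorch_pipeline(provider: DataProvider) -> None:
    from torch.utils.data import DataLoader

    pipeline = PyTorchPipeline(provider, [Transform(lambda x: x, "identity")])
    lazy_dataset = pipeline.to_framework_dataset("pytorch")
    memory_dataset = pipeline.to_memory_dataset()  # only if the data fits in RAM
    loader = DataLoader(memory_dataset, batch_size=32, shuffle=True)
    _ = (lazy_dataset, loader)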