Source code for xflow.utils.helper

import inspect
import itertools
import os
import random
import sys
from pathlib import Path
from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple

import __main__

from .typing import T

# =============================================================================
# Path helpers
# =============================================================================


def print_caller_directory():
    """
    Print the directory containing the file that invoked this function.

    The caller is taken from the interpreter stack (one frame above this
    function), so the reported directory is that of the calling source file.
    Handy for debugging where a script lives in a multi-file project.
    """
    # Index 0 is this function's own frame; index 1 is whoever called it.
    invoking_frame = inspect.stack()[1]
    directory = os.path.dirname(os.path.abspath(invoking_frame.filename))
    print("Caller script directory:", directory)


def get_base_dir() -> Path:
    """
    Return a base directory for the calling context, whatever the environment.

    The probes below run in order and the first one that succeeds wins. Every
    probe is best-effort: any exception it raises is swallowed and the next
    probe is tried.

    1. IPython/Jupyter session -> current working directory.
    2. ``__main__.__file__`` exists on disk -> its parent directory.
    3. Frozen executable (``sys.frozen``) -> directory of ``sys.executable``.
    4. First stack frame whose file is neither synthetic, IPython/Jupyter
       related, nor this module -> that file's parent directory.
    5. ``sys.argv[0]`` points to an existing file -> its parent directory.
    6. Otherwise -> current working directory.

    Returns:
        Resolved directory path chosen by the first successful probe.
    """
    # 1. Interactive notebook / IPython session.
    try:
        if any(mod in sys.modules for mod in ("ipykernel", "IPython")):
            from IPython import get_ipython

            # A live shell object means we really are inside IPython/Jupyter,
            # where the working directory is the best available answer.
            if get_ipython() is not None:
                return Path(os.getcwd()).resolve()
    except Exception:
        pass

    # 2. Regular script execution: __main__ carries a real file path.
    try:
        entry_script = getattr(__main__, "__file__", None)
        if entry_script and os.path.exists(entry_script):
            return Path(entry_script).parent.resolve()
    except Exception:
        pass

    # 3. Frozen executable.
    try:
        if getattr(sys, "frozen", False):
            return Path(sys.executable).parent.resolve()
    except Exception:
        pass

    # 4. Walk the call stack looking for the first external, on-disk caller.
    try:
        this_module = Path(__file__).resolve()
        for record in inspect.stack()[1:]:  # [0] is this function's own frame
            origin = record.filename
            # "<stdin>", "<string>", ... and bracketed names used by some REPLs
            synthetic = origin.startswith(("<", "["))
            notebook_internal = any(
                marker in origin for marker in ("IPython", "ipykernel", "jupyter")
            )
            if synthetic or notebook_internal or Path(origin).resolve() == this_module:
                continue
            candidate = Path(origin)
            if candidate.exists():
                return candidate.parent.resolve()
    except Exception:
        pass

    # 5. Script path from the command line.
    try:
        if sys.argv and sys.argv[0]:
            script = Path(sys.argv[0])
            if script.exists() and script.is_file():
                return script.parent.resolve()
    except Exception:
        pass

    # 6. Last resort: wherever the process is currently running.
    return Path(os.getcwd()).resolve()


# =============================================================================
# Iterable/Sequence helpers
# =============================================================================


def subsample_sequence(
    items: Sequence[T],
    n_samples: Optional[int] = None,
    fraction: Optional[float] = None,
    strategy: str = "random",
    seed: Optional[int] = 42,
) -> List[T]:
    """
    Return a subsample of a sequence as a new list.

    Args:
        items: Any sized, sliceable sequence (list, tuple, etc.) of type T.
        n_samples: Exact number of items to keep. Values outside
            ``[0, len(items)]`` are clamped into that range.
        fraction: Fraction of items to keep (0.0 to 1.0); the resulting count
            is truncated towards zero.
        strategy: One of
            "random"    - sample without replacement using ``seed``;
            "first"     - the leading items;
            "last"      - the trailing items;
            "stride"    - every ``len(items) // n``-th item starting at index 0;
            "reservoir" - reservoir sampling using ``seed``.
        seed: Random seed for reproducibility ("random" and "reservoir" only).

    Returns:
        List of sampled items of type T. When neither ``n_samples`` nor
        ``fraction`` is given, a list copy of all items is returned and
        ``strategy`` is not consulted.

    Raises:
        ValueError: If both ``n_samples`` and ``fraction`` are given, if
            ``fraction`` is outside [0.0, 1.0], or if ``strategy`` is unknown.
    """
    if n_samples is not None and fraction is not None:
        raise ValueError("Specify exactly one of n_samples or fraction, not both.")

    length = len(items)

    if n_samples is None and fraction is None:
        return list(items)

    if fraction is not None:
        if not 0.0 <= fraction <= 1.0:
            raise ValueError("fraction must be between 0.0 and 1.0")
        n_samples = int(length * fraction)

    # Clamp n_samples to [0, length]
    n_samples = max(0, min(n_samples, length))

    if strategy == "random":
        rng = random.Random(seed)
        return rng.sample(list(items), n_samples)

    if strategy == "first":
        return list(items[:n_samples])

    if strategy == "last":
        # Slice from an explicit non-negative start: ``items[-0:]`` would be
        # the WHOLE sequence, which is wrong when zero samples are requested.
        return list(items[length - n_samples:])

    if strategy == "stride":
        if n_samples == 0:
            return []
        step = max(1, length // n_samples)
        return list(itertools.islice(items, 0, None, step))[:n_samples]

    if strategy == "reservoir":
        # NOTE: ``len(items)`` is taken above, so the input must still be
        # sized even though the algorithm itself only needs a single pass.
        rng = random.Random(seed)
        reservoir: List[T] = []
        for i, elem in enumerate(items):
            if i < n_samples:
                reservoir.append(elem)
            else:
                j = rng.randint(0, i)
                if j < n_samples:
                    reservoir[j] = elem
        return reservoir

    raise ValueError(f"Unknown strategy: {strategy}")
def split_sequence(
    items: Sequence[T], split_ratio: float = 0.8, seed: int = 42, shuffle: bool = True
) -> Tuple[List[T], List[T]]:
    """
    Partition a sequence into a leading and a trailing list.

    Args:
        items: Any sequence (list, tuple, etc.) of type T. It is copied into a
            list first, so the input itself is never reordered.
        split_ratio: Share of the items placed in the first part (0.0 to 1.0).
            The cut index is ``int(len(items) * split_ratio)``.
        seed: Random seed used when shuffling, for reproducibility.
        shuffle: Whether to shuffle the copy before cutting it.

    Returns:
        Tuple of (first_part, second_part) as lists of type T.

    Raises:
        ValueError: If ``split_ratio`` is outside [0.0, 1.0].
    """
    ratio_in_range = 0.0 <= split_ratio <= 1.0
    if not ratio_in_range:
        raise ValueError(f"split_ratio must be between 0.0 and 1.0, got {split_ratio}")

    pool = list(items)
    if shuffle:
        random.Random(seed).shuffle(pool)

    cut = int(len(pool) * split_ratio)
    return pool[:cut], pool[cut:]


# =============================================================================
# Dictionary helpers
# =============================================================================


def deep_update(base: MutableMapping[str, Any], updates: Dict[str, Any]) -> None:
    """Recursively update a dictionary with another dictionary.

    Where both sides hold a ``dict`` under the same key the two are merged
    recursively; in every other case the value from ``updates`` replaces the
    one in ``base``. Values are stored by reference, not copied.

    Args:
        base: Dictionary to update (modified in-place)
        updates: Dictionary with updates to apply

    Example:
        >>> base = {"a": {"x": 1, "y": 2}, "b": 3}
        >>> updates = {"a": {"x": 10, "z": 3}, "c": 4}
        >>> deep_update(base, updates)
        >>> base
        {'a': {'x': 10, 'y': 2, 'z': 3}, 'b': 3, 'c': 4}
    """
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value


def _copy_nested_dicts(value: Any) -> Any:
    """Return ``value`` with every dict reachable through dict values copied.

    Non-dict values (including lists and the objects inside them) are returned
    as-is, so only the dictionary structure is duplicated.
    """
    if not isinstance(value, dict):
        return value
    import copy  # local import: keeps this helper self-contained

    # copy.copy keeps the concrete dict type (e.g. OrderedDict, defaultdict).
    duplicate = copy.copy(value)
    for key, item in value.items():
        duplicate[key] = _copy_nested_dicts(item)
    return duplicate


def deep_merge(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Merge multiple dictionaries recursively, returning a new dictionary.

    Dictionaries are applied from left to right, so on conflicting keys the
    value from a later dictionary wins; nested dicts are merged rather than
    replaced. The inputs are never modified: nested dicts are copied before
    being merged, while non-dict values are shared with the inputs.

    Args:
        *dicts: Dictionaries to merge, applied in order

    Returns:
        New merged dictionary (empty when no dictionaries are given)
    """
    result: Dict[str, Any] = {}
    for d in dicts:
        # Merge a structural copy: storing an input's nested dict directly in
        # ``result`` would let a later merge step mutate that input.
        deep_update(result, _copy_nested_dicts(d))
    return result


# =============================================================================
# Environment/OS helpers
# =============================================================================

_CGROUP_HINTS = ("docker", "containerd", "kubepods", "libpod", "buildkit", "podman")


def is_container() -> bool:
    """Best-effort check for whether the process runs inside a container.

    Several independent hints are probed in order and the first positive one
    returns True. A truthy ``IN_CONTAINER`` environment variable forces True;
    a value of "0", "false", "no" or "off" (any case) does not force False,
    detection simply continues with the remaining hints.

    Returns:
        True if any hint indicates a container, otherwise False.
    """
    # 0) explicit override wins
    flag = os.getenv("IN_CONTAINER")
    if flag and flag.lower() not in ("0", "false", "no", "off"):
        return True

    # 1) systemd's official hint
    try:
        if Path("/run/systemd/container").exists():  # readable by all
            return True
        # systemd sets this inside containers
        if os.geteuid() == 0:
            env1 = Path("/proc/1/environ").read_bytes()
            if b"container=" in env1:  # e.g., container=podman/docker/lxc
                return True
    except Exception:
        # Missing file, no permission, or no geteuid on this platform:
        # treat as "no hint" and keep probing.
        pass

    # 2) Podman marker
    if Path("/run/.containerenv").exists():  # created by Podman/CRI-O
        return True

    # 3) Docker marker (may be absent under buildx/buildkit)
    if Path("/.dockerenv").exists():
        return True

    # 4) Apptainer/Singularity (common on HPC)
    if any(k in os.environ for k in ("APPTAINER_NAME", "SINGULARITY_NAME")):
        return True

    # 5) Kubernetes (env injected by kubelet)
    if "KUBERNETES_SERVICE_HOST" in os.environ:
        return True

    # 6) cgroups heuristic (last resort; not 100% on cgroup v2)
    try:
        txt = Path("/proc/1/cgroup").read_text()
        if any(s in txt for s in _CGROUP_HINTS):
            return True
    except Exception:
        pass

    return False