import inspect
import itertools
import os
import random
import sys
import warnings
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple
import __main__
from .typing import T
# =============================================================================
# Path helpers
# =============================================================================
def print_caller_directory():
    """
    Print the directory containing the script that invoked this function.

    Handy for debugging/logging script origin in multi-file projects.
    """
    # Stack index 1 is the immediate caller of this function.
    frame = inspect.stack()[1]
    directory = os.path.dirname(os.path.abspath(frame.filename))
    print("Caller script directory:", directory)
def get_base_dir() -> Path:
    """
    Return the directory of the calling context, robust across environments.

    Tries, in order: Jupyter/IPython session (kernel's cwd), the directly
    executed script's directory (``__main__.__file__``), a frozen executable's
    directory (``sys.frozen``), the first external caller found on the stack,
    ``sys.argv[0]``, and finally the current working directory. Every probe is
    wrapped in a broad try/except on purpose: any failure falls through to the
    next, strictly less specific, strategy.
    """
    # 1. Jupyter/IPython: notebooks have no __file__, so use the kernel's cwd.
    try:
        # Presence of these modules suggests an IPython/Jupyter session.
        if "ipykernel" in sys.modules or "IPython" in sys.modules:
            # Confirm there is an ACTIVE shell before trusting the cwd.
            try:
                from IPython import get_ipython
                ipython = get_ipython()
                if ipython is not None:
                    # Jupyter keeps the notebook's directory as the cwd.
                    return Path(os.getcwd()).resolve()
            except ImportError:
                pass
    except Exception:
        pass
    # 2. Direct script execution: __main__.__file__ points at the entry script.
    try:
        main_file = getattr(__main__, "__file__", None)
        if main_file and os.path.exists(main_file):
            return Path(main_file).parent.resolve()
    except Exception:
        pass
    # 3. Frozen executable (e.g. PyInstaller sets sys.frozen); the unpacked
    #    app lives next to the executable, not next to any .py file.
    try:
        if getattr(sys, "frozen", False):
            return Path(sys.executable).parent.resolve()
    except Exception:
        pass
    # 4. Fallback: walk the stack for the first caller that is a real file,
    #    skipping this module itself and interpreter/REPL internals.
    try:
        current_file = Path(__file__).resolve()
        for frame_info in inspect.stack()[1:]:  # skip current frame
            filename = frame_info.filename
            # Skip interactive pseudo-filenames ("<stdin>", bracketed REPL
            # names), IPython/Jupyter internals, and this module.
            if (
                filename.startswith("<")
                or filename.startswith("[")  # some REPLs use bracketed names
                or "IPython" in filename
                or "ipykernel" in filename
                or "jupyter" in filename
                or Path(filename).resolve() == current_file
            ):
                continue
            file_path = Path(filename)
            if file_path.exists():
                return file_path.parent.resolve()
    except Exception:
        pass
    # 5. sys.argv[0]: works when invoked as "python script.py".
    try:
        if sys.argv and sys.argv[0]:
            script_path = Path(sys.argv[0])
            if script_path.exists() and script_path.is_file():
                return script_path.parent.resolve()
    except Exception:
        pass
    # 6. Ultimate fallback: current working directory.
    return Path(os.getcwd()).resolve()
# =============================================================================
# Iterable/Sequence helpers
# =============================================================================
def subsample_sequence(
    items: Sequence[T],
    n_samples: Optional[int] = None,
    fraction: Optional[float] = None,
    strategy: str = "random",
    seed: Optional[int] = 42,
) -> List[T]:
    """
    Subsample any sequence with a choice of strategies.

    Args:
        items: Any sequence (list, tuple, etc.) of type T.
        n_samples: Exact number to sample (mutually exclusive with fraction).
        fraction: Fraction to sample, in [0.0, 1.0] (mutually exclusive with
            n_samples).
        strategy: "random", "first", "last", "stride", or "reservoir".
        seed: Random seed for reproducibility of "random"/"reservoir".

    Returns:
        List of sampled items of type T. If neither n_samples nor fraction
        is given, a shallow copy of the full sequence is returned.

    Raises:
        ValueError: If both n_samples and fraction are given, fraction is out
            of range, or strategy is unknown.
    """
    # Validate parameters
    if n_samples is not None and fraction is not None:
        raise ValueError("Specify exactly one of n_samples or fraction, not both.")
    length = len(items)
    # No sampling requested: return a copy of everything.
    if n_samples is None and fraction is None:
        return list(items)
    if fraction is not None:
        if not 0.0 <= fraction <= 1.0:
            raise ValueError("fraction must be between 0.0 and 1.0")
        n_samples = int(length * fraction)
    # Clamp n_samples to [0, length]
    n_samples = max(0, min(n_samples, length))
    # Random sampling without replacement, deterministic for a fixed seed
    if strategy == "random":
        rng = random.Random(seed)
        return rng.sample(list(items), n_samples)
    # First n samples
    elif strategy == "first":
        return list(items[:n_samples])
    # Last n samples
    elif strategy == "last":
        # Slice from the front index rather than items[-n_samples:], which
        # would return the WHOLE sequence when n_samples == 0 (since -0 == 0).
        return list(items[length - n_samples:])
    # Stride sampling (lazy with islice)
    elif strategy == "stride":
        if n_samples == 0:
            return []
        # Evenly spaced picks; trailing slice guards against rounding overshoot.
        step = max(1, length // n_samples)
        return list(itertools.islice(items, 0, None, step))[:n_samples]
    # Classic reservoir sampling; also correct for one-pass iteration
    elif strategy == "reservoir":
        rng = random.Random(seed)
        reservoir: List[T] = []
        for i, elem in enumerate(items):
            if i < n_samples:
                reservoir.append(elem)
            else:
                # Replace an existing slot with probability n_samples/(i+1).
                j = rng.randint(0, i)
                if j < n_samples:
                    reservoir[j] = elem
        return reservoir
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
def split_sequence(
    items: Sequence[T], split_ratio: float = 0.8, seed: int = 42, shuffle: bool = True
) -> Tuple[List[T], List[T]]:
    """
    Split a sequence into two parts (e.g. train/validation).

    Args:
        items: Any sequence (list, tuple, etc.) of type T.
        split_ratio: Ratio of items placed in the first part (0.0 to 1.0).
        seed: Random seed for reproducible shuffling.
        shuffle: Whether to shuffle before splitting.

    Returns:
        Tuple of (first_part, second_part) as lists of type T; together they
        contain every input item exactly once.

    Raises:
        ValueError: If split_ratio is outside [0.0, 1.0].
    """
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be between 0.0 and 1.0, got {split_ratio}")
    # Work on a copy so the caller's sequence is never reordered in place.
    items_list = list(items)
    if shuffle:
        rng = random.Random(seed)
        rng.shuffle(items_list)
    # Truncation: the first part gets floor(len * ratio) items.
    split_idx = int(len(items_list) * split_ratio)
    first_part = items_list[:split_idx]
    second_part = items_list[split_idx:]
    return first_part, second_part
# =============================================================================
# Dictionary helpers
# =============================================================================
def deep_update(base: MutableMapping[str, Any], updates: Dict[str, Any]) -> None:
    """Recursively update a dictionary with another dictionary, in-place.

    Nested dictionaries are merged key-by-key; any other value (including a
    dict replacing a non-dict, or vice versa) is assigned directly.

    Args:
        base: Dictionary to update (modified in-place).
        updates: Dictionary with updates to apply.

    Example:
        >>> base = {'a': {'x': 1, 'y': 2}, 'b': 3}
        >>> deep_update(base, {'a': {'x': 10, 'z': 3}, 'c': 4})
        >>> base == {'a': {'x': 10, 'y': 2, 'z': 3}, 'b': 3, 'c': 4}
        True
    """
    for key, value in updates.items():
        # Recurse only when BOTH sides are dicts; otherwise replace outright.
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
def deep_merge(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Merge multiple dictionaries recursively, returning a new dictionary.

    Later dictionaries take precedence for conflicting keys; nested dicts are
    merged key-by-key instead of replaced wholesale. Nested dicts from the
    inputs are copied during the merge, so the result never aliases an input
    dictionary and merging never mutates the inputs. (The previous
    implementation inserted the first input's nested dicts by reference, so
    subsequent merge passes mutated the caller's data.)

    Args:
        *dicts: Dictionaries to merge (left-to-right, rightmost wins).

    Returns:
        New merged dictionary, structurally independent of the inputs.
    """
    def _merge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
        # Fold src into dst, deep-copying nested dicts so inputs stay untouched.
        for key, value in src.items():
            if isinstance(value, dict):
                existing = dst.get(key)
                dst[key] = _merge(existing if isinstance(existing, dict) else {}, value)
            else:
                dst[key] = value
        return dst

    result: Dict[str, Any] = {}
    for d in dicts:
        _merge(result, d)
    return result
# =============================================================================
# Environment/OS helpers
# =============================================================================
_CGROUP_HINTS = ("docker", "containerd", "kubepods", "libpod", "buildkit", "podman")


def is_container() -> bool:
    """Best-effort detection of whether this process runs inside a container.

    Probes a series of markers, most reliable first: an explicit IN_CONTAINER
    environment override, systemd's container hint, Podman/CRI-O and Docker
    marker files, Apptainer/Singularity and Kubernetes environment variables,
    and finally a cgroup-name heuristic.
    """
    # 0) Explicit override wins.
    override = os.getenv("IN_CONTAINER")
    if override and override.lower() not in ("0", "false", "no", "off"):
        return True
    # 1) systemd's official hint (set inside containers); as root, also check
    #    PID 1's environment for a container= entry (podman/docker/lxc).
    try:
        if Path("/run/systemd/container").exists():  # readable by all
            return True
        if os.geteuid() == 0 and b"container=" in Path("/proc/1/environ").read_bytes():
            return True
    except Exception:
        pass
    # 2) Podman/CRI-O marker, then 3) Docker marker (absent under buildkit).
    for marker in ("/run/.containerenv", "/.dockerenv"):
        if Path(marker).exists():
            return True
    # 4) Apptainer/Singularity (common on HPC).
    if "APPTAINER_NAME" in os.environ or "SINGULARITY_NAME" in os.environ:
        return True
    # 5) Kubernetes (env injected by kubelet).
    if "KUBERNETES_SERVICE_HOST" in os.environ:
        return True
    # 6) cgroups heuristic (last resort; not 100% on cgroup v2).
    try:
        cgroup_text = Path("/proc/1/cgroup").read_text()
    except Exception:
        return False
    return any(hint in cgroup_text for hint in _CGROUP_HINTS)
def instantiate(
    target,
    cfg=None,
    *,
    overrides=None,
    allow_positional: bool = False,
    allow_extra_kwargs: bool = False,
):
    """
    Safe instantiation/call helper.

    Calls ``target`` (class or function) with arguments taken from ``cfg``
    plus ``overrides``, after validating them against ``target``'s signature.

    - If cfg is a Mapping: passes as **kwargs (strict by default).
    - If cfg is a Sequence (not str/bytes): passes as *args (only if allow_positional=True).
    - Raises on unknown keys and missing required args.
    - If allow_extra_kwargs=True and target does not accept **kwargs,
      extra keys are dropped before binding/calling.

    Args:
        target: Callable to invoke.
        cfg: Mapping of kwargs, Sequence of positional args, or None (empty).
        overrides: Extra kwargs; take precedence over same-named cfg entries.
        allow_positional: Opt-in flag for Sequence-style positional configs.
        allow_extra_kwargs: Drop unknown keys instead of raising (only matters
            when target does not itself accept **kwargs).

    Returns:
        Whatever target(...) returns.

    Raises:
        TypeError: On unknown keys, a disabled positional config, an
            unsupported cfg type, or any signature-binding failure
            (e.g. missing required arguments).
    """
    cfg = {} if cfg is None else cfg
    overrides = {} if overrides is None else overrides
    sig = inspect.signature(target)
    # Detect if target accepts **kwargs (e.g. nn.Module subclasses forwarding to super().__init__)
    has_var_keyword = any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
    )
    # Only named parameters (exclude *args/**kwargs placeholders)
    named_params = {
        p.name
        for p in sig.parameters.values()
        if p.kind
        not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }
    def _sanitize_kwargs(kwargs: dict) -> dict:
        # Validate caller-supplied keys against the signature's named params.
        unknown = set(kwargs) - named_params
        if not unknown:
            return kwargs
        if has_var_keyword:
            # target's own **kwargs will absorb the extras; pass them through.
            return kwargs
        if allow_extra_kwargs:
            # Lenient mode: silently drop anything the target cannot accept.
            filtered = {k: v for k, v in kwargs.items() if k in named_params}
            return filtered
        raise TypeError(
            f"{getattr(target, '__name__', target)}: unexpected keys {sorted(unknown)}; "
            f"allowed keys {sorted(named_params)}"
        )
    # Mapping -> kwargs mode
    if isinstance(cfg, Mapping):
        kwargs = dict(cfg)
        kwargs.update(overrides)  # overrides win over cfg on key conflicts
        kwargs = _sanitize_kwargs(kwargs)
        # Bind to force "missing required", "positional-only passed as keyword", etc.
        sig.bind(**kwargs)
        return target(**kwargs)
    # Sequence -> positional mode (explicitly opt-in)
    if isinstance(cfg, Sequence) and not isinstance(cfg, (str, bytes, bytearray)):
        if not allow_positional:
            raise TypeError(
                "Positional config (list/tuple) is disabled. "
                "Pass a dict, or set allow_positional=True."
            )
        args = list(cfg)
        kwargs = dict(overrides)
        kwargs = _sanitize_kwargs(kwargs)
        warnings.warn(
            "Using positional args from a Sequence. This is more fragile than dict-based kwargs.",
            RuntimeWarning,
            stacklevel=2,
        )
        # Dry-run bind so binding errors are raised before target is invoked.
        sig.bind(*args, **kwargs)
        return target(*args, **kwargs)
    raise TypeError(
        f"cfg must be a Mapping (dict-like) or a Sequence (list/tuple), got {type(cfg).__name__}"
    )