# Source code for genro_toolbox.dict_utils

"""
Dictionary utilities for Genro-Toolbox.

Provides utilities for dict manipulation used across the library.
"""

import configparser
import inspect
import json
import os
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import Any, get_type_hints

from .treedict import TreeDict

try:
    import tomllib
except ImportError:
    tomllib = None  # type: ignore[assignment]

try:
    import yaml
except ImportError:
    yaml = None  # type: ignore[assignment]

# Marker that SmartOptions.__init__ looks for at the start of a string source
# (and strips from the ``env`` argument): the remainder is used as an
# environment-variable prefix instead of being treated as a config file path.
_ENV_PREFIX = "ENV:"
# Key names that dictExtract emits with a leading underscore ("class" -> "_class").
_RESERVED_ATTR_NAMES = ["class"]


def filtered_dict(
    data: Mapping[str, Any] | None,
    filter_fn: Callable[[str, Any], bool] | None = None,
) -> dict[str, Any]:
    """
    Return a dict filtered through ``filter_fn``.

    Args:
        data: Mapping with the original values (can be None).
        filter_fn: Optional callable receiving ``(key, value)`` and returning
            True if the pair should be kept. When None, the mapping is copied.
    """
    if not data:
        return {}
    if filter_fn is None:
        return dict(data)
    return {k: v for k, v in data.items() if filter_fn(k, v)}


def _merge_kwargs(
    incoming: Mapping[str, Any] | None,
    defaults: Mapping[str, Any] | None,
    *,
    filter_fn: Callable[[str, Any], bool] | None = None,
    ignore_none: bool = False,
    ignore_empty: bool = False,
) -> dict[str, Any]:
    """Overlay the filtered ``incoming`` entries on a copy of ``defaults``.

    Args:
        incoming: Runtime values; entries rejected by the combined filter are
            dropped before merging. May be None.
        defaults: Baseline values, copied as-is (never filtered). May be None.
        filter_fn: Optional ``(key, value) -> bool`` predicate for ``incoming``.
        ignore_none: Drop ``incoming`` entries whose value is None.
        ignore_empty: Drop ``incoming`` entries whose value is an empty
            string/bytes/builtin container.

    Returns:
        A new dict in which surviving ``incoming`` entries override ``defaults``.
    """
    accept = _compose_filter(filter_fn, ignore_none, ignore_empty)
    merged = dict(defaults or {})
    merged.update(filtered_dict(incoming, accept))
    return merged


def _compose_filter(
    filter_fn: Callable[[str, Any], bool] | None,
    ignore_none: bool,
    ignore_empty: bool,
) -> Callable[[str, Any], bool] | None:
    if not (filter_fn or ignore_none or ignore_empty):
        return None

    def predicate(key: str, value: Any) -> bool:
        if ignore_none and value is None:
            return False
        if ignore_empty and _is_empty_value(value):
            return False
        if filter_fn:
            return filter_fn(key, value)
        return True

    return predicate


def _is_empty_value(value: Any) -> bool:
    """Tell whether ``value`` is a zero-length str, bytes or builtin container.

    Only str, bytes, list, tuple, dict, set and frozenset (and their
    subclasses) are inspected; every other value, including None and 0,
    is reported as not empty.
    """
    sized_builtins = (str, bytes, list, tuple, dict, set, frozenset)
    return isinstance(value, sized_builtins) and len(value) == 0


def _extract_signature_info(
    func: Callable[..., Any],
) -> tuple[dict[str, Any], dict[str, type], list[str]]:
    """Extract defaults, simple types and positional names from a signature.

    Args:
        func: Callable whose signature describes the available options.

    Returns:
        A ``(defaults, types, positional_params)`` tuple where:

        - ``defaults`` maps each parameter that has a default to that default;
        - ``types`` maps each parameter annotated as ``str``, ``int``,
          ``float`` or ``bool`` (directly or as the first argument of an
          ``Annotated[...]``) to that type; other annotations are ignored;
        - ``positional_params`` lists, in declaration order, the parameters
          without a default.

    Variadic parameters (``*args`` / ``**kwargs``) are skipped entirely: they
    do not name a single option, so they must not consume a positional
    command-line token or appear as a typed option.
    """
    signature = inspect.signature(func)
    defaults: dict[str, Any] = {}
    types: dict[str, type] = {}
    positional_params: list[str] = []

    # get_type_hints resolves stringified annotations (PEP 563); when it cannot
    # (e.g. an unresolvable forward reference) fall back to the raw annotations.
    try:
        hints = get_type_hints(func)
    except (NameError, AttributeError, TypeError):
        hints = {}

    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    for name, param in signature.parameters.items():
        if param.kind in variadic_kinds:
            continue

        if param.default is not inspect.Parameter.empty:
            defaults[name] = param.default
        else:
            positional_params.append(name)

        # Prefer the resolved hint, otherwise the annotation as written.
        ann = hints.get(name, param.annotation)
        if ann is inspect.Parameter.empty:
            continue

        # An annotation whose origin is NoneType carries no usable type.
        if hasattr(ann, "__origin__") and ann.__origin__ is type(None):
            continue
        # Annotated[T, ...] exposes __metadata__; the real type is its first arg.
        if hasattr(ann, "__metadata__"):
            ann = ann.__args__[0]
        if ann in (str, int, float, bool):
            types[name] = ann

    return defaults, types, positional_params


def _load_from_callable(func: Callable[..., Any], argv: list[str] | None = None) -> dict[str, Any]:
    """Build an options dict from a callable's signature, optionally with argv.

    Without ``argv`` the signature defaults are returned as they are. With
    ``argv`` (even an empty list) the arguments are parsed on top of those
    defaults, using the signature's simple types and positional names.
    """
    defaults, types, positional_params = _extract_signature_info(func)
    if argv is None:
        return defaults
    return _parse_argv(argv, defaults, types, positional_params)


def _parse_argv(
    argv: list[str],
    defaults: dict[str, Any],
    types: dict[str, type],
    positional_params: list[str],
) -> dict[str, Any]:
    """Parse argv into a dict using defaults and types.

    Supported forms:

    - ``--key value``: the next token is the value (unless it starts with
      ``--``, in which case the option is ignored);
    - ``--key=value``: inline value, split on the first ``=``;
    - ``--flag``: when ``types[flag]`` is ``bool`` the option is set to True
      and no value token is consumed;
    - bare tokens: assigned in order to ``positional_params``; extra ones are
      dropped.

    Dashes in option names are converted to underscores. Values are passed
    through ``types[key]`` when the key is typed. For the inline form of a
    ``bool`` option the value is True only for ``true``/``1``/``yes``/``on``
    (case-insensitive), as in ``_load_env``.

    Args:
        argv: Command line tokens (without the program name).
        defaults: Baseline values copied into the result.
        types: Mapping of option name to converter type.
        positional_params: Names that receive bare tokens, in order.

    Returns:
        A new dict with ``defaults`` overridden by the parsed values.

    Raises:
        ValueError: Propagated from a type converter (e.g. ``int("abc")``).
    """
    truthy = ("true", "1", "yes", "on")

    def convert(key: str, raw: str) -> Any:
        return types[key](raw) if key in types else raw

    result = dict(defaults)
    positional_index = 0
    i = 0

    while i < len(argv):
        arg = argv[i]
        if arg.startswith("--"):
            name, has_inline_value, inline_value = arg[2:].partition("=")
            key = name.replace("-", "_")
            if has_inline_value:
                # --key=value: never consumes the following token.
                if types.get(key) is bool:
                    result[key] = inline_value.lower() in truthy
                else:
                    result[key] = convert(key, inline_value)
            elif types.get(key) is bool:
                # Bare boolean flag.
                result[key] = True
            elif i + 1 < len(argv) and not argv[i + 1].startswith("--"):
                result[key] = convert(key, argv[i + 1])
                i += 1  # the value token has been consumed
        elif positional_index < len(positional_params):
            key = positional_params[positional_index]
            result[key] = convert(key, arg)
            positional_index += 1
        i += 1

    return result


def _load_env(prefix: str, types: dict[str, type] | None = None) -> dict[str, Any]:
    """Load config from environment variables with given prefix.

    Args:
        prefix: Environment variable prefix (e.g., "MYAPP" for MYAPP_HOST, MYAPP_PORT)
        types: Optional dict mapping keys to types for conversion
    """
    prefix_upper = prefix.upper() + "_"
    result: dict[str, Any] = {}
    for key, raw_value in os.environ.items():
        if key.startswith(prefix_upper):
            # Remove prefix and convert to lowercase
            clean_key = key[len(prefix_upper) :].lower()
            # Convert type if specified
            converted: Any = raw_value
            if types and clean_key in types:
                target_type = types[clean_key]
                if target_type is bool:
                    converted = raw_value.lower() in ("true", "1", "yes", "on")
                else:
                    converted = target_type(raw_value)
            result[clean_key] = converted
    return result


def _load_config_file(path: str | Path) -> dict[str, Any]:
    """Load config from file based on extension. Returns {} if file doesn't exist."""
    path = Path(path)
    if not path.exists():
        return {}
    suffix = path.suffix.lower()

    if suffix in (".yaml", ".yml"):
        if yaml is None:
            raise ImportError("PyYAML is required to load YAML files: pip install pyyaml")
        with open(path) as f:
            return yaml.safe_load(f) or {}
    elif suffix == ".json":
        with open(path) as f:
            return json.load(f)
    elif suffix == ".toml":
        if tomllib is None:
            raise ImportError(
                "tomli is required to load TOML files on Python < 3.11: pip install tomli"
            )
        with open(path, "rb") as f:
            return tomllib.load(f)
    elif suffix == ".ini":
        parser = configparser.ConfigParser()
        parser.read(path)
        return {
            f"{section}_{key}": value
            for section in parser.sections()
            for key, value in parser.items(section)
        }
    else:
        raise ValueError(f"Unsupported config file format: {suffix}")


def _wrap_nested_dicts(data: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``data`` with structured values wrapped in SmartOptions.

    Per value:
    - a dict becomes ``SmartOptions(value)``;
    - a non-empty list made only of strings becomes a SmartOptions mapping
      each string to True (feature flags);
    - a non-empty list made only of dicts is indexed through
      ``_index_list_of_dicts``;
    - anything else (scalars, empty lists, mixed lists) is kept unchanged.
    """

    def wrap(value: Any) -> Any:
        if isinstance(value, dict):
            return SmartOptions(value)
        if not isinstance(value, list) or not value:
            return value
        if all(isinstance(entry, str) for entry in value):
            return SmartOptions(dict.fromkeys(value, True))
        if all(isinstance(entry, dict) for entry in value):
            return _index_list_of_dicts(value)
        return value

    return {key: wrap(value) for key, value in data.items()}


def _index_list_of_dicts(items: list[dict[str, Any]]) -> "SmartOptions":
    """Turn a list of dicts into a SmartOptions keyed by a shared field.

    The field used as index is the first key of the first dict. Each dict that
    has this field is wrapped in SmartOptions and stored under the field's
    value; dicts without it are skipped, and later duplicates replace earlier
    ones. An empty list, or an empty first dict, yields an empty SmartOptions.
    """
    if not items or not items[0]:
        return SmartOptions({})

    index_key = next(iter(items[0]))
    by_value = {
        entry[index_key]: SmartOptions(entry) for entry in items if index_key in entry
    }
    return SmartOptions(by_value)


class SmartOptions(TreeDict):
    """
    Convenience namespace for option management, built on TreeDict.

    Args:
        incoming: Mapping with runtime kwargs, or a file path (str/Path) to load,
            or a callable to extract defaults and types from its signature.
            A string starting with ``"ENV:"`` loads the environment variables
            with the prefix that follows instead of a file.
        defaults: Mapping with baseline options, or argv list when incoming is callable.
        env: Environment variable prefix (e.g., "MYAPP" for MYAPP_HOST).
            Only used when incoming is a callable (types from signature are used).
        argv: Command line arguments list. Only used when incoming is a callable.
        ignore_none: Skip incoming entries where the value is ``None``.
        ignore_empty: Skip empty strings/collections from incoming entries.
        filter_fn: Optional callable receiving ``(key, value)`` and returning
            True if the pair should be kept.

    If incoming is a string or Path and defaults is None, loads config from file.
    Nested dicts are recursively wrapped in SmartOptions.

    When incoming is a callable with env/argv:
    - Defaults come from function signature
    - env values override defaults (with type conversion)
    - argv values override env (with type conversion)

    Priority: defaults < env < argv

    Inherits from TreeDict:
    - Path notation access: opts["a.b.c"]
    - Context managers for thread/async safety
    """

    def __init__(
        self,
        incoming: Mapping[str, Any] | str | Path | Callable[..., Any] | None = None,
        defaults: Mapping[str, Any] | list[str] | None = None,
        *,
        env: str | None = None,
        argv: list[str] | None = None,
        ignore_none: bool = False,
        ignore_empty: bool = False,
        filter_fn: Callable[[str, Any], bool] | None = None,
    ):
        # If incoming is callable, use new env/argv parameters or legacy defaults.
        # Classes are excluded: only functions/callable instances are inspected.
        if callable(incoming) and not isinstance(incoming, type):
            sig_defaults, types, positional_params = _extract_signature_info(incoming)

            # Start with signature defaults
            result = dict(sig_defaults)

            # If env or argv keyword args are provided, use new API
            if env is not None or argv is not None:
                # Layer env values (with type conversion)
                if env is not None:
                    # Accept both "MYAPP" and "ENV:MYAPP".
                    prefix = env[len(_ENV_PREFIX) :] if env.startswith(_ENV_PREFIX) else env
                    env_values = _load_env(prefix, types)
                    result.update(env_values)

                # Layer argv values (with type conversion)
                if argv is not None:
                    # Parsed without defaults so only explicitly given values override.
                    argv_values = _parse_argv(argv, {}, types, positional_params)
                    result.update(argv_values)

                incoming = result
            else:
                # Legacy API: defaults is argv list
                legacy_argv = defaults if isinstance(defaults, list) else None
                incoming = _load_from_callable(incoming, legacy_argv)

            # For a callable source the baseline comes from the signature, so
            # ``defaults`` is cleared before merging.
            # NOTE(review): the nesting level of this statement was not
            # recoverable from the flattened listing — confirm it belongs to
            # the whole callable branch rather than only the legacy path.
            defaults = None

        # If incoming is a string, detect source type
        elif isinstance(incoming, str) and defaults is None:
            if incoming.startswith(_ENV_PREFIX):
                incoming = _load_env(incoming[len(_ENV_PREFIX) :])
            else:
                incoming = _load_config_file(incoming)

        elif isinstance(incoming, Path) and defaults is None:
            incoming = _load_config_file(incoming)

        # At this point incoming is Mapping or None, defaults is Mapping or None
        merged = _merge_kwargs(
            incoming,  # type: ignore[arg-type]
            defaults,  # type: ignore[arg-type]
            filter_fn=filter_fn,
            ignore_none=ignore_none,
            ignore_empty=ignore_empty,
        )

        # Wrap nested dicts recursively (SmartOptions-specific wrapping)
        merged = _wrap_nested_dicts(merged)

        # Initialize TreeDict with the merged data
        super().__init__(merged)

    def _wrap(self, value: Any) -> Any:
        """Override to wrap nested dicts as SmartOptions instead of TreeDict.

        SmartOptions values are already wrapped by _wrap_nested_dicts,
        so we just return them as-is.
        """
        if isinstance(value, SmartOptions):
            # Already a SmartOptions, return as-is
            return value
        if isinstance(value, TreeDict):
            # Share the same _data reference, but as SmartOptions.
            # __new__ + object.__setattr__ bypass __init__ and any attribute
            # hooks so no merge/wrap pass runs on the shared data.
            new_opts = SmartOptions.__new__(SmartOptions)
            object.__setattr__(new_opts, "_data", value._data)
            object.__setattr__(new_opts, "_lock", self._lock)
            object.__setattr__(new_opts, "_async_lock", None)
            return new_opts
        if isinstance(value, dict):
            return SmartOptions(value)
        return value

    def __add__(self, other: "SmartOptions | Mapping[str, Any]") -> "SmartOptions":
        """Merge two SmartOptions. Right side overrides left side."""
        other_data = other._data if isinstance(other, (SmartOptions, TreeDict)) else dict(other)
        merged = self.as_dict() | other_data
        return SmartOptions(merged)

    def __repr__(self) -> str:
        """Return string representation."""
        return f"SmartOptions({self.as_dict()})"
def dictExtract(source_dict, prefix, pop=False, slice_prefix=True):
    """Return a dict of the items with keys starting with prefix.

    Output keys listed in ``_RESERVED_ATTR_NAMES`` (e.g. ``class``) are
    emitted with a leading underscore (``_class``).

    :param source_dict: source dictionary
    :param prefix: the prefix of the items you need to extract
    :param pop: removes the items from the source dictionary
    :param slice_prefix: shortens the keys of the output dict removing the prefix
    :returns: a dict of the items with keys starting with prefix
    """
    lprefix = len(prefix) if slice_prefix else 0
    extract_fn = source_dict.pop if pop else source_dict.get

    extracted = {}
    # Iterate over a snapshot of the keys: with pop=True the source shrinks.
    for k in list(source_dict.keys()):
        if not k.startswith(prefix):
            continue
        short_key = k[lprefix:]
        if short_key in _RESERVED_ATTR_NAMES:
            short_key = f"_{short_key}"
        extracted[short_key] = extract_fn(k)
    return extracted
class DictObj(dict):
    """Dict with dot-access for attribute-style read/write.

    Attribute reads, writes and deletes are mapped onto dict items; a missing
    key surfaces as ``AttributeError`` so ``getattr``/``hasattr`` behave as
    usual.

    Example::

        ctx = DictObj()
        ctx.db = connection
        ctx.session = session_obj
        ctx.db.execute(...)  # dot-access
        "db" in ctx          # dict-access
    """

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails, so dict methods
        # such as ``keys`` keep working.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None