Source code for genro_toolbox.treedict

# Copyright 2025 Softwell S.r.l. - Genropy Team
# SPDX-License-Identifier: Apache-2.0

"""TreeDict - A hierarchical dictionary with dot notation access."""

from __future__ import annotations

import asyncio
import configparser
import json
from collections.abc import Iterator
from pathlib import Path
from threading import RLock
from typing import Any

try:
    import tomllib
except ImportError:
    tomllib = None  # type: ignore[assignment]

try:
    import yaml
except ImportError:
    yaml = None  # type: ignore[assignment]


class _TraversalError(Exception):
    """Internal signal for failed path traversal (not part of public API)."""


class TreeDict:
    """Hierarchical dictionary with dot-path access and list indexing (``#N``).

    Nested ``dict`` values are wrapped as ``TreeDict`` when stored, so a path
    such as ``"a.b.c"`` walks through nested mappings, and a segment of the
    form ``#N`` (e.g. ``"items.#0.name"``) indexes into a list. Lists are
    stored as-is; plain dicts found inside them are followed directly on
    read/delete and converted to ``TreeDict`` in place when a write goes
    through them.

    Locking is opt-in: ``with td:`` holds a reentrant thread lock and
    ``async with td:`` holds an ``asyncio.Lock``. Individual operations do
    not acquire either lock by themselves.
    """

    __slots__ = ("_data", "_lock", "_async_lock")

    def __init__(self, data: dict[str, Any] | str | None = None) -> None:
        """Initialize from a dict, a JSON object string, or None (empty).

        Raises:
            TypeError: If ``data`` (after JSON decoding) is not a dict.
        """
        object.__setattr__(self, "_data", {})
        object.__setattr__(self, "_lock", RLock())
        object.__setattr__(self, "_async_lock", None)  # created lazily in __aenter__
        if data is None:
            return
        if isinstance(data, str):
            data = json.loads(data)
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict or JSON string, got {type(data).__name__}")
        for key, value in data.items():
            self._data[key] = self._wrap(value)

    def __enter__(self) -> TreeDict:
        """Acquire the reentrant thread lock for the duration of a ``with`` block."""
        self._lock.acquire()
        return self

    def __exit__(self, *args: Any) -> None:
        """Release the thread lock acquired by ``__enter__``."""
        self._lock.release()

    async def __aenter__(self) -> TreeDict:
        """Acquire the asyncio lock (created on first use) for ``async with``."""
        if self._async_lock is None:
            object.__setattr__(self, "_async_lock", asyncio.Lock())
        await self._async_lock.acquire()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Release the asyncio lock acquired by ``__aenter__``."""
        self._async_lock.release()

    def _wrap(self, value: Any) -> Any:
        """Wrap nested dicts as TreeDict, leave other values unchanged.

        If value is a TreeDict, creates a new TreeDict sharing the same _data
        (but with its own locks).
        """
        if isinstance(value, TreeDict):
            # Share the same _data reference
            new_td = TreeDict.__new__(TreeDict)
            object.__setattr__(new_td, "_data", value._data)
            object.__setattr__(new_td, "_lock", RLock())
            object.__setattr__(new_td, "_async_lock", None)
            return new_td
        if isinstance(value, dict):
            return TreeDict(value)
        return value

    def _unwrap(self, value: Any) -> Any:
        """Unwrap TreeDict to plain dict recursively (also inside lists)."""
        if isinstance(value, TreeDict):
            return {k: self._unwrap(v) for k, v in value._data.items()}
        if isinstance(value, list):
            return [self._unwrap(item) for item in value]
        return value

    def _parse_key(self, key: str) -> tuple[bool, int | str]:
        """Parse a path segment, detecting list index (#N) syntax.

        Returns:
            Tuple of (is_list_index, index_or_key)
        """
        # isdecimal() (not isdigit()) so that characters like "²", which
        # isdigit() accepts but int() rejects, are treated as plain keys.
        if key.startswith("#") and key[1:].isdecimal():
            return True, int(key[1:])
        return False, key

    def _traverse_to_parent(self, parts: list[str], *, create: bool = False) -> Any:
        """Walk all path segments except the last, returning the parent container.

        Args:
            parts: All path segments (the full split path).
            create: If True, create intermediate TreeDict/list nodes as needed
                and convert plain dicts found inside lists to TreeDict in place,
                so that a subsequent write modifies the stored tree.

        Returns:
            Whatever ``parts[:-1]`` addresses: a TreeDict, a list, a plain dict
            (only on read-only walks through list elements), another value, or
            None when a segment is missing.

        Raises:
            _TraversalError: When navigation fails (wrong type, index out of range).
            TypeError: When ``create`` is True and a node cannot be created
                (``#N`` applied to a non-list, or a key under a scalar).
        """
        current: Any = self
        for i, part in enumerate(parts[:-1]):
            if current is None:
                raise _TraversalError
            is_list_idx, key = self._parse_key(part)
            if is_list_idx:
                assert isinstance(key, int)
                if not isinstance(current, list):
                    if create:
                        raise TypeError(f"Cannot index non-list with {part}")
                    raise _TraversalError
                if create:
                    while len(current) <= key:
                        current.append(None)
                    item = current[key]
                    if item is None:
                        next_is_list, _ = self._parse_key(parts[i + 1])
                        current[key] = [] if next_is_list else TreeDict()
                    elif isinstance(item, dict):
                        # Lists are stored unwrapped, so their dict elements are
                        # plain dicts; promote in place so the write is visible.
                        current[key] = TreeDict(item)
                elif key < 0 or key >= len(current):
                    raise _TraversalError
                current = current[key]
            elif isinstance(current, TreeDict):
                if create:
                    assert isinstance(key, str)
                    if key not in current._data or current._data[key] is None:
                        next_is_list, _ = self._parse_key(parts[i + 1])
                        current._data[key] = [] if next_is_list else TreeDict()
                    current = current._data[key]
                else:
                    current = current._data.get(key)
            elif not create and isinstance(current, dict):
                # Read-only walk through a plain dict held in a list: follow the
                # real object (not a copy) so deletions affect the stored data.
                current = current.get(key)
            else:
                if create:
                    raise TypeError(f"Cannot set attribute on {type(current)}")
                raise _TraversalError
        return current

    def _get_by_path(self, path: str) -> Any:
        """Get value by dot-separated path string; None if any segment is missing."""
        parts = path.split(".")
        try:
            current = self._traverse_to_parent(parts, create=False)
        except _TraversalError:
            return None
        if current is None:
            return None
        last_part = parts[-1]
        is_list_idx, key = self._parse_key(last_part)
        if is_list_idx:
            if not isinstance(current, list):
                return None
            assert isinstance(key, int)
            if key < 0 or key >= len(current):
                return None
            result = current[key]
            if isinstance(result, dict):
                return TreeDict(result)
            return result
        if isinstance(current, TreeDict):
            return current._data.get(key)
        if isinstance(current, dict):
            # Plain dict reached through a list element: wrap dict results so
            # callers always get a TreeDict for mapping values.
            result = current.get(key)
            return TreeDict(result) if isinstance(result, dict) else result
        return None

    def _set_by_path(self, path: str, value: Any) -> None:
        """Set value by dot-separated path string, creating intermediate nodes.

        Raises:
            TypeError: If the path crosses a value that cannot hold children.
        """
        parts = path.split(".")
        current = self._traverse_to_parent(parts, create=True)
        last_part = parts[-1]
        is_list_idx, key = self._parse_key(last_part)
        if is_list_idx:
            assert isinstance(key, int)
            if not isinstance(current, list):
                raise TypeError(f"Cannot index non-list with {last_part}")
            while len(current) <= key:
                current.append(None)
            current[key] = self._wrap(value)
        elif isinstance(current, TreeDict):
            assert isinstance(key, str)
            current._data[key] = self._wrap(value)
        else:
            raise TypeError(f"Cannot set attribute on {type(current)}")

    def _del_by_path(self, path: str) -> None:
        """Delete value by dot-separated path string.

        Raises:
            KeyError: If the path does not address an existing entry.
        """
        parts = path.split(".")
        try:
            current = self._traverse_to_parent(parts, create=False)
        except _TraversalError:
            raise KeyError(path) from None
        if current is None:
            raise KeyError(path)
        last_part = parts[-1]
        is_list_idx, key = self._parse_key(last_part)
        if is_list_idx:
            if not isinstance(current, list):
                raise KeyError(path)
            assert isinstance(key, int)
            if key < 0 or key >= len(current):
                raise KeyError(path)
            del current[key]
        elif isinstance(current, TreeDict):
            assert isinstance(key, str)
            if key not in current._data:
                raise KeyError(path)
            del current._data[key]
        elif isinstance(current, dict):
            # Plain dict stored inside a list: delete from the real object.
            if key not in current:
                raise KeyError(path)
            del current[key]
        else:
            raise KeyError(path)

    def __getitem__(self, path: str) -> Any:
        """Get value by dot-separated path string (None if missing)."""
        return self._get_by_path(path)

    def __setitem__(self, path: str, value: Any) -> None:
        """Set value by dot-separated path string."""
        self._set_by_path(path, value)

    def __delitem__(self, path: str) -> None:
        """Delete value by dot-separated path string."""
        self._del_by_path(path)

    def __contains__(self, key: str) -> bool:
        """Check if key exists (top-level only; dotted paths are not resolved)."""
        return key in self._data

    def __len__(self) -> int:
        """Return number of top-level keys."""
        return len(self._data)

    def __iter__(self) -> Iterator[str]:
        """Iterate over top-level keys."""
        return iter(self._data)

    def __repr__(self) -> str:
        """Return string representation."""
        return f"TreeDict({self._unwrap(self)})"

    def __eq__(self, other: object) -> bool:
        """Check equality with another TreeDict or dict."""
        if isinstance(other, TreeDict):
            return self._data == other._data
        if isinstance(other, dict):
            return self._unwrap(self) == other
        return NotImplemented

    def keys(self) -> Any:
        """Return top-level keys."""
        return self._data.keys()

    def values(self) -> Any:
        """Return top-level values."""
        return self._data.values()

    def items(self) -> Any:
        """Return top-level items."""
        return self._data.items()

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value by top-level key or dot-path, falling back to ``default``.

        For dotted or ``#N`` paths a stored ``None`` cannot be told apart from
        a missing path, so ``default`` is returned in both cases; plain
        top-level keys follow ``dict.get`` semantics.
        """
        if "." in key or key.startswith("#"):
            result = self._get_by_path(key)
            return default if result is None else result
        return self._data.get(key, default)

    def as_dict(self) -> dict[str, Any]:
        """Return plain dict representation (recursive unwrap)."""
        return self._unwrap(self)

    @classmethod
    def from_file(cls, path: str | Path) -> TreeDict:
        """Load a TreeDict from a config file, format chosen by extension.

        Supported: ``.json``, ``.yaml``/``.yml`` (needs PyYAML), ``.toml``
        (needs the standard-library ``tomllib``), ``.ini`` (each section
        becomes a top-level key). Text formats are read as UTF-8.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ImportError: If the parser for the format is unavailable.
            ValueError: If the extension is not supported.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Config file not found: {path}")
        suffix = path.suffix.lower()
        if suffix == ".json":
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        elif suffix in (".yaml", ".yml"):
            if yaml is None:
                raise ImportError("PyYAML is required to load YAML files: pip install pyyaml")
            with open(path, encoding="utf-8") as f:
                data = yaml.safe_load(f) or {}
        elif suffix == ".toml":
            if tomllib is None:
                # Only the stdlib tomllib is tried at import time, so pointing
                # users at a third-party package would not fix this.
                raise ImportError("Python 3.11+ (tomllib) is required to load TOML files")
            with open(path, "rb") as f:
                data = tomllib.load(f)
        elif suffix == ".ini":
            parser = configparser.ConfigParser()
            # read_file() on an open handle surfaces I/O errors, whereas
            # ConfigParser.read() silently skips files it cannot open.
            with open(path, encoding="utf-8") as f:
                parser.read_file(f)
            data = {section: dict(parser.items(section)) for section in parser.sections()}
        else:
            raise ValueError(f"Unsupported config file format: {suffix}")
        return cls(data)

    def walk(self, expand_lists: bool = False, _prefix: str = "") -> Iterator[tuple[str, Any]]:
        """Iterate over all paths and leaf values.

        Args:
            expand_lists: If True, traverse into lists using #N paths.
                If False, lists are treated as leaf values.

        Yields:
            Tuples of (path, value) for each leaf node.

        Example:
            >>> td = TreeDict({"a": 1, "b": {"c": 2}})
            >>> list(td.walk())
            [('a', 1), ('b.c', 2)]
        """
        for key, value in self._data.items():
            path = f"{_prefix}.{key}" if _prefix else key
            if isinstance(value, TreeDict):
                yield from value.walk(expand_lists=expand_lists, _prefix=path)
            elif expand_lists and isinstance(value, list):
                yield from self._walk_list(value, path, expand_lists)
            else:
                yield path, value

    def _walk_list(
        self, lst: list[Any], prefix: str, expand_lists: bool
    ) -> Iterator[tuple[str, Any]]:
        """Walk through a list, yielding paths with #N notation."""
        for i, item in enumerate(lst):
            path = f"{prefix}.#{i}"
            if isinstance(item, TreeDict):
                # TreeDict is not a dict subclass; nodes created by path writes
                # such as td["a.#0.b"] = 1 are TreeDicts and must be descended.
                yield from item.walk(expand_lists=expand_lists, _prefix=path)
            elif isinstance(item, dict):
                yield from TreeDict(item).walk(expand_lists=expand_lists, _prefix=path)
            elif expand_lists and isinstance(item, list):
                yield from self._walk_list(item, path, expand_lists)
            else:
                yield path, item