# Copyright 2025 Softwell S.r.l. - Genropy Team
# SPDX-License-Identifier: Apache-2.0
"""SmartAsync - Unified sync/async API decorator.
Automatic context detection for methods that work in both sync and async contexts.
This module is also available as a standalone package: pip install smartasync
Design Context:
This module is optimized for environments with pre-assigned thread workers
(e.g., Gunicorn with sync workers, Django, Flask). In these contexts, threads
are long-lived and reused across requests, so the per-thread loop pool provides
efficient event loop reuse without creation overhead.
Caveats - Nested Mixed Calls:
When async code offloads sync code to a thread (via to_thread), and that sync
code calls async functions, those async functions MUST be decorated with
@smartasync to work correctly. Without the decorator, the sync code receives
a raw coroutine object it cannot use.
Problematic chain:
async A() -> sync B() [in thread] -> async C() [NO decorator] = BROKEN
Safe chain:
async A() -> sync B() [in thread] -> async C() [@smartasync] = WORKS
Best practice: Apply @smartasync at the boundary where sync code calls async
code, i.e. on the async "leaf" functions that sync callers invoke (C in the
safe chain above). Avoid deep nesting of mixed calls.
Inline Usage:
smartasync can be used inline without the decorator syntax, useful for
wrapping third-party async functions or one-off calls:
# Wrap and call in one line
result = smartasync(some_async_func)(arg1, arg2)
# Or wrap once, call multiple times
wrapped = smartasync(third_party_async_func)
result1 = wrapped(args1)
result2 = wrapped(args2)
"""
import asyncio
import contextvars
import functools
import inspect
import threading

from .typeutils import is_awaitable
_async_mode: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
"genro_async_mode", default=None
)
def set_sync(active: bool = True) -> None:
    """Override context detection so callers are treated as synchronous.

    The override lives in a ContextVar and therefore applies to the current
    context. Passing ``active=False`` removes any override (including one
    installed by set_async) and restores auto-detection.
    """
    override = None
    if active:
        override = False
    _async_mode.set(override)
def set_async(active: bool = True) -> None:
    """Override context detection so callers are treated as asynchronous.

    The override lives in a ContextVar and therefore applies to the current
    context. Passing ``active=False`` removes any override (including one
    installed by set_sync) and restores auto-detection.
    """
    override = None
    if active:
        override = True
    _async_mode.set(override)
class AsyncHandler:
"""Manages per-thread event loops for sync context execution.
Provides a single point of access to determine async/sync context
and manage event loops for each thread.
The current_thread_loop property returns:
- None: if running in async context (external loop exists)
- EventLoop: if running in sync context (creates/reuses per-thread loop)
"""
def __init__(self):
self._thread_loops: dict[int, asyncio.AbstractEventLoop] = {}
self._reset_lock = threading.Lock()
@property
def current_thread_loop(self) -> asyncio.AbstractEventLoop | None:
"""Get event loop for current thread, or None if in async context."""
if is_async_context():
return None
tid = threading.get_ident()
loop = self._thread_loops.get(tid)
if loop is None or loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._thread_loops[tid] = loop
return loop
@current_thread_loop.setter
def current_thread_loop(self, value):
"""Set or remove the event loop for current thread (None to remove)."""
tid = threading.get_ident()
if value is None:
self._thread_loops.pop(tid, None)
else:
self._thread_loops[tid] = value
def reset(self):
"""Clear all cached event loops. Thread-safe.
Closes all loops before clearing. Use only in tests when no
other threads are actively using smartasync.
"""
with self._reset_lock:
for loop in self._thread_loops.values():
if not loop.is_closed():
loop.close()
self._thread_loops.clear()
# Module-level singleton shared by every smartasync-wrapped callable and by
# reset_smartasync_cache(); it owns the per-thread event loop cache.
_async_handler = AsyncHandler()
def is_async_context() -> bool:
    """Report whether the caller should be treated as running in async code.

    An explicit set_sync()/set_async() override wins. Without one, the answer
    is True exactly when an event loop is running in the current thread.
    """
    override = _async_mode.get()
    if override is None:
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: plain synchronous caller.
            return False
        else:
            return True
    return override
def reset_smartasync_cache():
    """Close and forget every per-thread event loop cached by smartasync.

    Intended for test setup/teardown so each test starts from a clean state.
    """
    handler = _async_handler
    handler.reset()
[docs]
def smartasync(method):
"""Decorator that adapts sync/async functions to work in both contexts.
Dispatches based on (async_context, is_coroutine):
sync+async→run_until_complete, sync+sync→passthrough,
async+async→return coroutine, async+sync→to_thread.
"""
# Import time: Detect if method is async
is_coro = asyncio.iscoroutinefunction(method)
@functools.wraps(method)
def wrapper(*args, **kwargs):
# Get loop for current thread (None if async context)
loop = _async_handler.current_thread_loop
async_context = loop is None
# Dispatch based on (async_context, is_coro) using pattern matching
match (async_context, is_coro):
case (False, True):
# Sync context + Async method -> Run with per-thread loop
coro = method(*args, **kwargs)
return loop.run_until_complete(coro)
case (False, False):
# Sync context + Sync method -> Direct call (pass-through)
return method(*args, **kwargs)
case (True, True):
# Async context + Async method -> Return coroutine to be awaited
return method(*args, **kwargs)
case (True, False):
# Async context + Sync method -> Offload to thread (don't block event loop)
return asyncio.to_thread(method, *args, **kwargs)
return wrapper
async def smartawait(value):
    """Await ``value`` repeatedly until the result is no longer awaitable.

    Non-awaitable inputs are returned unchanged; awaitables that resolve to
    further awaitables are unwrapped until a plain value remains.
    """
    result = value
    while True:
        if not is_awaitable(result):
            return result
        result = await result
def smartcontinuation(value, on_resolved, *args, **kwargs):
    """Call ``on_resolved(value, *args, **kwargs)``, deferring it if needed.

    For a plain value the callback runs immediately and its result is
    returned. For an awaitable, a coroutine is returned that awaits the value
    once and then hands the resolved result to the callback.
    """
    if not is_awaitable(value):
        return on_resolved(value, *args, **kwargs)

    async def _continue():
        return on_resolved(await value, *args, **kwargs)

    return _continue()
class SmartLock:
    """Async single-flight helper: concurrent callers share one in-flight call.

    The underlying ``asyncio.Lock`` is created lazily on first use. While a
    ``run_once()`` call is in progress, later callers await the same Future
    instead of invoking ``coro_func`` again. Once that call finishes, the next
    ``run_once()`` executes ``coro_func`` afresh (results are not cached).
    """

    __slots__ = ("_lock", "_future")

    def __init__(self):
        """Initialize with no lock or future (created on-demand)."""
        self._lock = None
        self._future = None

    async def run_once(self, coro_func, *args, **kwargs):
        """Execute coro_func once, sharing the outcome with concurrent callers.

        Args:
            coro_func: Coroutine function, called as ``coro_func(*args, **kwargs)``.

        Returns:
            The awaited result of ``coro_func``; concurrent callers receive the
            same object.

        Raises:
            Whatever ``coro_func`` raises; concurrent waiters receive the same
            exception. Waiters get ``asyncio.CancelledError`` if the in-flight
            call is cancelled or ``reset()`` is called.
        """
        # Fast path: a call is already in flight -- wait for its outcome.
        pending = self._future
        if pending is not None:
            # shield(): cancelling this waiter must not cancel the shared Future.
            return await asyncio.shield(pending)

        # Create lock on first use
        if self._lock is None:
            self._lock = asyncio.Lock()

        async with self._lock:
            # Double-check after acquiring lock
            pending = self._future
            if pending is not None:
                return await asyncio.shield(pending)

            # Keep a local reference: reset() may clear or cancel self._future
            # while coro_func is still running.
            future = asyncio.get_running_loop().create_future()
            self._future = future
            try:
                result = await coro_func(*args, **kwargs)
            except Exception as exc:
                if not future.done():
                    future.set_exception(exc)
                    # Mark the exception as retrieved so asyncio does not log
                    # "Future exception was never retrieved" when nobody waited.
                    future.exception()
                raise
            except BaseException:
                # CancelledError, KeyboardInterrupt, ...: release the waiters
                # instead of leaving them blocked on a Future that never resolves.
                if not future.done():
                    future.cancel()
                raise
            else:
                if not future.done():
                    future.set_result(result)
                return result
            finally:
                # Only clear our own Future; reset() may already have cleared it.
                if self._future is future:
                    self._future = None

    def reset(self):
        """Reset the lock state.

        Cancels any pending future, so callers waiting on it receive
        CancelledError. An in-flight ``coro_func`` keeps running and its
        caller still gets its result.
        """
        if self._future is not None and not self._future.done():
            self._future.cancel()
        self._future = None