from collections.abc import Callable, AsyncGenerator
from typing import Union, Any
from functools import wraps
from logging import Logger
import inspect

from ..types import (
    Message,
    Context,
    Command,
    CommandType,
    Response,
)
from ..injection import (
    Injector,
    inject_all,
    MessageInjector,
    ContextInjector,
    CommandInjector,
    LoggerInjector,
)
from .error_handling import with_failure_handling

# Every command registered through ``as_command``, keyed by command name.
# Registering a second command under an existing name replaces the earlier
# entry (plain dict assignment in ``_on_command_impl``).
decorated_commands: dict[str, CommandType] = {}


def _lift_command_fn(
    fn: Callable[..., Any],
) -> Callable[..., AsyncGenerator[Any, None]]:
    """Adapt *fn* into an async generator function.

    Command bodies may be written in four shapes. Each is wrapped so the
    caller can uniformly ``async for`` over the command's results:

    * async generator function -- returned unchanged;
    * coroutine function -- yields its awaited return value once;
    * generator function -- yields each value the generator produces;
    * plain function -- yields its return value once.

    Raises:
        ValueError: if *fn* is none of the above (for example a builtin or a
            callable object that is not a function).
    """
    if inspect.isasyncgenfunction(fn):
        return fn

    if inspect.iscoroutinefunction(fn):
        @wraps(fn)
        async def lifted(*args):
            yield await fn(*args)
    elif inspect.isgeneratorfunction(fn):
        @wraps(fn)
        async def lifted(*args):
            # The sync generator is iterated inside this coroutine, so any
            # blocking step in it blocks the event loop while it runs.
            for res in fn(*args):
                yield res
    elif inspect.isfunction(fn):
        @wraps(fn)
        async def lifted(*args):
            yield fn(*args)
    else:
        raise ValueError(
            f"cannot use {fn!r} as a command: expected a function, coroutine "
            "function, generator function or async generator function"
        )
    return lifted


def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
    """Pick one injector per parameter of *fn*, based on its annotation.

    ``Message``, ``Context``, ``Command`` and ``Logger`` annotations map to
    the matching built-in injector. An annotation that is itself an
    ``Injector`` instance is used as-is. The returned list is in parameter
    order.

    Raises:
        ValueError: if a parameter has no annotation, or its annotation is
            neither a supported type nor an ``Injector`` instance.
    """
    injectors = []
    for param_name, param in inspect.signature(fn).parameters.items():
        annot = param.annotation
        if annot is inspect.Parameter.empty:
            raise ValueError(
                f"parameter {param_name!r} of command {fn!r} has no "
                "annotation; annotate it with Message, Context, Command, "
                "Logger or an Injector"
            )
        if annot == Message:
            injectors.append(MessageInjector)
        elif annot == Context:
            injectors.append(ContextInjector)
        elif annot == Command:
            injectors.append(CommandInjector)
        elif annot == Logger:
            injectors.append(LoggerInjector)
        elif isinstance(annot, Injector):
            injectors.append(annot)
        elif isinstance(annot, str):
            # Annotations are kept as strings under postponed evaluation
            # (``from __future__ import annotations``); they are not resolved.
            raise ValueError(
                f"parameter {param_name!r} of command {fn!r} has string "
                f"annotation {annot!r}; string (postponed) annotations are "
                "not supported"
            )
        else:
            raise ValueError(
                f"parameter {param_name!r} of command {fn!r} has unsupported "
                f"annotation {annot!r}; expected Message, Context, Command, "
                "Logger or an Injector"
            )
    return injectors


def _make_response(message: Message, result: Any) -> Union[Response, None]:
    """Convert one value produced by a command into a ``Response``.

    ``None`` and existing ``Response`` objects are returned unchanged.
    Strings are used as the response text directly. Any other value is
    rendered with ``str()``. Responses are built with
    ``Response.from_message(message, text=...)``.
    """
    if result is None or isinstance(result, Response):
        return result
    # TODO handle attachments, other special returns
    text = result if isinstance(result, str) else str(result)
    return Response.from_message(message, text=text)


def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Register *fn* in ``decorated_commands`` under *name*; return *fn*.

    The registered handler takes ``(message, context)``. It resolves
    arguments for *fn* through the injectors chosen from *fn*'s annotations,
    then runs the lifted command. Each non-``None`` response converted from
    its results is passed to ``context.respond``.

    The lift and the injector lookup run here, at decoration time, so an
    unsupported function shape or parameter annotation raises ``ValueError``
    when the module defining the command is imported, not on first use.
    """
    lifted = _lift_command_fn(fn)
    injectors = _get_injectors(fn)

    # NOTE(review): failure policy is defined by ``with_failure_handling`` in
    # ``.error_handling`` -- see that module for what it catches/reports.
    @with_failure_handling
    @wraps(lifted)
    async def injected_command(message: Message, context: Context):
        # Presumably ``inject_all`` yields the injected values in injector
        # (i.e. parameter) order -- they are passed positionally below.
        async with inject_all(injectors, message, context) as args:
            async for result in lifted(*args):
                if (response := _make_response(message, result)) is not None:
                    await context.respond(response)

    decorated_commands[name] = injected_command
    return fn


def as_command(arg: Union[str, Callable[..., Any]]) -> Callable[..., Any]:
    """Register a function as a command; usable bare or with a name.

    ``@as_command`` registers the function under its ``__name__``.
    ``@as_command("name")`` registers it under ``"name"``. In both forms the
    decorated function itself is returned unchanged.

    The function's parameters are filled by annotation, not by position.
    Each must be annotated with ``Message``, ``Context``, ``Command``,
    ``Logger`` or an ``Injector`` instance.

    Raises:
        ValueError: at decoration time, if the function's shape or any
            parameter annotation is unsupported.
    """
    if isinstance(arg, str):
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            return _on_command_impl(arg, fn)

        return decorator
    return _on_command_impl(arg.__name__, arg)