123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- from collections.abc import Callable, AsyncGenerator
- from typing import Union, Any
- from functools import wraps
- from logging import Logger
- import inspect
- from ..types import (
- Message,
- Context,
- Command,
- CommandType,
- Response,
- )
- from ..injection import (
- Injector,
- inject_all,
- MessageInjector,
- ContextInjector,
- CommandInjector,
- LoggerInjector,
- )
- from .error_handling import with_failure_handling
# Registry of every command declared through ``as_command``: maps the command
# name to the wrapped async handler built for it. Entries are added as the
# decorated functions are defined.
decorated_commands: dict[str, CommandType] = dict()
def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
    """Adapt a command callable into an async-generator function.

    Every command is normalised to the same shape so callers can simply
    ``async for`` over whatever it produces:

    * async generator functions are returned unchanged;
    * coroutine functions yield their single awaited result;
    * sync generator functions yield each of their items in turn (note: they
      run synchronously inside the event loop);
    * any other callable (plain function, lambda, bound method,
      ``functools.partial``, callable object) yields its single return value,
      awaiting it first if the value is awaitable.

    Args:
        fn: The user-supplied command implementation.

    Returns:
        A callable accepting the same positional/keyword arguments as ``fn``
        whose call returns an async generator.

    Raises:
        ValueError: If ``fn`` is not callable.
    """
    if inspect.isasyncgenfunction(fn):
        return fn

    if inspect.iscoroutinefunction(fn):
        @wraps(fn)
        async def lifted(*args, **kwargs):
            yield await fn(*args, **kwargs)
    elif inspect.isgeneratorfunction(fn):
        @wraps(fn)
        async def lifted(*args, **kwargs):
            for item in fn(*args, **kwargs):
                yield item
    elif callable(fn):
        # ``inspect.isfunction`` alone would reject bound methods, partials and
        # callable objects even though they can be called just like functions.
        @wraps(fn)
        async def lifted(*args, **kwargs):
            result = fn(*args, **kwargs)
            # A sync callable may still hand back an awaitable (for example it
            # returns ``some_coroutine()``); await it instead of emitting the
            # raw, never-awaited coroutine object.
            if inspect.isawaitable(result):
                result = await result
            yield result
    else:
        raise ValueError(
            f"cannot use {fn!r} as a command: expected a function, coroutine "
            "function, generator function, async generator function or other "
            "callable"
        )
    return lifted
def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
    """Build the list of injectors that supply ``fn``'s arguments.

    Each parameter of ``fn`` must be annotated with one of:

    * ``Message`` -> ``MessageInjector``
    * ``Context`` -> ``ContextInjector``
    * ``Command`` -> ``CommandInjector``
    * ``logging.Logger`` -> ``LoggerInjector``
    * an ``Injector`` instance, which is used as-is.

    The returned list follows the order of ``fn``'s parameters.

    Raises:
        ValueError: If a parameter has no annotation or an unsupported one.
    """
    injectors = []
    for param in inspect.signature(fn).parameters.values():
        # Read the annotation from the signature we already computed instead of
        # indexing ``fn.__annotations__``: that lookup raised a bare KeyError
        # for unannotated parameters and fails for callables lacking the attr.
        annot = param.annotation
        if annot is inspect.Parameter.empty:
            raise ValueError(
                f"parameter {param.name!r} of command {fn!r} has no type "
                "annotation; annotate it with Message, Context, Command, "
                "Logger or an Injector instance"
            )
        if annot == Message:
            injectors.append(MessageInjector)
        elif annot == Context:
            injectors.append(ContextInjector)
        elif annot == Command:
            injectors.append(CommandInjector)
        elif annot == Logger:
            injectors.append(LoggerInjector)
        elif isinstance(annot, Injector):
            injectors.append(annot)
        else:
            raise ValueError(
                f"parameter {param.name!r} of command {fn!r} has unsupported "
                f"annotation {annot!r}; expected Message, Context, Command, "
                "Logger or an Injector instance"
            )
    return injectors
def _make_response(message: Message, result: Any) -> Union[Response, None]:
    """Convert one value produced by a command into a ``Response``.

    * ``None`` is passed through as ``None`` (no response).
    * A ``Response`` is returned unchanged.
    * Anything else becomes a text reply to ``message`` built with
      ``Response.from_message``; non-string values are rendered with ``str()``.

    Args:
        message: The message the command is responding to.
        result: A single value yielded or returned by the command.

    Returns:
        The response to send, or ``None`` when there is nothing to send.
    """
    if result is None or isinstance(result, Response):
        return result
    # TODO handle attachments, other special returns
    text = result if isinstance(result, str) else str(result)
    return Response.from_message(message, text=text)
def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Register ``fn`` as the command ``name`` and return ``fn`` unchanged.

    ``fn`` is first normalised to an async generator function, and the
    injectors for its parameters are resolved up front. Both steps may raise
    ``ValueError``, so a bad command is rejected at decoration time.

    The registered handler takes ``(message, context)``. It builds ``fn``'s
    arguments with ``inject_all``, runs the command, and sends every non-None
    response it produces through ``context.respond``. The handler is wrapped by
    ``with_failure_handling`` (see that module for what it does on errors).
    """
    as_async_gen = _lift_command_fn(fn)
    needed_injectors = _get_injectors(fn)

    async def run(message: Message, context: Context):
        async with inject_all(needed_injectors, message, context) as injected_args:
            async for item in as_async_gen(*injected_args):
                response = _make_response(message, item)
                if response is None:
                    continue
                await context.respond(response)

    # Same as stacking ``@with_failure_handling`` over ``@wraps(as_async_gen)``.
    decorated_commands[name] = with_failure_handling(wraps(as_async_gen)(run))
    return fn
def as_command(arg: Union[str, Callable[..., Any], None] = None):
    """Decorator that registers a function as a command.

    Supported forms::

        @as_command            # command name is the function's __name__
        def ping(): ...

        @as_command()          # same as above
        def ping(): ...

        @as_command("pong")    # explicit command name
        def ping(): ...

    The command's parameters are filled by injection, based on their
    annotations, so the callable may take any supported parameters, not only
    ``(Message, Context)``. The decorated function is returned unchanged.

    Args:
        arg: The command name, the function being decorated (bare-decorator
            form), or ``None`` to derive the name from the function.
    """
    if arg is None:
        return lambda fn: _on_command_impl(fn.__name__, fn)
    if isinstance(arg, str):
        return lambda fn: _on_command_impl(arg, fn)
    return _on_command_impl(arg.__name__, arg)
|