as_command.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from collections.abc import Callable, AsyncGenerator
  2. from typing import Union, Any
  3. from functools import wraps
  4. import inspect
  5. from ..types import (
  6. Message,
  7. Context,
  8. CommandType,
  9. Response,
  10. )
  11. from ..injection import Injector, inject_all, MessageInjector, ContextInjector
  12. from .error_handling import with_failure_handling
# Registry of command handlers keyed by command name. Entries are added as a
# side effect of registering a function through this module's decorator.
decorated_commands: dict[str, CommandType] = {}
  14. def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
  15. if inspect.isasyncgenfunction(fn):
  16. lifted = fn
  17. elif inspect.iscoroutinefunction(fn):
  18. @wraps(fn)
  19. async def lifted(*args):
  20. yield await fn(*args)
  21. elif inspect.isgeneratorfunction(fn):
  22. @wraps(fn)
  23. async def lifted(*args):
  24. for res in fn(*args):
  25. yield res
  26. elif inspect.isfunction(fn):
  27. @wraps(fn)
  28. async def lifted(*args):
  29. yield fn(*args)
  30. else:
  31. raise ValueError # TODO details
  32. return lifted
  33. def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
  34. injectors = []
  35. for param in inspect.signature(fn).parameters:
  36. annot = fn.__annotations__[param]
  37. if annot == Message:
  38. injectors.append(MessageInjector)
  39. elif annot == Context:
  40. injectors.append(ContextInjector)
  41. elif isinstance(annot, Injector):
  42. injectors.append(annot)
  43. else:
  44. raise ValueError # TODO details
  45. return injectors
  46. def _make_response(message: Message, result: Any) -> Response:
  47. if result is None or isinstance(result, Response):
  48. return result
  49. elif isinstance(result, str):
  50. return Response.from_message(message, text=result)
  51. # TODO handle attachments, other special returns
  52. else:
  53. return Response.from_message(message, text=str(result))
  54. def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
  55. lifted = _lift_command_fn(fn)
  56. injectors = _get_injectors(fn)
  57. @with_failure_handling
  58. @wraps(lifted)
  59. async def injected_command(message: Message, context: Context):
  60. async with inject_all(injectors, message, context) as args:
  61. async for result in lifted(*args):
  62. if (response := _make_response(message, result)) is not None:
  63. await context.respond(response)
  64. decorated_commands[name] = injected_command
  65. return fn
  66. def as_command(arg: Union[str, Callable[[Message, Context], Any]]):
  67. if isinstance(arg, str):
  68. return lambda fn: _on_command_impl(arg, fn)
  69. else:
  70. return _on_command_impl(arg.__name__, arg)