as_command.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from collections.abc import Callable, AsyncGenerator
  2. from typing import Union, Any
  3. from functools import wraps
  4. from logging import Logger
  5. import inspect
  6. from ..types import (
  7. Message,
  8. Context,
  9. Command,
  10. CommandType,
  11. Response,
  12. )
  13. from ..injection import (
  14. Injector,
  15. inject_all,
  16. MessageInjector,
  17. ContextInjector,
  18. CommandInjector,
  19. LoggerInjector,
  20. )
  21. from .error_handling import with_failure_handling
  22. decorated_commands: dict[str, CommandType] = {}
  23. def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
  24. if inspect.isasyncgenfunction(fn):
  25. lifted = fn
  26. elif inspect.iscoroutinefunction(fn):
  27. @wraps(fn)
  28. async def lifted(*args):
  29. yield await fn(*args)
  30. elif inspect.isgeneratorfunction(fn):
  31. @wraps(fn)
  32. async def lifted(*args):
  33. for res in fn(*args):
  34. yield res
  35. elif inspect.isfunction(fn):
  36. @wraps(fn)
  37. async def lifted(*args):
  38. yield fn(*args)
  39. else:
  40. raise ValueError # TODO details
  41. return lifted
  42. def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
  43. injectors = []
  44. for param in inspect.signature(fn).parameters:
  45. annot = fn.__annotations__[param]
  46. if annot == Message:
  47. injectors.append(MessageInjector)
  48. elif annot == Context:
  49. injectors.append(ContextInjector)
  50. elif annot == Command:
  51. injectors.append(CommandInjector)
  52. elif annot == Logger:
  53. injectors.append(LoggerInjector)
  54. elif isinstance(annot, Injector):
  55. injectors.append(annot)
  56. else:
  57. raise ValueError # TODO details
  58. return injectors
  59. def _make_response(message: Message, result: Any) -> Response:
  60. if result is None or isinstance(result, Response):
  61. return result
  62. if isinstance(result, str):
  63. return Response.from_message(message, text=result)
  64. # TODO handle attachments, other special returns
  65. return Response.from_message(message, text=str(result))
  66. def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
  67. lifted = _lift_command_fn(fn)
  68. injectors = _get_injectors(fn)
  69. @with_failure_handling
  70. @wraps(lifted)
  71. async def injected_command(message: Message, context: Context):
  72. async with inject_all(injectors, message, context) as args:
  73. async for result in lifted(*args):
  74. if (response := _make_response(message, result)) is not None:
  75. await context.respond(response)
  76. decorated_commands[name] = injected_command
  77. return fn
  78. def as_command(arg: Union[str, Callable[[Message, Context], Any]]):
  79. if isinstance(arg, str):
  80. return lambda fn: _on_command_impl(arg, fn)
  81. else:
  82. return _on_command_impl(arg.__name__, arg)