# as_command.py (3.1 KB)

  1. from collections.abc import Callable, AsyncGenerator
  2. from typing import Union, Any
  3. from functools import wraps
  4. from logging import Logger
  5. import inspect
  6. from ..types import (
  7. Message,
  8. Context,
  9. Command,
  10. CommandType,
  11. Response,
  12. Attachment,
  13. )
  14. from ..injection import (
  15. Injector,
  16. inject_all,
  17. MessageInjector,
  18. ContextInjector,
  19. CommandInjector,
  20. LoggerInjector,
  21. )
  22. from .error_handling import with_failure_handling
  23. decorated_commands: dict[str, CommandType] = {}
  24. def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
  25. if inspect.isasyncgenfunction(fn):
  26. lifted = fn
  27. elif inspect.iscoroutinefunction(fn):
  28. @wraps(fn)
  29. async def lifted(*args):
  30. yield await fn(*args)
  31. elif inspect.isgeneratorfunction(fn):
  32. @wraps(fn)
  33. async def lifted(*args):
  34. for res in fn(*args):
  35. yield res
  36. elif inspect.isfunction(fn):
  37. @wraps(fn)
  38. async def lifted(*args):
  39. yield fn(*args)
  40. else:
  41. raise ValueError(
  42. f"Commands should be functions, generators, async functions, or async generators"
  43. )
  44. return lifted
  45. def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
  46. injectors = []
  47. for param in inspect.signature(fn).parameters:
  48. annot = fn.__annotations__[param]
  49. if annot == Message:
  50. injectors.append(MessageInjector)
  51. elif annot == Context:
  52. injectors.append(ContextInjector)
  53. elif annot == Command:
  54. injectors.append(CommandInjector)
  55. elif annot == Logger:
  56. injectors.append(LoggerInjector)
  57. elif isinstance(annot, Injector):
  58. injectors.append(annot)
  59. else:
  60. raise ValueError(
  61. f"Annotations should be Injectors, {param} is {annot} in {fn.__name__}"
  62. )
  63. return injectors
  64. def _make_response(message: Message, result: Any) -> Response:
  65. if result is None or isinstance(result, Response):
  66. return result
  67. if isinstance(result, str):
  68. return Response.from_message(message, text=result)
  69. if isinstance(result, Attachment):
  70. return Response.from_message(message, attachments=[Attachment])
  71. return Response.from_message(message, text=str(result))
  72. def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
  73. lifted = _lift_command_fn(fn)
  74. injectors = _get_injectors(fn)
  75. @with_failure_handling
  76. @wraps(lifted)
  77. async def injected_command(message: Message, context: Context):
  78. async with inject_all(injectors, message, context) as args:
  79. async for result in lifted(*args):
  80. if (response := _make_response(message, result)) is not None:
  81. await context.respond(response)
  82. decorated_commands[name] = injected_command
  83. return fn
  84. def as_command(arg: Union[str, Callable[...]]) -> Callable[...]:
  85. if isinstance(arg, str):
  86. return lambda fn: _on_command_impl(arg, fn)
  87. else:
  88. return _on_command_impl(arg.__name__, arg)