# as_command.py
import inspect
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, Optional, Union

from ..types import (
    Message,
    Context,
    Command,
    CommandType,
    Response,
)
from ..injection import Injector, inject_all, MessageInjector, ContextInjector, CommandInjector
from .error_handling import with_failure_handling
  14. decorated_commands: dict[str, CommandType] = {}
  15. def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
  16. if inspect.isasyncgenfunction(fn):
  17. lifted = fn
  18. elif inspect.iscoroutinefunction(fn):
  19. @wraps(fn)
  20. async def lifted(*args):
  21. yield await fn(*args)
  22. elif inspect.isgeneratorfunction(fn):
  23. @wraps(fn)
  24. async def lifted(*args):
  25. for res in fn(*args):
  26. yield res
  27. elif inspect.isfunction(fn):
  28. @wraps(fn)
  29. async def lifted(*args):
  30. yield fn(*args)
  31. else:
  32. raise ValueError # TODO details
  33. return lifted
  34. def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
  35. injectors = []
  36. for param in inspect.signature(fn).parameters:
  37. annot = fn.__annotations__[param]
  38. if annot == Message:
  39. injectors.append(MessageInjector)
  40. elif annot == Context:
  41. injectors.append(ContextInjector)
  42. elif annot == Command:
  43. injectors.append(CommandInjector)
  44. elif isinstance(annot, Injector):
  45. injectors.append(annot)
  46. else:
  47. raise ValueError # TODO details
  48. return injectors
  49. def _make_response(message: Message, result: Any) -> Response:
  50. if result is None or isinstance(result, Response):
  51. return result
  52. if isinstance(result, str):
  53. return Response.from_message(message, text=result)
  54. # TODO handle attachments, other special returns
  55. return Response.from_message(message, text=str(result))
  56. def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
  57. lifted = _lift_command_fn(fn)
  58. injectors = _get_injectors(fn)
  59. @with_failure_handling
  60. @wraps(lifted)
  61. async def injected_command(message: Message, context: Context):
  62. async with inject_all(injectors, message, context) as args:
  63. async for result in lifted(*args):
  64. if (response := _make_response(message, result)) is not None:
  65. await context.respond(response)
  66. decorated_commands[name] = injected_command
  67. return fn
  68. def as_command(arg: Union[str, Callable[[Message, Context], Any]]):
  69. if isinstance(arg, str):
  70. return lambda fn: _on_command_impl(arg, fn)
  71. else:
  72. return _on_command_impl(arg.__name__, arg)