decorators.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
import asyncio
import dataclasses
import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, Union

from ..types import (
    Message,
    Context,
    CommandType,
    Response,
    StartupShutdownType,
    CommandConfiguration,
)
from .failure import with_failure_handling
from .injection import Injector, inject_all, inject_context, inject_message
  18. decorated_startup: list[StartupShutdownType] = []
  19. decorated_shutdown: list[StartupShutdownType] = []
  20. decorated_commands: dict[str, CommandType] = {}
  21. def on_startup(fn):
  22. decorated_startup.append(fn)
  23. return fn
  24. def on_shutdown(fn):
  25. decorated_shutdown.append(fn)
  26. return fn
  27. def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
  28. if inspect.isasyncgenfunction(fn):
  29. lifted = fn
  30. elif inspect.iscoroutinefunction(fn):
  31. @wraps(fn)
  32. async def lifted(*args):
  33. yield await fn(*args)
  34. elif inspect.isgeneratorfunction(fn):
  35. @wraps(fn)
  36. async def lifted(*args):
  37. for res in fn(*args):
  38. yield res
  39. elif inspect.isfunction(fn):
  40. @wraps(fn)
  41. async def lifted(*args):
  42. yield fn(*args)
  43. else:
  44. raise ValueError # TODO details
  45. return lifted
  46. def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
  47. injectors = []
  48. for param in inspect.signature(fn).parameters:
  49. annot = fn.__annotations__[param]
  50. if annot == Message:
  51. injectors.append(inject_message)
  52. elif annot == Context:
  53. injectors.append(inject_context)
  54. elif isinstance(annot, Injector):
  55. injectors.append(annot.inject)
  56. else:
  57. raise ValueError # TODO details
  58. return injectors
  59. def _make_response(result: Any) -> Response:
  60. if result is None or isinstance(result, Response):
  61. return result
  62. elif isinstance(result, str):
  63. return Response.from_message(message, text=result)
  64. # TODO handle attachments, other special returns
  65. else:
  66. return Response.from_message(message, text=str(result))
  67. def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
  68. lifted = _lift_command_fn(fn)
  69. injectors = _get_injectors(fn)
  70. @with_failure_handling
  71. @wraps(lifted)
  72. async def injected_command(message: Message, context: Context):
  73. async with inject_all(injectors, message, context) as args:
  74. async for result in lifted(*args):
  75. if (response := _make_response(result)) is not None:
  76. await context.respond(response)
  77. decorated_commands[name] = injected_command
  78. return fn
  79. def as_command(arg: Union[str, Callable[[Message, Context], Any]]):
  80. if isinstance(arg, str):
  81. return lambda fn: _on_command_impl(arg, fn)
  82. else:
  83. return _on_command_impl(arg.__name__, arg)
  84. def get_command_config() -> CommandConfiguration:
  85. return CommandConfiguration(
  86. commands=decorated_commands,
  87. call_and_response={},
  88. aliases={},
  89. bangs=(),
  90. startup=decorated_startup,
  91. shutdown=decorated_shutdown,
  92. )