decorators.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from collections.abc import Callable
  2. from typing import Union
  3. from functools import wraps
  4. import inspect
  5. import asyncio
  6. from ..types import (
  7. Message,
  8. Context,
  9. CommandType,
  10. Response,
  11. StartupShutdownType,
  12. CommandConfiguration,
  13. )
  14. from .failure import RollbotFailureException
  15. from .injection import Injector, inject_message, inject_context
  # Module-level registries filled in by the decorators defined in this file.

  # Callables registered with @on_startup, in registration order.
  decorated_startup: list[StartupShutdownType] = []

  # Callables registered with @on_shutdown, in registration order.
  decorated_shutdown: list[StartupShutdownType] = []

  # Command handlers registered with @as_command, keyed by command name.
  decorated_commands: dict[str, CommandType] = {}
  def on_startup(fn):
      """Decorator: record ``fn`` in ``decorated_startup`` and return it unchanged.

      The function is appended to the module-level list, so registration
      order is preserved. The decorated name still refers to the original
      function.
      """
      decorated_startup.append(fn)
      return fn
  def on_shutdown(fn):
      """Decorator: record ``fn`` in ``decorated_shutdown`` and return it unchanged.

      The function is appended to the module-level list, so registration
      order is preserved. The decorated name still refers to the original
      function.
      """
      decorated_shutdown.append(fn)
      return fn
  def _lift_to_async_generator(name, fn):
      """Return an async-generator function that yields the result(s) of ``fn``.

      Async generator functions are returned as-is. Coroutine functions,
      generator functions and plain functions are wrapped so the caller can
      always consume the command with ``async for``.

      Args:
          name: Command name, used only in the error message.
          fn: The object being registered as a command.

      Raises:
          ValueError: If ``fn`` is not one of the supported kinds of function.
      """
      if inspect.isasyncgenfunction(fn):
          return fn

      if inspect.iscoroutinefunction(fn):

          @wraps(fn)
          async def lifted(*args):
              yield await fn(*args)

      elif inspect.isgeneratorfunction(fn):

          @wraps(fn)
          async def lifted(*args):
              for res in fn(*args):
                  yield res

      # Must be checked last: inspect.isfunction() is also true for the
      # coroutine and generator functions handled above.
      elif inspect.isfunction(fn):

          @wraps(fn)
          async def lifted(*args):
              yield fn(*args)

      else:
          raise ValueError(
              f"cannot register {fn!r} as command {name!r}: expected a function, "
              "generator function, coroutine function or async generator function"
          )
      return lifted


  def _resolve_injectors(name, fn):
      """Return one injector callable per parameter of ``fn``, in signature order.

      A parameter annotated with ``Message`` maps to ``inject_message``, one
      annotated with ``Context`` maps to ``inject_context``, and one whose
      annotation is an ``Injector`` instance maps to that instance's
      ``inject`` method.

      Args:
          name: Command name, used only in error messages.
          fn: The function whose parameters are inspected.

      Raises:
          ValueError: If a parameter has no annotation or an unsupported one.
      """
      injectors = []
      for param in inspect.signature(fn).parameters:
          try:
              annot = fn.__annotations__[param]
          except KeyError:
              # Report a missing annotation the same way as an unsupported one
              # instead of leaking a bare KeyError out of the decorator.
              raise ValueError(
                  f"command {name!r}: parameter {param!r} has no annotation; "
                  "expected Message, Context or an Injector instance"
              ) from None

          if annot == Message:
              injectors.append(inject_message)
          elif annot == Context:
              injectors.append(inject_context)
          elif isinstance(annot, Injector):
              injectors.append(annot.inject)
          else:
              raise ValueError(
                  f"command {name!r}: parameter {param!r} has unsupported "
                  f"annotation {annot!r}; expected Message, Context or an "
                  "Injector instance"
              )
      return injectors


  def as_command(arg: Union[str, Callable]):
      """Register a function as a command in ``decorated_commands``.

      Usable in two forms:

      * ``@as_command`` registers the function under its ``__name__``.
      * ``@as_command("name")`` registers the function under ``name``.

      In both forms the decorated function itself is returned unchanged; the
      handler stored in ``decorated_commands`` is a wrapper that takes
      ``(message, context)``.

      Args:
          arg: Either the command name, or the function to register.

      Returns:
          The original function (bare form), or a decorator that registers a
          function and returns it (named form).

      Raises:
          ValueError: If the decorated object is not a supported kind of
              function, or one of its parameters is missing an annotation or
              has an unsupported annotation.
      """

      def impl(name, fn):
          lifted = _lift_to_async_generator(name, fn)
          injectors = _resolve_injectors(name, fn)

          async def command_impl(message: Message, context: Context):
              # Every injector is called with (message, context); the awaited
              # results become the positional arguments of the command.
              args = await asyncio.gather(*[inj(message, context) for inj in injectors])
              try:
                  async for result in lifted(*args):
                      if isinstance(result, Response):
                          response = result
                      elif isinstance(result, str):
                          response = Response.from_message(message, text=result)
                      # TODO handle attachments, other special returns
                      else:
                          response = Response.from_message(message, str(result))
                      await context.respond(response)
              except RollbotFailureException as exc:
                  # TODO handle errors more specifically
                  await context.respond(Response.from_message(message, str(exc.failure)))

          decorated_commands[name] = command_impl
          return fn

      if isinstance(arg, str):
          return lambda fn: impl(arg, fn)
      return impl(arg.__name__, arg)
  def get_command_config() -> CommandConfiguration:
      """Build a ``CommandConfiguration`` from the decorator registries.

      The ``commands``, ``startup`` and ``shutdown`` fields are the
      module-level registries themselves (not copies). ``call_and_response``
      and ``aliases`` are empty dicts and ``bangs`` is an empty tuple.
      """
      config = CommandConfiguration(
          commands=decorated_commands,
          call_and_response={},
          aliases={},
          bangs=(),
          startup=decorated_startup,
          shutdown=decorated_shutdown,
      )
      return config