decorators.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from collections.abc import Callable
  2. from typing import Union
  3. from functools import wraps
  4. import inspect
  5. import asyncio
  6. import dataclasses
  7. import json
  8. from ..types import (
  9. Message,
  10. Context,
  11. CommandType,
  12. Response,
  13. StartupShutdownType,
  14. CommandConfiguration,
  15. )
  16. from .failure import RollbotFailureException
  17. from .injection import Injector, inject_message, inject_context
  # Module-level registries filled in by the decorators defined in this module
  # and handed out together by get_command_config().

  # Callables registered with @on_startup, in registration order.
  decorated_startup: list[StartupShutdownType] = []

  # Callables registered with @on_shutdown, in registration order.
  decorated_shutdown: list[StartupShutdownType] = []

  # Command name -> handler built by @as_command.
  decorated_commands: dict[str, CommandType] = {}
  def on_startup(fn):
      """Register ``fn`` in the module's startup list and return it unchanged.

      Returning ``fn`` itself lets this be used as a plain decorator without
      altering the decorated callable.
      """
      decorated_startup.append(fn)
      return fn
  def on_shutdown(fn):
      """Register ``fn`` in the module's shutdown list and return it unchanged.

      Returning ``fn`` itself lets this be used as a plain decorator without
      altering the decorated callable.
      """
      decorated_shutdown.append(fn)
      return fn
  def _lift_to_async_generator(name, fn):
      """Return an async-generator function equivalent to ``fn``.

      Async generator functions are returned as-is; coroutine functions,
      generator functions and plain functions are wrapped so the caller can
      always consume results with ``async for``.

      Raises:
          ValueError: if ``fn`` is none of the supported kinds of function.
      """
      if inspect.isasyncgenfunction(fn):
          return fn

      if inspect.iscoroutinefunction(fn):

          @wraps(fn)
          async def lifted(*args):
              yield await fn(*args)

      elif inspect.isgeneratorfunction(fn):

          @wraps(fn)
          async def lifted(*args):
              for res in fn(*args):
                  yield res

      elif inspect.isfunction(fn):

          @wraps(fn)
          async def lifted(*args):
              yield fn(*args)

      else:
          raise ValueError(
              f"cannot register command {name!r}: expected a function, coroutine "
              f"function, generator function or async generator function, got {fn!r}"
          )
      return lifted


  def _resolve_injectors(name, fn):
      """Map each parameter of ``fn`` to the injector selected by its annotation.

      Raises:
          ValueError: if a parameter has no annotation, or its annotation is not
              ``Message``, ``Context`` or an ``Injector`` instance.
      """
      injectors = []
      for param_name, param in inspect.signature(fn).parameters.items():
          annot = param.annotation
          # Check for a missing annotation first so it is reported as a
          # descriptive ValueError rather than an opaque KeyError.
          if annot is inspect.Parameter.empty:
              raise ValueError(
                  f"cannot register command {name!r}: parameter {param_name!r} "
                  "has no annotation; annotate it with Message, Context or an "
                  "Injector instance"
              )
          if annot == Message:
              injectors.append(inject_message)
          elif annot == Context:
              injectors.append(inject_context)
          elif isinstance(annot, Injector):
              injectors.append(annot.inject)
          else:
              raise ValueError(
                  f"cannot register command {name!r}: parameter {param_name!r} "
                  f"has unsupported annotation {annot!r}; expected Message, "
                  "Context or an Injector instance"
              )
      return injectors


  def _register_command(name, fn):
      """Build the handler for ``fn``, store it under ``name`` and return ``fn``."""
      lifted = _lift_to_async_generator(name, fn)
      injectors = _resolve_injectors(name, fn)

      async def command_impl(message: Message, context: Context):
          # Injection happens outside the try block below, so a
          # RollbotFailureException raised by an injector is not turned into a
          # response here.
          args = await asyncio.gather(*[inj(message, context) for inj in injectors])
          try:
              async for result in lifted(*args):
                  if isinstance(result, Response):
                      response = result
                  elif isinstance(result, str):
                      response = Response.from_message(message, text=result)
                  # TODO handle attachments, other special returns
                  else:
                      response = Response.from_message(message, str(result))
                  await context.respond(response)
          except RollbotFailureException as exc:
              # TODO handle errors more specifically
              await context.respond(Response.from_message(message, str(exc.failure)))

      decorated_commands[name] = command_impl
      return fn


  def as_command(arg: Union[str, Callable]):
      """Register a function as a command; usable with or without a name.

      ``@as_command`` registers the function under its ``__name__``;
      ``@as_command("name")`` registers it under the given name. The decorated
      function is returned unchanged; the generated handler is stored in
      ``decorated_commands``.

      The handler calls every injector concurrently with ``(message, context)``,
      passes the results positionally to the function, and sends each produced
      value through ``context.respond``: ``Response`` objects as-is, anything
      else converted with ``Response.from_message``.

      Raises:
          ValueError: if the decorated object is not a supported kind of
              function, or one of its parameters has a missing or unsupported
              annotation.
      """
      if isinstance(arg, str):
          return lambda fn: _register_command(arg, fn)
      return _register_command(arg.__name__, arg)
  def get_command_config() -> CommandConfiguration:
      """Bundle everything registered through this module's decorators.

      The returned configuration references the live registry objects (no
      copies are made); call-and-response, aliases and bangs are left empty.
      """
      settings = dict(
          commands=decorated_commands,
          call_and_response={},
          aliases={},
          bangs=(),
          startup=decorated_startup,
          shutdown=decorated_shutdown,
      )
      return CommandConfiguration(**settings)
  def as_data(cls):
      """Turn ``cls`` into a dataclass backed by a JSON-blob SQL table.

      The table name is the class name with each ASCII capital replaced by
      ``_`` plus its lowercase form, with leading and trailing underscores
      stripped. The table has a ``key`` primary-key column and a ``body`` text
      column; each field annotated exactly ``int``, ``float`` or ``str`` also
      gets a virtual generated column ``__<field>`` extracted from the JSON
      body, plus a lookup query on it. Other annotations get no column.

      A startup hook that creates the table is registered through
      ``on_startup``.

      Returns:
          A new class of the same name that subclasses the dataclass version of
          ``cls`` and carries the query strings and (de)serialisation helpers
          as attributes literally named ``__field_queries``, ``__create_query``,
          ``__select_query``, ``__save_query``, ``__to_blob`` and
          ``__from_blob`` (no name mangling is applied to them here).
      """
      table_name = "".join(
          ("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__
      ).strip("_")
      columns = [
          "key TEXT NOT NULL PRIMARY KEY",
          # Single quotes: SQLite treats double-quoted text as an identifier and
          # only falls back to a string literal as a legacy misfeature.
          "body TEXT DEFAULT ''",
      ]
      queries = {}
      # Compared with == (not a dict lookup) so unhashable annotations are
      # tolerated and simply skipped.
      sql_types = ((int, "INT"), (float, "REAL"), (str, "TEXT"))
      for k, v in getattr(cls, "__annotations__", {}).items():
          t = next((sql for py, sql in sql_types if v == py), None)
          if t is None:
              continue
          columns.append(f"__{k} {t} GENERATED ALWAYS AS (json_extract(body, '$.{k}')) VIRTUAL")
          queries[k] = f"SELECT body FROM {table_name} WHERE __{k} = ?"
      create_query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"

      @on_startup
      async def create_table(context: Context):
          async with context.database() as db:
              await db.execute(create_query)
              await db.commit()

      base = dataclasses.dataclass(cls)
      new_class = type(
          cls.__name__,
          (base,),
          {
              # Keep the original identity metadata; type() would otherwise
              # attribute the class to this module and drop its docstring.
              "__module__": cls.__module__,
              "__qualname__": cls.__qualname__,
              "__doc__": base.__doc__,
              "__field_queries": queries,
              "__create_query": create_query,
              "__select_query": f"SELECT body FROM {table_name} WHERE key=:key",
              "__save_query": f"INSERT INTO {table_name} VALUES (:key, :body) ON CONFLICT(key) DO UPDATE SET body=:body",
              "__to_blob": lambda self: json.dumps(dataclasses.asdict(self)),
              # new_class is looked up when the lambda runs, after it is bound.
              "__from_blob": staticmethod(lambda blob: new_class(**json.loads(blob))),
          },
      )
      return new_class