# injection.py (3.1 KB)
import asyncio
import shlex
from argparse import ArgumentParser, Namespace
from collections.abc import Callable, Coroutine
from contextlib import asynccontextmanager
from typing import Any, Generic, Optional, Type, TypeVar

from aiosqlite.core import Connection

from ..types import Context, Message
  8. async def inject_message(message: Message, context: Context) -> Message:
  9. return message
  10. async def inject_context(message: Message, context: Context) -> Context:
  11. return context
  12. Dep = TypeVar("DepType")
  13. class Injector(Generic[Dep]):
  14. async def inject(self, message: Message, context: Context) -> Dep:
  15. raise NotImplemented
  16. class InjectorWithCleanup(Injector[Dep]):
  17. async def cleanup(self, dep: Dep):
  18. raise NotImplemented
  19. @asynccontextmanager
  20. async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
  21. deps = await asyncio.gather(*[inj(message, context) for inj in injectors])
  22. try:
  23. yield deps
  24. finally:
  25. for dep, inj in zip(deps, injectors):
  26. if isinstance(inj, InjectorWithCleanup):
  27. await inj.cleanup(dep)
  28. class ArgsInjector(Injector[str]):
  29. async def inject(self, message: Message, context: Context) -> str:
  30. return message.command.args
  31. class ArgListSplitOn(Injector[list[str]]):
  32. def __init__(self, split: Optional[str] = None):
  33. self.split = split
  34. async def inject(self, message: Message, context: Context) -> str:
  35. if self.split is not None:
  36. return message.command.args.split(self.split)
  37. else:
  38. return message.command.args.split()
  39. class ArgParse(Injector[Namespace]):
  40. def __init__(self, parser: ArgumentParser):
  41. self.parser = parser
  42. async def inject(self, message: Message, context: Context) -> Namespace:
  43. return self.parser.parse_args(shlex.split(message.text))
  44. class DatabaseInjector(InjectorWithCleanup[Connection]):
  45. async def inject(self, message: Message, context: Context) -> Connection:
  46. return context.database()
  47. async def cleanup(self, conn: Connection):
  48. await conn.close()
  49. class Lazy(InjectorWithCleanup[Callable[[], Coroutine[None, None, Dep]]]):
  50. def __init__(self, deferred: Injector[Dep]):
  51. self.deferred = deferred
  52. async def inject(
  53. self, message: Message, context: Context
  54. ) -> Callable[[], Coroutine[None, None, Dep]]:
  55. class _Wrapper:
  56. def __init__(self, deferred):
  57. self._calculated = None
  58. async def call():
  59. if self._calculated is None:
  60. self._calculated = await deferred.inject(message, context)
  61. return self._calculated
  62. self._call = call
  63. def __call__(self):
  64. return self._call()
  65. return _Wrapper(self.deferred)
  66. async def cleanup(self, dep: Callable[[], Coroutine[None, None, Dep]]):
  67. if isinstance(self.deferred, InjectorWithCleanup) and dep._calculated is not None:
  68. await self.deferred.cleanup(dep._calculated)
  69. Args = ArgsInjector()
  70. ArgList = ArgListSplitOn()
  71. Database = DatabaseInjector()