injection.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from typing import Generic, TypeVar, Optional
  2. from argparse import ArgumentParser, Namespace
  3. from collections.abc import Callable, Coroutine
  4. import shlex
  5. from aiosqlite.core import Connection
  6. from ..types import Message, Context
  7. async def inject_message(message: Message, context: Context) -> Message:
  8. return message
  9. async def inject_context(message: Message, context: Context) -> Context:
  10. return context
  11. Dep = TypeVar("DepType")
  12. class Injector(Generic[Dep]):
  13. async def inject(self, message: Message, context: Context) -> Dep:
  14. raise NotImplemented
  15. class ArgsInjector(Injector[str]):
  16. async def inject(self, message: Message, context: Context) -> str:
  17. return message.command.args
  18. class ArgListSplitOn(Injector[list[str]]):
  19. def __init__(self, split: Optional[str] = None):
  20. self.split = split
  21. async def inject(self, message: Message, context: Context) -> str:
  22. if self.split is not None:
  23. return message.command.args.split(self.split)
  24. else:
  25. return message.command.args.split()
  26. class ArgParse(Injector[Namespace]):
  27. def __init__(self, parser: ArgumentParser):
  28. self.parser = parser
  29. async def inject(self, message: Message, context: Context) -> Namespace:
  30. return self.parser.parse_args(shlex.split(message.text))
  31. class DatabaseInjector(Injector[Connection]):
  32. async def inject(self, message: Message, context: Context) -> Connection:
  33. return context.database()
  34. class Lazy(Injector[Callable[[], Coroutine[None, None, Dep]]]):
  35. def __init__(self, deferred: Injector[Dep]):
  36. self.deferred = deferred
  37. self._calculated = None
  38. async def inject(self, message: Message, context: Context) -> Callable[[], Coroutine[None, None, Dep]]:
  39. async def call():
  40. if self._calculated is None:
  41. self._calculated = await self.deferred.inject(message, context)
  42. return self._calculated
  43. return call
  44. Args = ArgsInjector()
  45. ArgList = ArgListSplitOn()
  46. Database = DatabaseInjector()