injection.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from typing import Generic, TypeVar, Optional
  2. from argparse import ArgumentParser, Namespace
  3. import shlex
  4. from aiosqlite.core import Connection
  5. from ..types import Message, Context
  6. async def inject_message(message: Message, context: Context) -> Message:
  7. return message
  8. async def inject_context(message: Message, context: Context) -> Context:
  9. return context
  10. Dep = TypeVar("DepType")
  11. class Injector(Generic[Dep]):
  12. async def inject(self, message: Message, context: Context) -> Dep:
  13. raise NotImplemented
  14. class ArgsInjector(Injector[str]):
  15. async def inject(self, message: Message, context: Context) -> str:
  16. return message.command.args
  17. class ArgListSplitOn(Injector[list[str]]):
  18. def __init__(self, split: Optional[str] = None):
  19. self.split = split
  20. async def inject(self, message: Message, context: Context) -> str:
  21. if self.split is not None:
  22. return message.command.args.split(self.split)
  23. else:
  24. return message.command.args.split()
  25. class ArgParse(Injector[Namespace]):
  26. def __init__(self, parser: ArgumentParser):
  27. self.parser = parser
  28. async def inject(self, message: Message, context: Context) -> Namespace:
  29. return self.parser.parse_args(shlex.split(message.text))
  30. class DatabaseInjector(Injector[Connection]):
  31. async def inject(self, message: Message, context: Context) -> Connection:
  32. return context.database()
  33. Args = ArgsInjector()
  34. ArgList = ArgListSplitOn()
  35. Database = DatabaseInjector()