|
@@ -1,104 +0,0 @@
|
|
|
-from typing import Generic, TypeVar, Optional, Type
|
|
|
-from argparse import ArgumentParser, Namespace
|
|
|
-from collections.abc import Callable, Coroutine
|
|
|
-from contextlib import asynccontextmanager
|
|
|
-import shlex
|
|
|
-
|
|
|
-from aiosqlite.core import Connection
|
|
|
-
|
|
|
-from ..types import Message, Context
|
|
|
-
|
|
|
-
|
|
|
-async def inject_message(message: Message, context: Context) -> Message:
|
|
|
- return message
|
|
|
-
|
|
|
-
|
|
|
-async def inject_context(message: Message, context: Context) -> Context:
|
|
|
- return context
|
|
|
-
|
|
|
-
|
|
|
-Dep = TypeVar("DepType")
|
|
|
-
|
|
|
-
|
|
|
-class Injector(Generic[Dep]):
|
|
|
- async def inject(self, message: Message, context: Context) -> Dep:
|
|
|
- raise NotImplemented
|
|
|
-
|
|
|
-
|
|
|
-class InjectorWithCleanup(Injector[Dep]):
|
|
|
- async def cleanup(self, dep: Dep):
|
|
|
- raise NotImplemented
|
|
|
-
|
|
|
-
|
|
|
-@asynccontextmanager
|
|
|
-async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
|
|
|
- deps = await asyncio.gather(*[inj(message, context) for inj in injectors])
|
|
|
- try:
|
|
|
- yield deps
|
|
|
- finally:
|
|
|
- for dep, inj in zip(deps, injectors):
|
|
|
- if isinstance(inj, InjectorWithCleanup):
|
|
|
- await inj.cleanup(dep)
|
|
|
-
|
|
|
-
|
|
|
-class ArgsInjector(Injector[str]):
|
|
|
- async def inject(self, message: Message, context: Context) -> str:
|
|
|
- return message.command.args
|
|
|
-
|
|
|
-
|
|
|
-class ArgListSplitOn(Injector[list[str]]):
|
|
|
- def __init__(self, split: Optional[str] = None):
|
|
|
- self.split = split
|
|
|
-
|
|
|
- async def inject(self, message: Message, context: Context) -> str:
|
|
|
- if self.split is not None:
|
|
|
- return message.command.args.split(self.split)
|
|
|
- else:
|
|
|
- return message.command.args.split()
|
|
|
-
|
|
|
-
|
|
|
-class ArgParse(Injector[Namespace]):
|
|
|
- def __init__(self, parser: ArgumentParser):
|
|
|
- self.parser = parser
|
|
|
-
|
|
|
- async def inject(self, message: Message, context: Context) -> Namespace:
|
|
|
- return self.parser.parse_args(shlex.split(message.text))
|
|
|
-
|
|
|
-
|
|
|
-class DatabaseInjector(InjectorWithCleanup[Connection]):
|
|
|
- async def inject(self, message: Message, context: Context) -> Connection:
|
|
|
- return context.database()
|
|
|
-
|
|
|
- async def cleanup(self, conn: Connection):
|
|
|
- await conn.close()
|
|
|
-
|
|
|
-
|
|
|
-class Lazy(InjectorWithCleanup[Callable[[], Coroutine[None, None, Dep]]]):
|
|
|
- def __init__(self, deferred: Injector[Dep]):
|
|
|
- self.deferred = deferred
|
|
|
-
|
|
|
- async def inject(
|
|
|
- self, message: Message, context: Context
|
|
|
- ) -> Callable[[], Coroutine[None, None, Dep]]:
|
|
|
- class _Wrapper:
|
|
|
- def __init__(self, deferred):
|
|
|
- self._calculated = None
|
|
|
- async def call():
|
|
|
- if self._calculated is None:
|
|
|
- self._calculated = await deferred.inject(message, context)
|
|
|
- return self._calculated
|
|
|
- self._call = call
|
|
|
-
|
|
|
- def __call__(self):
|
|
|
- return self._call()
|
|
|
-
|
|
|
- return _Wrapper(self.deferred)
|
|
|
-
|
|
|
- async def cleanup(self, dep: Callable[[], Coroutine[None, None, Dep]]):
|
|
|
- if isinstance(self.deferred, InjectorWithCleanup) and dep._calculated is not None:
|
|
|
- await self.deferred.cleanup(dep._calculated)
|
|
|
-
|
|
|
-
|
|
|
-Args = ArgsInjector()
|
|
|
-ArgList = ArgListSplitOn()
|
|
|
-Database = DatabaseInjector()
|