"""Injectors that derive per-message dependencies from a Message and its Context.

Each injector exposes ``inject(message, context)``; injectors that acquire a
resource also expose ``cleanup(dep)``. ``inject_all`` resolves a list of
injectors and runs the required cleanups when the consumer is done.
"""

import asyncio
import shlex
from argparse import ArgumentParser, Namespace
from collections.abc import AsyncIterator, Callable, Coroutine, Iterable
from contextlib import asynccontextmanager
from typing import Any, Generic, Optional, Type, TypeVar

from aiosqlite.core import Connection

from ..types import Context, Message


async def inject_message(message: Message, context: Context) -> Message:
    """Return the incoming message unchanged."""
    return message


async def inject_context(message: Message, context: Context) -> Context:
    """Return the incoming context unchanged."""
    return context


# The TypeVar's name must match the variable it is bound to for type checkers.
Dep = TypeVar("Dep")


class Injector(Generic[Dep]):
    """Base class for objects that produce a dependency of type ``Dep``."""

    async def inject(self, message: Message, context: Context) -> Dep:
        """Produce the dependency for ``message`` / ``context``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError


class InjectorWithCleanup(Injector[Dep]):
    """Injector whose produced value must be released via ``cleanup``."""

    async def cleanup(self, dep: Dep) -> None:
        """Release ``dep`` after the consumer is done with it.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError


async def _cleanup_all(pairs: Iterable[tuple[Injector[Any], Any]]) -> None:
    """Call ``cleanup`` for each (injector, dep) pair whose injector supports it.

    Every cleanup is attempted even if an earlier one fails; the first
    failure is re-raised after all of them have run.
    """
    first_error: Optional[Exception] = None
    for inj, dep in pairs:
        if not isinstance(inj, InjectorWithCleanup):
            continue
        try:
            await inj.cleanup(dep)
        except Exception as exc:
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error


@asynccontextmanager
async def inject_all(
    injectors: list[Injector[Any]], message: Message, context: Context
) -> AsyncIterator[list[Any]]:
    """Resolve all ``injectors`` concurrently and yield their values in order.

    On exit, ``cleanup`` is called (in injector order) for every value produced
    by an ``InjectorWithCleanup``. If any injector fails, the values that were
    produced successfully are cleaned up first and then the first injection
    error is re-raised, so already-acquired resources are not leaked.
    """
    results = await asyncio.gather(
        *(inj.inject(message, context) for inj in injectors),
        return_exceptions=True,
    )
    failures = [r for r in results if isinstance(r, BaseException)]
    if failures:
        try:
            await _cleanup_all(
                (inj, r)
                for inj, r in zip(injectors, results)
                if not isinstance(r, BaseException)
            )
        finally:
            # Surface the injection error even if a cleanup also failed
            # (the cleanup error is kept as implicit exception context).
            raise failures[0]
    try:
        yield results
    finally:
        await _cleanup_all(zip(injectors, results))


class ArgsInjector(Injector[str]):
    """Provides the raw argument string of the message's command."""

    async def inject(self, message: Message, context: Context) -> str:
        return message.command.args


class ArgListSplitOn(Injector[list[str]]):
    """Provides the command arguments split into a list.

    ``split`` is the separator passed to ``str.split``; ``None`` (the default)
    splits on runs of whitespace and drops empty strings.
    """

    def __init__(self, split: Optional[str] = None):
        self.split = split

    async def inject(self, message: Message, context: Context) -> list[str]:
        # str.split(None) is the same as str.split(), so no branch is needed.
        return message.command.args.split(self.split)


class ArgParse(Injector[Namespace]):
    """Parses the message with an ``argparse.ArgumentParser``."""

    def __init__(self, parser: ArgumentParser):
        self.parser = parser

    async def inject(self, message: Message, context: Context) -> Namespace:
        # NOTE(review): this tokenizes the full ``message.text`` while the other
        # injectors read ``message.command.args``; if ``text`` includes the
        # command name it will be parsed as a positional argument — confirm.
        # By default ``parse_args`` exits (raises SystemExit) on invalid input.
        return self.parser.parse_args(shlex.split(message.text))


class DatabaseInjector(InjectorWithCleanup[Connection]):
    """Provides a database connection from the context and closes it afterwards."""

    async def inject(self, message: Message, context: Context) -> Connection:
        # NOTE(review): ``context.database()`` is not awaited, so it is assumed
        # to return a ready-to-use Connection — confirm against Context.
        return context.database()

    async def cleanup(self, conn: Connection) -> None:
        await conn.close()


class _LazyValue(Generic[Dep]):
    """Zero-argument callable that runs a deferred injector on first await.

    The injected value is cached after the first successful call, so repeated
    calls return the same object. A failed injection is not cached and will be
    retried on the next call.
    """

    def __init__(self, deferred: Injector[Dep], message: Message, context: Context):
        self._deferred = deferred
        self._message = message
        self._context = context
        # Tracked separately from _calculated so that a legitimately-None
        # result is cached (and cleaned up) instead of being re-injected.
        self._computed = False
        self._calculated: Optional[Dep] = None

    async def _resolve(self) -> Dep:
        if not self._computed:
            self._calculated = await self._deferred.inject(self._message, self._context)
            self._computed = True
        return self._calculated

    def __call__(self) -> Coroutine[None, None, Dep]:
        return self._resolve()


class Lazy(InjectorWithCleanup[Callable[[], Coroutine[None, None, Dep]]]):
    """Defers another injector until the consumer actually awaits the value.

    ``inject`` returns a zero-argument callable; awaiting its result runs the
    wrapped injector once and caches the value. ``cleanup`` forwards to the
    wrapped injector's ``cleanup`` only if the value was ever computed.
    """

    def __init__(self, deferred: Injector[Dep]):
        self.deferred = deferred

    async def inject(
        self, message: Message, context: Context
    ) -> Callable[[], Coroutine[None, None, Dep]]:
        return _LazyValue(self.deferred, message, context)

    async def cleanup(self, dep: Callable[[], Coroutine[None, None, Dep]]) -> None:
        if (
            isinstance(self.deferred, InjectorWithCleanup)
            and isinstance(dep, _LazyValue)
            and dep._computed
        ):
            await self.deferred.cleanup(dep._calculated)


# Shared, stateless injector instances.
Args = ArgsInjector()
ArgList = ArgListSplitOn()
Database = DatabaseInjector()