123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
import asyncio
import shlex
from argparse import ArgumentParser, Namespace
from collections.abc import Callable, Coroutine
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, Generic, Optional, Type, TypeVar

from aiosqlite.core import Connection

from ..types import Message, Context
async def inject_message(message: Message, context: Context) -> Message:
    """Injector function that supplies the triggering message itself.

    ``context`` is accepted only so that every injector shares the
    ``(message, context)`` signature; it is not used here.
    """
    return message
async def inject_context(message: Message, context: Context) -> Context:
    """Injector function that supplies the handler's context object.

    ``message`` is accepted only so that every injector shares the
    ``(message, context)`` signature; it is not used here.
    """
    return context
# Type of the dependency an injector produces. The name passed to TypeVar
# must match the variable it is bound to, or type checkers reject it.
Dep = TypeVar("Dep")
class Injector(Generic[Dep]):
    """Base class for objects that produce one dependency for a handler.

    Subclasses override :meth:`inject`. Instances are also callable with
    ``(message, context)``, so they can be used anywhere a plain async
    injector function such as ``inject_message`` is expected.
    """

    async def inject(self, message: Message, context: Context) -> Dep:
        """Produce the dependency for ``message`` under ``context``."""
        # NotImplementedError is the abstract-method signal; the constant
        # ``NotImplemented`` is not an exception and cannot be raised.
        raise NotImplementedError

    async def __call__(self, message: Message, context: Context) -> Dep:
        """Delegate to :meth:`inject` (plain-function injector signature)."""
        return await self.inject(message, context)
class InjectorWithCleanup(Injector[Dep]):
    """Injector whose dependency must be released after use.

    ``inject_all`` awaits :meth:`cleanup` with each value produced by such an
    injector when its scope exits.
    """

    async def cleanup(self, dep: Dep):
        """Release ``dep``, a value previously returned by :meth:`inject`."""
        # NotImplementedError, not NotImplemented: the latter is a constant
        # and raising it produces a TypeError instead.
        raise NotImplementedError
def _begin_injection(injector: Any, message: Message, context: Context):
    """Return the awaitable that produces ``injector``'s dependency.

    Accepts both :class:`Injector` instances (resolved via ``inject``) and
    plain async callables with the ``(message, context)`` signature, such as
    ``inject_message``.
    """
    if isinstance(injector, Injector):
        return injector.inject(message, context)
    return injector(message, context)


@asynccontextmanager
async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
    """Resolve all ``injectors`` concurrently and clean up on exit.

    Yields a list of dependencies in the same order as ``injectors``. When the
    ``async with`` body finishes (normally or by raising), ``cleanup`` is
    awaited for every dependency produced by an :class:`InjectorWithCleanup`,
    in reverse order. Every cleanup runs even if an earlier one raises.

    If any injector fails, the dependencies that were produced successfully
    are still cleaned up, and the first failure (in ``injectors`` order) is
    re-raised.
    """
    results = await asyncio.gather(
        *(_begin_injection(inj, message, context) for inj in injectors),
        # Collect failures instead of aborting on the first one, so that
        # resources already created (e.g. DB connections) can be released.
        # NOTE: an injector that *returns* an exception instance as its
        # value would be treated as a failure here.
        return_exceptions=True,
    )
    async with AsyncExitStack() as stack:
        for inj, result in zip(injectors, results):
            if isinstance(inj, InjectorWithCleanup) and not isinstance(result, BaseException):
                stack.push_async_callback(inj.cleanup, result)
        failures = [result for result in results if isinstance(result, BaseException)]
        if failures:
            raise failures[0]
        yield list(results)
class ArgsInjector(Injector[str]):
    """Injector yielding ``message.command.args`` exactly as received.

    No splitting or unquoting is done; see ``ArgListSplitOn`` and
    ``ArgParse`` for structured forms.
    """

    async def inject(self, message: Message, context: Context) -> str:
        command = message.command
        return command.args
class ArgListSplitOn(Injector[list[str]]):
    """Inject ``message.command.args`` split into a list of strings.

    With ``split=None`` (the default) the text is split on runs of
    whitespace, like ``str.split()``. Otherwise it is split on every
    occurrence of the given separator.

    Raises:
        ValueError: at construction time if ``split`` is the empty string,
            which ``str.split`` rejects.
    """

    def __init__(self, split: Optional[str] = None):
        # Fail fast here instead of on every incoming message.
        if split == "":
            raise ValueError("split separator must not be empty")
        self.split = split

    async def inject(self, message: Message, context: Context) -> list[str]:
        # str.split(None) is documented to behave exactly like str.split(),
        # so the default needs no special case.
        return message.command.args.split(self.split)
class ArgParse(Injector[Namespace]):
    """Inject the result of running an :class:`argparse.ArgumentParser`.

    The message text is tokenised with :func:`shlex.split`, so shell-style
    quoting works, and is then passed to ``parser.parse_args``.
    """

    def __init__(self, parser: ArgumentParser):
        self.parser = parser

    async def inject(self, message: Message, context: Context) -> Namespace:
        """Parse the message into a :class:`~argparse.Namespace`.

        Raises:
            ValueError: if the text has unbalanced quotes, or if the parser
                rejects the arguments (this includes ``--help``).
        """
        # NOTE(review): the other argument injectors in this file read
        # ``message.command.args``, but this one parses ``message.text``. If
        # ``text`` includes the command name, the parser sees it as the first
        # positional argument. Confirm which is intended.
        argv = shlex.split(message.text)  # raises ValueError on bad quoting
        try:
            return self.parser.parse_args(argv)
        except SystemExit as exc:
            # argparse reports errors (and handles --help) by raising
            # SystemExit. asyncio lets SystemExit escape the task and stop the
            # event loop, so convert it into an ordinary per-command failure.
            raise ValueError(f"could not parse arguments: {argv!r}") from exc
class DatabaseInjector(InjectorWithCleanup[Connection]):
    """Injector providing the connection returned by ``context.database()``.

    :meth:`cleanup` closes that connection when the injection scope ends.
    """

    async def inject(self, message: Message, context: Context) -> Connection:
        # NOTE(review): assumes context.database() synchronously returns a
        # ready-to-use connection, not a coroutine or an unstarted aiosqlite
        # connection that still needs awaiting. Confirm against Context.
        connection = context.database()
        return connection

    async def cleanup(self, conn: Connection):
        """Close the connection handed out by :meth:`inject`."""
        await conn.close()
class _LazyValue(Generic[Dep]):
    """Zero-argument callable returned by :class:`Lazy`.

    Each call returns a coroutine that yields the deferred dependency. The
    underlying injector runs at most once successfully, and its result is
    cached for all later calls.
    """

    def __init__(self, deferred: Injector[Dep], message: Message, context: Context):
        self._deferred = deferred
        self._message = message
        self._context = context
        # Serialises concurrent first calls so the deferred injector is not
        # run twice, which would leak the first result (e.g. a connection).
        self._lock = asyncio.Lock()
        # Tracked separately from ``value`` so that a legitimate None result
        # is still cached and still cleaned up.
        self.resolved = False
        self.value: Optional[Dep] = None

    async def _get(self) -> Dep:
        async with self._lock:
            if not self.resolved:
                self.value = await self._deferred.inject(self._message, self._context)
                self.resolved = True
        return self.value

    def __call__(self) -> Coroutine[None, None, Dep]:
        return self._get()


class Lazy(InjectorWithCleanup[Callable[[], Coroutine[None, None, Dep]]]):
    """Defer another injector until the handler actually asks for it.

    The injected value is a zero-argument callable. Awaiting ``dep()`` runs
    the wrapped injector on first use and returns the cached value after
    that. If the wrapped injector needs cleanup, it is cleaned up only when
    the value was actually created.
    """

    def __init__(self, deferred: Injector[Dep]):
        self.deferred = deferred

    async def inject(
        self, message: Message, context: Context
    ) -> Callable[[], Coroutine[None, None, Dep]]:
        return _LazyValue(self.deferred, message, context)

    async def cleanup(self, dep: Callable[[], Coroutine[None, None, Dep]]):
        """Release the deferred value if, and only if, it was resolved."""
        if isinstance(self.deferred, InjectorWithCleanup) and dep.resolved:
            await self.deferred.cleanup(dep.value)
# Shared, ready-made injector instances for the common cases.
Database = DatabaseInjector()
Args = ArgsInjector()
ArgList = ArgListSplitOn(split=None)  # split on runs of whitespace
|