1234567891011121314151617181920212223242526272829303132333435363738 |
- from collections.abc import Callable
- from typing import Generic, TypeVar, Any
- from contextlib import asynccontextmanager
- from ..types import Message, Context
# Type variable for the dependency value an injector produces.
# The string passed to TypeVar must equal the variable it is bound to;
# otherwise static type checkers reject the declaration.
Dep = TypeVar("Dep")
class Injector(Generic[Dep]):
    """Base class for objects that build a dependency value.

    Subclasses override :meth:`inject`, which receives the incoming
    message and its context and returns the dependency.
    """

    async def inject(self, message: Message, context: Context) -> Dep:
        """Build the dependency for ``message`` within ``context``.

        This base implementation has no behavior of its own.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError()
class InjectorWithCleanup(Injector[Dep]):
    """Injector whose produced values must be released after use.

    ``inject_all`` passes each value this injector produced back to
    :meth:`cleanup` when its context exits.
    """

    async def cleanup(self, dep: Dep):
        """Release ``dep``, a value previously produced by ``inject``.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError()
class Simple(Injector[Dep]):
    """Injector that wraps a plain callable.

    Args:
        extract: invoked as ``extract(message, context)``; whatever it
            returns is used as the dependency as-is (it is not awaited).
    """

    def __init__(self, extract: Callable[[Message, Context], Dep]):
        self.extract = extract

    async def inject(self, message: Message, context: Context) -> Dep:
        """Return the result of ``self.extract(message, context)``."""
        extractor = self.extract
        produced = extractor(message, context)
        return produced
@asynccontextmanager
async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
    """Run every injector and release cleanup-requiring values on exit.

    Usage::

        async with inject_all(injectors, message, context) as deps:
            ...

    Args:
        injectors: injectors to run, in order.
        message: forwarded to each ``inject`` call.
        context: forwarded to each ``inject`` call.

    Yields:
        A list holding one dependency per injector, in ``injectors`` order.

    Every value produced by an :class:`InjectorWithCleanup` is passed to
    that injector's ``cleanup`` when the context exits. This happens
    whether the body finishes normally, raises, or a later ``inject`` call
    fails. Cleanups run in reverse injection order, and one failing
    ``cleanup`` does not stop the remaining ones from running; the
    exception is still propagated afterwards.
    """
    from contextlib import AsyncExitStack

    async with AsyncExitStack() as stack:
        deps = []
        for inj in injectors:
            dep = await inj.inject(message, context)
            deps.append(dep)
            if isinstance(inj, InjectorWithCleanup):
                # Register the (injector, value) pair immediately. Cleanup
                # then works without the yielded list, which the caller
                # may mutate.
                stack.push_async_callback(inj.cleanup, dep)
        yield deps
|