import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Generic, TypeVar

from ..types import Message, Context

# The name passed to TypeVar must match the variable it is assigned to;
# type checkers reject a mismatch.
Dep = TypeVar("Dep")


class Injector(Generic[Dep]):
    """Produce a dependency of type ``Dep`` for a message/context pair."""

    async def inject(self, message: Message, context: Context) -> Dep:
        """Return the dependency for *message* and *context*.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError


class InjectorWithCleanup(Injector[Dep]):
    """An injector whose dependency is handed back to ``cleanup`` after use."""

    async def cleanup(self, dep: Dep) -> None:
        """Release *dep*, a value previously returned by ``inject``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError


async def _try_inject(
    injector: Injector[Any], message: Message, context: Context
) -> tuple[bool, Any]:
    """Run one injector and return ``(True, dep)`` or ``(False, exception)``."""
    try:
        return True, await injector.inject(message, context)
    except Exception as exc:  # reported to the caller, which re-raises it
        return False, exc


async def _cleanup_all(pairs: list[tuple[Injector[Any], Any]]) -> list[Exception]:
    """Call ``cleanup`` for every cleanup-capable injector in *pairs*.

    Every cleanup is attempted even if an earlier one fails; the collected
    exceptions are returned (in order) instead of being raised.
    """
    errors: list[Exception] = []
    for injector, dep in pairs:
        if not isinstance(injector, InjectorWithCleanup):
            continue
        try:
            await injector.cleanup(dep)
        except Exception as exc:
            errors.append(exc)
    return errors


@asynccontextmanager
async def inject_all(
    injectors: list[Injector[Any]], message: Message, context: Context
) -> AsyncIterator[list[Any]]:
    """Run all *injectors* concurrently and yield their dependencies.

    The yielded list is in the same order as *injectors*. When the ``async
    with`` body exits (normally or with an exception), ``cleanup`` is called
    for each ``InjectorWithCleanup`` with the dependency it produced.

    Raises:
        Exception: if any ``inject`` call fails. All injectors are allowed to
            settle first, dependencies that were produced are cleaned up,
            and the failure of the earliest failing injector (in list order)
            is re-raised. Cleanup errors on this path are dropped so the
            injection error stays the one reported.
        Exception: the first ``cleanup`` error, raised after every cleanup
            has been attempted.
    """
    # Each inject call is wrapped so that one failure does not make gather
    # return early and strand the dependencies the other injectors created.
    # NOTE: cancellation is not an ``Exception`` and still propagates straight
    # out of gather, as it did before.
    outcomes = await asyncio.gather(
        *(_try_inject(inj, message, context) for inj in injectors)
    )

    failures = [value for ok, value in outcomes if not ok]
    if failures:
        acquired = [
            (inj, value)
            for inj, (ok, value) in zip(injectors, outcomes)
            if ok
        ]
        await _cleanup_all(acquired)
        raise failures[0]

    deps = [value for _, value in outcomes]
    try:
        yield deps
    finally:
        cleanup_errors = await _cleanup_all(list(zip(injectors, deps)))
        if cleanup_errors:
            raise cleanup_errors[0]