base.py 874 B

1234567891011121314151617181920212223242526272829
  1. from typing import Generic, TypeVar, Any
  2. from contextlib import asynccontextmanager
  3. import asyncio
  4. from ..types import Message, Context
  5. Dep = TypeVar("DepType")
  6. class Injector(Generic[Dep]):
  7. async def inject(self, message: Message, context: Context) -> Dep:
  8. raise NotImplementedError
  9. class InjectorWithCleanup(Injector[Dep]):
  10. async def cleanup(self, dep: Dep):
  11. raise NotImplementedError
  12. @asynccontextmanager
  13. async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
  14. # TODO want to catch exceptions for the inject calls too to ensure cleanup
  15. deps = await asyncio.gather(*[inj.inject(message, context) for inj in injectors])
  16. try:
  17. yield deps
  18. finally:
  19. for dep, inj in zip(deps, injectors):
  20. if isinstance(inj, InjectorWithCleanup):
  21. await inj.cleanup(dep)