base.py 802 B

1234567891011121314151617181920212223242526272829
  1. from typing import Generic, TypeVar, Any
  2. from contextlib import asynccontextmanager
  3. from ..types import Message, Context
  # Type of the dependency an injector produces. The name passed to TypeVar
  # must match the variable it is bound to, or type checkers reject it.
  Dep = TypeVar("Dep")
  class Injector(Generic[Dep]):
      """Base class for objects that produce a dependency of type ``Dep``.

      Subclasses are expected to override :meth:`inject`; this base
      implementation only raises.
      """

      async def inject(self, message: Message, context: Context) -> Dep:
          """Build and return the dependency for *message* within *context*.

          Raises:
              NotImplementedError: always, in this base class.
          """
          raise NotImplementedError()
  class InjectorWithCleanup(Injector[Dep]):
      """Injector whose produced dependencies must be released after use.

      ``inject_all`` in this module awaits :meth:`cleanup` with each dependency
      produced by an instance of this class when its context exits.
      """

      async def cleanup(self, dep: Dep) -> None:
          """Release *dep*, a value previously returned by :meth:`inject`.

          Raises:
              NotImplementedError: always, in this base class.
          """
          raise NotImplementedError()
  @asynccontextmanager
  async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
      """Inject every dependency for *message*/*context*; clean them up on exit.

      Each injector's ``inject(message, context)`` is awaited in order, and the
      results are yielded as a list aligned index-for-index with *injectors*.
      The context can exit normally, through an exception raised in the
      ``async with`` body, or because one of the ``inject`` calls raised. In
      every case, ``cleanup(dep)`` is awaited for each dependency that an
      ``InjectorWithCleanup`` has already produced.

      Cleanups run in reverse injection order (last in, first out), mirroring
      nested ``async with`` blocks. A cleanup that raises does not prevent the
      remaining cleanups from running. The exception propagates once all
      cleanups have been attempted.

      Args:
          injectors: Injectors to resolve, in order.
          message: Passed through to every ``inject`` call.
          context: Passed through to every ``inject`` call.

      Yields:
          The injected dependencies, one per injector, in the same order.
      """
      # Local import keeps this block self-contained without touching the
      # file-level import block.
      from contextlib import AsyncExitStack

      async with AsyncExitStack() as stack:
          deps = []
          for inj in injectors:
              dep = await inj.inject(message, context)
              deps.append(dep)
              if isinstance(inj, InjectorWithCleanup):
                  # Register the cleanup as soon as the dependency exists, so it
                  # is released even if a later injector fails. Binding ``dep``
                  # here also makes cleanup independent of any mutation the
                  # caller performs on the yielded list.
                  stack.push_async_callback(inj.cleanup, dep)
          yield deps