base.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738
from collections.abc import AsyncIterator, Callable, Iterable
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, Generic, TypeVar

from ..types import Context, Message
  5. Dep = TypeVar("DepType")
  6. class Injector(Generic[Dep]):
  7. async def inject(self, message: Message, context: Context) -> Dep:
  8. raise NotImplementedError
  9. class InjectorWithCleanup(Injector[Dep]):
  10. async def cleanup(self, dep: Dep):
  11. raise NotImplementedError
  12. class Simple(Injector[Dep]):
  13. def __init__(self, extract: Callable[[Message, Context], Dep]):
  14. self.extract = extract
  15. async def inject(self, message: Message, context: Context) -> Dep:
  16. return self.extract(message, context)
  17. @asynccontextmanager
  18. async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
  19. try:
  20. deps = []
  21. for inj in injectors:
  22. deps.append(await inj.inject(message, context))
  23. yield deps
  24. finally:
  25. for dep, inj in zip(deps, injectors):
  26. if isinstance(inj, InjectorWithCleanup):
  27. await inj.cleanup(dep)