|
@@ -1,6 +1,8 @@
|
|
from collections.abc import Callable
|
|
from collections.abc import Callable
|
|
from typing import Union
|
|
from typing import Union
|
|
|
|
+from functools import wraps
|
|
import inspect
|
|
import inspect
|
|
|
|
+import asyncio
|
|
|
|
|
|
from ..types import (
|
|
from ..types import (
|
|
Message,
|
|
Message,
|
|
@@ -11,6 +13,7 @@ from ..types import (
|
|
CommandConfiguration,
|
|
CommandConfiguration,
|
|
)
|
|
)
|
|
from .failure import RollbotFailureException
|
|
from .failure import RollbotFailureException
|
|
|
|
+from .injection import Injector, inject_message, inject_context
|
|
|
|
|
|
decorated_startup: list[StartupShutdownType] = []
|
|
decorated_startup: list[StartupShutdownType] = []
|
|
decorated_shutdown: list[StartupShutdownType] = []
|
|
decorated_shutdown: list[StartupShutdownType] = []
|
|
@@ -33,25 +36,40 @@ def as_command(arg: Union[str, Callable]):
|
|
lifted = fn
|
|
lifted = fn
|
|
elif inspect.iscoroutinefunction(fn):
|
|
elif inspect.iscoroutinefunction(fn):
|
|
|
|
|
|
|
|
+ @wraps(fn)
|
|
async def lifted(*args):
|
|
async def lifted(*args):
|
|
yield await fn(*args)
|
|
yield await fn(*args)
|
|
|
|
|
|
elif inspect.isgeneratorfunction(fn):
|
|
elif inspect.isgeneratorfunction(fn):
|
|
|
|
|
|
|
|
+ @wraps(fn)
|
|
async def lifted(*args):
|
|
async def lifted(*args):
|
|
for res in fn(*args):
|
|
for res in fn(*args):
|
|
yield res
|
|
yield res
|
|
|
|
|
|
elif inspect.isfunction(fn):
|
|
elif inspect.isfunction(fn):
|
|
|
|
|
|
|
|
+ @wraps(fn)
|
|
async def lifted(*args):
|
|
async def lifted(*args):
|
|
yield fn(*args)
|
|
yield fn(*args)
|
|
|
|
|
|
else:
|
|
else:
|
|
raise ValueError # TODO details
|
|
raise ValueError # TODO details
|
|
|
|
|
|
|
|
+ injectors = []
|
|
|
|
+ for param in inspect.signature(fn).parameters:
|
|
|
|
+ annot = fn.__annotations__[param]
|
|
|
|
+ if annot == Message:
|
|
|
|
+ injectors.append(inject_message)
|
|
|
|
+ elif annot == Context:
|
|
|
|
+ injectors.append(inject_context)
|
|
|
|
+ elif isinstance(annot, Injector):
|
|
|
|
+ injectors.append(annot.inject)
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError # TODO details
|
|
|
|
+
|
|
async def command_impl(message: Message, context: Context):
|
|
async def command_impl(message: Message, context: Context):
|
|
- args = [] # TODO implement dep injection
|
|
|
|
|
|
+ args = await asyncio.gather(*[inj(message, context) for inj in injectors])
|
|
|
|
|
|
try:
|
|
try:
|
|
async for result in lifted(*args):
|
|
async for result in lifted(*args):
|