|
@@ -0,0 +1,96 @@
|
|
|
+from collections.abc import Callable, AsyncGenerator
|
|
|
+from typing import Union, Any
|
|
|
+from functools import wraps
|
|
|
+import inspect
|
|
|
+import asyncio
|
|
|
+import dataclasses
|
|
|
+import json
|
|
|
+
|
|
|
+from ..types import (
|
|
|
+ Message,
|
|
|
+ Context,
|
|
|
+ CommandType,
|
|
|
+ Response,
|
|
|
+ StartupShutdownType,
|
|
|
+ CommandConfiguration,
|
|
|
+)
|
|
|
+from .error_handling import with_failure_handling
|
|
|
+from .injection import Injector, inject_message, inject_context, inject_all
|
|
|
+
|
|
|
+decorated_commands: dict[str, CommandType] = {}
|
|
|
+
|
|
|
+
|
|
|
+def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
|
|
|
+ if inspect.isasyncgenfunction(fn):
|
|
|
+ lifted = fn
|
|
|
+ elif inspect.iscoroutinefunction(fn):
|
|
|
+
|
|
|
+ @wraps(fn)
|
|
|
+ async def lifted(*args):
|
|
|
+ yield await fn(*args)
|
|
|
+
|
|
|
+ elif inspect.isgeneratorfunction(fn):
|
|
|
+
|
|
|
+ @wraps(fn)
|
|
|
+ async def lifted(*args):
|
|
|
+ for res in fn(*args):
|
|
|
+ yield res
|
|
|
+
|
|
|
+ elif inspect.isfunction(fn):
|
|
|
+
|
|
|
+ @wraps(fn)
|
|
|
+ async def lifted(*args):
|
|
|
+ yield fn(*args)
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError # TODO details
|
|
|
+
|
|
|
+ return lifted
|
|
|
+
|
|
|
+
|
|
|
+def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
|
|
|
+ injectors = []
|
|
|
+ for param in inspect.signature(fn).parameters:
|
|
|
+ annot = fn.__annotations__[param]
|
|
|
+ if annot == Message:
|
|
|
+ injectors.append(inject_message)
|
|
|
+ elif annot == Context:
|
|
|
+ injectors.append(inject_context)
|
|
|
+ elif isinstance(annot, Injector):
|
|
|
+ injectors.append(annot.inject)
|
|
|
+ else:
|
|
|
+ raise ValueError # TODO details
|
|
|
+ return injectors
|
|
|
+
|
|
|
+
|
|
|
+def _make_response(result: Any) -> Response:
|
|
|
+ if result is None or isinstance(result, Response):
|
|
|
+ return result
|
|
|
+ elif isinstance(result, str):
|
|
|
+ return Response.from_message(message, text=result)
|
|
|
+ # TODO handle attachments, other special returns
|
|
|
+ else:
|
|
|
+ return Response.from_message(message, text=str(result))
|
|
|
+
|
|
|
+
|
|
|
+def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
|
|
|
+ lifted = _lift_command_fn(fn)
|
|
|
+ injectors = _get_injectors(fn)
|
|
|
+
|
|
|
+ @with_failure_handling
|
|
|
+ @wraps(lifted)
|
|
|
+ async def injected_command(message: Message, context: Context):
|
|
|
+ async with inject_all(injectors, message, context) as args:
|
|
|
+ async for result in lifted(*args):
|
|
|
+ if (response := _make_response(result)) is not None:
|
|
|
+ await context.respond(response)
|
|
|
+
|
|
|
+ decorated_commands[name] = injected_command
|
|
|
+ return fn
|
|
|
+
|
|
|
+
|
|
|
+def as_command(arg: Union[str, Callable[[Message, Context], Any]]):
|
|
|
+ if isinstance(arg, str):
|
|
|
+ return lambda fn: _on_command_impl(arg, fn)
|
|
|
+ else:
|
|
|
+ return _on_command_impl(arg.__name__, arg)
|