Browse Source

Finishing big refactor and implementing cleanup for Injectors

Kirk Trombley 4 years ago
parent
commit
fae2377cf5

+ 3 - 1
lib/rollbot/__init__.py

@@ -1,3 +1,5 @@
 from .bot import Rollbot
+from .failure import RollbotFailure
 from .types import *
-from .command import *
+from .decorators import *
+from .injection import *

+ 0 - 3
lib/rollbot/command/__init__.py

@@ -1,3 +0,0 @@
-from .decorators import as_command, on_startup, on_shutdown, get_command_config
-from .failure import RollbotFailure, with_failure_handling
-from .injection import Args, ArgList, ArgListSplitOn, ArgParse, Database, Lazy

+ 0 - 119
lib/rollbot/command/decorators.py

@@ -1,119 +0,0 @@
-from collections.abc import Callable, AsyncGenerator
-from typing import Union
-from functools import wraps
-import inspect
-import asyncio
-import dataclasses
-import json
-
-from ..types import (
-    Message,
-    Context,
-    CommandType,
-    Response,
-    StartupShutdownType,
-    CommandConfiguration,
-)
-from .failure import with_failure_handling
-from .injection import Injector, inject_message, inject_context, inject_all
-
-decorated_startup: list[StartupShutdownType] = []
-decorated_shutdown: list[StartupShutdownType] = []
-decorated_commands: dict[str, CommandType] = {}
-
-
-def on_startup(fn):
-    decorated_startup.append(fn)
-    return fn
-
-
-def on_shutdown(fn):
-    decorated_shutdown.append(fn)
-    return fn
-
-
-def _lift_command_fn(fn: Callable[..., Any]) -> Callable[..., AsyncGenerator[Any, None]]:
-    if inspect.isasyncgenfunction(fn):
-        lifted = fn
-    elif inspect.iscoroutinefunction(fn):
-
-        @wraps(fn)
-        async def lifted(*args):
-            yield await fn(*args)
-
-    elif inspect.isgeneratorfunction(fn):
-
-        @wraps(fn)
-        async def lifted(*args):
-            for res in fn(*args):
-                yield res
-
-    elif inspect.isfunction(fn):
-
-        @wraps(fn)
-        async def lifted(*args):
-            yield fn(*args)
-
-    else:
-        raise ValueError  # TODO details
-
-    return lifted
-
-
-def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
-    injectors = []
-    for param in inspect.signature(fn).parameters:
-        annot = fn.__annotations__[param]
-        if annot == Message:
-            injectors.append(inject_message)
-        elif annot == Context:
-            injectors.append(inject_context)
-        elif isinstance(annot, Injector):
-            injectors.append(annot.inject)
-        else:
-            raise ValueError  # TODO details
-    return injectors
-
-
-def _make_response(result: Any) -> Response:
-    if result is None or isinstance(result, Response):
-        return result
-    elif isinstance(result, str):
-        return Response.from_message(message, text=result)
-    # TODO handle attachments, other special returns
-    else:
-        return Response.from_message(message, text=str(result))
-
-
-def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
-    lifted = _lift_command_fn(fn)
-    injectors = _get_injectors(fn)
-
-    @with_failure_handling
-    @wraps(lifted)
-    async def injected_command(message: Message, context: Context):
-        async with inject_all(injectors, message, context) as args:
-            async for result in lifted(*args):
-                if (response := _make_response(result)) is not None:
-                    await context.respond(response)
-
-    decorated_commands[name] = injected_command
-    return fn
-
-
-def as_command(arg: Union[str, Callable[[Message, Context], Any]]):
-    if isinstance(arg, str):
-        return lambda fn: _on_command_impl(arg, fn)
-    else:
-        return _on_command_impl(arg.__name__, arg)
-
-
-def get_command_config() -> CommandConfiguration:
-    return CommandConfiguration(
-        commands=decorated_commands,
-        call_and_response={},
-        aliases={},
-        bangs=(),
-        startup=decorated_startup,
-        shutdown=decorated_shutdown,
-    )

+ 0 - 53
lib/rollbot/command/failure.py

@@ -1,53 +0,0 @@
-from functools import wraps
-from enum import Enum, auto
-
-from ..types import Message, Context, Response, CommandType
-
-
-class RollbotFailureException(BaseException):
-    def __init__(self, failure):
-        super().__init__()
-        self.failure = failure
-
-
-class RollbotFailure(Enum):
-    INVALID_COMMAND = auto()
-    MISSING_SUBCOMMAND = auto()
-    INVALID_SUBCOMMAND = auto()
-    INVALID_ARGUMENTS = auto()
-    SERVICE_DOWN = auto()
-    PERMISSIONS = auto()
-    INTERNAL_ERROR = auto()
-
-    def get_debugging(self):
-        debugging = {}
-        reason = getattr(self, "reason", None)
-        if reason is not None:
-            debugging["explain"] = reason
-        exception = getattr(self, "exception", None)
-        if exception is not None:
-            debugging["exception"] = exception
-        return debugging
-
-    def with_reason(self, reason):
-        self.reason = reason
-        return self
-
-    def with_exception(self, exception):
-        self.exception = exception
-        return self
-
-    def raise_exc(self):
-        raise RollbotFailureException(self)
-
-
-def with_failure_handling(fn: CommandType) -> CommandType:
-    @wraps(fn)
-    async def wrapped(message: Message, context: Context):
-        try:
-            await fn()
-        except RollbotFailureException as exc:
-            # TODO handle errors more specifically
-            await context.respond(Response.from_message(message, str(exc.failure)))
-
-    return wrapped

+ 0 - 104
lib/rollbot/command/injection.py

@@ -1,104 +0,0 @@
-from typing import Generic, TypeVar, Optional, Type
-from argparse import ArgumentParser, Namespace
-from collections.abc import Callable, Coroutine
-from contextlib import asynccontextmanager
-import shlex
-
-from aiosqlite.core import Connection
-
-from ..types import Message, Context
-
-
-async def inject_message(message: Message, context: Context) -> Message:
-    return message
-
-
-async def inject_context(message: Message, context: Context) -> Context:
-    return context
-
-
-Dep = TypeVar("DepType")
-
-
-class Injector(Generic[Dep]):
-    async def inject(self, message: Message, context: Context) -> Dep:
-        raise NotImplemented
-
-
-class InjectorWithCleanup(Injector[Dep]):
-    async def cleanup(self, dep: Dep):
-        raise NotImplemented
-
-
-@asynccontextmanager
-async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
-    deps = await asyncio.gather(*[inj(message, context) for inj in injectors])
-    try:
-        yield deps
-    finally:
-        for dep, inj in zip(deps, injectors):
-            if isinstance(inj, InjectorWithCleanup):
-                await inj.cleanup(dep)
-
-
-class ArgsInjector(Injector[str]):
-    async def inject(self, message: Message, context: Context) -> str:
-        return message.command.args
-
-
-class ArgListSplitOn(Injector[list[str]]):
-    def __init__(self, split: Optional[str] = None):
-        self.split = split
-
-    async def inject(self, message: Message, context: Context) -> str:
-        if self.split is not None:
-            return message.command.args.split(self.split)
-        else:
-            return message.command.args.split()
-
-
-class ArgParse(Injector[Namespace]):
-    def __init__(self, parser: ArgumentParser):
-        self.parser = parser
-
-    async def inject(self, message: Message, context: Context) -> Namespace:
-        return self.parser.parse_args(shlex.split(message.text))
-
-
-class DatabaseInjector(InjectorWithCleanup[Connection]):
-    async def inject(self, message: Message, context: Context) -> Connection:
-        return context.database()
-
-    async def cleanup(self, conn: Connection):
-        await conn.close()
-
-
-class Lazy(InjectorWithCleanup[Callable[[], Coroutine[None, None, Dep]]]):
-    def __init__(self, deferred: Injector[Dep]):
-        self.deferred = deferred
-
-    async def inject(
-        self, message: Message, context: Context
-    ) -> Callable[[], Coroutine[None, None, Dep]]:
-        class _Wrapper:
-            def __init__(self, deferred):
-                self._calculated = None
-                async def call():
-                    if self._calculated is None:
-                        self._calculated = await deferred.inject(message, context)
-                    return self._calculated
-                self._call = call
-
-            def __call__(self):
-                return self._call()
-        
-        return _Wrapper(self.deferred)
-
-    async def cleanup(self, dep: Callable[[], Coroutine[None, None, Dep]]):
-        if isinstance(self.deferred, InjectorWithCleanup) and dep._calculated is not None:
-            await self.deferred.cleanup(dep._calculated)
-
-
-Args = ArgsInjector()
-ArgList = ArgListSplitOn()
-Database = DatabaseInjector()

+ 1 - 1
lib/rollbot/decorators/__init__.py

@@ -8,8 +8,8 @@ __all__ = [
     "as_command",
     "on_startup",
     "on_shutdown",
-    "get_command_config",
     "with_failure_handling",
+    "get_command_config",
 ]
 
 

+ 6 - 11
lib/rollbot/decorators/as_command.py

@@ -2,20 +2,15 @@ from collections.abc import Callable, AsyncGenerator
 from typing import Union, Any
 from functools import wraps
 import inspect
-import asyncio
-import dataclasses
-import json
 
 from ..types import (
     Message,
     Context,
     CommandType,
     Response,
-    StartupShutdownType,
-    CommandConfiguration,
 )
+from ..injection import Injector, inject_all, MessageInjector, ContextInjector
 from .error_handling import with_failure_handling
-from .injection import Injector, inject_message, inject_context, inject_all
 
 decorated_commands: dict[str, CommandType] = {}
 
@@ -53,17 +48,17 @@ def _get_injectors(fn: Callable[..., Any]) -> list[Injector]:
     for param in inspect.signature(fn).parameters:
         annot = fn.__annotations__[param]
         if annot == Message:
-            injectors.append(inject_message)
+            injectors.append(MessageInjector)
         elif annot == Context:
-            injectors.append(inject_context)
+            injectors.append(ContextInjector)
         elif isinstance(annot, Injector):
-            injectors.append(annot.inject)
+            injectors.append(annot)
         else:
             raise ValueError  # TODO details
     return injectors
 
 
-def _make_response(result: Any) -> Response:
+def _make_response(message: Message, result: Any) -> Response:
     if result is None or isinstance(result, Response):
         return result
     elif isinstance(result, str):
@@ -82,7 +77,7 @@ def _on_command_impl(name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
     async def injected_command(message: Message, context: Context):
         async with inject_all(injectors, message, context) as args:
             async for result in lifted(*args):
-                if (response := _make_response(result)) is not None:
+                if (response := _make_response(message, result)) is not None:
                     await context.respond(response)
 
     decorated_commands[name] = injected_command

+ 1 - 1
lib/rollbot/decorators/error_handling.py

@@ -8,7 +8,7 @@ def with_failure_handling(fn: CommandType) -> CommandType:
     @wraps(fn)
     async def wrapped(message: Message, context: Context):
         try:
-            await fn()
+            await fn(message, context)
         except RollbotFailureException as exc:
             # TODO handle errors more specifically
             await context.respond(Response.from_message(message, str(exc.failure)))

+ 4 - 0
lib/rollbot/injection/__init__.py

@@ -0,0 +1,4 @@
+from .base import inject_all, Injector, InjectorWithCleanup
+from .args import *
+from .data import *
+from .util import *

+ 42 - 0
lib/rollbot/injection/args.py

@@ -0,0 +1,42 @@
+from argparse import ArgumentParser, Namespace
+from typing import Optional
+import shlex
+
+from ..types import Message, Context
+from .base import Injector
+
+__all__ = [
+    "Args",
+    "ArgList",
+    "ArgListSplitOn",
+    "ArgParse",
+]
+
+
+class ArgsInjector(Injector[str]):
+    async def inject(self, message: Message, context: Context) -> str:
+        return message.command.args
+
+
+class ArgListSplitOn(Injector[list[str]]):
+    def __init__(self, split: Optional[str] = None):
+        self.split = split
+
+    async def inject(self, message: Message, context: Context) -> str:
+        if self.split is not None:
+            return message.command.args.split(self.split)
+        else:
+            return message.command.args.split()
+
+
+class ArgParse(Injector[Namespace]):
+    def __init__(self, parser: ArgumentParser):
+        self.parser = parser
+
+    async def inject(self, message: Message, context: Context) -> Namespace:
+        return self.parser.parse_args(shlex.split(message.text))
+
+
+Args = ArgsInjector()
+ArgList = ArgListSplitOn()
+# TODO Arg(n)

+ 28 - 0
lib/rollbot/injection/base.py

@@ -0,0 +1,28 @@
+from typing import Generic, TypeVar, Any
+from contextlib import asynccontextmanager
+import asyncio
+
+from ..types import Message, Context
+
+Dep = TypeVar("DepType")
+
+
+class Injector(Generic[Dep]):
+    async def inject(self, message: Message, context: Context) -> Dep:
+        raise NotImplementedError
+
+
+class InjectorWithCleanup(Injector[Dep]):
+    async def cleanup(self, dep: Dep):
+        raise NotImplementedError
+
+
+@asynccontextmanager
+async def inject_all(injectors: list[Injector[Any]], message: Message, context: Context):
+    deps = await asyncio.gather(*[inj.inject(message, context) for inj in injectors])
+    try:
+        yield deps
+    finally:
+        for dep, inj in zip(deps, injectors):
+            if isinstance(inj, InjectorWithCleanup):
+                await inj.cleanup(dep)

+ 20 - 0
lib/rollbot/injection/data.py

@@ -0,0 +1,20 @@
+from aiosqlite import Connection
+
+from ..types import Message, Context
+from .base import InjectorWithCleanup
+
+__all__ = [
+    "Database",
+]
+
+
+class DatabaseInjector(InjectorWithCleanup[Connection]):
+    async def inject(self, message: Message, context: Context) -> Connection:
+        return await context.database()
+
+    async def cleanup(self, conn: Connection):
+        await conn.close()
+
+
+Database = DatabaseInjector()
+# TODO data store for blob data

+ 55 - 0
lib/rollbot/injection/util.py

@@ -0,0 +1,55 @@
+from collections.abc import Callable, Coroutine
+from typing import TypeVar
+
+from ..types import Message, Context
+from .base import Injector, InjectorWithCleanup
+
+__all__ = [
+    "Lazy",
+    "MessageInjector",
+    "ContextInjector",
+]
+
+
+class MessageInjector(Injector[Message]):
+    async def inject(self, message: Message, context: Context) -> Message:
+        return message
+
+MessageInjector = MessageInjector()
+
+
+class ContextInjector(Injector[Context]):
+    async def inject(self, message: Message, context: Context) -> Context:
+        return context
+
+ContextInjector = ContextInjector()
+
+Dep = TypeVar("Dep")
+
+
+class Lazy(InjectorWithCleanup[Callable[[], Coroutine[None, None, Dep]]]):
+    def __init__(self, deferred: Injector[Dep]):
+        self.deferred = deferred
+
+    async def inject(
+        self, message: Message, context: Context
+    ) -> Callable[[], Coroutine[None, None, Dep]]:
+        class _Wrapper:
+            def __init__(self, deferred):
+                self._calculated = None
+                async def call():
+                    if self._calculated is None:
+                        self._calculated = await deferred.inject(message, context)
+                    return self._calculated
+                self._call = call
+
+            def __call__(self):
+                return self._call()
+        
+        return _Wrapper(self.deferred)
+
+    async def cleanup(self, dep: Callable[[], Coroutine[None, None, Dep]]):
+        if isinstance(self.deferred, InjectorWithCleanup) and dep._calculated is not None:
+            await self.deferred.cleanup(dep._calculated)
+
+

+ 11 - 0
lib/rollbot/types.py

@@ -7,6 +7,17 @@ from typing import Union, Any, Optional
 
 from aiosqlite.core import Connection
 
+__all__ = [
+    "Attachment",
+    "Message",
+    "Command",
+    "Response",
+    "Context",
+    "CommandType",
+    "StartupShutdownType",
+    "CommandConfiguration",
+]
+
 
 @dataclass
 class Attachment:

+ 10 - 15
repl_driver.py

@@ -48,21 +48,18 @@ async def count_command(message, context):
 @rollbot.as_command
 async def count2(args: rollbot.ArgList, connect: rollbot.Lazy(rollbot.Database)):
     name = args[0] if len(args) > 0 else "main"
-    async with await connect() as db:
-        await db.execute(
-            "INSERT INTO counter VALUES (?, 1) \
-            ON CONFLICT (name) DO UPDATE SET count=count + 2",
-            (name,),
-        )
-        await db.commit()
-        async with db.execute(
-            "SELECT count FROM counter WHERE name = ?", (name,)
-        ) as cursor:
-            res = (await cursor.fetchone())[0]
+    db = await connect()
+    await db.execute(
+        "INSERT INTO counter VALUES (?, 1) \
+        ON CONFLICT (name) DO UPDATE SET count=count + 2",
+        (name,),
+    )
+    await db.commit()
+    async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
+        res = (await cursor.fetchone())[0]
     return f"{name} = {res}"
 
 
-
 @rollbot.on_startup
 async def make_table(context):
     async with context.database() as db:
@@ -77,9 +74,7 @@ async def make_table(context):
 
 @rollbot.on_shutdown
 async def shutdown(context):
-    await context.respond(
-        rollbot.Response(origin_id="REPL", channel_id=".", text="Shutting down!")
-    )
+    await context.respond(rollbot.Response(origin_id="REPL", channel_id=".", text="Shutting down!"))
 
 
 @rollbot.as_command