Преглед на файлове

Implementing first pass at some injectors

Kirk Trombley преди 4 години
родител
ревизия
0cf7e9e0de
променени са 4 файла, в които са добавени 99 реда и са изтрити 6 реда
  1. 1 0
      lib/rollbot/command/__init__.py
  2. 19 1
      lib/rollbot/command/decorators.py
  3. 57 0
      lib/rollbot/command/injection.py
  4. 22 5
      repl_driver.py

+ 1 - 0
lib/rollbot/command/__init__.py

@@ -1,2 +1,3 @@
 from .decorators import as_command, on_startup, on_shutdown, get_command_config
 from .failure import RollbotFailure
+from .injection import Args, ArgList, ArgListSplitOn, ArgParse, Database

+ 19 - 1
lib/rollbot/command/decorators.py

@@ -1,6 +1,8 @@
 from collections.abc import Callable
 from typing import Union
+from functools import wraps
 import inspect
+import asyncio
 
 from ..types import (
     Message,
@@ -11,6 +13,7 @@ from ..types import (
     CommandConfiguration,
 )
 from .failure import RollbotFailureException
+from .injection import Injector, inject_message, inject_context
 
 decorated_startup: list[StartupShutdownType] = []
 decorated_shutdown: list[StartupShutdownType] = []
@@ -33,25 +36,40 @@ def as_command(arg: Union[str, Callable]):
             lifted = fn
         elif inspect.iscoroutinefunction(fn):
 
+            @wraps(fn)
             async def lifted(*args):
                 yield await fn(*args)
 
         elif inspect.isgeneratorfunction(fn):
 
+            @wraps(fn)
             async def lifted(*args):
                 for res in fn(*args):
                     yield res
 
         elif inspect.isfunction(fn):
 
+            @wraps(fn)
             async def lifted(*args):
                 yield fn(*args)
 
         else:
             raise ValueError  # TODO details
 
+        injectors = []
+        for param in inspect.signature(fn).parameters:
+            annot = fn.__annotations__[param]
+            if annot == Message:
+                injectors.append(inject_message)
+            elif annot == Context:
+                injectors.append(inject_context)
+            elif isinstance(annot, Injector):
+                injectors.append(annot.inject)
+            else:
+                raise ValueError  # TODO details
+
         async def command_impl(message: Message, context: Context):
-            args = []  # TODO implement dep injection
+            args = await asyncio.gather(*[inj(message, context) for inj in injectors])
 
             try:
                 async for result in lifted(*args):

+ 57 - 0
lib/rollbot/command/injection.py

@@ -0,0 +1,57 @@
+from typing import Generic, TypeVar, Optional
+from argparse import ArgumentParser, Namespace
+import shlex
+
+from aiosqlite.core import Connection
+
+from ..types import Message, Context
+
+
+async def inject_message(message: Message, context: Context) -> Message:
+    return message
+
+
+async def inject_context(message: Message, context: Context) -> Context:
+    return context
+
+
+Dep = TypeVar("DepType")
+
+
+class Injector(Generic[Dep]):
+    async def inject(self, message: Message, context: Context) -> Dep:
+        raise NotImplemented
+
+
+class ArgsInjector(Injector[str]):
+    async def inject(self, message: Message, context: Context) -> str:
+        return message.command.args
+
+
+class ArgListSplitOn(Injector[list[str]]):
+    def __init__(self, split: Optional[str] = None):
+        self.split = split
+
+    async def inject(self, message: Message, context: Context) -> str:
+        if self.split is not None:
+            return message.command.args.split(self.split)
+        else:
+            return message.command.args.split()
+
+
+class ArgParse(Injector[Namespace]):
+    def __init__(self, parser: ArgumentParser):
+        self.parser = parser
+
+    async def inject(self, message: Message, context: Context) -> Namespace:
+        return self.parser.parse_args(shlex.split(message.text))
+
+
+class DatabaseInjector(Injector[Connection]):
+    async def inject(self, message: Message, context: Context) -> Connection:
+        return context.database()
+
+
+Args = ArgsInjector()
+ArgList = ArgListSplitOn()
+Database = DatabaseInjector()

+ 22 - 5
repl_driver.py

@@ -34,8 +34,7 @@ async def count_command(message, context):
     async with context.database() as db:
         await db.execute(
             "INSERT INTO counter VALUES (?, 1) \
-                          ON CONFLICT (name) DO \
-                          UPDATE SET count=count + 1",
+            ON CONFLICT (name) DO UPDATE SET count=count + 1",
             (name,),
         )
         await db.commit()
@@ -46,14 +45,32 @@ async def count_command(message, context):
     await context.respond(rollbot.Response.from_message(message, f"{name} = {res}"))
 
 
+@rollbot.as_command
+async def count2(args: rollbot.ArgList, db: rollbot.Database):
+    name = args[0] if len(args) > 0 else "main"
+    async with db:
+        await db.execute(
+            "INSERT INTO counter VALUES (?, 1) \
+            ON CONFLICT (name) DO UPDATE SET count=count + 2",
+            (name,),
+        )
+        await db.commit()
+        async with db.execute(
+            "SELECT count FROM counter WHERE name = ?", (name,)
+        ) as cursor:
+            res = (await cursor.fetchone())[0]
+    return f"{name} = {res}"
+
+
+
 @rollbot.on_startup
 async def make_table(context):
     async with context.database() as db:
         await db.execute(
             "CREATE TABLE IF NOT EXISTS counter ( \
-                            name TEXT NOT NULL PRIMARY KEY, \
-                            count INTEGER NOT NULL DEFAULT 0 \
-                        );"
+                name TEXT NOT NULL PRIMARY KEY, \
+                count INTEGER NOT NULL DEFAULT 0 \
+            );"
         )
         await db.commit()