Kirk Trombley 4 anos atrás
pai
commit
07dfbb86db
5 arquivos alterados com 52 adições e 17 exclusões
  1. 14 4
      lib/rollbot/bot.py
  2. 14 1
      lib/rollbot/command/decorators.py
  3. 4 2
      lib/rollbot/types.py
  4. 2 2
      lib/setup.py
  5. 18 8
      repl_driver.py

+ 14 - 4
lib/rollbot/bot.py

@@ -9,7 +9,9 @@ from .types import CommandConfiguration, Message, Response, Context
 # TODO logging
 
 
-RawMsg = TypeVar('RawMsg')
+RawMsg = TypeVar("RawMsg")
+
+
 @dataclass
 class Rollbot(Generic[RawMsg]):
     command_config: CommandConfiguration
@@ -32,10 +34,14 @@ class Rollbot(Generic[RawMsg]):
         raise NotImplemented("Must be implemented by driver")
 
     async def on_startup(self):
-        await asyncio.gather(*[task(self.context) for task in self.command_config.startup])
+        await asyncio.gather(
+            *[task(self.context) for task in self.command_config.startup]
+        )
 
     async def on_shutdown(self):
-        await asyncio.gather(*[task(self.context) for task in self.command_config.shutdown])
+        await asyncio.gather(
+            *[task(self.context) for task in self.command_config.shutdown]
+        )
 
     async def on_message(self, incoming: RawMsg):
         message = self.parse(incoming)
@@ -58,7 +64,11 @@ class Rollbot(Generic[RawMsg]):
 
         command_call = self.command_config.commands.get(command, None)
         if command_call is None:
-            await self.respond(Response.from_message(message, f"Sorry! I don't know the command {command}."))
+            await self.respond(
+                Response.from_message(
+                    message, f"Sorry! I don't know the command {command}."
+                )
+            )
             return
 
         await command_call(message, self.context)

+ 14 - 1
lib/rollbot/command/decorators.py

@@ -2,7 +2,14 @@ from collections.abc import Callable
 from typing import Union
 import inspect
 
-from ..types import Message, Context, CommandType, Response, StartupShutdownType, CommandConfiguration
+from ..types import (
+    Message,
+    Context,
+    CommandType,
+    Response,
+    StartupShutdownType,
+    CommandConfiguration,
+)
 from .failure import RollbotFailureException
 
 decorated_startup: list[StartupShutdownType] = []
@@ -25,15 +32,21 @@ def as_command(arg: Union[str, Callable]):
         if inspect.isasyncgenfunction(fn):
             lifted = fn
         elif inspect.iscoroutinefunction(fn):
+
             async def lifted(*args):
                 yield await fn(*args)
+
         elif inspect.isgeneratorfunction(fn):
+
             async def lifted(*args):
                 for res in fn(*args):
                     yield res
+
         elif inspect.isfunction(fn):
+
             async def lifted(*args):
                 yield fn(*args)
+
         else:
             raise ValueError  # TODO details
 

+ 4 - 2
lib/rollbot/types.py

@@ -32,7 +32,9 @@ class Response:
     attachments: Optional[list[Attachment]] = None
 
     @staticmethod
-    def from_message(msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None) -> "Response":
+    def from_message(
+        msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
+    ) -> "Response":
         return Response(
             origin_id=msg.origin_id,
             channel_id=msg.channel_id,
@@ -78,4 +80,4 @@ class CommandConfiguration:
             bangs=(*self.bangs, *other.bangs),
             startup=[*self.startup, *other.startup],
             shutdown=[*self.shutdown, *other.shutdown],
-        )
+        )

+ 2 - 2
lib/setup.py

@@ -11,5 +11,5 @@ setup(
     packages=["rollbot"],
     install_requires=[
         "aiosqlite",
-    ]
-)
+    ],
+)

+ 18 - 8
repl_driver.py

@@ -32,11 +32,16 @@ async def count_command(message, context):
     args = message.text.split("count", maxsplit=1)[1].strip().split()
     name = args[0] if len(args) > 0 else "main"
     async with context.database() as db:
-        await db.execute("INSERT INTO counter VALUES (?, 1) \
+        await db.execute(
+            "INSERT INTO counter VALUES (?, 1) \
                           ON CONFLICT (name) DO \
-                          UPDATE SET count=count + 1", (name,))
+                          UPDATE SET count=count + 1",
+            (name,),
+        )
         await db.commit()
-        async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
+        async with db.execute(
+            "SELECT count FROM counter WHERE name = ?", (name,)
+        ) as cursor:
             res = (await cursor.fetchone())[0]
     await context.respond(rollbot.Response.from_message(message, f"{name} = {res}"))
 
@@ -44,16 +49,20 @@ async def count_command(message, context):
 @rollbot.on_startup
 async def make_table(context):
     async with context.database() as db:
-        await db.execute("CREATE TABLE IF NOT EXISTS counter ( \
+        await db.execute(
+            "CREATE TABLE IF NOT EXISTS counter ( \
                             name TEXT NOT NULL PRIMARY KEY, \
                             count INTEGER NOT NULL DEFAULT 0 \
-                        );")
+                        );"
+        )
         await db.commit()
 
 
 @rollbot.on_shutdown
 async def shutdown(context):
-    await context.respond(rollbot.Response(origin_id="REPL", channel_id=".", text="Shutting down!"))
+    await context.respond(
+        rollbot.Response(origin_id="REPL", channel_id=".", text="Shutting down!")
+    )
 
 
 @rollbot.as_command
@@ -77,7 +86,7 @@ async def coroutine():
 async def asyncgen():
     yield "This is"
     await asyncio.sleep(0.5)
-    yield "an async" 
+    yield "an async"
     await asyncio.sleep(0.5)
     yield "generator!"
 
@@ -102,6 +111,7 @@ config = rollbot.get_command_config().extend(
 
 bot = MyBot(config, "/tmp/my.db")
 
+
 async def run():
     await bot.on_startup()
     try:
@@ -112,4 +122,4 @@ async def run():
     await bot.on_shutdown()
 
 
-asyncio.run(run())
+asyncio.run(run())