Browse Source

Add database connection and example

Kirk Trombley 4 years ago
parent
commit
3faf32aa1d
3 changed files with 61 additions and 19 deletions
  1. 17 6
      lib/rollbot/bot.py
  2. 9 6
      lib/rollbot/types.py
  3. 35 7
      repl_driver.py

+ 17 - 6
lib/rollbot/bot.py

@@ -1,5 +1,8 @@
 from dataclasses import dataclass
 from typing import Any, Generic, TypeVar
+import asyncio
+
+import aiosqlite
 
 from .types import CommandConfiguration, Message, Response, Context
 
@@ -12,6 +15,13 @@ class Rollbot(Generic[RawMsg]):
     command_config: CommandConfiguration
     database_file: str
 
+    def __post_init__(self):
+        self.context = Context(
+            config=self.read_config,
+            respond=self.respond,
+            database=lambda: aiosqlite.connect(self.database_file),
+        )
+
     def read_config(self, key: str) -> Any:
         raise NotImplemented("Must be implemented by driver")
 
@@ -21,6 +31,12 @@ class Rollbot(Generic[RawMsg]):
     async def respond(self, response: Response):
         raise NotImplemented("Must be implemented by driver")
 
+    async def on_startup(self):
+        await asyncio.gather(*[task(self.context) for task in self.command_config.startup])
+
+    async def on_shutdown(self):
+        await asyncio.gather(*[task(self.context) for task in self.command_config.shutdown])
+
     async def on_message(self, incoming: RawMsg):
         message = self.parse(incoming)
         if message.text is None:
@@ -45,9 +61,4 @@ class Rollbot(Generic[RawMsg]):
             await self.respond(Response.from_message(message, f"Sorry! I don't know the command {command}."))
             return
 
-        await command_call(Context(
-            message=message,
-            config=self.read_config,
-            respond=self.respond,
-            # database=..., # TODO database
-        ))
+        await command_call(message, self.context)

+ 9 - 6
lib/rollbot/types.py

@@ -1,8 +1,10 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import datetime
 from collections.abc import Callable, Coroutine, Container
 from typing import Union, Any, Optional
 
+from aiosqlite.core import Connection
+
 
 @dataclass
 class Attachment:
@@ -26,7 +28,7 @@ class Message:
 class Response:
     origin_id: str
     channel_id: str
-    text: str
+    text: Optional[str]
     attachments: list[Attachment]
 
     @staticmethod
@@ -41,13 +43,12 @@ class Response:
 
 @dataclass
 class Context:
-    message: Message
     config: Callable[[str], Any]
     respond: Callable[[], Coroutine[None, None, None]]
-    # database: Callable # TODO proper type
+    database: Callable[[], Coroutine[None, None, Connection]]
 
 
-CommandType = Callable[[Context], Coroutine[None, None, None]]
+CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
 
 
 @dataclass
@@ -55,4 +56,6 @@ class CommandConfiguration:
     commands: dict[str, CommandType]
     call_and_response: dict[str, str]
     aliases: dict[str, str]
-    bangs: Container[str] = ("!",)
+    bangs: Container[str] = ("!",)
+    startup: list[Callable[[Context], Coroutine[None, None, None]]] = field(default_factory=list)
+    shutdown: list[Callable[[Context], Coroutine[None, None, None]]] = field(default_factory=list)

+ 35 - 7
repl_driver.py

@@ -24,14 +24,36 @@ class MyBot(rollbot.Rollbot[str]):
         print(res, flush=True)
 
 
-async def goodbye_command(context):
-    await context.respond(rollbot.Response.from_message(context.message, "Goodbye!"))
+async def goodbye_command(message, context):
+    await context.respond(rollbot.Response.from_message(message, "Goodbye!"))
+
+
+async def count_command(message, context):
+    args = message.text.split("count", maxsplit=1)[1].strip().split()
+    name = args[0] if len(args) > 0 else "main"
+    async with context.database() as db:
+        await db.execute("INSERT INTO counter VALUES (?, 1) \
+                          ON CONFLICT (name) DO \
+                          UPDATE SET count=count + 1", (name,))
+        await db.commit()
+        async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
+            res = (await cursor.fetchone())[0]
+    await context.respond(rollbot.Response.from_message(message, f"{name} = {res}"))
+
+
+async def make_table(context):
+    async with context.database() as db:
+        await db.execute("CREATE TABLE IF NOT EXISTS counter ( \
+                            name TEXT NOT NULL PRIMARY KEY, \
+                            count INTEGER NOT NULL DEFAULT 0 \
+                        );")
+        await db.commit()
 
 
 config = rollbot.CommandConfiguration(
-    bangs=("/",),
     commands={
         "goodbye": goodbye_command,
+        "count": count_command,
     },
     call_and_response={
         "hello": "Hello!",
@@ -39,15 +61,21 @@ config = rollbot.CommandConfiguration(
     aliases={
         "hi": "hello",
         "bye": "goodbye",
-    }
+    },
+    bangs=("/",),
+    startup=[make_table],
 )
 
 bot = MyBot(config, "/tmp/my.db")
 
-
 async def run():
-    while True:
-        await bot.on_message(input("> "))
+    await bot.on_startup()
+    try:
+        while True:
+            await bot.on_message(input("> "))
+    except EOFError:
+        pass
+    await bot.on_shutdown()
 
 
 asyncio.run(run())