Jelajahi Sumber

Add database connection and example

Kirk Trombley 4 tahun lalu
induk
melakukan
3faf32aa1d
3 mengubah file dengan 61 tambahan dan 19 penghapusan
  1. 17 6
      lib/rollbot/bot.py
  2. 9 6
      lib/rollbot/types.py
  3. 35 7
      repl_driver.py

+ 17 - 6
lib/rollbot/bot.py

@@ -1,5 +1,8 @@
 from dataclasses import dataclass
 from dataclasses import dataclass
 from typing import Any, Generic, TypeVar
 from typing import Any, Generic, TypeVar
+import asyncio
+
+import aiosqlite
 
 
 from .types import CommandConfiguration, Message, Response, Context
 from .types import CommandConfiguration, Message, Response, Context
 
 
@@ -12,6 +15,13 @@ class Rollbot(Generic[RawMsg]):
     command_config: CommandConfiguration
     command_config: CommandConfiguration
     database_file: str
     database_file: str
 
 
+    def __post_init__(self):
+        self.context = Context(
+            config=self.read_config,
+            respond=self.respond,
+            database=lambda: aiosqlite.connect(self.database_file),
+        )
+
     def read_config(self, key: str) -> Any:
     def read_config(self, key: str) -> Any:
         raise NotImplemented("Must be implemented by driver")
         raise NotImplemented("Must be implemented by driver")
 
 
@@ -21,6 +31,12 @@ class Rollbot(Generic[RawMsg]):
     async def respond(self, response: Response):
     async def respond(self, response: Response):
         raise NotImplemented("Must be implemented by driver")
         raise NotImplemented("Must be implemented by driver")
 
 
+    async def on_startup(self):
+        await asyncio.gather(*[task(self.context) for task in self.command_config.startup])
+
+    async def on_shutdown(self):
+        await asyncio.gather(*[task(self.context) for task in self.command_config.shutdown])
+
     async def on_message(self, incoming: RawMsg):
     async def on_message(self, incoming: RawMsg):
         message = self.parse(incoming)
         message = self.parse(incoming)
         if message.text is None:
         if message.text is None:
@@ -45,9 +61,4 @@ class Rollbot(Generic[RawMsg]):
             await self.respond(Response.from_message(message, f"Sorry! I don't know the command {command}."))
             await self.respond(Response.from_message(message, f"Sorry! I don't know the command {command}."))
             return
             return
 
 
-        await command_call(Context(
-            message=message,
-            config=self.read_config,
-            respond=self.respond,
-            # database=..., # TODO database
-        ))
+        await command_call(message, self.context)

+ 9 - 6
lib/rollbot/types.py

@@ -1,8 +1,10 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import datetime
 from datetime import datetime
 from collections.abc import Callable, Coroutine, Container
 from collections.abc import Callable, Coroutine, Container
 from typing import Union, Any, Optional
 from typing import Union, Any, Optional
 
 
+from aiosqlite.core import Connection
+
 
 
 @dataclass
 @dataclass
 class Attachment:
 class Attachment:
@@ -26,7 +28,7 @@ class Message:
 class Response:
 class Response:
     origin_id: str
     origin_id: str
     channel_id: str
     channel_id: str
-    text: str
+    text: Optional[str]
     attachments: list[Attachment]
     attachments: list[Attachment]
 
 
     @staticmethod
     @staticmethod
@@ -41,13 +43,12 @@ class Response:
 
 
 @dataclass
 @dataclass
 class Context:
 class Context:
-    message: Message
     config: Callable[[str], Any]
     config: Callable[[str], Any]
     respond: Callable[[], Coroutine[None, None, None]]
     respond: Callable[[], Coroutine[None, None, None]]
-    # database: Callable # TODO proper type
+    database: Callable[[], Coroutine[None, None, Connection]]
 
 
 
 
-CommandType = Callable[[Context], Coroutine[None, None, None]]
+CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
 
 
 
 
 @dataclass
 @dataclass
@@ -55,4 +56,6 @@ class CommandConfiguration:
     commands: dict[str, CommandType]
     commands: dict[str, CommandType]
     call_and_response: dict[str, str]
     call_and_response: dict[str, str]
     aliases: dict[str, str]
     aliases: dict[str, str]
-    bangs: Container[str] = ("!",)
+    bangs: Container[str] = ("!",)
+    startup: list[Callable[[Context], Coroutine[None, None, None]]] = field(default_factory=list)
+    shutdown: list[Callable[[Context], Coroutine[None, None, None]]] = field(default_factory=list)

+ 35 - 7
repl_driver.py

@@ -24,14 +24,36 @@ class MyBot(rollbot.Rollbot[str]):
         print(res, flush=True)
         print(res, flush=True)
 
 
 
 
-async def goodbye_command(context):
-    await context.respond(rollbot.Response.from_message(context.message, "Goodbye!"))
+async def goodbye_command(message, context):
+    await context.respond(rollbot.Response.from_message(message, "Goodbye!"))
+
+
+async def count_command(message, context):
+    args = message.text.split("count", maxsplit=1)[1].strip().split()
+    name = args[0] if len(args) > 0 else "main"
+    async with context.database() as db:
+        await db.execute("INSERT INTO counter VALUES (?, 1) \
+                          ON CONFLICT (name) DO \
+                          UPDATE SET count=count + 1", (name,))
+        await db.commit()
+        async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
+            res = (await cursor.fetchone())[0]
+    await context.respond(rollbot.Response.from_message(message, f"{name} = {res}"))
+
+
+async def make_table(context):
+    async with context.database() as db:
+        await db.execute("CREATE TABLE IF NOT EXISTS counter ( \
+                            name TEXT NOT NULL PRIMARY KEY, \
+                            count INTEGER NOT NULL DEFAULT 0 \
+                        );")
+        await db.commit()
 
 
 
 
 config = rollbot.CommandConfiguration(
 config = rollbot.CommandConfiguration(
-    bangs=("/",),
     commands={
     commands={
         "goodbye": goodbye_command,
         "goodbye": goodbye_command,
+        "count": count_command,
     },
     },
     call_and_response={
     call_and_response={
         "hello": "Hello!",
         "hello": "Hello!",
@@ -39,15 +61,21 @@ config = rollbot.CommandConfiguration(
     aliases={
     aliases={
         "hi": "hello",
         "hi": "hello",
         "bye": "goodbye",
         "bye": "goodbye",
-    }
+    },
+    bangs=("/",),
+    startup=[make_table],
 )
 )
 
 
 bot = MyBot(config, "/tmp/my.db")
 bot = MyBot(config, "/tmp/my.db")
 
 
-
 async def run():
 async def run():
-    while True:
-        await bot.on_message(input("> "))
+    await bot.on_startup()
+    try:
+        while True:
+            await bot.on_message(input("> "))
+    except EOFError:
+        pass
+    await bot.on_shutdown()
 
 
 
 
 asyncio.run(run())
 asyncio.run(run())