瀏覽代碼

First pass at dataclass storage API

Kirk Trombley 4 年之前
父節點
當前提交
14bb88fed2
共有 2 個文件被更改,包括 101 次插入3 次删除
  1. 85 2
      lib/rollbot/injection/data.py
  2. 16 1
      repl_driver.py

+ 85 - 2
lib/rollbot/injection/data.py

@@ -1,10 +1,16 @@
+from typing import Generic, TypeVar, Type, Optional, Any
+from collections.abc import AsyncGenerator
+import dataclasses
+import json
+
 from aiosqlite import Connection
 
 from ..types import Message, Context
-from .base import InjectorWithCleanup
+from .base import Injector, InjectorWithCleanup
 
 __all__ = [
     "Database",
+    "Data",
 ]
 
 
@@ -17,4 +23,81 @@ class DatabaseInjector(InjectorWithCleanup[Connection]):
 
 
 Database = DatabaseInjector()
-# TODO data store for blob data
+
+DataType = TypeVar("DataType")
+
+
+class DataStore(Generic[DataType]):
+    def __init__(self, datatype: Type[DataType], connection: Connection):
+        if not dataclasses.is_dataclass(datatype):
+            raise ValueError
+        self.datatype = datatype
+        self.connection = connection
+        self.table_name = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in datatype.__name__).strip("_")
+
+    async def setup(self):
+        await self.connection.execute(
+            f'CREATE TABLE IF NOT EXISTS {self.table_name} ( \
+                key TEXT NOT NULL PRIMARY KEY, \
+                body TEXT DEFAULT "" \
+            )'
+        )
+        await self.connection.commit()
+
+    async def load(self, key: str) -> Optional[DataType]:
+        async with self.connection.execute(f"SELECT body FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
+            found = await cursor.fetchone()
+        if found is None:
+            return found
+        return self.datatype(**json.loads(found[0]))
+
+    async def load_or(self, key: str, **kw) -> DataType:
+        result = await self.load(key)
+        if result is not None:
+            return result
+        result = self.datatype(**kw)
+        await self.save(key, result)
+        return result
+
+    async def all(self) -> AsyncGenerator[tuple[str, DataType], None]:
+        async with self.connection.execute(f"SELECT key, body FROM {self.table_name}") as cursor:
+            async for (key, body) in cursor:
+                yield (key, self.datatype(**json.loads(body)))
+
+    async def save(self, key: str, obj: DataType):
+        blob = json.dumps(dataclasses.asdict(obj))
+        await self.connection.execute(
+            f"INSERT INTO {self.table_name} VALUES (:key, :body) \
+                ON CONFLICT(key) DO UPDATE SET body=:body",
+            { "key": key, "body": blob }
+        )
+        await self.connection.commit()
+
+
+class DataFor(Injector[Optional[DataType]]):
+    def __init__(self, datatype: Type[DataType], key: Injector[str], kwargs: dict[str, Any]):
+        self.datatype = datatype
+        self.key = key
+        self.kwargs = kwargs
+
+    async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
+        key = await self.key.inject(message, context)
+        async with context.database() as db:
+            store = DataStore(self.datatype, db)
+            await store.setup()
+            return await store.load_or(key, **self.kwargs)
+
+
+class Data(InjectorWithCleanup[DataStore[DataType]]):
+    def __init__(self, datatype: Type[DataType]):
+        self.datatype = datatype
+        self.For = lambda key, **kw: DataFor(datatype, key, kw)
+
+    async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
+        store = DataStore(self.datatype, await context.database())
+        await store.setup()
+        return store
+
+    async def cleanup(self, store: DataStore[DataType]):
+        await store.connection.close()
+

+ 16 - 1
repl_driver.py

@@ -1,8 +1,15 @@
 from datetime import datetime
+from dataclasses import dataclass
 import asyncio
 
 import rollbot
-from rollbot.injection import ArgList, Lazy, Database
+from rollbot.injection import ArgList, Lazy, Database, Data, Args
+
+
+@dataclass
+class MyCounter:
+    one: int = 0
+    two: int = 0
 
 
 class MyBot(rollbot.Rollbot[str]):
@@ -59,6 +66,14 @@ async def count2(args: ArgList, connect: Lazy(Database)):
     return f"{name} = {res}"
 
 
+@rollbot.as_command
+async def count3(counter: Data(MyCounter).For(Args), store: Data(MyCounter), args: Args):
+    counter.one += 1
+    counter.two += 2
+    await store.save(args, counter)
+    return f"{args} = {counter}"
+
+
 @rollbot.on_startup
 async def make_table(context):
     async with context.database() as db: