|
@@ -1,10 +1,16 @@
|
|
|
+from typing import Generic, TypeVar, Type, Optional, Any
|
|
|
+from collections.abc import AsyncGenerator
|
|
|
+import dataclasses
|
|
|
+import json
|
|
|
+
|
|
|
from aiosqlite import Connection
|
|
|
|
|
|
from ..types import Message, Context
|
|
|
-from .base import InjectorWithCleanup
|
|
|
+from .base import Injector, InjectorWithCleanup
|
|
|
|
|
|
__all__ = [
|
|
|
"Database",
|
|
|
+ "Data",
|
|
|
]
|
|
|
|
|
|
|
|
@@ -17,4 +23,81 @@ class DatabaseInjector(InjectorWithCleanup[Connection]):
|
|
|
|
|
|
|
|
|
Database = DatabaseInjector()
|
|
|
-# TODO data store for blob data
|
|
|
+
|
|
|
+DataType = TypeVar("DataType")
|
|
|
+
|
|
|
+
|
|
|
+class DataStore(Generic[DataType]):
|
|
|
+ def __init__(self, datatype: Type[DataType], connection: Connection):
|
|
|
+ if not dataclasses.is_dataclass(datatype):
|
|
|
+ raise ValueError
|
|
|
+ self.datatype = datatype
|
|
|
+ self.connection = connection
|
|
|
+ self.table_name = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in datatype.__name__).strip("_")
|
|
|
+
|
|
|
+ async def setup(self):
|
|
|
+ await self.connection.execute(
|
|
|
+ f'CREATE TABLE IF NOT EXISTS {self.table_name} ( \
|
|
|
+ key TEXT NOT NULL PRIMARY KEY, \
|
|
|
+ body TEXT DEFAULT "" \
|
|
|
+ )'
|
|
|
+ )
|
|
|
+ await self.connection.commit()
|
|
|
+
|
|
|
+ async def load(self, key: str) -> Optional[DataType]:
|
|
|
+ async with self.connection.execute(f"SELECT body FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
|
|
|
+ found = await cursor.fetchone()
|
|
|
+ if found is None:
|
|
|
+ return found
|
|
|
+ return self.datatype(**json.loads(found[0]))
|
|
|
+
|
|
|
+ async def load_or(self, key: str, **kw) -> DataType:
|
|
|
+ result = await self.load(key)
|
|
|
+ if result is not None:
|
|
|
+ return result
|
|
|
+ result = self.datatype(**kw)
|
|
|
+ await self.save(key, result)
|
|
|
+ return result
|
|
|
+
|
|
|
+ async def all(self) -> AsyncGenerator[tuple[str, DataType], None]:
|
|
|
+ async with self.connection.execute(f"SELECT key, body FROM {self.table_name}") as cursor:
|
|
|
+ async for (key, body) in cursor:
|
|
|
+ yield (key, self.datatype(**json.loads(body)))
|
|
|
+
|
|
|
+ async def save(self, key: str, obj: DataType):
|
|
|
+ blob = json.dumps(dataclasses.asdict(obj))
|
|
|
+ await self.connection.execute(
|
|
|
+ f"INSERT INTO {self.table_name} VALUES (:key, :body) \
|
|
|
+ ON CONFLICT(key) DO UPDATE SET body=:body",
|
|
|
+ { "key": key, "body": blob }
|
|
|
+ )
|
|
|
+ await self.connection.commit()
|
|
|
+
|
|
|
+
|
|
|
+class DataFor(Injector[Optional[DataType]]):
|
|
|
+ def __init__(self, datatype: Type[DataType], key: Injector[str], kwargs: dict[str, Any]):
|
|
|
+ self.datatype = datatype
|
|
|
+ self.key = key
|
|
|
+ self.kwargs = kwargs
|
|
|
+
|
|
|
+ async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
|
|
|
+ key = await self.key.inject(message, context)
|
|
|
+ async with context.database() as db:
|
|
|
+ store = DataStore(self.datatype, db)
|
|
|
+ await store.setup()
|
|
|
+ return await store.load_or(key, **self.kwargs)
|
|
|
+
|
|
|
+
|
|
|
+class Data(InjectorWithCleanup[DataStore[DataType]]):
|
|
|
+ def __init__(self, datatype: Type[DataType]):
|
|
|
+ self.datatype = datatype
|
|
|
+ self.For = lambda key, **kw: DataFor(datatype, key, kw)
|
|
|
+
|
|
|
+ async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
|
|
|
+ store = DataStore(self.datatype, await context.database())
|
|
|
+ await store.setup()
|
|
|
+ return store
|
|
|
+
|
|
|
+ async def cleanup(self, store: DataStore[DataType]):
|
|
|
+ await store.connection.close()
|
|
|
+
|