"""Injectors providing database connections and dataclass-backed key/value storage."""

from typing import Generic, TypeVar, Type, Optional, Any
from collections.abc import AsyncGenerator
import dataclasses
import json

from aiosqlite import Connection

from ..types import Message, Context
from .base import Injector, InjectorWithCleanup

__all__ = [
    "Database",
    "Data",
]


class DatabaseInjector(InjectorWithCleanup[Connection]):
    """Inject the connection obtained from ``context.database()``; close it on cleanup."""

    async def inject(self, message: Message, context: Context) -> Connection:
        """Return the awaited result of ``context.database()``."""
        return await context.database()

    async def cleanup(self, conn: Connection):
        """Close the connection previously returned by :meth:`inject`."""
        await conn.close()


Database = DatabaseInjector()

DataType = TypeVar("DataType")


class DataStore(Generic[DataType]):
    """Key/value store persisting instances of one dataclass type as JSON.

    Each dataclass type gets its own table with two columns: ``key`` (the
    primary key) and ``body`` (the JSON-encoded fields of the instance).
    Values are encoded with ``dataclasses.asdict`` and decoded with
    ``datatype(**fields)``, so nested dataclasses come back as plain dicts.
    """

    def __init__(self, datatype: Type[DataType], connection: Connection):
        """Bind the store to a dataclass type and an open connection.

        Raises:
            ValueError: if ``datatype`` is not a dataclass *type*.
        """
        # is_dataclass() is also true for dataclass instances, which have no
        # __name__; reject them here instead of failing with AttributeError.
        if not (isinstance(datatype, type) and dataclasses.is_dataclass(datatype)):
            raise ValueError(f"expected a dataclass type, got {datatype!r}")
        self.datatype = datatype
        self.connection = connection
        # Derive the table name from the class name: drop non-alphanumeric
        # characters, prefix each ASCII capital with "_" and lowercase it,
        # then trim the leading underscore (e.g. "UserData" -> "user_data").
        self.table_name = "".join(
            ("_" + c.lower()) if "A" <= c <= "Z" else c
            for c in datatype.__name__
            if c.isalnum()
        ).strip("_")

    @property
    def _quoted_table(self) -> str:
        """Table name as a double-quoted SQL identifier.

        Quoting keeps class names that map to SQL keywords (``Order``,
        ``Group``, ``Index``, ...) from producing syntax errors. The name
        contains only alphanumerics and underscores, so no escaping is needed.
        """
        return f'"{self.table_name}"'

    async def _setup(self):
        """Create the backing table if it does not exist yet, then commit."""
        # The default is written as '' rather than "": SQLite only accepts a
        # double-quoted string literal as a legacy fallback.
        await self.connection.execute(
            f"CREATE TABLE IF NOT EXISTS {self._quoted_table} ("
            " key TEXT NOT NULL PRIMARY KEY,"
            " body TEXT DEFAULT ''"
            " )"
        )
        await self.connection.commit()

    async def load(self, key: str) -> Optional[DataType]:
        """Return the instance stored under ``key``, or ``None`` if absent."""
        async with self.connection.execute(
            f"SELECT body FROM {self._quoted_table} WHERE key = ?", (key,)
        ) as cursor:
            found = await cursor.fetchone()
        if found is None:
            return None
        return self.datatype(**json.loads(found[0]))

    async def load_or(self, key: str, **kw) -> DataType:
        """Return the instance stored under ``key``, creating it if missing.

        When nothing is stored, a new instance is built as ``datatype(**kw)``,
        saved under ``key`` and returned.
        """
        result = await self.load(key)
        if result is not None:
            return result
        result = self.datatype(**kw)
        await self.save(key, result)
        return result

    async def all(self) -> AsyncGenerator[tuple[str, DataType], None]:
        """Yield ``(key, instance)`` for every row in the table."""
        async with self.connection.execute(
            f"SELECT key, body FROM {self._quoted_table}"
        ) as cursor:
            async for (key, body) in cursor:
                yield (key, self.datatype(**json.loads(body)))

    async def save(self, key: str, obj: DataType):
        """Insert ``obj`` under ``key``, or overwrite the existing row; then commit."""
        blob = json.dumps(dataclasses.asdict(obj))
        await self.connection.execute(
            f"INSERT INTO {self._quoted_table} VALUES (:key, :body)"
            " ON CONFLICT(key) DO UPDATE SET body=:body",
            {"key": key, "body": blob},
        )
        await self.connection.commit()


class DataFor(Injector[Optional[DataType]]):
    """Inject the stored instance for a key that is itself produced by an injector."""

    def __init__(self, datatype: Type[DataType], key: Injector[str], kwargs: dict[str, Any]):
        """Remember the dataclass type, the key injector and the creation defaults.

        ``kwargs`` are the constructor arguments used when no row exists yet.
        """
        self.datatype = datatype
        self.key = key
        self.kwargs = kwargs

    async def inject(self, message: Message, context: Context) -> DataType:
        """Resolve the key, then load the matching instance or create and save it."""
        key = await self.key.inject(message, context)
        # The connection is only needed for this lookup, so it is scoped to
        # the ``async with`` block rather than handed to a cleanup hook.
        async with context.database() as db:
            store = DataStore(self.datatype, db)
            return await store.load_or(key, **self.kwargs)


class Data(InjectorWithCleanup[DataStore[DataType]]):
    """Inject a :class:`DataStore` for a dataclass type.

    ``Data(T).For(key_injector, **defaults)`` builds a :class:`DataFor` that
    injects a single instance instead of the whole store.
    """

    def __init__(self, datatype: Type[DataType]):
        self.datatype = datatype
        self.For = lambda key, **kw: DataFor(datatype, key, kw)

    @staticmethod
    async def initialize(datatype: Type[DataType], connection: Connection):
        """Create the table backing ``datatype`` on ``connection`` if it is missing."""
        store = DataStore(datatype, connection)
        await store._setup()

    async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
        """Return a store bound to the connection from ``context.database()``."""
        store = DataStore(self.datatype, await context.database())
        return store

    async def cleanup(self, store: DataStore[DataType]):
        """Close the connection held by the store returned from :meth:`inject`."""
        await store.connection.close()