123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- from typing import Generic, TypeVar, Type, Optional, Any
- from collections.abc import AsyncGenerator
- import dataclasses
- import json
- from aiosqlite import Connection
- from ..types import Message, Context
- from .base import Injector, InjectorWithCleanup
# Public names re-exported by ``from <this module> import *``.
__all__ = ["Database", "Data"]
class DatabaseInjector(InjectorWithCleanup[Connection]):
    """Inject the connection returned by ``context.database()`` and close it afterwards.

    NOTE(review): ``cleanup`` is presumably invoked by the injector framework
    once the message has been handled — confirm against ``InjectorWithCleanup``.
    """

    async def inject(self, message: Message, context: Context) -> Connection:
        """Await ``context.database()`` and hand the resulting connection over."""
        connection = await context.database()
        return connection

    async def cleanup(self, conn: Connection):
        """Close the connection previously produced by ``inject``."""
        await conn.close()
# Type variable for the user-supplied dataclass handled by DataStore / DataFor / Data.
DataType = TypeVar("DataType")

# Ready-made injector instance supplying a database connection.
Database = DatabaseInjector()
class DataStore(Generic[DataType]):
    """Persist dataclass instances as JSON blobs in a per-type SQLite table.

    Each record is one ``(key, body)`` row, where ``body`` is
    ``json.dumps(dataclasses.asdict(obj))``. Records are rebuilt with
    ``datatype(**json.loads(body))``, so nested dataclass fields come back
    as plain dicts.

    The table name is derived from the dataclass name by converting
    CamelCase to snake_case (e.g. ``UserProfile`` -> ``user_profile``).

    Args:
        datatype: The dataclass type (the class itself, not an instance).
        connection: An open aiosqlite connection. This class never closes it.

    Raises:
        ValueError: If ``datatype`` is not a dataclass type.
    """

    def __init__(self, datatype: Type[DataType], connection: Connection):
        # is_dataclass() is also True for dataclass *instances*; those have no
        # __name__, so reject them here with a clear error instead.
        if not (isinstance(datatype, type) and dataclasses.is_dataclass(datatype)):
            raise ValueError(f"{datatype!r} is not a dataclass type")
        self.datatype = datatype
        self.connection = connection
        # CamelCase -> snake_case: keep only alphanumerics, turn each ASCII
        # capital into "_" + lowercase, then trim the leading "_".
        self.table_name = "".join(
            ("_" + c.lower()) if "A" <= c <= "Z" else c
            for c in datatype.__name__
            if c.isalnum()
        ).strip("_")

    @property
    def _quoted_table(self) -> str:
        """``table_name`` as a quoted SQL identifier.

        Quoting lets names that collide with SQL keywords (``order``,
        ``group``, ``index``, ...) work, and still refers to the same table
        as the unquoted name.
        """
        return '"' + self.table_name.replace('"', '""') + '"'

    async def setup(self):
        """Create the backing table if it does not exist yet, then commit."""
        # Use '' (string literal), not "" (identifier quoting): double-quoted
        # string literals are a legacy quirk that some SQLite builds reject.
        await self.connection.execute(
            f"CREATE TABLE IF NOT EXISTS {self._quoted_table} ("
            " key TEXT NOT NULL PRIMARY KEY,"
            " body TEXT DEFAULT ''"
            ")"
        )
        await self.connection.commit()

    async def load(self, key: str) -> Optional[DataType]:
        """Return the record stored under ``key``, or ``None`` if absent."""
        async with self.connection.execute(
            f"SELECT body FROM {self._quoted_table} WHERE key = ?", (key,)
        ) as cursor:
            found = await cursor.fetchone()
        if found is None:
            return None
        return self.datatype(**json.loads(found[0]))

    async def load_or(self, key: str, **kw) -> DataType:
        """Return the record under ``key``.

        If there is no record, build one from ``kw``, save it, and return it.
        """
        result = await self.load(key)
        if result is not None:
            return result
        result = self.datatype(**kw)
        await self.save(key, result)
        return result

    async def all(self) -> AsyncGenerator[tuple[str, DataType], None]:
        """Yield every ``(key, record)`` pair in the table."""
        async with self.connection.execute(
            f"SELECT key, body FROM {self._quoted_table}"
        ) as cursor:
            async for key, body in cursor:
                yield key, self.datatype(**json.loads(body))

    async def save(self, key: str, obj: DataType):
        """Insert or replace the record stored under ``key``, then commit."""
        blob = json.dumps(dataclasses.asdict(obj))
        # Explicit column list so the statement does not depend on column
        # order/count. excluded.body is the value that failed to insert.
        await self.connection.execute(
            f"INSERT INTO {self._quoted_table} (key, body) VALUES (:key, :body) "
            "ON CONFLICT(key) DO UPDATE SET body = excluded.body",
            {"key": key, "body": blob},
        )
        await self.connection.commit()
class DataFor(Injector[Optional[DataType]]):
    """Injector that yields the stored record for a key, creating it if missing.

    The connection is held only while the record is looked up: the
    ``async with`` block exits before the value is returned. The returned
    object keeps no reference to the store, so later changes to it are not
    saved automatically.

    Args:
        datatype: The dataclass type to load.
        key: Injector that produces the record key for the current message.
        kwargs: Field values used to build (and save) a new record when none
            exists for the key.
    """

    def __init__(self, datatype: Type[DataType], key: Injector[str], kwargs: dict[str, Any]):
        self.datatype = datatype
        self.key = key
        self.kwargs = kwargs

    async def inject(self, message: Message, context: Context) -> DataType:
        """Resolve the key and return the matching (possibly newly created) record."""
        import copy  # local import: only this method needs it

        key = await self.key.inject(message, context)
        async with context.database() as db:
            store = DataStore(self.datatype, db)
            await store.setup()
            # Deep-copy the defaults so mutable values (lists, dicts) are not
            # shared between records created by different injections.
            return await store.load_or(key, **copy.deepcopy(self.kwargs))
class Data(InjectorWithCleanup[DataStore[DataType]]):
    """Injector that provides a ``DataStore`` for ``datatype`` on its own connection.

    ``Data(MyType).For(key_injector, **defaults)`` builds a ``DataFor``
    injector that yields a single record instead of the whole store.

    Args:
        datatype: The dataclass type managed by the injected store.
    """

    def __init__(self, datatype: Type[DataType]):
        self.datatype = datatype
        # Per-instance callable that captures this instance's datatype.
        self.For = lambda key, **kw: DataFor(datatype, key, kw)

    async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
        """Open a connection, ensure the table exists, and return the store."""
        connection = await context.database()
        try:
            store = DataStore(self.datatype, connection)
            await store.setup()
        except BaseException:
            # cleanup() closes connections via the store it is given; if no
            # store is returned, nothing else can close this one.
            await connection.close()
            raise
        return store

    async def cleanup(self, store: DataStore[DataType]):
        """Close the connection backing a store returned by ``inject``."""
        await store.connection.close()
|