data.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from typing import Generic, TypeVar, Type, Optional, Any
  2. from collections.abc import AsyncGenerator
  3. import dataclasses
  4. import json
  5. from aiosqlite import Connection
  6. from ..types import Message, Context
  7. from .base import Injector, InjectorWithCleanup
  8. __all__ = [
  9. "Database",
  10. "Data",
  11. ]
  12. class DatabaseInjector(InjectorWithCleanup[Connection]):
  13. async def inject(self, message: Message, context: Context) -> Connection:
  14. return await context.database()
  15. async def cleanup(self, conn: Connection):
  16. await conn.close()
  17. Database = DatabaseInjector()
  18. DataType = TypeVar("DataType")
  19. class DataStore(Generic[DataType]):
  20. def __init__(self, datatype: Type[DataType], connection: Connection):
  21. if not dataclasses.is_dataclass(datatype):
  22. raise ValueError
  23. self.datatype = datatype
  24. self.connection = connection
  25. self.table_name = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in datatype.__name__).strip("_")
  26. async def setup(self):
  27. await self.connection.execute(
  28. f'CREATE TABLE IF NOT EXISTS {self.table_name} ( \
  29. key TEXT NOT NULL PRIMARY KEY, \
  30. body TEXT DEFAULT "" \
  31. )'
  32. )
  33. await self.connection.commit()
  34. async def load(self, key: str) -> Optional[DataType]:
  35. async with self.connection.execute(f"SELECT body FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
  36. found = await cursor.fetchone()
  37. if found is None:
  38. return found
  39. return self.datatype(**json.loads(found[0]))
  40. async def load_or(self, key: str, **kw) -> DataType:
  41. result = await self.load(key)
  42. if result is not None:
  43. return result
  44. result = self.datatype(**kw)
  45. await self.save(key, result)
  46. return result
  47. async def all(self) -> AsyncGenerator[tuple[str, DataType], None]:
  48. async with self.connection.execute(f"SELECT key, body FROM {self.table_name}") as cursor:
  49. async for (key, body) in cursor:
  50. yield (key, self.datatype(**json.loads(body)))
  51. async def save(self, key: str, obj: DataType):
  52. blob = json.dumps(dataclasses.asdict(obj))
  53. await self.connection.execute(
  54. f"INSERT INTO {self.table_name} VALUES (:key, :body) \
  55. ON CONFLICT(key) DO UPDATE SET body=:body",
  56. { "key": key, "body": blob }
  57. )
  58. await self.connection.commit()
  59. class DataFor(Injector[Optional[DataType]]):
  60. def __init__(self, datatype: Type[DataType], key: Injector[str], kwargs: dict[str, Any]):
  61. self.datatype = datatype
  62. self.key = key
  63. self.kwargs = kwargs
  64. async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
  65. key = await self.key.inject(message, context)
  66. async with context.database() as db:
  67. store = DataStore(self.datatype, db)
  68. await store.setup()
  69. return await store.load_or(key, **self.kwargs)
  70. class Data(InjectorWithCleanup[DataStore[DataType]]):
  71. def __init__(self, datatype: Type[DataType]):
  72. self.datatype = datatype
  73. self.For = lambda key, **kw: DataFor(datatype, key, kw)
  74. async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
  75. store = DataStore(self.datatype, await context.database())
  76. await store.setup()
  77. return store
  78. async def cleanup(self, store: DataStore[DataType]):
  79. await store.connection.close()