data.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from typing import Generic, TypeVar, Type, Optional, Any
  2. from collections.abc import AsyncGenerator
  3. import dataclasses
  4. import json
  5. from aiosqlite import Connection
  6. from ..types import Message, Context
  7. from .base import Injector, InjectorWithCleanup
  8. __all__ = [
  9. "Database",
  10. "Data",
  11. ]
  12. class DatabaseInjector(InjectorWithCleanup[Connection]):
  13. async def inject(self, message: Message, context: Context) -> Connection:
  14. return await context.database()
  15. async def cleanup(self, conn: Connection):
  16. await conn.close()
  17. Database = DatabaseInjector()
  18. DataType = TypeVar("DataType")
  19. class DataStore(Generic[DataType]):
  20. def __init__(self, datatype: Type[DataType], connection: Connection):
  21. if not dataclasses.is_dataclass(datatype):
  22. raise ValueError
  23. self.datatype = datatype
  24. self.connection = connection
  25. self.table_name = "".join(
  26. ("_" + c.lower()) if "A" <= c <= "Z" else c for c in datatype.__name__ if c.isalnum()
  27. ).strip("_")
  28. async def _setup(self):
  29. await self.connection.execute(
  30. f'CREATE TABLE IF NOT EXISTS {self.table_name} ( \
  31. key TEXT NOT NULL PRIMARY KEY, \
  32. body TEXT DEFAULT "" \
  33. )'
  34. )
  35. await self.connection.commit()
  36. async def load(self, key: str) -> Optional[DataType]:
  37. async with self.connection.execute(
  38. f"SELECT body FROM {self.table_name} WHERE key = ?", (key,)
  39. ) as cursor:
  40. found = await cursor.fetchone()
  41. if found is None:
  42. return found
  43. return self.datatype(**json.loads(found[0]))
  44. async def load_or(self, key: str, **kw) -> DataType:
  45. result = await self.load(key)
  46. if result is not None:
  47. return result
  48. result = self.datatype(**kw)
  49. await self.save(key, result)
  50. return result
  51. async def all(self, **kw) -> AsyncGenerator[tuple[str, DataType], None]:
  52. query = f"SELECT key, body FROM {self.table_name}"
  53. filter_params = []
  54. if len(kw) > 0:
  55. query += " WHERE " + (" AND ".join("json_extract(body, ?) = ?" for _ in range(len(kw))))
  56. for (key, value) in kw.items():
  57. filter_params.append(f"$.{''.join(k for k in key if k.isalnum() or k == '_')}")
  58. filter_params.append(value)
  59. async with self.connection.execute(query, filter_params) as cursor:
  60. async for (key, body) in cursor:
  61. yield (key, self.datatype(**json.loads(body)))
  62. async def save(self, key: str, obj: DataType):
  63. blob = json.dumps(dataclasses.asdict(obj))
  64. await self.connection.execute(
  65. f"INSERT INTO {self.table_name} VALUES (:key, :body) \
  66. ON CONFLICT(key) DO UPDATE SET body=:body",
  67. {"key": key, "body": blob},
  68. )
  69. await self.connection.commit()
  70. class DataFor(Injector[Optional[DataType]]):
  71. def __init__(self, datatype: Type[DataType], key: Injector[str], kwargs: dict[str, Any]):
  72. self.datatype = datatype
  73. self.key = key
  74. self.kwargs = kwargs
  75. async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
  76. key = await self.key.inject(message, context)
  77. async with context.database() as db:
  78. store = DataStore(self.datatype, db)
  79. return await store.load_or(key, **self.kwargs)
  80. class Data(InjectorWithCleanup[DataStore[DataType]]):
  81. def __init__(self, datatype: Type[DataType]):
  82. self.datatype = datatype
  83. self.For = lambda key, **kw: DataFor(datatype, key, kw)
  84. @staticmethod
  85. async def initialize(datatype: Type[DataType], connection: Connection):
  86. store = DataStore(datatype, connection)
  87. await store._setup()
  88. async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
  89. store = DataStore(self.datatype, await context.database())
  90. return store
  91. async def cleanup(self, store: DataStore[DataType]):
  92. await store.connection.close()