data.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from typing import Generic, TypeVar, Type, Optional, Any
  2. from collections.abc import AsyncGenerator
  3. import dataclasses
  4. import json
  5. from aiosqlite import Connection
  6. from ..types import Message, Context
  7. from .base import Injector, InjectorWithCleanup
  8. __all__ = [
  9. "Database",
  10. "Data",
  11. ]
  12. class DatabaseInjector(InjectorWithCleanup[Connection]):
  13. async def inject(self, message: Message, context: Context) -> Connection:
  14. return await context.database()
  15. async def cleanup(self, conn: Connection):
  16. await conn.close()
  17. Database = DatabaseInjector()
  18. DataType = TypeVar("DataType")
  19. class DataStore(Generic[DataType]):
  20. def __init__(self, datatype: Type[DataType], connection: Connection):
  21. if not dataclasses.is_dataclass(datatype):
  22. raise ValueError
  23. self.datatype = datatype
  24. self.connection = connection
  25. self.table_name = "".join(
  26. ("_" + c.lower()) if "A" <= c <= "Z" else c for c in datatype.__name__ if c.isalnum()
  27. ).strip("_")
  28. async def _setup(self):
  29. # TODO virtual columns with indexes for certain types?
  30. await self.connection.execute(
  31. f'CREATE TABLE IF NOT EXISTS {self.table_name} ( \
  32. key TEXT NOT NULL PRIMARY KEY, \
  33. body TEXT DEFAULT "" \
  34. )'
  35. )
  36. await self.connection.commit()
  37. async def load(self, key: str) -> Optional[DataType]:
  38. async with self.connection.execute(
  39. f"SELECT body FROM {self.table_name} WHERE key = ?", (key,)
  40. ) as cursor:
  41. found = await cursor.fetchone()
  42. if found is None:
  43. return found
  44. return self.datatype(**json.loads(found[0]))
  45. async def load_or(self, key: str, **kw) -> DataType:
  46. result = await self.load(key)
  47. if result is not None:
  48. return result
  49. result = self.datatype(**kw)
  50. await self.save(key, result)
  51. return result
  52. async def all(self, **kw) -> AsyncGenerator[tuple[str, DataType], None]:
  53. query = f"SELECT key, body FROM {self.table_name}"
  54. filter_params = []
  55. if len(kw) > 0:
  56. query += " WHERE " + (" AND ".join("json_extract(body, ?) = ?" for _ in range(len(kw))))
  57. for (key, value) in kw.items():
  58. filter_params.append(f"$.{''.join(k for k in key if k.isalnum() or k == '_')}")
  59. filter_params.append(value)
  60. async with self.connection.execute(query, filter_params) as cursor:
  61. async for (key, body) in cursor:
  62. yield (key, self.datatype(**json.loads(body)))
  63. async def save(self, key: str, obj: DataType):
  64. blob = json.dumps(dataclasses.asdict(obj))
  65. await self.connection.execute(
  66. f"INSERT INTO {self.table_name} VALUES (:key, :body) \
  67. ON CONFLICT(key) DO UPDATE SET body=:body",
  68. {"key": key, "body": blob},
  69. )
  70. await self.connection.commit()
  71. class DataFor(Injector[Optional[DataType]]):
  72. def __init__(self, datatype: Type[DataType], key: Injector[str], kwargs: dict[str, Any]):
  73. self.datatype = datatype
  74. self.key = key
  75. self.kwargs = kwargs
  76. async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
  77. key = await self.key.inject(message, context)
  78. async with context.database() as db:
  79. store = DataStore(self.datatype, db)
  80. return await store.load_or(key, **self.kwargs)
  81. class Data(InjectorWithCleanup[DataStore[DataType]]):
  82. def __init__(self, datatype: Type[DataType]):
  83. self.datatype = datatype
  84. self.For = lambda key, **kw: DataFor(datatype, key, kw)
  85. @staticmethod
  86. async def initialize(datatype: Type[DataType], connection: Connection):
  87. store = DataStore(datatype, connection)
  88. await store._setup()
  89. async def inject(self, message: Message, context: Context) -> DataStore[DataType]:
  90. store = DataStore(self.datatype, await context.database())
  91. return store
  92. async def cleanup(self, store: DataStore[DataType]):
  93. await store.connection.close()