浏览代码

Add optional filters to Data.all

Kirk Trombley 4 年之前
父节点
当前提交
70133cfb18
共有 2 个文件被更改,包括 25 次插入3 次删除
  1. 9 2
      lib/rollbot/injection/data.py
  2. 16 1
      repl_driver.py

+ 9 - 2
lib/rollbot/injection/data.py

@@ -63,8 +63,15 @@ class DataStore(Generic[DataType]):
         await self.save(key, result)
         return result
 
-    async def all(self) -> AsyncGenerator[tuple[str, DataType], None]:
-        async with self.connection.execute(f"SELECT key, body FROM {self.table_name}") as cursor:
+    async def all(self, **kw) -> AsyncGenerator[tuple[str, DataType], None]:
+        query = f"SELECT key, body FROM {self.table_name}"
+        filter_params = []
+        if len(kw) > 0:
+            query += " WHERE " + (" AND ".join("json_extract(body, ?) = ?" for _ in range(len(kw))))
+            for (key, value) in kw.items():
+                filter_params.append(f"$.{''.join(k for k in key if k.isalnum() or k == '_')}")
+                filter_params.append(value)
+        async with self.connection.execute(query, filter_params) as cursor:
             async for (key, body) in cursor:
                 yield (key, self.datatype(**json.loads(body)))
 

+ 16 - 1
repl_driver.py

@@ -73,6 +73,20 @@ async def count3(counter: Data(MyCounter).For(Args), store: Data(MyCounter), arg
     return f"{args} = {counter}"
 
 
+@rollbot.as_command
+async def count6(counters: Data(MyCounter)):
+    async for (key, counter) in counters.all(one=6):
+        yield f"{key} = {counter}"
+    async for (key, counter) in counters.all(two=6):
+        yield f"{key} = {counter}"
+
+
+@rollbot.as_command
+async def lscounters(counters: Data(MyCounter)):
+    async for (key, counter) in counters.all():
+        yield f"{key} = {counter}"
+
+
 @rollbot.on_startup
 async def make_table(context):
     async with context.database() as db:
@@ -144,7 +158,8 @@ async def run():
             await bot.on_message(input("> "))
     except EOFError:
         pass
-    await bot.on_shutdown()
+    finally:
+        await bot.on_shutdown()
 
 
 asyncio.run(run())