repl_driver.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from datetime import datetime
  2. from dataclasses import dataclass
  3. import asyncio
  4. import rollbot
  5. from rollbot.injection import ArgList, Lazy, Database, Data, Args
@rollbot.initialize_data
@dataclass
class MyCounter:
    """Pair of integer counters used as the ``Data(MyCounter)`` type by ``count3``.

    Registered with rollbot through ``rollbot.initialize_data``; ``count3``
    increments ``one`` by 1 and ``two`` by 2 on every call and saves the result.
    """

    # Incremented by 1 per ``count3`` call.
    one: int = 0
    # Incremented by 2 per ``count3`` call.
    two: int = 0
  11. class MyBot(rollbot.Rollbot[str]):
  12. def read_config(self, key):
  13. return key
  14. def parse(self, raw):
  15. return rollbot.Message(
  16. origin_id="REPL",
  17. channel_id=".",
  18. sender_id=".",
  19. timestamp=datetime.now(),
  20. origin_admin=True,
  21. channel_admin=True,
  22. text=raw,
  23. attachments=[],
  24. )
  25. async def respond(self, res):
  26. print(res, flush=True)
  27. async def goodbye_command(message, context):
  28. await context.respond(rollbot.Response.from_message(message, "Goodbye!"))
  29. async def count_command(message, context):
  30. args = message.text.split("count", maxsplit=1)[1].strip().split()
  31. name = args[0] if len(args) > 0 else "main"
  32. async with context.database() as db:
  33. await db.execute(
  34. "INSERT INTO counter VALUES (?, 1) \
  35. ON CONFLICT (name) DO UPDATE SET count=count + 1",
  36. (name,),
  37. )
  38. await db.commit()
  39. async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
  40. res = (await cursor.fetchone())[0]
  41. await context.respond(rollbot.Response.from_message(message, f"{name} = {res}"))
  42. @rollbot.as_command
  43. async def count2(args: ArgList, connect: Lazy(Database)):
  44. name = args[0] if len(args) > 0 else "main"
  45. db = await connect()
  46. await db.execute(
  47. "INSERT INTO counter VALUES (?, 1) \
  48. ON CONFLICT (name) DO UPDATE SET count=count + 2",
  49. (name,),
  50. )
  51. await db.commit()
  52. async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
  53. res = (await cursor.fetchone())[0]
  54. return f"{name} = {res}"
  55. @rollbot.as_command
  56. async def count3(counter: Data(MyCounter).For(Args), store: Data(MyCounter), args: Args):
  57. counter.one += 1
  58. counter.two += 2
  59. await store.save(args, counter)
  60. return f"{args} = {counter}"
  61. @rollbot.on_startup
  62. async def make_table(context):
  63. async with context.database() as db:
  64. await db.execute(
  65. "CREATE TABLE IF NOT EXISTS counter ( \
  66. name TEXT NOT NULL PRIMARY KEY, \
  67. count INTEGER NOT NULL DEFAULT 0 \
  68. );"
  69. )
  70. await db.commit()
  71. @rollbot.on_shutdown
  72. async def shutdown(context):
  73. await context.respond(rollbot.Response(origin_id="REPL", channel_id=".", text="Shutting down!"))
  74. @rollbot.as_command
  75. def simple():
  76. return "Simple!"
  77. @rollbot.as_command
  78. def generator():
  79. yield "This is"
  80. yield "a generator!"
  81. @rollbot.as_command
  82. async def coroutine():
  83. await asyncio.sleep(1.0)
  84. return "Here's a coroutine!"
  85. @rollbot.as_command
  86. async def asyncgen():
  87. yield "This is"
  88. await asyncio.sleep(0.5)
  89. yield "an async"
  90. await asyncio.sleep(0.5)
  91. yield "generator!"
  92. config = rollbot.get_command_config().extend(
  93. rollbot.CommandConfiguration(
  94. commands={
  95. "goodbye": goodbye_command,
  96. "count": count_command,
  97. },
  98. call_and_response={
  99. "hello": "Hello!",
  100. },
  101. aliases={
  102. "hi": "hello",
  103. "bye": "goodbye",
  104. },
  105. bangs=("/",),
  106. )
  107. )
  108. bot = MyBot(config, "/tmp/my.db")
  109. async def run():
  110. await bot.on_startup()
  111. try:
  112. while True:
  113. await bot.on_message(input("> "))
  114. except EOFError:
  115. pass
  116. await bot.on_shutdown()
  117. asyncio.run(run())