repl_driver.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from datetime import datetime
  2. from logging import Logger
  3. import logging.config
  4. import asyncio
  5. import rollbot
  6. from rollbot.injection import ArgList, Lazy, Database, Data, Args, Arg, Subcommand
  7. logging.config.fileConfig('logging.conf', disable_existing_loggers=False)
@rollbot.initialize_data
class MyCounter:
    """Pair of integer counters stored through ``Data(MyCounter)`` injection.

    Instances are mutated and persisted with ``store.save(...)`` by the
    counting commands (``count3``, ``count_improved``) and listed by
    ``count6`` / ``lscounters``. The storage mechanics come from
    ``rollbot.initialize_data``, which is not visible here.
    """

    one: int = 0  # changed by +1 / -1 in the counting commands
    two: int = 0  # changed by +2 / -2 in the counting commands
  12. class MyBot(rollbot.Rollbot[str]):
  13. def read_config(self, key):
  14. return key
  15. def parse(self, raw):
  16. return rollbot.Message(
  17. origin_id="REPL",
  18. channel_id=".",
  19. sender_id=".",
  20. timestamp=datetime.now(),
  21. origin_admin=True,
  22. channel_admin=True,
  23. text=raw,
  24. attachments=[],
  25. )
  26. async def respond(self, res):
  27. print(res, flush=True)
  28. async def goodbye_command(message, context):
  29. await context.respond(rollbot.Response.from_message(message, "Goodbye!"))
  30. async def count_command(message, context):
  31. args = message.text.split("count", maxsplit=1)[1].strip().split()
  32. name = args[0] if len(args) > 0 else "main"
  33. async with context.database() as db:
  34. await db.execute(
  35. "INSERT INTO counter VALUES (?, 1) \
  36. ON CONFLICT (name) DO UPDATE SET count=count + 1",
  37. (name,),
  38. )
  39. await db.commit()
  40. async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
  41. res = (await cursor.fetchone())[0]
  42. await context.respond(rollbot.Response.from_message(message, f"{name} = {res}"))
  43. @rollbot.as_command
  44. async def count2(name: Arg(0, required=False, default="main"), connect: Lazy(Database)):
  45. db = await connect()
  46. await db.execute(
  47. "INSERT INTO counter VALUES (?, 1) \
  48. ON CONFLICT (name) DO UPDATE SET count=count + 2",
  49. (name,),
  50. )
  51. await db.commit()
  52. async with db.execute("SELECT count FROM counter WHERE name = ?", (name,)) as cursor:
  53. res = (await cursor.fetchone())[0]
  54. return f"{name} = {res}"
  55. @rollbot.as_command
  56. async def count3(counter: Data(MyCounter).For(Args), store: Data(MyCounter), args: Args):
  57. counter.one += 1
  58. counter.two += 2
  59. await store.save(args, counter)
  60. return f"{args} = {counter}"
  61. @rollbot.as_command
  62. async def count6(counters: Data(MyCounter)):
  63. async for (key, counter) in counters.all(one=6):
  64. yield f"{key} = {counter}"
  65. async for (key, counter) in counters.all(two=6):
  66. yield f"{key} = {counter}"
  67. @rollbot.as_command
  68. async def lscounters(counters: Data(MyCounter)):
  69. async for (key, counter) in counters.all():
  70. yield f"{key} = {counter}"
  71. @rollbot.as_command("count.")
  72. async def count_improved(
  73. subc: Subcommand,
  74. name: Subcommand.Arg(0, required=False, default="main"),
  75. counter: Data(MyCounter).For(Subcommand.Arg(0, required=False, default="main")),
  76. store: Data(MyCounter),
  77. log: Logger,
  78. ):
  79. if subc.name == "up":
  80. counter.one += 1
  81. counter.two += 2
  82. elif subc.name == "down":
  83. counter.one -= 1
  84. counter.two -= 2
  85. elif subc.name == "show":
  86. yield f"{name} = {counter}"
  87. log.info(f"Saving {counter} under {name}")
  88. await store.save(name, counter)
  89. @rollbot.on_startup
  90. async def make_table(context):
  91. async with context.database() as db:
  92. await db.execute(
  93. "CREATE TABLE IF NOT EXISTS counter ( \
  94. name TEXT NOT NULL PRIMARY KEY, \
  95. count INTEGER NOT NULL DEFAULT 0 \
  96. );"
  97. )
  98. await db.commit()
  99. @rollbot.on_shutdown
  100. async def shutdown(context):
  101. await context.respond(rollbot.Response(origin_id="REPL", channel_id=".", text="Shutting down!"))
  102. @rollbot.as_command
  103. def simple():
  104. return "Simple!"
  105. @rollbot.as_command
  106. def generator():
  107. yield "This is"
  108. yield "a generator!"
  109. @rollbot.as_command
  110. async def coroutine():
  111. await asyncio.sleep(1.0)
  112. return "Here's a coroutine!"
  113. @rollbot.as_command
  114. async def asyncgen():
  115. yield "This is"
  116. await asyncio.sleep(0.5)
  117. yield "an async"
  118. await asyncio.sleep(0.5)
  119. yield "generator!"
  120. config = rollbot.get_command_config().extend(
  121. rollbot.CommandConfiguration(
  122. commands={
  123. "goodbye": goodbye_command,
  124. "count": count_command,
  125. },
  126. call_and_response={
  127. "hello": "Hello!",
  128. },
  129. aliases={
  130. "hi": "hello",
  131. "bye": "goodbye",
  132. },
  133. bangs=("/",),
  134. )
  135. )
  136. bot = MyBot(config, "/tmp/my.db")
  137. async def run():
  138. await bot.on_startup()
  139. try:
  140. while True:
  141. await bot.on_message(input("> "))
  142. except EOFError:
  143. pass
  144. finally:
  145. await bot.on_shutdown()
  146. asyncio.run(run())