bot.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import logging
  2. import time
  3. import traceback
  4. from dataclasses import dataclass
  5. from typing import List, Type, Dict, Callable, Any
  6. from .messaging import RollbotResponse, RollbotFailure
  7. from .plugins import as_plugin, RollbotPlugin
  8. from .database import init_db_at_url
  9. def lift_response(call, response):
  10. @as_plugin(call)
  11. def response_func(db, msg):
  12. return RollbotResponse(msg, txt=response)
  13. return response_func
  14. @dataclass
  15. class RollbotConfig:
  16. plugins: List[Type[RollbotPlugin]]
  17. db_url: str
  18. reply_callback: Callable[[str, str], None]
  19. aliases: Dict[str, str]
  20. responses: Dict[str, str]
  21. sleep_time: float
  22. other: Dict[str, Any]
  23. def get(self, key):
  24. c = self.other
  25. for k in key.split("."):
  26. c = c[k]
  27. return c
  28. class Rollbot:
  29. def __init__(self, config, logger=logging.getLogger(__name__)):
  30. self.logger = logger
  31. self.session_manager_factory = lambda: None
  32. self.post_callback = config.reply_callback or (lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}"))
  33. self.commands = {}
  34. self.to_start = set()
  35. self.to_stop = set()
  36. self.sleep_time = config.sleep_time
  37. self.last_exception = None
  38. self.config = config
  39. self.logger.info("Loading command plugins")
  40. for plugin_class in config.plugins:
  41. plugin_instance = plugin_class(self, logger=logger)
  42. if plugin_instance.command in self.commands:
  43. self.logger.error(f"Duplicate command word '{plugin_instance.command}'")
  44. raise ValueError(f"Duplicate command word '{plugin_instance.command}'")
  45. self.commands[plugin_instance.command] = plugin_instance
  46. if "on_start" in plugin_class.__dict__:
  47. self.to_start.add(plugin_instance)
  48. if "on_shutdown" in plugin_class.__dict__:
  49. self.to_stop.add(plugin_instance)
  50. self.logger.info(f"Finished loading plugins, {len(self.commands)} commands found")
  51. self.logger.info("Loading simple responses")
  52. for cmd, response in config.responses.items():
  53. if cmd in self.commands:
  54. self.logger.error(f"Duplicate command word '{cmd}'")
  55. raise ValueError(f"Duplicate command word '{cmd}'")
  56. self.commands[cmd] = lift_response(cmd, response)(self, logger=logger)
  57. self.logger.info(f"Finished loading simple responses, {len(self.commands)} total commands available")
  58. self.logger.info("Loading aliases")
  59. for alias, cmd in config.aliases.items():
  60. if cmd not in self.commands:
  61. self.logger.error(f"Missing aliased command word '{cmd}'")
  62. raise ValueError(f"Missing aliased command word '{cmd}'")
  63. if alias in self.commands:
  64. self.logger.error(f"Duplicate command word '{alias}'")
  65. raise ValueError(f"Duplicate command word '{alias}'")
  66. self.commands[alias] = self.commands[cmd]
  67. self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")
  68. def init_db(self):
  69. self.session_manager_factory = init_db_at_url(self.config.db_url)
  70. def start_plugins(self):
  71. self.logger.info("Starting plugins")
  72. with self.session_manager_factory() as session:
  73. for cmd in self.to_start:
  74. cmd.on_start(session)
  75. self.logger.info("Finished starting plugins")
  76. def shutdown_plugins(self):
  77. self.logger.info("Shutting down plugins")
  78. with self.session_manager_factory() as session:
  79. for cmd in self.to_stop:
  80. cmd.on_shutdown(session)
  81. self.logger.info("Finished shutting down plugins")
  82. def run_command(self, message):
  83. if not message.is_command:
  84. self.logger.warn(f"Tried to run non-command message {message.message_id}")
  85. return RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
  86. if message.command == "help":
  87. topic = next(message.args())
  88. targeted = self.commands.get(topic, None)
  89. if targeted is None:
  90. return RollbotResponse(message, failure=RollbotFailure.INVALID_ARGUMENTS, debugging={"explain": f"Could not find command {topic}"})
  91. return RollbotResponse(message, txt=targeted.help_msg())
  92. plugin = self.commands.get(message.command, None)
  93. if plugin is None:
  94. self.logger.warn(f"Message {message.message_id} had a command {message.command} that could not be run.")
  95. return RollbotResponse(message, failure=RollbotFailure.INVALID_COMMAND)
  96. with self.session_manager_factory() as session:
  97. response = plugin.on_command(session, message)
  98. if not response.is_success:
  99. self.logger.warn(f"Message {message.message_id} caused failure")
  100. self.logger.warn(response.info)
  101. return response
  102. def handle_command(self, message):
  103. if not message.is_command:
  104. self.logger.debug("Ignoring non-command message")
  105. return
  106. self.logger.info(f"Handling message {message.message_id}")
  107. t = time.time()
  108. try:
  109. response = self.run_command(message)
  110. except Exception as e:
  111. self.logger.exception(f"Exception during command execution for message {message.message_id}")
  112. response = RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
  113. self.last_exception = "".join(traceback.format_exc())
  114. if not response.respond:
  115. self.logger.info(f"Skipping response to message {message.message_id}")
  116. return
  117. self.logger.info(f"Responding to message {message.message_id}")
  118. sleep = self.sleep_time - time.time() + t
  119. if sleep > 0:
  120. self.logger.info(f"Sleeping for {sleep:.3f}s before responding")
  121. time.sleep(sleep)
  122. if response.is_success:
  123. if response.txt is not None:
  124. self.post_callback(response.txt, message.group_id)
  125. if response.img is not None:
  126. self.post_callback(response.img, message.group_id)
  127. else:
  128. self.post_callback(response.failure_msg, message.group_id)
  129. self.logger.warning(f"Failed command response: {response}")
  130. t = time.time() - t
  131. self.logger.info(f"Exiting command thread for {message.message_id} after {t:.3f}s")
  132. def manually_post_message(self, message_text, group_id):
  133. self.post_callback(message_text, group_id)