bot.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import logging
  2. import time
  3. from contextlib import contextmanager
  4. from dataclasses import dataclass
  5. from typing import List, Type, Dict, Callable, Any
  6. from sqlalchemy.orm.session import Session
  7. from .messaging import RollbotResponse, RollbotFailure
  8. from .plugins import as_plugin, RollbotPlugin
  9. def lift_response(call, response):
  10. @as_plugin(call)
  11. def response_func(db, msg):
  12. return RollbotResponse(msg, txt=response)
  13. return response_func
  14. def get_session_manager_factory(session_factory):
  15. @contextmanager
  16. def session_manager_factory():
  17. """Provide a transactional scope around a series of operations."""
  18. session = session_factory()
  19. try:
  20. yield session
  21. session.commit()
  22. except:
  23. # TODO there is some worry that this would rollback things in other threads...
  24. # we should probably find a more correct solution for managing the threaded
  25. # db access, but the risk is fairly low at this point.
  26. session.rollback()
  27. raise
  28. finally:
  29. session.close()
  30. return session_manager_factory
  31. @dataclass
  32. class RollbotConfig:
  33. plugins: List[Type[RollbotPlugin]]
  34. session_factory: Callable[[], Session]
  35. reply_callback: Callable[[str, str], None]
  36. aliases: Dict[str, str]
  37. responses: Dict[str, str]
  38. sleep_time: float
  39. other: Dict[str, Any]
  40. def get(self, key):
  41. c = self.other
  42. for k in key.split("."):
  43. c = c[k]
  44. return c
  45. class Rollbot:
  46. def __init__(self, config, logger=logging.getLogger(__name__)):
  47. self.logger = logger
  48. if config.session_factory is not None:
  49. self.session_manager_factory = get_session_manager_factory(config.session_factory)
  50. else:
  51. self.session_manager_factory = lambda: None
  52. self.post_callback = config.reply_callback or (lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}"))
  53. self.commands = {}
  54. self.to_start = set()
  55. self.to_stop = set()
  56. self.sleep_time = config.sleep_time
  57. self.last_exception = None
  58. self.config = config
  59. self.logger.info("Loading command plugins")
  60. for plugin_class in config.plugins:
  61. plugin_instance = plugin_class(self, logger=logger)
  62. if plugin_instance.command in self.commands:
  63. self.logger.error(f"Duplicate command word '{plugin_instance.command}'")
  64. raise ValueError(f"Duplicate command word '{plugin_instance.command}'")
  65. self.commands[plugin_instance.command] = plugin_instance
  66. if "on_start" in plugin_class.__dict__:
  67. self.to_start.add(plugin_instance)
  68. if "on_shutdown" in plugin_class.__dict__:
  69. self.to_stop.add(plugin_instance)
  70. self.logger.info(f"Finished loading plugins, {len(self.commands)} commands found")
  71. self.logger.info("Loading simple responses")
  72. for cmd, response in config.responses.items():
  73. if cmd in self.commands:
  74. self.logger.error(f"Duplicate command word '{cmd}'")
  75. raise ValueError(f"Duplicate command word '{cmd}'")
  76. self.commands[cmd] = lift_response(cmd, response)(self, logger=logger)
  77. self.logger.info(f"Finished loading simple responses, {len(self.commands)} total commands available")
  78. self.logger.info("Loading aliases")
  79. for alias, cmd in config.aliases.items():
  80. if cmd not in self.commands:
  81. self.logger.error(f"Missing aliased command word '{cmd}'")
  82. raise ValueError(f"Missing aliased command word '{cmd}'")
  83. if alias in self.commands:
  84. self.logger.error(f"Duplicate command word '{alias}'")
  85. raise ValueError(f"Duplicate command word '{alias}'")
  86. self.commands[alias] = self.commands[cmd]
  87. self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")
  88. def start_plugins(self):
  89. self.logger.info("Starting plugins")
  90. with self.session_manager_factory() as session:
  91. for cmd in self.to_start:
  92. cmd.on_start(session)
  93. self.logger.info("Finished starting plugins")
  94. def shutdown_plugins(self):
  95. self.logger.info("Shutting down plugins")
  96. with self.session_manager_factory() as session:
  97. for cmd in self.to_stop:
  98. cmd.on_shutdown(session)
  99. self.logger.info("Finished shutting down plugins")
  100. def run_command(self, message):
  101. if not message.is_command:
  102. self.logger.warn(f"Tried to run non-command message {message.message_id}")
  103. return RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
  104. if message.command == "help":
  105. topic = next(message.args())
  106. targeted = self.commands.get(topic, None)
  107. if targeted is None:
  108. return RollbotResponse(message, failure=RollbotFailure.INVALID_ARGUMENTS, debugging={"explain": f"Could not find command {topic}"})
  109. return RollbotResponse(message, txt=targeted.help_msg())
  110. plugin = self.commands.get(message.command, None)
  111. if plugin is None:
  112. self.logger.warn(f"Message {message.message_id} had a command {message.command} that could not be run.")
  113. return RollbotResponse(message, failure=RollbotFailure.INVALID_COMMAND)
  114. with self.session_manager_factory() as session:
  115. response = plugin.on_command(session, message)
  116. if not response.is_success:
  117. self.logger.warn(f"Message {message.message_id} caused failure")
  118. self.logger.warn(response.info)
  119. return response
  120. def handle_command(self, message):
  121. if not message.is_command:
  122. self.logger.debug("Ignoring non-command message")
  123. return
  124. self.logger.info(f"Handling message {message.message_id}")
  125. t = time.time()
  126. try:
  127. response = self.run_command(message)
  128. except Exception as e:
  129. self.logger.exception(f"Exception during command execution for message {message.message_id}")
  130. response = RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
  131. self.last_exception = e
  132. if not response.respond:
  133. self.logger.info(f"Skipping response to message {message.message_id}")
  134. return
  135. self.logger.info(f"Responding to message {message.message_id}")
  136. sleep = self.sleep_time - time.time() + t
  137. if sleep > 0:
  138. self.logger.info(f"Sleeping for {sleep:.3f}s before responding")
  139. time.sleep(sleep)
  140. if response.is_success:
  141. if response.txt is not None:
  142. self.post_callback(response.txt, message.group_id)
  143. if response.img is not None:
  144. self.post_callback(response.img, message.group_id)
  145. else:
  146. self.post_callback(response.failure_msg, message.group_id)
  147. self.logger.warning(f"Failed command response: {response}")
  148. t = time.time() - t
  149. self.logger.info(f"Exiting command thread for {message.message_id} after {t:.3f}s")
  150. def manually_post_message(self, message_text, group_id):
  151. self.post_callback(message_text, group_id)