bot.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import logging
  2. import time
  3. import traceback
  4. from contextlib import contextmanager
  5. from dataclasses import dataclass
  6. from typing import List, Type, Dict, Callable, Any
  7. from sqlalchemy.orm.session import Session
  8. from .messaging import RollbotResponse, RollbotFailure
  9. from .plugins import as_plugin, RollbotPlugin
  10. from .database import init_db_at_url
  def lift_response(call, response):
      """Turn a fixed reply string into a plugin for the command word *call*.

      The returned object is whatever ``as_plugin(call)`` produces when applied
      to a handler that always answers with ``RollbotResponse(msg, txt=response)``.
      ``Rollbot`` instantiates it as ``plugin(bot, logger=...)``.
      """
      # Evaluate the decorator factory first, exactly as ``@as_plugin(call)`` would.
      make_plugin = as_plugin(call)

      def response_func(db, msg):
          # The session argument is ignored: a canned reply needs no database.
          return RollbotResponse(msg, txt=response)

      return make_plugin(response_func)
  def get_session_manager_factory(session_factory):
      """Wrap *session_factory* in a transactional context-manager factory.

      Args:
          session_factory: Zero-argument callable returning a session object that
              supports ``commit()``, ``rollback()`` and ``close()``.

      Returns:
          A zero-argument callable. Each call returns a context manager that
          yields a new session from *session_factory* and then does one of two
          things:

          * commits it if the ``with`` body completes, or
          * rolls it back and re-raises if the body (or the commit) raises.

          In both cases the session is closed afterwards.
      """

      @contextmanager
      def session_manager_factory():
          """Provide a transactional scope around a series of operations."""
          session = session_factory()
          try:
              yield session
              session.commit()
          except BaseException:
              # Explicit BaseException (rather than a bare ``except:``) makes it
              # clear that *every* abnormal exit, including KeyboardInterrupt,
              # triggers a rollback before the error is propagated.
              # TODO(review): rollback only touches this session object; it could
              # affect other threads only if session_factory hands out a shared
              # session (e.g. a thread-unaware registry). Confirm how
              # init_db_at_url builds its factory before relying on this in
              # threaded code.
              session.rollback()
              raise
          finally:
              session.close()

      return session_manager_factory
  @dataclass
  class RollbotConfig:
      """Settings consumed by ``Rollbot``.

      Free-form extra settings live in ``other`` and are read with :meth:`get`.
      """

      # Plugin classes; Rollbot instantiates each one once.
      plugins: List[Type[RollbotPlugin]]
      # Database URL handed to init_db_at_url by Rollbot.init_db.
      db_url: str
      # Called as (text, group_id); Rollbot falls back to logging if this is falsy.
      reply_callback: Callable[[str, str], None]
      # alias word -> existing command word.
      aliases: Dict[str, str]
      # command word -> fixed reply text.
      responses: Dict[str, str]
      # Minimum seconds Rollbot.handle_command waits before replying.
      sleep_time: float
      # Arbitrary nested settings reachable through get().
      other: Dict[str, Any]

      def get(self, key):
          """Return the value at dotted path *key* (e.g. ``"a.b.c"``) inside ``other``.

          Each ``.``-separated segment is used as a subscript on the previous
          result, so a missing segment raises whatever that container's ``[]``
          raises (``KeyError`` for a dict).
          """
          node = self.other
          rest = key
          while True:
              segment, dot, rest = rest.partition(".")
              node = node[segment]
              if not dot:
                  return node
  class Rollbot:
      """Dispatch chat commands to plugins and post their responses.

      Command words are registered from ``config`` in this order:

      1. plugin classes (``config.plugins``),
      2. fixed-text replies (``config.responses``),
      3. aliases to already-registered words (``config.aliases``).

      Any clash between words raises ``ValueError`` during construction.
      """

      def __init__(self, config, logger=logging.getLogger(__name__)):
          """Build the command table from *config*.

          Args:
              config: A ``RollbotConfig``-like object. ``plugins``,
                  ``reply_callback``, ``sleep_time``, ``responses`` and
                  ``aliases`` are read here; ``db_url`` is read by :meth:`init_db`.
              logger: Logger used by the bot and passed to every plugin.

          Raises:
              ValueError: A command word is registered twice, or an alias
                  targets an unknown word.
          """
          self.logger = logger
          # Placeholder until init_db() runs. NOTE(review): any method that opens
          # a session before init_db() fails, because None is not a context manager.
          self.session_manager_factory = lambda: None
          self.post_callback = config.reply_callback or (
              lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}")
          )
          self.commands = {}
          self.to_start = set()
          self.to_stop = set()
          self.sleep_time = config.sleep_time
          self.last_exception = None
          # Set before plugins are built, since their constructors receive ``self``.
          self.config = config

          self._load_plugins(config.plugins)
          self._load_responses(config.responses)
          self._load_aliases(config.aliases)

      def _ensure_unregistered(self, word):
          """Log and raise ``ValueError`` if *word* is already a command word."""
          if word in self.commands:
              self.logger.error(f"Duplicate command word '{word}'")
              raise ValueError(f"Duplicate command word '{word}'")

      def _load_plugins(self, plugin_classes):
          """Instantiate each plugin class and register it under its ``command``."""
          self.logger.info("Loading command plugins")
          for plugin_class in plugin_classes:
              plugin_instance = plugin_class(self, logger=self.logger)
              self._ensure_unregistered(plugin_instance.command)
              self.commands[plugin_instance.command] = plugin_instance
              # Only hooks defined directly in the class body are collected;
              # hooks inherited from a base class are not.
              if "on_start" in plugin_class.__dict__:
                  self.to_start.add(plugin_instance)
              if "on_shutdown" in plugin_class.__dict__:
                  self.to_stop.add(plugin_instance)
          self.logger.info(f"Finished loading plugins, {len(self.commands)} commands found")

      def _load_responses(self, responses):
          """Register one fixed-reply plugin per ``word -> text`` entry."""
          self.logger.info("Loading simple responses")
          for cmd, response in responses.items():
              self._ensure_unregistered(cmd)
              self.commands[cmd] = lift_response(cmd, response)(self, logger=self.logger)
          self.logger.info(f"Finished loading simple responses, {len(self.commands)} total commands available")

      def _load_aliases(self, aliases):
          """Point each alias word at the plugin of an already-registered word."""
          self.logger.info("Loading aliases")
          for alias, cmd in aliases.items():
              if cmd not in self.commands:
                  self.logger.error(f"Missing aliased command word '{cmd}'")
                  raise ValueError(f"Missing aliased command word '{cmd}'")
              self._ensure_unregistered(alias)
              self.commands[alias] = self.commands[cmd]
          self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")

      def init_db(self):
          """Connect to ``config.db_url`` and install the transactional session factory."""
          session_factory = init_db_at_url(self.config.db_url)
          self.session_manager_factory = get_session_manager_factory(session_factory)

      def start_plugins(self):
          """Call ``on_start(session)`` on every plugin that defines it, in one session."""
          self.logger.info("Starting plugins")
          with self.session_manager_factory() as session:
              for cmd in self.to_start:
                  cmd.on_start(session)
          self.logger.info("Finished starting plugins")

      def shutdown_plugins(self):
          """Call ``on_shutdown(session)`` on every plugin that defines it, in one session."""
          self.logger.info("Shutting down plugins")
          with self.session_manager_factory() as session:
              for cmd in self.to_stop:
                  cmd.on_shutdown(session)
          self.logger.info("Finished shutting down plugins")

      def _help_response(self, message):
          """Answer ``help <word>`` with the help text of the plugin bound to *word*."""
          # Default avoids StopIteration (previously reported as INTERNAL_ERROR)
          # when "help" is sent without an argument.
          topic = next(message.args(), None)
          if topic is None:
              return RollbotResponse(
                  message,
                  failure=RollbotFailure.INVALID_ARGUMENTS,
                  debugging={"explain": "help requires a command word to describe"},
              )
          targeted = self.commands.get(topic)
          if targeted is None:
              return RollbotResponse(
                  message,
                  failure=RollbotFailure.INVALID_ARGUMENTS,
                  debugging={"explain": f"Could not find command {topic}"},
              )
          return RollbotResponse(message, txt=targeted.help_msg())

      def run_command(self, message):
          """Run the plugin for ``message.command`` and return its response.

          ``help`` is handled by the bot itself. Non-command messages produce
          ``INTERNAL_ERROR``, and unknown words produce ``INVALID_COMMAND``.
          Plugin commands run inside one transactional session.
          """
          if not message.is_command:
              self.logger.warning(f"Tried to run non-command message {message.message_id}")
              return RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
          if message.command == "help":
              return self._help_response(message)
          plugin = self.commands.get(message.command)
          if plugin is None:
              self.logger.warning(f"Message {message.message_id} had a command {message.command} that could not be run.")
              return RollbotResponse(message, failure=RollbotFailure.INVALID_COMMAND)
          with self.session_manager_factory() as session:
              response = plugin.on_command(session, message)
          if not response.is_success:
              self.logger.warning(f"Message {message.message_id} caused failure")
              self.logger.warning(response.info)
          return response

      def handle_command(self, message):
          """Run *message* as a command and post the reply through ``post_callback``.

          Steps, in order:

          * Non-command messages are ignored.
          * Exceptions from the command are logged and stored in
            ``last_exception``, and the user receives an ``INTERNAL_ERROR``
            reply instead.
          * The reply is delayed so that at least ``sleep_time`` seconds pass
            after the message starts being handled.
          """
          if not message.is_command:
              self.logger.debug("Ignoring non-command message")
              return
          self.logger.info(f"Handling message {message.message_id}")
          # Monotonic clock: immune to wall-clock adjustments during the wait.
          start = time.monotonic()
          try:
              response = self.run_command(message)
          except Exception:
              # Capture the traceback first so it is kept even if building the
              # failure response itself goes wrong.
              self.last_exception = traceback.format_exc()
              self.logger.exception(f"Exception during command execution for message {message.message_id}")
              response = RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
          if not response.respond:
              self.logger.info(f"Skipping response to message {message.message_id}")
              return
          self.logger.info(f"Responding to message {message.message_id}")
          sleep = self.sleep_time - (time.monotonic() - start)
          if sleep > 0:
              self.logger.info(f"Sleeping for {sleep:.3f}s before responding")
              time.sleep(sleep)
          if response.is_success:
              if response.txt is not None:
                  self.post_callback(response.txt, message.group_id)
              if response.img is not None:
                  self.post_callback(response.img, message.group_id)
          else:
              self.post_callback(response.failure_msg, message.group_id)
              self.logger.warning(f"Failed command response: {response}")
          elapsed = time.monotonic() - start
          self.logger.info(f"Exiting command thread for {message.message_id} after {elapsed:.3f}s")

      def manually_post_message(self, message_text, group_id):
          """Send *message_text* to *group_id* through ``post_callback``."""
          self.post_callback(message_text, group_id)