import logging
import time
import traceback
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Type, Dict, Callable, Any

from sqlalchemy.orm.session import Session

from .messaging import RollbotResponse, RollbotFailure
from .plugins import as_plugin, RollbotPlugin
from .database import init_db_at_url


def lift_response(call, response):
    """Wrap a fixed text reply as a plugin for the command word ``call``.

    The returned object is whatever ``as_plugin(call)`` produces for a
    handler that ignores the database session and always answers with
    ``response`` as the response text.
    """

    @as_plugin(call)
    def response_func(db, msg):
        return RollbotResponse(msg, txt=response)

    return response_func


def get_session_manager_factory(session_factory):
    """Return a context-manager factory built on top of ``session_factory``.

    Each ``with`` block gets a fresh session from ``session_factory()``. The
    session is committed when the block finishes normally, rolled back (and
    the exception re-raised) when the block raises, and closed in every case.
    """

    @contextmanager
    def session_manager_factory():
        """Provide a transactional scope around a series of operations."""
        session = session_factory()
        try:
            yield session
            session.commit()
        except BaseException:
            # Deliberately as broad as a bare ``except``: whatever interrupted
            # the block, undo the pending work and let the error propagate.
            # TODO there is some worry that this would rollback things in other threads...
            # we should probably find a more correct solution for managing the threaded
            # db access, but the risk is fairly low at this point.
            session.rollback()
            raise
        finally:
            session.close()

    return session_manager_factory


def _uninitialized_session_manager():
    """Placeholder session factory used until ``Rollbot.init_db`` is called."""
    raise RuntimeError("Database is not initialized; call Rollbot.init_db() first")


@dataclass
class RollbotConfig:
    """Settings consumed by :class:`Rollbot`."""

    # Plugin classes; each is instantiated as ``plugin_class(bot, logger=...)``.
    plugins: List[Type[RollbotPlugin]]
    # Passed to ``init_db_at_url`` by ``Rollbot.init_db``.
    db_url: str
    # Called as ``reply_callback(text, group_id)``; a falsy value makes
    # Rollbot fall back to logging the reply instead.
    reply_callback: Callable[[str, str], None]
    # Maps an alias word to the already-registered command word it stands for.
    aliases: Dict[str, str]
    # Maps a command word to the fixed text it should answer with.
    responses: Dict[str, str]
    # Minimum number of seconds between receiving a command and replying.
    sleep_time: float
    # Free-form nested settings, read through ``get``.
    other: Dict[str, Any]

    def get(self, key):
        """Look up a dotted ``key`` (e.g. ``"a.b.c"``) inside ``other``.

        Each dot-separated component indexes one level deeper. A missing
        component raises whatever the container raises for it (``KeyError``
        for dicts).
        """
        c = self.other
        for k in key.split("."):
            c = c[k]
        return c


class Rollbot:
    """Command dispatcher: maps command words to plugins and posts replies."""

    def __init__(self, config, logger=logging.getLogger(__name__)):
        """Instantiate plugins, simple responses and aliases from ``config``.

        Raises:
            ValueError: if two commands/aliases share a word, or an alias
                points at a command word that was not registered.
        """
        self.logger = logger
        # Replaced by init_db(); until then any database use fails loudly.
        self.session_manager_factory = _uninitialized_session_manager
        self.post_callback = config.reply_callback or (
            lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}")
        )
        self.commands = {}
        self.to_start = set()
        self.to_stop = set()
        self.sleep_time = config.sleep_time
        self.last_exception = None
        self.config = config

        self.logger.info("Loading command plugins")
        for plugin_class in config.plugins:
            plugin_instance = plugin_class(self, logger=logger)
            if plugin_instance.command in self.commands:
                raise self._config_error(f"Duplicate command word '{plugin_instance.command}'")
            self.commands[plugin_instance.command] = plugin_instance
            # Only hooks defined directly on the plugin class itself are
            # scheduled; hooks that are merely inherited are skipped.
            if "on_start" in plugin_class.__dict__:
                self.to_start.add(plugin_instance)
            if "on_shutdown" in plugin_class.__dict__:
                self.to_stop.add(plugin_instance)
        self.logger.info(f"Finished loading plugins, {len(self.commands)} commands found")

        self.logger.info("Loading simple responses")
        for cmd, response in config.responses.items():
            if cmd in self.commands:
                raise self._config_error(f"Duplicate command word '{cmd}'")
            self.commands[cmd] = lift_response(cmd, response)(self, logger=logger)
        self.logger.info(f"Finished loading simple responses, {len(self.commands)} total commands available")

        self.logger.info("Loading aliases")
        for alias, cmd in config.aliases.items():
            if cmd not in self.commands:
                raise self._config_error(f"Missing aliased command word '{cmd}'")
            if alias in self.commands:
                raise self._config_error(f"Duplicate command word '{alias}'")
            # An alias shares the very same plugin instance as its target.
            self.commands[alias] = self.commands[cmd]
        self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")

    def _config_error(self, description):
        """Log ``description`` as an error and return a ValueError carrying it."""
        self.logger.error(description)
        return ValueError(description)

    def init_db(self):
        """Set up the database at ``config.db_url`` and enable sessions."""
        session_factory = init_db_at_url(self.config.db_url)
        self.session_manager_factory = get_session_manager_factory(session_factory)

    def start_plugins(self):
        """Call ``on_start(session)`` on every plugin that defines it."""
        self.logger.info("Starting plugins")
        with self.session_manager_factory() as session:
            for cmd in self.to_start:
                cmd.on_start(session)
        self.logger.info("Finished starting plugins")

    def shutdown_plugins(self):
        """Call ``on_shutdown(session)`` on every plugin that defines it."""
        self.logger.info("Shutting down plugins")
        with self.session_manager_factory() as session:
            for cmd in self.to_stop:
                cmd.on_shutdown(session)
        self.logger.info("Finished shutting down plugins")

    def run_command(self, message):
        """Execute the command in ``message`` and return its RollbotResponse.

        ``help <word>`` is answered directly with the target plugin's
        ``help_msg()``. Any other command is run through its plugin's
        ``on_command`` inside a database session. Unknown commands,
        non-command messages and a ``help`` without a known topic produce a
        failure response rather than an exception.
        """
        if not message.is_command:
            self.logger.warning(f"Tried to run non-command message {message.message_id}")
            return RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)

        if message.command == "help":
            # A bare "help" has no topic; answer with a proper failure instead
            # of letting next() raise StopIteration.
            topic = next(message.args(), None)
            if topic is None:
                return RollbotResponse(
                    message,
                    failure=RollbotFailure.INVALID_ARGUMENTS,
                    debugging={"explain": "No command given to look up"},
                )
            targeted = self.commands.get(topic, None)
            if targeted is None:
                return RollbotResponse(
                    message,
                    failure=RollbotFailure.INVALID_ARGUMENTS,
                    debugging={"explain": f"Could not find command {topic}"},
                )
            return RollbotResponse(message, txt=targeted.help_msg())

        plugin = self.commands.get(message.command, None)
        if plugin is None:
            self.logger.warning(
                f"Message {message.message_id} had a command {message.command} that could not be run."
            )
            return RollbotResponse(message, failure=RollbotFailure.INVALID_COMMAND)

        with self.session_manager_factory() as session:
            response = plugin.on_command(session, message)

        if not response.is_success:
            self.logger.warning(f"Message {message.message_id} caused failure")
            self.logger.warning(response.info)

        return response

    def handle_command(self, message):
        """Run ``message`` as a command and post the outcome via the callback.

        Non-command messages are ignored. Exceptions raised while running the
        command are logged, remembered in ``last_exception`` as a formatted
        traceback, and reported to the group as an internal error. The reply
        is delayed so that at least ``sleep_time`` seconds pass between the
        start of handling and the response.
        """
        if not message.is_command:
            self.logger.debug("Ignoring non-command message")
            return

        self.logger.info(f"Handling message {message.message_id}")
        t = time.time()
        try:
            response = self.run_command(message)
        except Exception:
            self.logger.exception(f"Exception during command execution for message {message.message_id}")
            response = RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
            self.last_exception = traceback.format_exc()

        if not response.respond:
            self.logger.info(f"Skipping response to message {message.message_id}")
            return

        self.logger.info(f"Responding to message {message.message_id}")

        # Time still owed before replying, given how long the command took.
        sleep = self.sleep_time - time.time() + t
        if sleep > 0:
            self.logger.info(f"Sleeping for {sleep:.3f}s before responding")
            time.sleep(sleep)

        if response.is_success:
            if response.txt is not None:
                self.post_callback(response.txt, message.group_id)
            if response.img is not None:
                self.post_callback(response.img, message.group_id)
        else:
            self.post_callback(response.failure_msg, message.group_id)
            self.logger.warning(f"Failed command response: {response}")

        t = time.time() - t
        self.logger.info(f"Exiting command thread for {message.message_id} after {t:.3f}s")

    def manually_post_message(self, message_text, group_id):
        """Post ``message_text`` to ``group_id`` through the reply callback."""
        self.post_callback(message_text, group_id)