Przeglądaj źródła

Moving a lot of the complex logic into the rollbot class and out of the app.py

Kirk Trombley 6 lat temu
rodzic
commit
5abc6ea970
3 zmienionych plików z 78 dodań i 45 usunięć
  1. 15 14
      src/app.py
  2. 1 5
      src/config.py
  3. 62 26
      src/rollbot.py

+ 15 - 14
src/app.py

@@ -1,11 +1,11 @@
 import atexit
-import time
 from logging.config import dictConfig
 from threading import Thread
 
 from flask import Flask, request, render_template, jsonify
 
-from config import BOTS_LOOKUP, SLEEP_TIME, get_config
+import db
+from config import BOTS_LOOKUP, get_config
 from command_system import RollbotMessage, RollbotResponse, RollbotFailure
 from rollbot import Rollbot
 from util import post_message
@@ -27,8 +27,18 @@ dictConfig({
 })
 
 app = Flask(__name__)
-
-rollbot = Rollbot(app.logger)
+app.config["PROPAGATE_EXCEPTIONS"] = True
+rollbot = Rollbot(
+    app.logger,
+    plugins=get_config("plugins"),
+    aliases=get_config("aliases"),
+    responses=get_config("responses"),
+    lookup=BOTS_LOOKUP,
+    sleep_time=float(get_config("sleep_time"))
+)
+app.logger.info("Initializing database tables")
+db.init_db()
+app.logger.info("Finished initializing database")
 rollbot.start_plugins()
 atexit.register(rollbot.shutdown_plugins)
 
@@ -47,18 +57,9 @@ def teamspeak():
 def execute():
     json = request.get_json()
     msg = RollbotMessage.from_groupme(json)
-
-    if msg.group_id not in BOTS_LOOKUP:
-        app.logger.warning(f"Received message from unknown group ID {msg.group_id}")
-        return jsonify({"message": "Invalid group ID"}), 403
-
-    if msg.is_command:
-        t = Thread(target=lambda: rollbot.handle_message(msg))
-        t.start()
-
+    rollbot.handle_command_threaded(msg)
     return "", 204
 
 
 if __name__ == "__main__":
-    # default deployment in debug mode
     app.run(host="0.0.0.0", port=6070)

+ 1 - 5
src/config.py

@@ -29,9 +29,5 @@ def get_secret(key):
 BOTS_LOOKUP = get_secret("bots")
 GLOBAL_ADMINS = get_secret("auths.global")
 GROUP_ADMINS = get_secret("auths.group")
-PLUGINS = get_config("plugins")
-ALIASES = get_config("aliases")
-RESPONSES = get_config("responses")
 API_KEY = get_secret("api_key")
-DB_FILE = os.path.abspath(get_config("database"))
-SLEEP_TIME = float(get_config("sleep_time"))
+DB_FILE = os.path.abspath(get_config("database"))

+ 62 - 26
src/rollbot.py

@@ -1,43 +1,77 @@
 import importlib
 import time
+from threading import Thread
 
 import db
-from config import PLUGINS, ALIASES, RESPONSES, BOTS_LOOKUP, SLEEP_TIME
-from command_system import RollbotResponse, RollbotFailure
+from command_system import RollbotResponse, RollbotFailure, as_plugin
 from util import post_message
 
 
+def lift_response(call, response):
+    @as_plugin(call)
+    def response_func(db, msg):
+        return RollbotResponse(msg, txt=response)
+    return response_func
+
+
 class Rollbot:
-    def __init__(self, logger):
+    def __init__(self, logger, plugins={}, aliases={}, responses={}, lookup={}, sleep_time=0.0):
         self.logger = logger
         self.commands = {}
+        self.to_start = set()
+        self.to_stop = set()
+
         self.logger.info("Loading command plugins")
-        for module, classes in PLUGINS.items():
+        for module, classes in plugins.items():
             plugin_module = importlib.import_module("plugins." + module)
             for class_name in classes:
+                self.logger.info(class_name)
                 plugin_class = getattr(plugin_module, class_name)
-                plugin_instance = plugin_class(logger)
+                plugin_instance = plugin_class(logger=logger)
                 if plugin_instance.command in self.commands:
                     self.logger.error(f"Duplicate command word '{plugin_instance.command}'")
                     raise ValueError(f"Duplicate command word '{plugin_instance.command}'")
                 self.commands[plugin_instance.command] = plugin_instance
+                if "on_start" in plugin_class.__dict__:
+                    self.to_start.add(plugin_instance)
+                if "on_shutdown" in plugin_class.__dict__:
+                    self.to_stop.add(plugin_instance)
         self.logger.info(f"Finished loading plugins, {len(self.commands)} commands found")
-        self.logger.info("Initializing database tables")
-        db.init_db()
-        self.logger.info("Finished initializing database")
+
+        self.logger.info("Loading simple responses")
+        for cmd, response in responses.items():
+            if cmd in self.commands:
+                self.logger.error(f"Duplicate command word '{cmd}'")
+                raise ValueError(f"Duplicate command word '{cmd}'")
+            self.commands[cmd] = lift_response(cmd, response)(logger=logger)
+        self.logger.info(f"Finished loading simple responses, {len(self.commands)} total commands available")
+
+        self.logger.info("Loading aliases")
+        for alias, cmd in aliases.items():
+            if cmd not in self.commands:
+                self.logger.error(f"Missing aliased command word '{cmd}'")
+                raise ValueError(f"Missing aliased command word '{cmd}'")
+            if alias in self.commands:
+                self.logger.error(f"Duplicate command word '{alias}'")
+                raise ValueError(f"Duplicate command word '{alias}'")
+            self.commands[alias] = self.commands[cmd]
+        self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")
+
+        self.bot_lookup = lookup
+        self.sleep_time = sleep_time
 
     def start_plugins(self):
-        self.logger.info("Starting all plugins")
+        self.logger.info("Starting plugins")
         with db.session_scope() as session:
-            for plugin_instance in self.commands.values():
-                plugin_instance.on_start(session)
+            for cmd in self.to_start:
+                cmd.on_start(session)
         self.logger.info("Finished starting plugins")
 
     def shutdown_plugins(self):
-        self.logger.info("Shutting down all plugins")
+        self.logger.info("Shutting down plugins")
         with db.session_scope() as session:
-            for plugin_instance in self.commands.values():
-                plugin_instance.on_shutdown(session)
+            for cmd in self.to_stop:
+                cmd.on_shutdown(session)
         self.logger.info("Finished shutting down plugins")
 
     def run_command(self, message):
@@ -45,16 +79,10 @@ class Rollbot:
             self.logger.warn(f"Tried to run non-command message {message.message_id}")
             return RollbotResponse(msg, failure=RollbotFailure.INTERNAL_ERROR)
 
-        # if this command is aliased, resolve that first, otherwise use the literal command
-        cmd = ALIASES.get(message.command, message.command)
-
-        if cmd in RESPONSES:
-            return RollbotResponse(message, txt=RESPONSES[cmd])
-
-        plugin = self.commands.get(cmd, None)
+        plugin = self.commands.get(message.command, None)
 
         if plugin is None:
-            self.logger.warn(f"Message {message.message_id} had a command {message.command} (resolved to {cmd}) that could not be run.")
+            self.logger.warn(f"Message {message.message_id} had a command {message.command} that could not be run.")
             return RollbotResponse(message, failure=RollbotFailure.INVALID_COMMAND)
 
         with db.session_scope() as session:
@@ -66,13 +94,17 @@ class Rollbot:
 
         return response
 
-    def handle_message(self, msg):
+    def handle_command(self, msg):
+        if msg.group_id not in self.bot_lookup:
+            self.logger.warning(f"Received message from unknown group ID {msg.group_id}")
+            return
+
         self.logger.info(f"Handling message {msg.message_id}")
         t = time.time()
         try:
             response = self.run_command(msg)
         except Exception as e:
-            self.logger.error(f"Exception during command execution {e}, for message {msg.message_id}")
+            self.logger.exception(f"Exception during command execution for message {msg.message_id}")
             response = RollbotResponse(msg, failure=RollbotFailure.INTERNAL_ERROR)
 
         if not response.respond:
@@ -81,12 +113,12 @@ class Rollbot:
 
         self.logger.info(f"Responding to message {msg.message_id}")
 
-        sleep = SLEEP_TIME - time.time() + t
+        sleep = self.sleep_time - time.time() + t
         if sleep > 0:
             self.logger.info(f"Sleeping for {sleep:.3f}s before responding")
             time.sleep(sleep)
 
-        bot_id = BOTS_LOOKUP[msg.group_id]
+        bot_id = self.bot_lookup[msg.group_id]
         if response.is_success:
             if response.txt is not None:
                 post_message(response.txt, bot_id)
@@ -98,3 +130,7 @@ class Rollbot:
 
         t = time.time() - t
         self.logger.info(f"Exiting command thread for {msg.message_id} after {t:.3f}s")
+
+    def handle_command_threaded(self, message):
+        t = Thread(target=lambda: self.handle_command(message))
+        t.start()