ソースを参照

Merge branch 'devops/improved-code-layout' of kirkleon/rollbot3 into master

kirkleon 6 年 前
コミット
f3f32627bd

+ 1 - 0
config/config.toml

@@ -36,3 +36,4 @@ bump = "Bumping the chat!\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\
 [teamspeak]
 host = "kirkleon.ddns.net"
 user = "serveradmin"
+scrolling = true

+ 1 - 4
rollbot-docker.sh

@@ -49,12 +49,9 @@ clean_container() {
 case $1 in
     "l"|"logs")
         STATUS=$(status_check)
-        if [ "$STATUS" = "true" ]
+        if [ "$STATUS" = "true" ] || [ "$STATUS" = "false" ]
         then
             docker logs -f $CONTAINER_NAME
-        elif [ "$STATUS" = "false" ]
-        then
-            echo "Existing container $CONTAINER_NAME is stopped."
         else
             echo "No existing container $CONTAINER_NAME could be found."
         fi

+ 25 - 74
src/app.py

@@ -1,15 +1,17 @@
 import atexit
-import time
 from logging.config import dictConfig
-from threading import Thread
 
 from flask import Flask, request, render_template
 
-from config import BOTS_LOOKUP, SLEEP_TIME
-from command_system import RollbotMessage, RollbotResponse, RollbotFailure
+import db
+from config import BOTS_LOOKUP, get_config, get_secret
+from command_system import RollbotMessage
 from rollbot import Rollbot
 from util import post_message
 
+GLOBAL_ADMINS = get_secret("auths.global")
+GROUP_ADMINS = get_secret("auths.group")
+
 dictConfig({
     "version": 1,
     "formatters": {"default": {
@@ -27,92 +29,41 @@ dictConfig({
 })
 
 app = Flask(__name__)
-
-rollbot = Rollbot(app.logger)
+app.config["PROPAGATE_EXCEPTIONS"] = True
+rollbot = Rollbot(
+    logger=app.logger,
+    plugins=get_config("plugins"),
+    aliases=get_config("aliases"),
+    responses=get_config("responses"),
+    lookup=BOTS_LOOKUP,
+    sleep_time=float(get_config("sleep_time")),
+    callback=post_message,
+    session_factory=db.session_scope
+)
+app.logger.info("Initializing database tables")
+db.init_db()
+app.logger.info("Finished initializing database")
 rollbot.start_plugins()
 atexit.register(rollbot.shutdown_plugins)
 
 
 @app.route("/teamspeak", methods=["GET"])
 def teamspeak():
-    response = rollbot.run_command(RollbotMessage.from_web("!teamspeak"))
+    response = rollbot.run_command(RollbotMessage("MANUAL", None, None, None, None, "!teamspeak", False))
     if response.is_success:
         response = response.txt
     else:
         response = response.failure_msg
-    return render_template("teamspeak.html", r=response)
-
-
-@app.route("/services", methods=["GET", "POST"])
-def services():
-    if request.method == "POST":
-        msg = RollbotMessage.from_web(request.form["cmd"])
-        if msg.is_command:
-            response = rollbot.run_command(msg)
-            if response.is_success:
-                txt = response.txt
-                img = response.img
-            else:
-                txt = response.failure_msg
-                img = None
-            return render_template("services.html", r=response.respond, txt=txt, img=img)
-    return render_template("services.html", r=None)
-
+    return render_template("teamspeak.html", r=response, scrolling=get_config("teamspeak.scrolling"))
 
 
 @app.route("/rollbot", methods=["POST"])
 def execute():
     json = request.get_json()
-    msg = RollbotMessage.from_groupme(json)
-
-    if not msg.is_command:
-        app.logger.debug("Received non-command message")
-        return "", 204
-
-    if msg.group_id not in BOTS_LOOKUP:
-        app.logger.warning(f"Received message from unknown group ID {msg.group_id}")
-        return "Invalid group ID", 403
-
-
-    def run_command_and_respond():
-        app.logger.info(f"Entering command thread for {msg.message_id}")
-        t = time.time()
-        try:
-            response = rollbot.run_command(msg)
-        except Exception as e:
-            app.logger.error(f"Exception during command execution {e}, for message {msg.message_id}")
-            response = RollbotResponse(msg, failure=RollbotFailure.INTERNAL_ERROR)
-
-        if not response.respond:
-            app.logger.info(f"Skipping response to message {msg.message_id}")
-            return
-
-        app.logger.info(f"Responding to message {msg.message_id}")
-
-        sleep = SLEEP_TIME - time.time() + t
-        if sleep > 0:
-            app.logger.info(f"Sleeping for {sleep:.3f}s before responding")
-            time.sleep(sleep)
-
-        bot_id = BOTS_LOOKUP[msg.group_id]
-        if response.is_success:
-            if response.txt is not None:
-                post_message(response.txt, bot_id)
-            if response.img is not None:
-                post_message(response.img, bot_id)
-        else:
-            post_message(response.failure_msg, bot_id)
-            app.logger.warning(f"Failed command response: {response}")
-
-        t = time.time() - t
-        app.logger.info(f"Exiting command thread for {msg.message_id} after {t:.3f}s")
-
-    t = Thread(target=run_command_and_respond)
-    t.start()
-
-    return "OK", 200
+    msg = RollbotMessage.from_groupme(json, global_admins=GLOBAL_ADMINS, group_admins=GROUP_ADMINS)
+    rollbot.handle_command_threaded(msg)
+    return "", 204
 
 
 if __name__ == "__main__":
-    # default deployment in debug mode
     app.run(host="0.0.0.0", port=6070)

+ 15 - 17
src/command_system.py

@@ -4,8 +4,6 @@ from enum import Enum, auto
 
 from sqlalchemy.ext.declarative import declarative_base
 
-from config import GLOBAL_ADMINS, GROUP_ADMINS
-
 
 BANGS = ('!',)
 
@@ -29,6 +27,7 @@ class RollbotMessage:
     group_id: str
     message_id: str
     message_txt: str
+    from_admin: bool
 
     def __post_init__(self):
         self.is_command = False
@@ -39,22 +38,21 @@ class RollbotMessage:
                 self.command = cmd.lower()
                 self.raw_args = raw
 
-        self.from_admin = self.sender_id is not None and \
-            self.sender_id in GLOBAL_ADMINS or (
-                self.group_id in GROUP_ADMINS and
-                self.sender_id in GROUP_ADMINS[self.group_id])
-
     @staticmethod
-    def from_groupme(msg):
-        return RollbotMessage("GROUPME", msg["name"], msg["sender_id"], msg["group_id"], msg["id"], msg["text"].strip())
-
-    @staticmethod
-    def from_web(content):
-        content = content.strip()
-        if len(content) > 0 and content[0] not in BANGS:
-            content = BANGS[0] + content
-        # TODO should still assign an id...
-        return RollbotMessage("WEB_FRONTEND", "user", None, None, None, content)
+    def from_groupme(msg, global_admins=(), group_admins={}):
+        sender_id = msg["sender_id"]
+        group_id = msg["group_id"]
+        return RollbotMessage(
+            "GROUPME",
+            msg["name"],
+            sender_id,
+            group_id,
+            msg["id"],
+            msg["text"].strip(),
+            sender_id in global_admins or (
+                group_id in group_admins and
+                sender_id in group_admins[group_id])
+        )
 
     def args(self, normalize=True):
         arg, rest = pop_arg(self.raw_args)

+ 1 - 7
src/config.py

@@ -27,11 +27,5 @@ def get_secret(key):
 
 
 BOTS_LOOKUP = get_secret("bots")
-GLOBAL_ADMINS = get_secret("auths.global")
-GROUP_ADMINS = get_secret("auths.group")
-PLUGINS = get_config("plugins")
-ALIASES = get_config("aliases")
-RESPONSES = get_config("responses")
 API_KEY = get_secret("api_key")
-DB_FILE = os.path.abspath(get_config("database"))
-SLEEP_TIME = float(get_config("sleep_time"))
+DB_FILE = os.path.abspath(get_config("database"))

+ 5 - 2
src/db.py

@@ -1,14 +1,14 @@
 from contextlib import contextmanager
 
 from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import sessionmaker, scoped_session
 
 from config import DB_FILE
 from command_system import ModelBase
 
 
 engine = create_engine("sqlite:///" + DB_FILE)
-Session = sessionmaker(bind=engine)
+Session = scoped_session(sessionmaker(bind=engine))
 
 
 def init_db():
@@ -23,6 +23,9 @@ def session_scope():
         yield session
         session.commit()
     except:
+        # TODO there is some worry that this would rollback things in other threads...
+        # we should probably find a more correct solution for managing the threaded
+        # db access, but the risk is fairly low at this point.
         session.rollback()
         raise
     finally:

+ 3 - 6
src/plugins/teamspeak.py

@@ -11,9 +11,8 @@ TS3_LOGIN = ("login %s %s\n" % (_TS3_USER, _TS3_PASS)).encode("utf-8")
 
 
 class teamspeak(RollbotPlugin):
-    def __init__(self, command, logger=logging.getLogger(__name__)):
+    def __init__(self, logger=logging.getLogger(__name__)):
         RollbotPlugin.__init__(self, "teamspeak", logger)
-        self.logger.info(f"Intializing Teamspeak command")
 
     def on_command(self, db, message):
         try:
@@ -53,9 +52,8 @@ class teamspeak(RollbotPlugin):
 
 
 class speamteek(teamspeak):
-    def __init__(self, command, logger=logging.getLogger(__name__)):
+    def __init__(self, logger=logging.getLogger(__name__)):
         RollbotPlugin.__init__(self, "speamteek", logger)
-        self.logger.info(f"Intializing Speamteek command")
 
     def on_command(self, db, message):
         r = super().on_command(db, message)
@@ -65,9 +63,8 @@ class speamteek(teamspeak):
 
 
 class teamscream(teamspeak):
-    def __init__(self, command, logger=logging.getLogger(__name__)):
+    def __init__(self, logger=logging.getLogger(__name__)):
         RollbotPlugin.__init__(self, "teamscream", logger)
-        self.logger.info(f"Intializing Speamteek command")
 
     def on_command(self, db, message):
         r = super().on_command(db, message)

+ 99 - 27
src/rollbot.py

@@ -1,61 +1,92 @@
 import importlib
+import logging
+import time
+from threading import Thread
 
-import db
-from config import PLUGINS, ALIASES, RESPONSES, BOTS_LOOKUP
-from command_system import RollbotResponse, RollbotFailure
+from command_system import RollbotResponse, RollbotFailure, as_plugin
+
+
+def lift_response(call, response):
+    @as_plugin(call)
+    def response_func(db, msg):
+        return RollbotResponse(msg, txt=response)
+    return response_func
 
 
 class Rollbot:
-    def __init__(self, logger):
+    def __init__(self, logger=logging.getLogger(__name__), plugins={}, aliases={}, responses={}, lookup={}, sleep_time=0.0, session_factory=None, callback=None):
         self.logger = logger
+        self.session_factory = session_factory or (lambda: None)
+        self.callback = callback or (lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}"))
         self.commands = {}
+        self.to_start = set()
+        self.to_stop = set()
+
         self.logger.info("Loading command plugins")
-        for module, classes in PLUGINS.items():
+        for module, classes in plugins.items():
             plugin_module = importlib.import_module("plugins." + module)
             for class_name in classes:
+                self.logger.info(class_name)
                 plugin_class = getattr(plugin_module, class_name)
-                plugin_instance = plugin_class(logger)
+                plugin_instance = plugin_class(logger=logger)
                 if plugin_instance.command in self.commands:
                     self.logger.error(f"Duplicate command word '{plugin_instance.command}'")
                     raise ValueError(f"Duplicate command word '{plugin_instance.command}'")
                 self.commands[plugin_instance.command] = plugin_instance
+                if "on_start" in plugin_class.__dict__:
+                    self.to_start.add(plugin_instance)
+                if "on_shutdown" in plugin_class.__dict__:
+                    self.to_stop.add(plugin_instance)
         self.logger.info(f"Finished loading plugins, {len(self.commands)} commands found")
-        self.logger.info("Initializing database tables")
-        db.init_db()
-        self.logger.info("Finished initializing database")
+
+        self.logger.info("Loading simple responses")
+        for cmd, response in responses.items():
+            if cmd in self.commands:
+                self.logger.error(f"Duplicate command word '{cmd}'")
+                raise ValueError(f"Duplicate command word '{cmd}'")
+            self.commands[cmd] = lift_response(cmd, response)(logger=logger)
+        self.logger.info(f"Finished loading simple responses, {len(self.commands)} total commands available")
+
+        self.logger.info("Loading aliases")
+        for alias, cmd in aliases.items():
+            if cmd not in self.commands:
+                self.logger.error(f"Missing aliased command word '{cmd}'")
+                raise ValueError(f"Missing aliased command word '{cmd}'")
+            if alias in self.commands:
+                self.logger.error(f"Duplicate command word '{alias}'")
+                raise ValueError(f"Duplicate command word '{alias}'")
+            self.commands[alias] = self.commands[cmd]
+        self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")
+
+        self.bot_lookup = lookup
+        self.sleep_time = sleep_time
 
     def start_plugins(self):
-        self.logger.info("Starting all plugins")
-        with db.session_scope() as session:
-            for plugin_instance in self.commands.values():
-                plugin_instance.on_start(session)
+        self.logger.info("Starting plugins")
+        with self.session_factory() as session:
+            for cmd in self.to_start:
+                cmd.on_start(session)
         self.logger.info("Finished starting plugins")
 
     def shutdown_plugins(self):
-        self.logger.info("Shutting down all plugins")
-        with db.session_scope() as session:
-            for plugin_instance in self.commands.values():
-                plugin_instance.on_shutdown(session)
+        self.logger.info("Shutting down plugins")
+        with self.session_factory() as session:
+            for cmd in self.to_stop:
+                cmd.on_shutdown(session)
         self.logger.info("Finished shutting down plugins")
 
     def run_command(self, message):
         if not message.is_command:
             self.logger.warn(f"Tried to run non-command message {message.message_id}")
-            return
-
-        # if this command is aliased, resolve that first, otherwise use the literal command
-        cmd = ALIASES.get(message.command, message.command)
+            return RollbotResponse(msg, failure=RollbotFailure.INTERNAL_ERROR)
 
-        if cmd in RESPONSES:
-            return RollbotResponse(message, txt=RESPONSES[cmd])
-
-        plugin = self.commands.get(cmd, None)
+        plugin = self.commands.get(message.command, None)
 
         if plugin is None:
-            self.logger.warn(f"Message {message.message_id} had a command {message.command} (resolved to {cmd}) that could not be run.")
+            self.logger.warn(f"Message {message.message_id} had a command {message.command} that could not be run.")
             return RollbotResponse(message, failure=RollbotFailure.INVALID_COMMAND)
 
-        with db.session_scope() as session:
+        with self.session_factory() as session:
             response = plugin.on_command(session, message)
 
         if not response.is_success:
@@ -63,3 +94,44 @@ class Rollbot:
             self.logger.warn(response.info)
 
         return response
+
+    def handle_command(self, msg):
+        if msg.group_id not in self.bot_lookup:
+            self.logger.warning(f"Received message from unknown group ID {msg.group_id}")
+            return
+
+        self.logger.info(f"Handling message {msg.message_id}")
+        t = time.time()
+        try:
+            response = self.run_command(msg)
+        except Exception as e:
+            self.logger.exception(f"Exception during command execution for message {msg.message_id}")
+            response = RollbotResponse(msg, failure=RollbotFailure.INTERNAL_ERROR)
+
+        if not response.respond:
+            self.logger.info(f"Skipping response to message {msg.message_id}")
+            return
+
+        self.logger.info(f"Responding to message {msg.message_id}")
+
+        sleep = self.sleep_time - time.time() + t
+        if sleep > 0:
+            self.logger.info(f"Sleeping for {sleep:.3f}s before responding")
+            time.sleep(sleep)
+
+        bot_id = self.bot_lookup[msg.group_id]
+        if response.is_success:
+            if response.txt is not None:
+                self.callback(response.txt, bot_id)
+            if response.img is not None:
+                self.callback(response.img, bot_id)
+        else:
+            self.callback(response.failure_msg, bot_id)
+            self.logger.warning(f"Failed command response: {response}")
+
+        t = time.time() - t
+        self.logger.info(f"Exiting command thread for {msg.message_id} after {t:.3f}s")
+
+    def handle_command_threaded(self, message):
+        t = Thread(target=lambda: self.handle_command(message))
+        t.start()

+ 0 - 29
src/templates/services.html

@@ -1,29 +0,0 @@
-<!doctype html>
-<title>Rollbot 3.0</title>
-<link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='rollbot.css') }}">
-<div class="page">
-    <h1>Rollbot 3.0</h1>
-    {% if r is not none %}
-        {% if r %}
-            {% if txt is not none%}
-                {{ txt }}<br/>
-            {% endif %}
-            {% if img is not none%}
-                <img src="{{ img }}"/><br/>
-            {% endif %}
-        {% else %}
-            Operation produced no response!
-        {% endif %}
-        Try entering another command if you need help!
-    {% else %}
-        Hello! I'm Rollbot! You can enter a command below!<br/>
-        You can also elide the initial "!" if you want!
-    {% endif %}
-    <br/>
-    <br/>
-    <br/>
-    <form action="/services" method=POST>
-        <input type=text autofocus name="cmd"/><br/><br/>
-        <button type=submit>Submit</button>
-    </form>
-</div>

+ 10 - 2
src/templates/teamspeak.html

@@ -1,8 +1,16 @@
 <!doctype html>
 <title>Rollbot 3.0 - Teamspeak Server Status</title>
 <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='rollbot.css') }}">
+{% if scrolling %}
+<marquee direction="up" style="height 600px" behavior="alternate">
+<marquee direction="right" style="width 600px" behavior="alternate">
+{% endif %}
 <div class="page">
     <h1>Rollbot 3.0 - TeamSpeak Server Status</h1>
-    {{ r }}<br/>
-    <a href="{{ url_for('services') }}">Click here to return to my main front-end!</a>
+    {{ r }}
 </div>
+{% if scrolling %}
+</marquee>
+</marquee>
+{% endif %}
+

+ 15 - 4
src/test_driver.py

@@ -4,18 +4,29 @@ import os
 
 from rollbot import Rollbot
 from command_system import RollbotMessage
-from config import DB_FILE
+from config import BOTS_LOOKUP, DB_FILE, get_config
 
 try:
     os.remove(DB_FILE)
 except FileNotFoundError:
     pass
-rollbot = Rollbot(logging.getLogger(__name__))
+
+import db
+
+rollbot = Rollbot(
+    plugins=get_config("plugins"),
+    aliases=get_config("aliases"),
+    responses=get_config("responses"),
+    lookup=BOTS_LOOKUP,
+    sleep_time=float(get_config("sleep_time")),
+    callback=None,
+    session_factory=db.session_scope
+)
+db.init_db()
 
 
 def test_drive(msg, from_admin=True):
-    rmsg = RollbotMessage("mock", None, None, "test_group", None, msg)
-    rmsg.from_admin = from_admin
+    rmsg = RollbotMessage("mock", None, None, "test_group", None, msg, from_admin)
     r =rollbot.run_command(rmsg)
     print(msg, ":", r.txt, ",", r.failure)