Browse Source

Rewriting groupme bot driver to use RollbotConfig

Kirk Trombley 4 years ago
parent
commit
fd60f05e8b
2 changed files with 51 additions and 29 deletions
  1. 4 7
      src/db.py
  2. 47 22
      src/groupme_bot.py

+ 4 - 7
src/db.py

@@ -3,12 +3,9 @@ from sqlalchemy.orm import sessionmaker, scoped_session
 
 from rollbot import ModelBase
 
-from config import DB_FILE
 
-
-engine = create_engine("sqlite:///" + DB_FILE)
-Session = scoped_session(sessionmaker(bind=engine))
-
-
-def init_db():
+def init_db(db_file):
+    engine = create_engine("sqlite:///" + db_file)
+    Session = scoped_session(sessionmaker(bind=engine))
     ModelBase.metadata.create_all(engine)
+    return Session

+ 47 - 22
src/groupme_bot.py

@@ -2,21 +2,20 @@ import atexit
 from logging.config import dictConfig
 from threading import Thread
 import random
+import os
+import os.path
 
 from flask import Flask, request
 import requests
 import requests.exceptions
+import toml
 
-from rollbot import Rollbot, RollbotMessage, RollbotPlugin
+from rollbot import RollbotConfig, Rollbot, RollbotMessage, RollbotPlugin
 
 import plugins
-
 import db
-from config import BOTS_LOOKUP, get_config, get_secret
-
-GLOBAL_ADMINS = get_secret("auths.global")
-GROUP_ADMINS = get_secret("auths.group")
 
+# Configure loggers
 dictConfig({
     "version": 1,
     "formatters": {"default": {
@@ -33,13 +32,31 @@ dictConfig({
     }
 })
 
+# Read bot configuration
+config_dir = os.environ.get("ROLLBOT_CFG_DIR", ".")
+with open(os.path.join(config_dir, "config.toml")) as infile:
+    raw_config = toml.load(infile)
+with open(os.path.join(config_dir, "secrets.toml")) as infile:
+    raw_secrets = toml.load(infile)
+
+db_file = os.path.abspath(raw_config["database"])
+auths = raw_secrets["auths"]
+global_admins = auths["global"]
+group_admins = auths["group"]
+bots_lookup = raw_secrets["groupme_bots"]
+
+# Define Flask app
+app = Flask(__name__)
+app.config["PROPAGATE_EXCEPTIONS"] = True
+
+# Define reply logic
 max_msg_len = 1000
 split_text = "\n..."
 msg_slice = max_msg_len - len(split_text)
 
 
 def post_groupme_message(msg, group_id):
-    bot_id = BOTS_LOOKUP[group_id]
+    bot_id = bots_lookup[group_id]
     msgs = []
     rem = msg
     while len(rem) > max_msg_len:
@@ -62,29 +79,37 @@ def post_groupme_message(msg, group_id):
             app.log_exception(h)
 
 
-app = Flask(__name__)
-app.config["PROPAGATE_EXCEPTIONS"] = True
-rollbot = Rollbot(
-    logger=app.logger,
-    plugin_classes=RollbotPlugin.find_all_plugins(),
-    aliases=get_config("aliases"),
-    responses=get_config("responses"),
-    sleep_time=float(get_config("sleep_time")),
-    callback=post_groupme_message,
-    session_factory=db.Session
-)
+# Init db
 app.logger.info("Initializing database tables")
-db.init_db()
+session_factory = db.init_db(db_file)
 app.logger.info("Finished initializing database")
+
+# Create bot
+config = RollbotConfig(
+    plugins=RollbotPlugin.find_all_plugins(),
+    session_factory=session_factory,
+    reply_callback=post_groupme_message,
+    aliases=raw_config.get("aliases", []),
+    responses=raw_config.get("responses", []),
+    sleep_time=raw_config.get("sleep_time", 0.0),
+    other={
+        **{k: v for k, v in raw_config.items() if k not in ("database", "aliases", "responses", "sleep_time")},
+        **{k: v for k, v in raw_secrets.items() if k not in ("auths", "groupme_bots")},
+    },
+)
+rollbot = Rollbot(config, logger=app.logger)
+
+# Setup plugins
 rollbot.start_plugins()
 atexit.register(rollbot.shutdown_plugins)
 
 
+# Routing
 @app.route("/", methods=["POST"])
 def execute():
     json = request.get_json()
-    msg = RollbotMessage.from_groupme(json, global_admins=GLOBAL_ADMINS, group_admins=GROUP_ADMINS)
-    if msg.group_id not in BOTS_LOOKUP:
+    msg = RollbotMessage.from_groupme(json, global_admins=global_admins, group_admins=group_admins)
+    if msg.group_id not in bots_lookup:
         app.logger.warning(f"Received message from unknown group ID {msg.group_id}")
         return "", 400
     t = Thread(target=lambda: rollbot.handle_command(msg))
@@ -94,7 +119,7 @@ def execute():
 
 @app.route("/health")
 def health():
-    return "Rollbot healthy!", 200
+    return "", 204
 
 
 if __name__ == "__main__":