Explorar el Código

Moving more eventually shared code into the lib

Kirk Trombley hace 4 años
padre
commit
ad3f3cec23
Se han modificado 4 ficheros con 23 adiciones y 25 borrados
  1. 7 5
      lib/rollbot/bot.py
  2. 10 1
      lib/rollbot/database.py
  3. 0 11
      src/db.py
  4. 6 8
      src/groupme_bot.py

+ 7 - 5
lib/rollbot/bot.py

@@ -9,6 +9,7 @@ from sqlalchemy.orm.session import Session
 
 from .messaging import RollbotResponse, RollbotFailure
 from .plugins import as_plugin, RollbotPlugin
+from .database import init_db_at_url
 
 
 def lift_response(call, response):
@@ -40,7 +41,7 @@ def get_session_manager_factory(session_factory):
 @dataclass
 class RollbotConfig:
     plugins: List[Type[RollbotPlugin]]
-    session_factory: Callable[[], Session]
+    db_url: str
     reply_callback: Callable[[str, str], None]
     aliases: Dict[str, str]
     responses: Dict[str, str]
@@ -57,10 +58,7 @@ class RollbotConfig:
 class Rollbot:
     def __init__(self, config, logger=logging.getLogger(__name__)):
         self.logger = logger
-        if config.session_factory is not None:
-            self.session_manager_factory = get_session_manager_factory(config.session_factory)
-        else:
-            self.session_manager_factory = lambda: None
+        self.session_manager_factory = lambda: None
         self.post_callback = config.reply_callback or (lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}"))
         self.commands = {}
         self.to_start = set()
@@ -101,6 +99,10 @@ class Rollbot:
             self.commands[alias] = self.commands[cmd]
         self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")
 
+    def init_db(self):
+        session_factory = init_db_at_url(self.config.db_url)
+        self.session_manager_factory = get_session_manager_factory(session_factory)
+
     def start_plugins(self):
         self.logger.info("Starting plugins")
         with self.session_manager_factory() as session:

+ 10 - 1
lib/rollbot/database.py

@@ -1,11 +1,20 @@
 import datetime
 
-from sqlalchemy import Column, DateTime, Binary, String, Float, Integer
+from sqlalchemy import Column, DateTime, Binary, String, Float, Integer, create_engine
+from sqlalchemy.orm import sessionmaker, scoped_session
 from sqlalchemy.ext.declarative import declarative_base
 
 
 ModelBase = declarative_base()
 
+
+def init_db_at_url(url):
+    engine = create_engine(url)
+    session_factory = scoped_session(sessionmaker(bind=engine))
+    ModelBase.metadata.create_all(engine)
+    return session_factory
+
+
 def get_columns(cls, banned=()):
     columns = {}
     for name, typ in cls.__annotations__.items():

+ 0 - 11
src/db.py

@@ -1,11 +0,0 @@
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker, scoped_session
-
-from rollbot import ModelBase
-
-
-def init_db(db_file):
-    engine = create_engine("sqlite:///" + db_file)
-    Session = scoped_session(sessionmaker(bind=engine))
-    ModelBase.metadata.create_all(engine)
-    return Session

+ 6 - 8
src/groupme_bot.py

@@ -13,7 +13,6 @@ import toml
 from rollbot import RollbotConfig, Rollbot, RollbotMessage, RollbotPlugin
 
 import plugins
-import db
 
 # Configure loggers
 dictConfig({
@@ -39,7 +38,6 @@ with open(os.path.join(config_dir, "config.toml")) as infile:
 with open(os.path.join(config_dir, "secrets.toml")) as infile:
     raw_secrets = toml.load(infile)
 
-db_file = os.path.abspath(raw_config["database"])
 auths = raw_secrets["auths"]
 global_admins = auths["global"]
 group_admins = auths["group"]
@@ -79,15 +77,10 @@ def post_groupme_message(msg, group_id):
             app.log_exception(h)
 
 
-# Init db
-app.logger.info("Initializing database tables")
-session_factory = db.init_db(db_file)
-app.logger.info("Finished initializing database")
-
 # Create bot
 config = RollbotConfig(
     plugins=RollbotPlugin.find_all_plugins(),
-    session_factory=session_factory,
+    db_url="sqlite:///" + os.path.abspath(raw_config["database"]),
     reply_callback=post_groupme_message,
     aliases=raw_config.get("aliases", []),
     responses=raw_config.get("responses", []),
@@ -99,6 +92,11 @@ config = RollbotConfig(
 )
 rollbot = Rollbot(config, logger=app.logger)
 
+# Init db
+app.logger.info("Initializing database tables")
+rollbot.init_db()
+app.logger.info("Finished initializing database")
+
 # Setup plugins
 rollbot.start_plugins()
 atexit.register(rollbot.shutdown_plugins)