Browse Source

Little bit of database logic refactoring

Kirk Trombley 4 years ago
parent
commit
feb980c8f1
2 changed files with 18 additions and 25 deletions
  1. 1 24
      lib/rollbot/bot.py
  2. 17 1
      lib/rollbot/database.py

+ 1 - 24
lib/rollbot/bot.py

@@ -1,12 +1,9 @@
 import logging
 import time
 import traceback
-from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import List, Type, Dict, Callable, Any
 
-from sqlalchemy.orm.session import Session
-
 from .messaging import RollbotResponse, RollbotFailure
 from .plugins import as_plugin, RollbotPlugin
 from .database import init_db_at_url
@@ -19,25 +16,6 @@ def lift_response(call, response):
     return response_func
 
 
-def get_session_manager_factory(session_factory):
-    @contextmanager
-    def session_manager_factory():
-        """Provide a transactional scope around a series of operations."""
-        session = session_factory()
-        try:
-            yield session
-            session.commit()
-        except:
-            # TODO there is some worry that this would rollback things in other threads...
-            # we should probably find a more correct solution for managing the threaded
-            # db access, but the risk is fairly low at this point.
-            session.rollback()
-            raise
-        finally:
-            session.close()
-    return session_manager_factory
-
-
 @dataclass
 class RollbotConfig:
     plugins: List[Type[RollbotPlugin]]
@@ -100,8 +78,7 @@ class Rollbot:
         self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")
 
     def init_db(self):
-        session_factory = init_db_at_url(self.config.db_url)
-        self.session_manager_factory = get_session_manager_factory(session_factory)
+        self.session_manager_factory = init_db_at_url(self.config.db_url)
 
     def start_plugins(self):
         self.logger.info("Starting plugins")

+ 17 - 1
lib/rollbot/database.py

@@ -1,4 +1,5 @@
 import datetime
+from contextlib import contextmanager
 
 from sqlalchemy import Column, DateTime, Binary, String, Float, Integer, create_engine
 from sqlalchemy.orm import sessionmaker, scoped_session
@@ -12,7 +13,22 @@ def init_db_at_url(url):
     engine = create_engine(url)
     session_factory = scoped_session(sessionmaker(bind=engine))
     ModelBase.metadata.create_all(engine)
-    return session_factory
+    @contextmanager
+    def session_manager_factory():
+        """Provide a transactional scope around a series of operations."""
+        session = session_factory()
+        try:
+            yield session
+            session.commit()
+        except:
+            # TODO there is some worry that this would rollback things in other threads...
+            # we should probably find a more correct solution for managing the threaded
+            # db access, but the risk is fairly low at this point.
+            session.rollback()
+            raise
+        finally:
+            session.close()
+    return session_manager_factory
 
 
 def get_columns(cls, banned=()):