Browse Source

Improving some docs and moving some db code into lib where it belongs

Kirk Trombley 5 years ago
parent
commit
5883e156da
6 changed files with 97 additions and 29 deletions
  1. 27 4
      lib/rollbot/bot.py
  2. 66 1
      lib/rollbot/messaging.py
  3. 0 19
      src/db.py
  4. 1 1
      src/discord_bot.py
  5. 1 1
      src/groupme_bot.py
  6. 2 3
      src/test_driver.py

+ 27 - 4
lib/rollbot/bot.py

@@ -1,5 +1,6 @@
 import logging
 import time
+from contextlib import contextmanager
 
 from .messaging import RollbotResponse, RollbotFailure
 from .plugins import as_plugin
@@ -12,10 +13,32 @@ def lift_response(call, response):
     return response_func
 
 
+def get_session_manager_factory(session_factory):
+    @contextmanager
+    def session_manager_factory():
+        """Provide a transactional scope around a series of operations."""
+        session = session_factory()
+        try:
+            yield session
+            session.commit()
+        except:
+            # TODO there is some worry that this would rollback things in other threads...
+            # we should probably find a more correct solution for managing the threaded
+            # db access, but the risk is fairly low at this point.
+            session.rollback()
+            raise
+        finally:
+            session.close()
+    return session_manager_factory
+
+
 class Rollbot:
     def __init__(self, logger=logging.getLogger(__name__), plugin_classes={}, aliases={}, responses={}, sleep_time=0.0, session_factory=None, callback=None):
         self.logger = logger
-        self.session_factory = session_factory or (lambda: None)
+        if session_factory is not None:
+            self.session_manager_factory = get_session_manager_factory(session_factory)
+        else:
+            self.session_manager_factory = lambda: None
         self.post_callback = callback or (lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}"))
         self.commands = {}
         self.to_start = set()
@@ -57,14 +80,14 @@ class Rollbot:
 
     def start_plugins(self):
         self.logger.info("Starting plugins")
-        with self.session_factory() as session:
+        with self.session_manager_factory() as session:
             for cmd in self.to_start:
                 cmd.on_start(session)
         self.logger.info("Finished starting plugins")
 
     def shutdown_plugins(self):
         self.logger.info("Shutting down plugins")
-        with self.session_factory() as session:
+        with self.session_manager_factory() as session:
             for cmd in self.to_stop:
                 cmd.on_shutdown(session)
         self.logger.info("Finished shutting down plugins")
@@ -80,7 +103,7 @@ class Rollbot:
             self.logger.warn(f"Message {message.message_id} had a command {message.command} that could not be run.")
             return RollbotResponse(message, failure=RollbotFailure.INVALID_COMMAND)
 
-        with self.session_factory() as session:
+        with self.session_manager_factory() as session:
             response = plugin.on_command(session, message)
 
         if not response.is_success:

+ 66 - 1
lib/rollbot/messaging.py

@@ -6,16 +6,57 @@ BANGS = ('!',)
 
 
 def pop_arg(text):
+    """
+    Pop an argument from a string of text. The text is split at the first
+    substring containing only whitespace characters, ignoring leading and
+    trailing whitespace of the string, with the text preceeding being 
+    returned as the first value, and the text following being returned 
+    as the second value. 
+    
+    Both return values will be stripped of leading and trailing whitespace. 
+    If the given text is None, both return values will be None. If there is 
+    no text following the split point, the second return value will be None.
+    """
     if text is None:
         return None, None
+    text = text.strip()
     parts = text.split(maxsplit=1)
     if len(parts) == 1:
         return parts[0], None
-    return parts[0], parts[1].strip()
+    return parts[0], parts[1]
 
 
 @dataclass
 class RollbotMessage:
+    """
+    A data class modeling a message that was received by a Rollbot.
+    This class is used both for mundane messages and for commands.
+
+    Init Fields:
+    src - a string describing the source of the message, usually GROUPME
+    name - the plaintext display name of the sender of the message
+    sender_id - the service-specific id of the sender, which will not change
+    group_id - the id of the "group" the message was sent to, which can 
+        mean different concepts depending on src
+    message_id - the service-specific unique id of the message
+    message_txt - the raw, full text of the message
+    from_admin - a boolean flag denoting if the sender has admin privileges
+
+    Derived Fields:
+    is_command - a boolean flag denoting if the message is a command. This 
+        will be true if message_txt begins with a "!" character followed by
+        one or more non-whitespace characters (with whitespace between the
+        bang and the first non-whitespace character being ignored)
+    raw_command - the raw text of the command, i.e., the first "word" after
+        the bang, with leading and trailing whitespace removed. This field 
+        will only be present if is_command is True
+    command - raw_command normalized to lower case. This field will only be 
+        present if is_command is True
+    raw_args - the raw text of the arguments following command, i.e., the 
+        remaining ontent of message_txt, with leading and trailing whitespace
+        removed. This field will only be present if is_command is True
+
+    """
     src: str
     name: str
     sender_id: str
@@ -82,11 +123,35 @@ class RollbotMessage:
         )
 
     def args(self, normalize=True):
+        """
+        Lazily pop arguments from the raw argument string of this message
+        and yield them one at a time as a generator. If the optional
+        normalize parameter is set to False, the arguments will
+        be returned exactly as they appear in the message, and if normalize
+        is set to True or omitted, the arguments will be converted to lower
+        case. 
+        
+        For details on argument "popping", see the rollbot.pop_arg function.
+
+        Behavior is undefined if this method is called on a message whose
+        is_command field is false.
+        """
         arg, rest = pop_arg(self.raw_args)
         while arg is not None:
             yield arg.lower() if normalize else arg
             arg, rest = pop_arg(rest)
 
+    def arg_list(self):
+        """
+        Take the raw argument string of this message and split it on any
+        sequence of one or more whitespace characters, and return the result.
+        This can be useful to pass to an argparse.ArgumentParser.
+
+        Behavior is undefined if this method is called on a message whose
+        is_command field is false.
+        """
+        return self.raw_args.split()
+
 
 class RollbotFailure(Enum):
     INVALID_COMMAND = auto()

+ 0 - 19
src/db.py

@@ -1,5 +1,3 @@
-from contextlib import contextmanager
-
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker, scoped_session
 
@@ -14,20 +12,3 @@ Session = scoped_session(sessionmaker(bind=engine))
 
 def init_db():
     ModelBase.metadata.create_all(engine)
-
-
-@contextmanager
-def session_scope():
-    """Provide a transactional scope around a series of operations."""
-    session = Session()
-    try:
-        yield session
-        session.commit()
-    except:
-        # TODO there is some worry that this would rollback things in other threads...
-        # we should probably find a more correct solution for managing the threaded
-        # db access, but the risk is fairly low at this point.
-        session.rollback()
-        raise
-    finally:
-        session.close()

+ 1 - 1
src/discord_bot.py

@@ -46,7 +46,7 @@ rollbot = Rollbot(
     aliases=get_config("aliases"),
     responses=get_config("responses"),
     callback=lambda msg, channel_id: msg_queue.append((msg, discord.utils.get(client.get_all_channels(), id=int(channel_id)))),
-    session_factory=db.session_scope
+    session_factory=db.Session
 )
 rollbot.logger.info("Initializing database tables")
 db.init_db()

+ 1 - 1
src/groupme_bot.py

@@ -53,7 +53,7 @@ rollbot = Rollbot(
     responses=get_config("responses"),
     sleep_time=float(get_config("sleep_time")),
     callback=post_groupme_message,
-    session_factory=db.session_scope
+    session_factory=db.Session
 )
 app.logger.info("Initializing database tables")
 db.init_db()

+ 2 - 3
src/test_driver.py

@@ -14,13 +14,12 @@ except FileNotFoundError:
 import db
 
 rollbot = Rollbot(
-    plugins=get_config("plugins"),
+    plugin_classes=get_config("plugins"),
     aliases=get_config("aliases"),
     responses=get_config("responses"),
-    lookup=BOTS_LOOKUP,
     sleep_time=float(get_config("sleep_time")),
     callback=None,
-    session_factory=db.session_scope
+    session_factory=db.Session
 )
 db.init_db()