|
@@ -1,9 +1,13 @@
|
|
|
import logging
|
|
|
import time
|
|
|
from contextlib import contextmanager
|
|
|
+from dataclasses import dataclass
|
|
|
+from typing import List, Type, Dict, Callable, Any
|
|
|
+
|
|
|
+from sqlalchemy.orm.session import Session
|
|
|
|
|
|
from .messaging import RollbotResponse, RollbotFailure
|
|
|
-from .plugins import as_plugin
|
|
|
+from .plugins import as_plugin, RollbotPlugin
|
|
|
|
|
|
|
|
|
def lift_response(call, response):
|
|
@@ -32,20 +36,40 @@ def get_session_manager_factory(session_factory):
|
|
|
return session_manager_factory
|
|
|
|
|
|
|
|
|
+@dataclass
|
|
|
+class RollbotConfig:
|
|
|
+ plugins: List[Type[RollbotPlugin]]
|
|
|
+ session_factory: Callable[[], Session]
|
|
|
+ reply_callback: Callable[[str, str], None]
|
|
|
+ aliases: Dict[str, str]
|
|
|
+ responses: Dict[str, str]
|
|
|
+ sleep_time: float
|
|
|
+ other: Dict[str, Any]
|
|
|
+
|
|
|
+ def get(self, key):
|
|
|
+ c = self.other
|
|
|
+ for k in key.split("."):
|
|
|
+ c = c[k]
|
|
|
+ return c
|
|
|
+
|
|
|
+
|
|
|
class Rollbot:
|
|
|
- def __init__(self, logger=logging.getLogger(__name__), plugin_classes={}, aliases={}, responses={}, sleep_time=0.0, session_factory=None, callback=None):
|
|
|
+ def __init__(self, config, logger=logging.getLogger(__name__)):
|
|
|
self.logger = logger
|
|
|
- if session_factory is not None:
|
|
|
- self.session_manager_factory = get_session_manager_factory(session_factory)
|
|
|
+ if config.session_factory is not None:
|
|
|
+ self.session_manager_factory = get_session_manager_factory(config.session_factory)
|
|
|
else:
|
|
|
self.session_manager_factory = lambda: None
|
|
|
- self.post_callback = callback or (lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}"))
|
|
|
+ self.post_callback = config.reply_callback or (lambda txt, gid: self.logger.info(f"Responding to {gid} with {txt}"))
|
|
|
self.commands = {}
|
|
|
self.to_start = set()
|
|
|
self.to_stop = set()
|
|
|
+ self.sleep_time = config.sleep_time
|
|
|
+ self.last_exception = None
|
|
|
+ self.config = config
|
|
|
|
|
|
self.logger.info("Loading command plugins")
|
|
|
- for plugin_class in plugin_classes:
|
|
|
+ for plugin_class in config.plugins:
|
|
|
plugin_instance = plugin_class(self, logger=logger)
|
|
|
if plugin_instance.command in self.commands:
|
|
|
self.logger.error(f"Duplicate command word '{plugin_instance.command}'")
|
|
@@ -58,7 +82,7 @@ class Rollbot:
|
|
|
self.logger.info(f"Finished loading plugins, {len(self.commands)} commands found")
|
|
|
|
|
|
self.logger.info("Loading simple responses")
|
|
|
- for cmd, response in responses.items():
|
|
|
+ for cmd, response in config.responses.items():
|
|
|
if cmd in self.commands:
|
|
|
self.logger.error(f"Duplicate command word '{cmd}'")
|
|
|
raise ValueError(f"Duplicate command word '{cmd}'")
|
|
@@ -66,7 +90,7 @@ class Rollbot:
|
|
|
self.logger.info(f"Finished loading simple responses, {len(self.commands)} total commands available")
|
|
|
|
|
|
self.logger.info("Loading aliases")
|
|
|
- for alias, cmd in aliases.items():
|
|
|
+ for alias, cmd in config.aliases.items():
|
|
|
if cmd not in self.commands:
|
|
|
self.logger.error(f"Missing aliased command word '{cmd}'")
|
|
|
raise ValueError(f"Missing aliased command word '{cmd}'")
|
|
@@ -76,8 +100,6 @@ class Rollbot:
|
|
|
self.commands[alias] = self.commands[cmd]
|
|
|
self.logger.info(f"Finished loading aliases, {len(self.commands)} total commands + aliases available")
|
|
|
|
|
|
- self.sleep_time = sleep_time
|
|
|
-
|
|
|
def start_plugins(self):
|
|
|
self.logger.info("Starting plugins")
|
|
|
with self.session_manager_factory() as session:
|
|
@@ -131,6 +153,7 @@ class Rollbot:
|
|
|
except Exception as e:
|
|
|
self.logger.exception(f"Exception during command execution for message {message.message_id}")
|
|
|
response = RollbotResponse(message, failure=RollbotFailure.INTERNAL_ERROR)
|
|
|
+ self.last_exception = e
|
|
|
|
|
|
if not response.respond:
|
|
|
self.logger.info(f"Skipping response to message {message.message_id}")
|