Browse Source

Adding a new way to declare and access database entities, and a local run script

Kirk Trombley 6 years ago
parent
commit
147b47aa36
2 changed files with 61 additions and 6 deletions
  1. 5 0
      rollbot-local.sh
  2. 56 6
      src/command_system.py

+ 5 - 0
rollbot-local.sh

@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+pushd src/
+ROLLBOT_CFG_DIR=../config python3 app.py
+popd

+ 56 - 6
src/command_system.py

@@ -3,6 +3,8 @@ from dataclasses import dataclass
 from enum import Enum, auto
 import inspect
 import functools
+import pickle
+import datetime
 
 from sqlalchemy import Column, DateTime, Binary, String, Float, Integer
 from sqlalchemy.ext.declarative import declarative_base
@@ -41,6 +43,51 @@ class GroupBasedSingleton(ModelBase):
             db.add(sing)
         return sing
 
+    def set_binary(self, obj):
+        self.binary_data = pickle.dumps(obj)
+
+    def get_binary(self):
+        if self.binary_data is None:
+            return None
+        return pickle.loads(self.binary_data)
+
+
+def as_group_singleton(cls):
+    columns = {}
+    for name, typ in cls.__annotations__.items():
+        if name == "group_id":
+            raise ValueError(f"Cannot have column named group_id in as_group_singleton class {cls.__name__}")
+        if typ == int:
+            columns[name] = Column(Integer)
+        elif typ == float:
+            columns[name] = Column(Float)
+        elif typ == str:
+            columns[name] = Column(String)
+        elif typ in (object, "binary"):
+            columns[name] = Column(Binary)
+        elif typ == datetime.datetime:
+            columns[name] = Column(DateTime)
+        else:
+            raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
+
+    cons_params = {k: getattr(cls, k, None) for k in columns}
+
+    def get_or_create_standin(cls, db, group_id):
+        sing = db.query(cls).get(group_id)
+        if sing is None:
+            sing = cls(group_id=group_id, **cons_params)
+            db.add(sing)
+        return sing
+
+    columns["__tablename__"] = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
+    columns["group_id"] = Column(String, primary_key=True)
+
+    return type(
+        cls.__name__,
+        (ModelBase,),
+        dict(**columns, get_or_create=classmethod(get_or_create_standin))
+    )
+
 
 def pop_arg(text):
     if text is None:
@@ -194,16 +241,19 @@ def as_plugin(command):
         converters = []
         for p in sig.parameters:
             if p in ("msg", "message", "_msg"):
-                converters.append(lambda self, db, msg: msg)
+                converters.append(lambda cmd, db, msg: msg)
             elif p in ("db", "database"):
-                converters.append(lambda self, db, msg: db)
+                converters.append(lambda cmd, db, msg: db)
             elif p in ("log", "logger"):
-                converters.append(lambda self, db, msg: self.logger)
+                converters.append(lambda cmd, db, msg: cmd.logger)
             elif p in ("bot", "rollbot"):
-                converters.append(lambda self, db, msg: self.bot)
+                converters.append(lambda cmd, db, msg: cmd.bot)
             elif p.startswith("data") or p.endswith("data") or p in ("group_singleton", "singleton"):
-                subp = fn.__annotations__.get(p, "")
-                converters.append(lambda self, db, msg, subp=subp: GroupBasedSingleton.get_or_create(db, msg.group_id, self.command, subp))
+                annot = fn.__annotations__.get(p, p)
+                if isinstance(annot, str):
+                    converters.append(lambda cmd, db, msg, subp=annot: GroupBasedSingleton.get_or_create(db, msg.group_id, cmd.command, subp))
+                else:
+                    converters.append(lambda cmd, db, msg, sing_cls=annot: sing_cls.get_or_create(db, msg.group_id))
             else:
                 raise ValueError(f"Illegal argument name {p} in decorated plugin {command_name}")