Browse Source

Internal refactor to allow creation of new group singleton-esque injected data

Kirk Trombley 5 years ago
parent
commit
0468eb5dcb
2 changed files with 31 additions and 11 deletions
  1. 29 9
      lib/rollbot/database.py
  2. 2 2
      lib/rollbot/plugins.py

+ 29 - 9
lib/rollbot/database.py

@@ -6,11 +6,10 @@ from sqlalchemy.ext.declarative import declarative_base
 
 ModelBase = declarative_base()
 
-
-def as_group_singleton(cls):
+def get_columns(cls, banned=()):
     columns = {}
     for name, typ in cls.__annotations__.items():
-        if name == "group_id":
+        if name in banned:
             raise ValueError(f"Cannot have column named group_id in as_group_singleton class {cls.__name__}")
         if typ == int:
             columns[name] = Column(Integer)
@@ -24,21 +23,42 @@ def as_group_singleton(cls):
             columns[name] = Column(DateTime)
         else:
             raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
+    return columns
+
+
+def get_column_defaults(cls, columns):
+    return {k: getattr(cls, k, None) for k in columns}
+
 
-    cons_params = {k: getattr(cls, k, None) for k in columns}
+def get_table_name(cls):
+    return "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
 
-    def get_or_create_standin(cls, db, group_id):
-        sing = db.query(cls).get(group_id)
+
+def make_db_class(cls, key_fields):
+    columns = get_columns(cls, banned=key_fields)
+    cons_params = get_column_defaults(cls, columns)
+    if len(key_fields) == 1:
+        key_extractor = lambda msg: getattr(msg, key_fields[0])
+    else:
+        key_extractor = lambda msg: tuple(getattr(msg, k) for k in key_fields)
+
+    def get_or_create_standin(cls, db, msg):
+        sing = db.query(cls).get(key_extractor(msg))
         if sing is None:
-            sing = cls(group_id=group_id, **cons_params)
+            sing = cls(**{k: getattr(msg, k) for k in key_fields}, **cons_params)
             db.add(sing)
         return sing
 
-    columns["__tablename__"] = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
-    columns["group_id"] = Column(String, primary_key=True)
+    columns["__tablename__"] = get_table_name(cls)
+    for k in key_fields:
+        columns[k] = Column(String, primary_key=True)
 
     return type(
         cls.__name__,
         (ModelBase, cls),
         dict(**columns, get_or_create=classmethod(get_or_create_standin))
     )
+
+
+def as_group_singleton(cls):
+    return make_db_class(cls, ("group_id",))

+ 2 - 2
lib/rollbot/plugins.py

@@ -52,9 +52,9 @@ def as_plugin(command):
                 converters.append(lambda cmd, db, msg: cmd.bot)
             elif p in ("subc", "subcommand"):
                 converters.append(lambda cmd, db, msg: RollbotMessage.from_subcommand(msg))
-            elif p.startswith("data") or p.endswith("data") or p in ("group_singleton", "singleton"):
+            elif p.startswith("data") or p.endswith("data"):
                 annot = fn.__annotations__.get(p, p)
-                converters.append(lambda cmd, db, msg, sing_cls=annot: sing_cls.get_or_create(db, msg.group_id))
+                converters.append(lambda cmd, db, msg, sing_cls=annot: sing_cls.get_or_create(db, msg))
             else:
                 raise ValueError(f"Illegal argument name {p} in decorated plugin {command_name}")