|
@@ -6,11 +6,10 @@ from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
|
|
ModelBase = declarative_base()
|
|
|
|
|
|
-
|
|
|
-def as_group_singleton(cls):
|
|
|
+def get_columns(cls, banned=()):
|
|
|
columns = {}
|
|
|
for name, typ in cls.__annotations__.items():
|
|
|
- if name == "group_id":
|
|
|
+ if name in banned:
|
|
|
raise ValueError(f"Cannot have column named group_id in as_group_singleton class {cls.__name__}")
|
|
|
if typ == int:
|
|
|
columns[name] = Column(Integer)
|
|
@@ -24,21 +23,42 @@ def as_group_singleton(cls):
|
|
|
columns[name] = Column(DateTime)
|
|
|
else:
|
|
|
raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
|
|
|
+ return columns
|
|
|
+
|
|
|
+
|
|
|
+def get_column_defaults(cls, columns):
|
|
|
+ return {k: getattr(cls, k, None) for k in columns}
|
|
|
+
|
|
|
|
|
|
- cons_params = {k: getattr(cls, k, None) for k in columns}
|
|
|
+def get_table_name(cls):
|
|
|
+ return "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
|
|
|
|
|
|
- def get_or_create_standin(cls, db, group_id):
|
|
|
- sing = db.query(cls).get(group_id)
|
|
|
+
|
|
|
+def make_db_class(cls, key_fields):
|
|
|
+ columns = get_columns(cls, banned=key_fields)
|
|
|
+ cons_params = get_column_defaults(cls, columns)
|
|
|
+ if len(key_fields) == 1:
|
|
|
+ key_extractor = lambda msg: getattr(msg, key_fields[0])
|
|
|
+ else:
|
|
|
+ key_extractor = lambda msg: tuple(getattr(msg, k) for k in key_fields)
|
|
|
+
|
|
|
+ def get_or_create_standin(cls, db, msg):
|
|
|
+ sing = db.query(cls).get(key_extractor(msg))
|
|
|
if sing is None:
|
|
|
- sing = cls(group_id=group_id, **cons_params)
|
|
|
+ sing = cls(**{k: getattr(msg, k) for k in key_fields}, **cons_params)
|
|
|
db.add(sing)
|
|
|
return sing
|
|
|
|
|
|
- columns["__tablename__"] = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
|
|
|
- columns["group_id"] = Column(String, primary_key=True)
|
|
|
+ columns["__tablename__"] = get_table_name(cls)
|
|
|
+ for k in key_fields:
|
|
|
+ columns[k] = Column(String, primary_key=True)
|
|
|
|
|
|
return type(
|
|
|
cls.__name__,
|
|
|
(ModelBase, cls),
|
|
|
dict(**columns, get_or_create=classmethod(get_or_create_standin))
|
|
|
)
|
|
|
+
|
|
|
+
|
|
|
+def as_group_singleton(cls):
|
|
|
+ return make_db_class(cls, ("group_id",))
|