database.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import datetime
  2. from sqlalchemy import Column, DateTime, Binary, String, Float, Integer
  3. from sqlalchemy.ext.declarative import declarative_base
# Declarative base shared by every model class built by make_db_class below
# (each generated class lists ModelBase as its first base).
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated in
# SQLAlchemy 1.4+ in favour of sqlalchemy.orm.declarative_base — verify
# against the pinned SQLAlchemy version.
ModelBase = declarative_base()
  5. def get_columns(cls, banned=()):
  6. columns = {}
  7. for name, typ in cls.__annotations__.items():
  8. if name in banned:
  9. raise ValueError(f"Cannot have column named group_id in as_group_singleton class {cls.__name__}")
  10. if typ == int:
  11. columns[name] = Column(Integer)
  12. elif typ == float:
  13. columns[name] = Column(Float)
  14. elif typ == str:
  15. columns[name] = Column(String)
  16. elif typ in (object, "binary"):
  17. columns[name] = Column(Binary)
  18. elif typ == datetime.datetime:
  19. columns[name] = Column(DateTime)
  20. else:
  21. raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
  22. return columns
  23. def get_column_defaults(cls, columns):
  24. return {k: getattr(cls, k, None) for k in columns}
  25. def get_table_name(cls):
  26. return "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
  27. def make_db_class(cls, key_fields):
  28. columns = get_columns(cls, banned=key_fields)
  29. cons_params = get_column_defaults(cls, columns)
  30. if len(key_fields) == 1:
  31. key_extractor = lambda msg: getattr(msg, key_fields[0])
  32. else:
  33. key_extractor = lambda msg: tuple(getattr(msg, k) for k in key_fields)
  34. def get_or_create_standin(cls, db, msg):
  35. sing = db.query(cls).get(key_extractor(msg))
  36. if sing is None:
  37. sing = cls(**{k: getattr(msg, k) for k in key_fields}, **cons_params)
  38. db.add(sing)
  39. return sing
  40. columns["__tablename__"] = get_table_name(cls)
  41. for k in key_fields:
  42. columns[k] = Column(String, primary_key=True)
  43. return type(
  44. cls.__name__,
  45. (ModelBase, cls),
  46. dict(**columns, get_or_create=classmethod(get_or_create_standin))
  47. )
  48. def as_group_singleton(cls):
  49. return make_db_class(cls, ("group_id",))