database.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import datetime
  2. from sqlalchemy import Column, DateTime, Binary, String, Float, Integer
  3. from sqlalchemy.ext.declarative import declarative_base
# Declarative base class; as_group_singleton() uses it as the first base of
# every model class it generates.
ModelBase = declarative_base()
  5. def as_group_singleton(cls):
  6. columns = {}
  7. for name, typ in cls.__annotations__.items():
  8. if name == "group_id":
  9. raise ValueError(f"Cannot have column named group_id in as_group_singleton class {cls.__name__}")
  10. if typ == int:
  11. columns[name] = Column(Integer)
  12. elif typ == float:
  13. columns[name] = Column(Float)
  14. elif typ == str:
  15. columns[name] = Column(String)
  16. elif typ in (object, "binary"):
  17. columns[name] = Column(Binary)
  18. elif typ == datetime.datetime:
  19. columns[name] = Column(DateTime)
  20. else:
  21. raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
  22. cons_params = {k: getattr(cls, k, None) for k in columns}
  23. def get_or_create_standin(cls, db, group_id):
  24. sing = db.query(cls).get(group_id)
  25. if sing is None:
  26. sing = cls(group_id=group_id, **cons_params)
  27. db.add(sing)
  28. return sing
  29. columns["__tablename__"] = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
  30. columns["group_id"] = Column(String, primary_key=True)
  31. return type(
  32. cls.__name__,
  33. (ModelBase, cls),
  34. dict(**columns, get_or_create=classmethod(get_or_create_standin))
  35. )