database.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import datetime
  2. from sqlalchemy import Column, DateTime, Binary, String, Float, Integer, create_engine
  3. from sqlalchemy.orm import sessionmaker, scoped_session
  4. from sqlalchemy.ext.declarative import declarative_base
  5. ModelBase = declarative_base()
  6. def init_db_at_url(url):
  7. engine = create_engine(url)
  8. session_factory = scoped_session(sessionmaker(bind=engine))
  9. ModelBase.metadata.create_all(engine)
  10. return session_factory
  11. def get_columns(cls, banned=()):
  12. columns = {}
  13. for name, typ in cls.__annotations__.items():
  14. if name in banned:
  15. raise ValueError(f"Cannot have column named group_id in as_group_singleton class {cls.__name__}")
  16. if typ == int:
  17. columns[name] = Column(Integer)
  18. elif typ == float:
  19. columns[name] = Column(Float)
  20. elif typ == str:
  21. columns[name] = Column(String)
  22. elif typ in (object, "binary"):
  23. columns[name] = Column(Binary)
  24. elif typ == datetime.datetime:
  25. columns[name] = Column(DateTime)
  26. else:
  27. raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
  28. return columns
  29. def get_column_defaults(cls, columns):
  30. return {k: getattr(cls, k, None) for k in columns}
  31. def get_table_name(cls):
  32. return "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
  33. def make_db_class(cls, key_fields):
  34. columns = get_columns(cls, banned=key_fields)
  35. cons_params = get_column_defaults(cls, columns)
  36. if len(key_fields) == 1:
  37. key_extractor = lambda msg: getattr(msg, key_fields[0])
  38. else:
  39. key_extractor = lambda msg: tuple(getattr(msg, k) for k in key_fields)
  40. def get_or_create_standin(cls, db, msg):
  41. sing = db.query(cls).get(key_extractor(msg))
  42. if sing is None:
  43. sing = cls(**{k: getattr(msg, k) for k in key_fields}, **cons_params)
  44. db.add(sing)
  45. return sing
  46. columns["__tablename__"] = get_table_name(cls)
  47. for k in key_fields:
  48. columns[k] = Column(String, primary_key=True)
  49. return type(
  50. cls.__name__,
  51. (ModelBase, cls),
  52. dict(**columns, get_or_create=classmethod(get_or_create_standin))
  53. )
  54. def as_group_singleton(cls):
  55. return make_db_class(cls, ("group_id",))