database.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import datetime
  2. from contextlib import contextmanager
  3. from sqlalchemy import Column, DateTime, LargeBinary, String, Float, Integer, create_engine
  4. from sqlalchemy.orm import sessionmaker, scoped_session
  5. from sqlalchemy.ext.declarative import declarative_base
# Declarative base shared by this module: init_db_at_url creates tables from
# its metadata, and make_db_class mixes it into every generated model class.
ModelBase = declarative_base()
  7. def init_db_at_url(url):
  8. engine = create_engine(url)
  9. session_factory = scoped_session(sessionmaker(bind=engine))
  10. ModelBase.metadata.create_all(engine)
  11. @contextmanager
  12. def session_manager_factory():
  13. """Provide a transactional scope around a series of operations."""
  14. session = session_factory()
  15. try:
  16. yield session
  17. session.commit()
  18. except:
  19. # TODO there is some worry that this would rollback things in other threads...
  20. # we should probably find a more correct solution for managing the threaded
  21. # db access, but the risk is fairly low at this point.
  22. session.rollback()
  23. raise
  24. finally:
  25. session.close()
  26. return session_manager_factory
  27. def get_columns(cls, banned=()):
  28. columns = {}
  29. for name, typ in cls.__annotations__.items():
  30. if name in banned:
  31. raise ValueError(f"Cannot have column named group_id in as_group_singleton class {cls.__name__}")
  32. if typ == int:
  33. columns[name] = Column(Integer)
  34. elif typ == float:
  35. columns[name] = Column(Float)
  36. elif typ == str:
  37. columns[name] = Column(String)
  38. elif typ in (object, "binary"):
  39. columns[name] = Column(LargeBinary)
  40. elif typ == datetime.datetime:
  41. columns[name] = Column(DateTime)
  42. else:
  43. raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
  44. return columns
  45. def get_column_defaults(cls, columns):
  46. return {k: getattr(cls, k, None) for k in columns}
  47. def get_table_name(cls):
  48. return "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
  49. def make_db_class(cls, key_fields):
  50. columns = get_columns(cls, banned=key_fields)
  51. cons_params = get_column_defaults(cls, columns)
  52. if len(key_fields) == 1:
  53. key_extractor = lambda msg: getattr(msg, key_fields[0])
  54. else:
  55. key_extractor = lambda msg: tuple(getattr(msg, k) for k in key_fields)
  56. def get_or_create_standin(cls, db, msg):
  57. sing = db.query(cls).get(key_extractor(msg))
  58. if sing is None:
  59. sing = cls(**{k: getattr(msg, k) for k in key_fields}, **cons_params)
  60. db.add(sing)
  61. return sing
  62. def create_from_key_standin(cls, key):
  63. if isinstance(key, tuple):
  64. key_dict = dict(zip(key_fields, key))
  65. else:
  66. key_dict = { key_fields[0]: key }
  67. return cls(**key_dict, **cons_params)
  68. columns["__tablename__"] = get_table_name(cls)
  69. for k in key_fields:
  70. columns[k] = Column(String, primary_key=True)
  71. return type(
  72. cls.__name__,
  73. (ModelBase, cls),
  74. dict(
  75. **columns,
  76. get_or_create=classmethod(get_or_create_standin),
  77. create_from_key=classmethod(create_from_key_standin)
  78. )
  79. )
  80. def as_group_singleton(cls):
  81. return make_db_class(cls, ("group_id",))
  82. def as_sender_singleton(cls):
  83. return make_db_class(cls, ("sender_id",))