"""Helpers for turning annotated plain classes into SQLAlchemy model classes."""
import datetime
from contextlib import contextmanager

from sqlalchemy import Column, DateTime, LargeBinary, String, Float, Integer, create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base

ModelBase = declarative_base()

# (annotation, SQLAlchemy column type) pairs understood by get_columns().
# Kept as a sequence scanned with ``==`` (rather than a dict) so annotations
# that are not hashable still end up in the "unsupported" error path.
_ANNOTATION_COLUMN_TYPES = (
    (int, Integer),
    (float, Float),
    (str, String),
    (object, LargeBinary),
    ("binary", LargeBinary),
    (datetime.datetime, DateTime),
)


def init_db_at_url(url):
    """Create an engine for *url*, create all tables, and return a session scope.

    Every table registered on ``ModelBase`` at call time is created on the
    new engine.

    Returns:
        A zero-argument context-manager factory. Entering it yields a session;
        leaving it normally commits, leaving it with an exception rolls back
        and re-raises. The session is closed in both cases.
    """
    engine = create_engine(url)
    session_factory = scoped_session(sessionmaker(bind=engine))
    ModelBase.metadata.create_all(engine)

    @contextmanager
    def session_manager_factory():
        """Provide a transactional scope around a series of operations."""
        session = session_factory()
        try:
            yield session
            session.commit()
        except BaseException:
            # Deliberately catches everything (including KeyboardInterrupt):
            # the exception is always re-raised after the rollback.
            # TODO there is some worry that this would rollback things in other threads...
            # we should probably find a more correct solution for managing the threaded
            # db access, but the risk is fairly low at this point.
            session.rollback()
            raise
        finally:
            session.close()

    return session_manager_factory


def get_columns(cls, banned=()):
    """Build a ``{name: Column}`` mapping from the annotations of *cls*.

    Supported annotations are ``int``, ``float``, ``str``, ``object`` or the
    string ``"binary"`` (both stored as ``LargeBinary``), and
    ``datetime.datetime``.

    Args:
        cls: Class whose ``__annotations__`` describe the columns.
        banned: Names that may not be used as annotated attributes.

    Raises:
        ValueError: If an annotated name appears in *banned*.
        TypeError: If an annotation is not one of the supported types.
    """
    columns = {}
    for name, typ in cls.__annotations__.items():
        if name in banned:
            # Report the actual offending name; the banned set is not always
            # ("group_id",), so a hard-coded message would be misleading.
            raise ValueError(
                f"Cannot have column named {name} in singleton class {cls.__name__}"
            )
        for annotation, column_type in _ANNOTATION_COLUMN_TYPES:
            if typ == annotation:
                columns[name] = Column(column_type)
                break
        else:
            raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
    return columns


def get_column_defaults(cls, columns):
    """Return ``{name: default}`` for *columns*, read from class attributes of *cls*.

    Names without a corresponding class attribute map to ``None``.
    """
    return {k: getattr(cls, k, None) for k in columns}


def get_table_name(cls):
    """Derive a table name from the class name.

    Each ASCII uppercase letter is replaced by ``_`` plus its lowercase form,
    and leading/trailing underscores are stripped, e.g. ``MyClass`` becomes
    ``my_class``.
    """
    return "".join(
        ("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__
    ).strip("_")


def make_db_class(cls, key_fields):
    """Create a SQLAlchemy model class from the annotated class *cls*.

    The returned class has the same name as *cls*, inherits from both
    ``ModelBase`` and *cls*, has one column per annotation of *cls*, and one
    ``String`` primary-key column per entry of *key_fields*. It also gains two
    classmethods:

    * ``get_or_create(db, msg)`` looks the row up by the key attributes read
      from *msg*; if none exists, a new instance with default column values is
      added to the session *db*. The instance is returned either way.
    * ``create_from_key(key)`` builds a new instance (not added to any
      session) from a key value, or from a tuple of key values.

    Args:
        cls: Annotated class describing the non-key columns.
        key_fields: Sequence of primary-key column names.

    Raises:
        ValueError: If *cls* annotates a name listed in *key_fields*.
        TypeError: If *cls* uses an unsupported annotation.
    """
    columns = get_columns(cls, banned=key_fields)
    # Computed before the key columns and __tablename__ are added below, so
    # the defaults cover only the annotated (non-key) columns.
    cons_params = get_column_defaults(cls, columns)

    if len(key_fields) == 1:
        def key_extractor(msg):
            # Single-column primary key: a scalar identity.
            return getattr(msg, key_fields[0])
    else:
        def key_extractor(msg):
            # Composite primary key: a tuple in key_fields order.
            return tuple(getattr(msg, k) for k in key_fields)

    def get_or_create_standin(cls, db, msg):
        sing = db.query(cls).get(key_extractor(msg))
        if sing is None:
            sing = cls(**{k: getattr(msg, k) for k in key_fields}, **cons_params)
            db.add(sing)
        return sing

    def create_from_key_standin(cls, key):
        if isinstance(key, tuple):
            key_dict = dict(zip(key_fields, key))
        else:
            key_dict = {key_fields[0]: key}
        return cls(**key_dict, **cons_params)

    columns["__tablename__"] = get_table_name(cls)
    for k in key_fields:
        columns[k] = Column(String, primary_key=True)

    return type(
        cls.__name__,
        (ModelBase, cls),
        dict(
            **columns,
            get_or_create=classmethod(get_or_create_standin),
            create_from_key=classmethod(create_from_key_standin),
        ),
    )


def as_group_singleton(cls):
    """Turn *cls* into a model class keyed by a single ``group_id`` column."""
    return make_db_class(cls, ("group_id",))


def as_sender_singleton(cls):
    """Turn *cls* into a model class keyed by a single ``sender_id`` column."""
    return make_db_class(cls, ("sender_id",))