123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- import datetime
- from contextlib import contextmanager
- from sqlalchemy import Column, DateTime, LargeBinary, String, Float, Integer, create_engine
- from sqlalchemy.orm import sessionmaker, scoped_session
- from sqlalchemy.ext.declarative import declarative_base
# Declarative base shared by every model class built by make_db_class();
# init_db_at_url() creates the tables registered on its metadata.
ModelBase = declarative_base()


def init_db_at_url(url, **engine_kwargs):
    """Connect to *url*, create missing tables, and return a session-scope factory.

    Args:
        url: SQLAlchemy database URL passed to ``create_engine``.
        **engine_kwargs: Optional extra keyword arguments forwarded to
            ``create_engine`` (e.g. ``echo=True``). With none given the call
            is exactly ``create_engine(url)``.

    Returns:
        A zero-argument function returning a context manager. Entering it
        yields a session. On a clean exit the session is committed. On any
        exception it is rolled back and the exception re-raised. In every
        case the session is closed afterwards.

    Note:
        ``ModelBase.metadata.create_all`` only creates tables for model
        classes already registered on ``ModelBase`` at call time, and skips
        tables that already exist.
    """
    engine = create_engine(url, **engine_kwargs)
    # scoped_session's default registry is thread-local, so each thread gets
    # (and reuses) its own Session from session_factory().
    session_factory = scoped_session(sessionmaker(bind=engine))
    ModelBase.metadata.create_all(engine)

    @contextmanager
    def session_manager_factory():
        """Provide a transactional scope around a series of operations."""
        session = session_factory()
        try:
            yield session
            session.commit()
        except BaseException:
            # Roll back on *any* failure (KeyboardInterrupt included) before
            # propagating it. The registry is thread-local, so this only
            # touches the calling thread's session, not other threads'.
            # NOTE(review): nested scopes in the same thread receive the SAME
            # session, so an inner scope commits/rolls back/closes the outer
            # scope's work as well.
            session.rollback()
            raise
        finally:
            session.close()

    return session_manager_factory
def get_columns(cls, banned=()):
    """Build a SQLAlchemy ``Column`` for every annotated attribute of *cls*.

    Supported annotations and the column type each maps to:

    * ``int`` -> ``Integer``
    * ``float`` -> ``Float``
    * ``str`` -> ``String``
    * ``object`` or the string ``"binary"`` -> ``LargeBinary``
    * ``datetime.datetime`` -> ``DateTime``

    Args:
        cls: Class whose annotations describe the data columns.
        banned: Attribute names that must not be annotated on *cls*
            (the caller's key fields, which it adds as columns itself).

    Returns:
        A dict mapping each annotated attribute name to a new ``Column``.

    Raises:
        ValueError: If an annotated name appears in *banned*.
        TypeError: If an annotation is not one of the supported types.
    """
    columns = {}
    # A class with no annotations at all may lack ``__annotations__`` on
    # older Python versions; treat that as "no columns".
    annotations = getattr(cls, "__annotations__", {})
    for name, typ in annotations.items():
        if name in banned:
            # *banned* holds whichever key fields the caller uses (e.g.
            # sender_id), so report the offending name rather than assuming
            # it is group_id.
            raise ValueError(
                f"Cannot have column named {name} in class {cls.__name__}: "
                "it is reserved as a key field"
            )
        if typ == int:
            columns[name] = Column(Integer)
        elif typ == float:
            columns[name] = Column(Float)
        elif typ == str:
            columns[name] = Column(String)
        elif typ in (object, "binary"):
            columns[name] = Column(LargeBinary)
        elif typ == datetime.datetime:
            columns[name] = Column(DateTime)
        else:
            raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
    return columns
def get_column_defaults(cls, columns):
    """Map each column name to the value *cls* defines for it at class level.

    Only the keys of *columns* are used. A name that *cls* has no attribute
    for maps to ``None``. The result preserves the iteration order of
    *columns*.
    """
    defaults = {}
    for column_name in columns:
        defaults[column_name] = getattr(cls, column_name, None)
    return defaults
def get_table_name(cls):
    """Derive a snake_case table name from the class name.

    Every ASCII uppercase letter becomes ``_`` followed by its lowercase
    form, and leading/trailing underscores are then stripped, e.g.
    ``UserStats`` -> ``user_stats``. Consecutive capitals are split one by
    one (``HTTPServer`` -> ``h_t_t_p_server``). The scheme must stay as is,
    because the result is used as ``__tablename__``.
    """
    pieces = []
    for ch in cls.__name__:
        if "A" <= ch <= "Z":
            pieces.append("_")
            pieces.append(ch.lower())
        else:
            pieces.append(ch)
    return "".join(pieces).strip("_")
def make_db_class(cls, key_fields):
    """Turn the annotated plain class *cls* into a SQLAlchemy model keyed by *key_fields*.

    The returned class subclasses both ``ModelBase`` and *cls*:

    * Each annotated attribute of *cls* becomes a column (see ``get_columns``).
    * Its class-level value, or ``None``, is the default for new rows.
    * Every name in *key_fields* becomes a ``String`` primary-key column.
    * The table name comes from ``get_table_name(cls)``.

    Two classmethods are added:

    * ``get_or_create(db, msg)``: fetch the row whose key equals the key
      attributes of *msg*. If there is none, build one with default values
      and ``db.add`` it. Return the row either way.
    * ``create_from_key(key)``: build (but do not add) a default row for
      *key*. For a single key field, pass the bare value (a 1-tuple is also
      accepted). For composite keys, pass a tuple with one value per key
      field, in *key_fields* order.

    Args:
        cls: Plain class whose annotations describe the data columns. It
            must not annotate any of *key_fields*.
        key_fields: Names of the primary-key columns. A single string is
            treated as a one-element tuple.

    Returns:
        The new mapped class, carrying the name, module, qualified name and
        docstring of *cls*.

    Raises:
        ValueError: If *key_fields* is empty or *cls* annotates a key field.
        TypeError: If *cls* has an unsupported annotation.
    """
    if isinstance(key_fields, str):
        # A bare string would otherwise be iterated as one-character names.
        key_fields = (key_fields,)
    key_fields = tuple(key_fields)
    if not key_fields:
        raise ValueError(
            f"Cannot build db class {cls.__name__} without any key fields"
        )

    columns = get_columns(cls, banned=key_fields)
    # The mapped class shadows these attributes with Column objects, so the
    # plain class-level values are captured here and passed explicitly as
    # constructor arguments whenever a new row is built.
    cons_params = get_column_defaults(cls, columns)
    single_key = len(key_fields) == 1

    def key_from_msg(msg):
        """Return *msg*'s primary-key identity in the form Query.get() takes."""
        # Scalar for a single key column, otherwise a tuple in key_fields
        # order (the key columns below are created in that same order).
        if single_key:
            return getattr(msg, key_fields[0])
        return tuple(getattr(msg, k) for k in key_fields)

    def get_or_create_standin(model, db, msg):
        """Return the row keyed like *msg*, adding a default row if absent."""
        row = db.query(model).get(key_from_msg(msg))
        if row is None:
            row = model(
                **{k: getattr(msg, k) for k in key_fields}, **cons_params
            )
            db.add(row)
        return row

    def create_from_key_standin(model, key):
        """Build (without adding) a default row for primary key *key*."""
        if isinstance(key, tuple):
            # zip() would silently drop extra values or leave key fields
            # unset on a length mismatch, yielding a partial primary key.
            if len(key) != len(key_fields):
                raise ValueError(
                    f"{model.__name__} key {key!r} must have "
                    f"{len(key_fields)} value(s) for {key_fields}"
                )
            key_dict = dict(zip(key_fields, key))
        elif single_key:
            key_dict = {key_fields[0]: key}
        else:
            raise TypeError(
                f"{model.__name__} has a composite key {key_fields}; "
                f"pass a tuple of values, got {key!r}"
            )
        return model(**key_dict, **cons_params)

    namespace = dict(columns)
    namespace["__tablename__"] = get_table_name(cls)
    for k in key_fields:
        namespace[k] = Column(String, primary_key=True)
    namespace.update(
        # Without these, type() records this module as the class's module
        # and the bare name as its qualname, which breaks pickling of
        # instances and makes reprs point at the wrong place.
        __module__=cls.__module__,
        __qualname__=cls.__qualname__,
        __doc__=cls.__doc__,
        get_or_create=classmethod(get_or_create_standin),
        create_from_key=classmethod(create_from_key_standin),
    )
    return type(cls.__name__, (ModelBase, cls), namespace)
def as_group_singleton(cls):
    """Map *cls* to a model keyed by a ``group_id`` string primary key.

    Usable as a class decorator. The annotated attributes of *cls* become
    columns; ``group_id`` itself is added by ``make_db_class`` and must not
    be annotated on *cls*.
    """
    group_key = ("group_id",)
    return make_db_class(cls, group_key)
def as_sender_singleton(cls):
    """Map *cls* to a model keyed by a ``sender_id`` string primary key.

    Usable as a class decorator. The annotated attributes of *cls* become
    columns; ``sender_id`` itself is added by ``make_db_class`` and must not
    be annotated on *cls*.
    """
    sender_key = ("sender_id",)
    return make_db_class(cls, sender_key)
|