12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
import datetime
import operator

# NOTE(review): `Binary` was deprecated in favour of `LargeBinary` and, as far as
# I know, is no longer importable from `sqlalchemy` in 1.4+ — confirm the pinned
# SQLAlchemy version still provides it.
from sqlalchemy import Binary, Column, DateTime, Float, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
# Declarative base shared by every model built here: make_db_class mixes it into
# each generated class, and init_db_at_url creates the tables registered on its
# metadata.
ModelBase = declarative_base()
def init_db_at_url(url):
    """Connect to the database at *url*, create model tables, return a session registry.

    Only tables registered on ``ModelBase.metadata`` at call time are created
    (``create_all`` skips tables that already exist), so model classes must be
    defined/imported before this is called.

    Args:
        url: SQLAlchemy database URL passed to ``create_engine``.

    Returns:
        A ``scoped_session`` registry whose sessions are bound to the new engine.
    """
    engine = create_engine(url)
    ModelBase.metadata.create_all(engine)
    return scoped_session(sessionmaker(bind=engine))
def get_columns(cls, banned=()):
    """Build one SQLAlchemy ``Column`` per annotated attribute of *cls*.

    Annotation -> column type:

    * ``int`` -> ``Integer``
    * ``float`` -> ``Float``
    * ``str`` -> ``String``
    * ``bytes``, ``object`` or the literal string ``"binary"`` -> ``Binary``
    * ``datetime.datetime`` -> ``DateTime``

    The string spellings of these annotations (``"int"``, ``"str"``,
    ``"datetime.datetime"``, ``"datetime"``, ...) are accepted too, so classes
    defined under ``from __future__ import annotations`` work.

    Args:
        cls: class whose annotations describe the columns.
        banned: names that must not be annotated on *cls*, e.g. key fields the
            caller adds itself.

    Returns:
        dict mapping attribute name to a new, unbound ``Column``, in
        annotation order.

    Raises:
        ValueError: an annotated name appears in *banned*.
        TypeError: an annotation has no supported column type.
    """
    # (accepted annotations, column type). Matched with ``in`` (identity or ==),
    # so an unusual/unhashable annotation just falls through to the TypeError.
    type_table = (
        ((int, "int"), Integer),
        ((float, "float"), Float),
        ((str, "str"), String),
        ((bytes, object, "bytes", "object", "binary"), Binary),
        ((datetime.datetime, "datetime.datetime", "datetime"), DateTime),
    )
    # getattr default: before Python 3.10 a class with no annotations (own or
    # inherited) has no __annotations__ attribute at all.
    annotations = getattr(cls, "__annotations__", {})
    columns = {}
    for name, typ in annotations.items():
        if name in banned:
            raise ValueError(
                f"Cannot have column named {name} in class {cls.__name__}: "
                f"it is reserved as a key field"
            )
        column_type = next(
            (sa_type for accepted, sa_type in type_table if typ in accepted),
            None,
        )
        if column_type is None:
            raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
        columns[name] = Column(column_type)
    return columns
def get_column_defaults(cls, columns):
    """Map each name in *columns* to the class-level value *cls* defines for it.

    Names that *cls* does not define map to ``None``. The result is used as the
    constructor keyword arguments for newly created rows, so class attributes
    act as per-column defaults.
    """
    defaults = {}
    for column_name in columns:
        defaults[column_name] = getattr(cls, column_name, None)
    return defaults
def get_table_name(cls):
    """Convert the CamelCase name of *cls* into a snake_case table name.

    Every ASCII capital ``A``-``Z`` becomes ``_`` plus its lowercase form, then
    any leading/trailing underscores are stripped. Capital runs are split
    letter by letter: ``GroupStats`` -> ``group_stats``,
    ``HTTPLog`` -> ``h_t_t_p_log``.
    """
    pieces = []
    for ch in cls.__name__:
        if "A" <= ch <= "Z":
            pieces.append("_")
            pieces.append(ch.lower())
        else:
            pieces.append(ch)
    return "".join(pieces).strip("_")
def make_db_class(cls, key_fields):
    """Turn the annotated class *cls* into a SQLAlchemy model keyed by *key_fields*.

    The returned class subclasses both ``ModelBase`` and *cls*:

    * its columns come from ``get_columns(cls)``;
    * its table name comes from ``get_table_name(cls)``;
    * every name in *key_fields* becomes a ``String`` primary-key column
      (a composite key when there are several).

    It also gets a ``get_or_create(db, msg)`` classmethod. That method loads the
    row whose key equals the matching attribute(s) of ``msg``. On a miss it
    builds a new instance from those key values plus the class-level defaults
    of *cls*, and ``db.add``s it.

    Args:
        cls: annotated class to convert; it must not annotate any key field.
        key_fields: non-empty sequence of key attribute names; a single ``str``
            is accepted as shorthand for a one-element tuple.

    Returns:
        The new mapped class, carrying the name, qualified name, module and
        docstring of *cls*.

    Raises:
        ValueError: *key_fields* is empty, or *cls* annotates a key field.
        TypeError: *cls* has an unsupported annotation (from ``get_columns``).
    """
    if isinstance(key_fields, str):
        # A bare string would otherwise be iterated as one-character names.
        key_fields = (key_fields,)
    key_fields = tuple(key_fields)
    if not key_fields:
        raise ValueError(f"make_db_class needs at least one key field for {cls.__name__}")

    columns = get_columns(cls, banned=key_fields)
    cons_params = get_column_defaults(cls, columns)
    # attrgetter yields a bare value for one name and a tuple for several:
    # the same shapes the previous single/composite branches produced.
    key_extractor = operator.attrgetter(*key_fields)

    def get_or_create_standin(model, db, msg):
        # Primary-key lookup; on a miss the new row is only added to the
        # session here (no flush/commit).
        sing = db.query(model).get(key_extractor(msg))
        if sing is None:
            sing = model(**{k: getattr(msg, k) for k in key_fields}, **cons_params)
            db.add(sing)
        return sing

    columns["__tablename__"] = get_table_name(cls)
    for k in key_fields:
        columns[k] = Column(String, primary_key=True)
    return type(
        cls.__name__,
        (ModelBase, cls),
        dict(
            **columns,
            get_or_create=classmethod(get_or_create_standin),
            # Without these, type() records this helper module as __module__
            # and sets __doc__ to None, breaking pickling/repr of the model.
            __module__=cls.__module__,
            __qualname__=cls.__qualname__,
            __doc__=cls.__doc__,
        ),
    )
def as_group_singleton(cls):
    """Class decorator: map *cls* to a table holding one row per ``group_id``.

    ``group_id`` becomes the ``String`` primary key, so *cls* must not annotate
    it itself. The resulting model's ``get_or_create(db, msg)`` is keyed on
    ``msg.group_id``.
    """
    singleton_key = ("group_id",)
    return make_db_class(cls, key_fields=singleton_key)
|