import datetime

# LargeBinary replaces the long-deprecated ``Binary`` alias, which newer SQLAlchemy
# releases no longer export (importing it made this module fail to load).
from sqlalchemy import Column, DateTime, Float, Integer, LargeBinary, String

try:  # SQLAlchemy >= 1.4: canonical location.
    from sqlalchemy.orm import declarative_base
except ImportError:  # Older SQLAlchemy only provides the ext.declarative version.
    from sqlalchemy.ext.declarative import declarative_base

ModelBase = declarative_base()

# Annotation -> SQLAlchemy column type, compared with ``==`` in order.
# ``object`` (and the string "binary", handled below) request a binary blob
# column; ``bytes`` is accepted as the natural spelling of the same thing.
_ANNOTATION_COLUMN_TYPES = (
    (int, Integer),
    (float, Float),
    (str, String),
    (object, LargeBinary),
    (bytes, LargeBinary),
    (datetime.datetime, DateTime),
)

# The same mapping keyed by annotation source text. Modules using
# ``from __future__ import annotations`` (PEP 563) store every annotation as a
# string, which the type-object table above would never match.
_ANNOTATION_NAME_COLUMN_TYPES = {
    "int": Integer,
    "float": Float,
    "str": String,
    "object": LargeBinary,
    "bytes": LargeBinary,
    "binary": LargeBinary,
    "datetime": DateTime,
    "datetime.datetime": DateTime,
}

# Attribute names the decorator defines itself and annotations may not reuse.
_GENERATED_NAMES = ("group_id", "get_or_create")


def _column_type_for(typ):
    """Return the SQLAlchemy type for annotation *typ*, or None if it is unsupported."""
    if isinstance(typ, str):
        return _ANNOTATION_NAME_COLUMN_TYPES.get(typ)
    for py_type, column_type in _ANNOTATION_COLUMN_TYPES:
        if typ == py_type:
            return column_type
    return None


def _table_name_for(class_name):
    """Convert a CamelCase class name to a snake_case table name.

    Every ASCII capital starts a new word (``GroupConfig`` -> ``group_config``,
    ``HTTPState`` -> ``h_t_t_p_state``), and leading/trailing underscores are
    stripped. Kept exactly as-is: changing it would rename existing tables.
    """
    return "".join(
        "_" + ch.lower() if "A" <= ch <= "Z" else ch for ch in class_name
    ).strip("_")


def as_group_singleton(cls):
    """Turn an annotated plain class into a one-row-per-group SQLAlchemy model.

    Each annotated attribute of *cls* becomes a column whose SQL type follows
    the annotation: int, float, str, datetime.datetime, or object / bytes /
    "binary" for a binary blob. String annotations (PEP 563) are understood too.
    A value assigned to the attribute at class level becomes the initial value
    of a newly created row; attributes without one start as None.

    The returned class subclasses both ``ModelBase`` and *cls*, is mapped to a
    table named after *cls* in snake_case, has a string primary key
    ``group_id``, and gains a ``get_or_create(db, group_id)`` classmethod.

    Raises:
        ValueError: if *cls* annotates ``group_id`` or ``get_or_create``, which
            the decorator generates itself.
        TypeError: if an annotation has no supported column type.
    """
    columns = {}
    # getattr(): on older Pythons a class without any annotations has no
    # __annotations__ attribute at all.
    for name, typ in getattr(cls, "__annotations__", {}).items():
        if name in _GENERATED_NAMES:
            raise ValueError(
                f"Cannot have column named {name} in as_group_singleton class {cls.__name__}"
            )
        column_type = _column_type_for(typ)
        if column_type is None:
            raise TypeError(f"Unsupported annotation {typ} for {name} in {cls.__name__}")
        columns[name] = Column(column_type)

    # Snapshot class-level values now so every new row starts from the same defaults.
    cons_params = {k: getattr(cls, k, None) for k in columns}

    def get_or_create_standin(model, db, group_id):
        """Return the row for *group_id*, creating and ``db.add``-ing it if missing.

        A newly created row is only added to the session; it is not flushed or
        committed here.
        """
        # NOTE(review): Query.get is legacy in SQLAlchemy 2.0 (Session.get replaces it).
        sing = db.query(model).get(group_id)
        if sing is None:
            sing = model(group_id=group_id, **cons_params)
            db.add(sing)
        return sing

    columns["__tablename__"] = _table_name_for(cls.__name__)
    columns["group_id"] = Column(String, primary_key=True)
    return type(
        cls.__name__,
        (ModelBase, cls),
        {**columns, "get_or_create": classmethod(get_or_create_standin)},
    )