arg_wiring.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from ...messaging import RollbotMessage
  2. class ArgConverter:
  3. def __init__(self, conv):
  4. self.conv = conv
  5. Message = ArgConverter(lambda _, __, msg: msg)
  6. Database = ArgConverter(lambda _, db, __: db)
  7. Logger = ArgConverter(lambda cmd, _, __: cmd.logger)
  8. Bot = ArgConverter(lambda cmd, _, __: cmd.bot)
  9. ArgList = ArgConverter(lambda _, __, msg: msg.arg_list())
  10. Subcommand = ArgConverter(lambda _, __, msg: RollbotMessage.from_subcommand(msg))
  11. class Config(ArgConverter):
  12. def __init__(self, key):
  13. super().__init__(lambda cmd, _, __, key=key: cmd.bot.config.get(key))
  14. class Singleton(ArgConverter):
  15. def __init__(self, cls):
  16. super().__init__(lambda _, db, msg, cls=cls: cls.get_or_create(db, msg))
  17. def get_converters(parameters, annotations):
  18. converters = []
  19. for p in parameters:
  20. annot = annotations.get(p, None)
  21. if isinstance(annot, ArgConverter):
  22. converters.append(annot.conv)
  23. elif p in ("msg", "message", "_msg"):
  24. converters.append(lambda cmd, db, msg: msg)
  25. elif p in ("db", "database"):
  26. converters.append(lambda cmd, db, msg: db)
  27. elif p in ("log", "logger"):
  28. converters.append(lambda cmd, db, msg: cmd.logger)
  29. elif p in ("bot", "rollbot"):
  30. converters.append(lambda cmd, db, msg: cmd.bot)
  31. elif p in ("args", "arg_list"):
  32. converters.append(lambda cmd, db, msg: msg.arg_list())
  33. elif p in ("subc", "subcommand"):
  34. converters.append(lambda cmd, db, msg: RollbotMessage.from_subcommand(msg))
  35. elif p in ("cfg", "config"):
  36. converters.append(lambda cmd, db, msg: cmd.bot.config)
  37. elif p.startswith("cfg") or p.endswith("cfg"):
  38. annot = annot or p
  39. converters.append(lambda cmd, db, msg, key=annot: cmd.bot.config.get(key))
  40. elif p.startswith("data") or p.endswith("data"):
  41. converters.append(lambda cmd, db, msg, sing_cls=annot: sing_cls.get_or_create(db, msg))
  42. else:
  43. raise ValueError(p)
  44. return converters