arg_wiring.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import inspect
  2. from ...messaging import RollbotMessage
  # --- ArgConverter: base wrapper for argument converters ---
class ArgConverter:
    """Hold a callable that produces one argument for a command function.

    The wrapped callable is invoked positionally as ``conv(cmd, db, msg)``
    and its return value becomes the argument. Instances are used as
    parameter annotations to select how each parameter is filled.
    """

    def __init__(self, conv):
        # Callable taking (cmd, db, msg) and returning the argument value.
        self.conv = conv
  # --- Ready-made converters for common command dependencies ---
# Each converter is called positionally as conv(cmd, db, msg); slots a
# converter does not need are named `_` / `__`.

# The message that triggered the command.
Message = ArgConverter(lambda _, __, msg: msg)

# The database handle passed in by the caller.
Database = ArgConverter(lambda _, db, __: db)

# Attributes read off the command object.
Logger = ArgConverter(lambda cmd, _, __: cmd.logger)
Bot = ArgConverter(lambda cmd, _, __: cmd.bot)

# Values derived from the message.
ArgList = ArgConverter(lambda _, __, msg: msg.arg_list())
Subcommand = ArgConverter(lambda _, __, msg: RollbotMessage.from_subcommand(msg))
  # --- Config: bot configuration converter ---
class Config(ArgConverter):
    """Converter that supplies the bot configuration, or one entry of it.

    With no ``key`` the argument is ``cmd.bot.config`` itself; otherwise it
    is ``cmd.bot.config.get(key)``.
    """

    def __init__(self, key=None):
        if key is not None:
            # Bind `key` as a default so the lambda keeps this call's value.
            chosen = lambda cmd, _, __, key=key: cmd.bot.config.get(key)
        else:
            chosen = lambda cmd, _, __: cmd.bot.config
        super().__init__(chosen)
  # --- run_converter_function: call a function with converter-wired args ---
def run_converter_function(fn, cmd, db, msg):
    """Call ``fn`` with every parameter filled in by a converter.

    Parameters are wired with the same rules as command parameters (see
    ``get_converters``). An ``ArgConverter`` annotation selects the
    converter. Otherwise the parameter's name picks a built-in one, so a
    helper such as ``def key(msg): ...`` works without annotations.

    Args:
        fn: callable whose parameters are to be filled.
        cmd, db, msg: context handed to each converter as ``conv(cmd, db, msg)``.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        ValueError: a parameter can be wired neither by annotation nor by name.
    """
    params = inspect.signature(fn).parameters
    # Read annotations through the signature rather than fn.__annotations__,
    # so callables without that attribute (e.g. functools.partial) also work.
    annotations = {
        name: param.annotation
        for name, param in params.items()
        if param.annotation is not inspect.Parameter.empty
    }
    converters = get_converters(params, annotations)
    return fn(*[conv(cmd, db, msg) for conv in converters])
  # --- Singleton: per-message model instance converter ---
class Singleton(ArgConverter):
    """Converter that yields a ``model_cls`` object tied to the message.

    Used directly, the argument is ``model_cls.get_or_create(db, msg)``.
    ``by`` and ``by_all`` build converters that instead look objects up by
    key(s) computed from a converter-wired function.
    """

    def __init__(self, model_cls):
        super().__init__(lambda _, db, msg, model_cls=model_cls: model_cls.get_or_create(db, msg))
        self.model_cls = model_cls

    def _fetch_or_create(self, db, key):
        """Return the stored object for ``key``, else ``model_cls.create_from_key(key)``."""
        # NOTE(review): whether create_from_key adds the new object to the
        # session is not visible here -- confirm in the model class.
        found = db.query(self.model_cls).get(key)
        return found or self.model_cls.create_from_key(key)

    def by(self, fn):
        """Return a converter resolving the single key produced by ``fn``."""
        def fetch_one(cmd, db, msg):
            wanted = run_converter_function(fn, cmd, db, msg)
            return self._fetch_or_create(db, wanted)
        return ArgConverter(fetch_one)

    def by_all(self, fn):
        """Return a converter resolving each key in the iterable ``fn`` produces."""
        def fetch_many(cmd, db, msg):
            wanted_keys = run_converter_function(fn, cmd, db, msg)
            return [self._fetch_or_create(db, k) for k in wanted_keys]
        return ArgConverter(fetch_many)
  # --- Query: filtered model query converter ---
class Query(ArgConverter):
    """Converter that yields all ``model_cls`` rows matching every filter.

    Each registered filter is a function whose parameters are wired by
    ``run_converter_function``. Its return value is passed to the query's
    ``filter`` method. The argument is the result of ``.all()``.
    """

    def __init__(self, model_cls):
        super().__init__(self._convert)
        self.model_cls = model_cls
        self.filters = []

    def filter(self, fn):
        """Register ``fn`` as another filter and return this Query (for chaining).

        NOTE: this mutates the instance in place. A Query shared between
        several annotations accumulates all filters added through it.
        """
        self.filters.append(fn)
        return self

    def _convert(self, cmd, db, msg):
        result = db.query(self.model_cls)
        for clause_fn in self.filters:
            clause = run_converter_function(clause_fn, cmd, db, msg)
            result = result.filter(clause)
        return result.all()
  # --- Lazy: deferred, memoized converter ---
class Lazy(ArgConverter):
    """Defer another converter until the command explicitly asks for its value.

    The command receives an object with a ``get()`` method instead of the
    converted value. The wrapped converter runs on the first ``get()``, and
    the result is reused on later calls within the same invocation.
    """

    def __init__(self, child):
        super().__init__(self._convert)
        self.child = child    # ArgConverter whose .conv is deferred
        self.result = None    # cached value of child.conv once evaluated
        self.run = False      # True once child.conv has been evaluated

    def _convert(self, cmd, db, msg):
        # The Lazy used as an annotation is a single object shared by every
        # invocation of the command. Bind this invocation's context to a
        # fresh instance so neither the context nor the cached result leaks
        # across invocations.
        bound = Lazy(self.child)
        bound.cmd = cmd
        bound.db = db
        bound.msg = msg
        return bound

    def get(self):
        """Evaluate the wrapped converter on first use and return the cached value."""
        if not self.run:
            self.result = self.child.conv(self.cmd, self.db, self.msg)
            # Previously never set, so the child was re-run on every get().
            self.run = True
        return self.result
  61. def get_converters(parameters, annotations):
  62. converters = []
  63. for p in parameters:
  64. annot = annotations.get(p, None)
  65. if isinstance(annot, ArgConverter):
  66. converters.append(annot.conv)
  67. elif p in ("msg", "message", "_msg"):
  68. converters.append(lambda cmd, db, msg: msg)
  69. elif p in ("db", "database"):
  70. converters.append(lambda cmd, db, msg: db)
  71. elif p in ("log", "logger"):
  72. converters.append(lambda cmd, db, msg: cmd.logger)
  73. elif p in ("bot", "rollbot"):
  74. converters.append(lambda cmd, db, msg: cmd.bot)
  75. elif p in ("args", "arg_list"):
  76. converters.append(lambda cmd, db, msg: msg.arg_list())
  77. elif p in ("subc", "subcommand"):
  78. converters.append(lambda cmd, db, msg: RollbotMessage.from_subcommand(msg))
  79. elif p in ("cfg", "config"):
  80. converters.append(lambda cmd, db, msg: cmd.bot.config)
  81. elif p.startswith("cfg") or p.endswith("cfg"):
  82. annot = annot or p
  83. converters.append(lambda cmd, db, msg, key=annot: cmd.bot.config.get(key))
  84. elif p.startswith("data") or p.endswith("data"):
  85. converters.append(lambda cmd, db, msg, sing_cls=annot: sing_cls.get_or_create(db, msg))
  86. else:
  87. raise ValueError(p)
  88. return converters