injection.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import inspect
  2. from ..messaging import RollbotMessage, RollbotFailure
  3. class ArgConverter:
  4. def __init__(self, conv):
  5. self.conv = conv
  6. Message = ArgConverter(lambda _, __, msg: msg)
  7. Database = ArgConverter(lambda _, db, __: db)
  8. Logger = ArgConverter(lambda cmd, _, __: cmd.logger)
  9. Bot = ArgConverter(lambda cmd, _, __: cmd.bot)
  10. ArgString = ArgConverter(lambda _, __, msg: msg.raw_args)
  11. ArgList = ArgConverter(lambda _, __, msg: msg.arg_list())
  12. Subcommand = ArgConverter(lambda _, __, msg: RollbotMessage.from_subcommand(msg))
  13. Subcommand.ArgString = ArgConverter(lambda _, __, msg: RollbotMessage.from_subcommand(msg).raw_args)
  14. Subcommand.ArgList = ArgConverter(lambda _, __, msg: RollbotMessage.from_subcommand(msg).arg_list())
  15. class Arg(ArgConverter):
  16. def __init__(self, index, conversion=str, fail_msg=None):
  17. super().__init__(self._convert)
  18. self.index = index
  19. self.conversion = conversion
  20. self.fail_msg = fail_msg or "Invalid argument: {}"
  21. def _convert(self, cmd, db, msg):
  22. try:
  23. arg = msg.arg_list()[self.index]
  24. except IndexError:
  25. RollbotFailure.INVALID_ARGUMENTS.with_reason(f"Missing argument {self.index}").raise_exc()
  26. try:
  27. return self.conversion(arg)
  28. except ValueError:
  29. RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()
  30. class OptArg(ArgConverter):
  31. def __init__(self, index, conversion=str, default=None, fail_msg=None):
  32. super().__init__(self._convert)
  33. self.index = index
  34. self.conversion = conversion
  35. self.default = default
  36. self.fail_msg = fail_msg or "Invalid argument: {}"
  37. def _convert(self, cmd, db, msg):
  38. if msg.raw_args is None:
  39. return self.default
  40. try:
  41. arg = msg.arg_list()[self.index]
  42. except IndexError:
  43. return self.default
  44. try:
  45. return self.conversion(arg)
  46. except ValueError:
  47. RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()
  48. class _SubcArg(Arg):
  49. def _convert(self, cmd, db, msg):
  50. return super()._convert(self, cmd, db, RollbotMessage.from_subcommand(msg))
  51. class _SubcOptArg(OptArg):
  52. def _convert(self, cmd, db, msg):
  53. return super()._convert(self, cmd, db, RollbotMessage.from_subcommand(msg))
  54. Subcommand.Arg = _SubcArg
  55. Subcommand.OptArg = _SubcOptArg
  56. class Config(ArgConverter):
  57. def __init__(self, key=None):
  58. if key is None:
  59. super().__init__(lambda cmd, _, __: cmd.bot.config)
  60. else:
  61. super().__init__(lambda cmd, _, __, key=key: cmd.bot.config.get(key))
  62. def _run_converter_function(fn, cmd, db, msg):
  63. return fn(*[fn.__annotations__[param].conv(cmd, db, msg) for param in inspect.signature(fn).parameters])
  64. class Singleton(ArgConverter):
  65. def __init__(self, model_cls):
  66. super().__init__(lambda _, db, msg, model_cls=model_cls: model_cls.get_or_create(db, msg))
  67. self.model_cls = model_cls
  68. def by(self, fn):
  69. def do_get(cmd, db, msg):
  70. key = _run_converter_function(fn, cmd, db, msg)
  71. return db.query(self.model_cls).get(key) or self.model_cls.create_from_key(key)
  72. return ArgConverter(do_get)
  73. def by_all(self, fn):
  74. def do_get(cmd, db, msg):
  75. return [db.query(self.model_cls).get(x) or self.model_cls.create_from_key(x) for x in _run_converter_function(fn, cmd, db, msg)]
  76. return ArgConverter(do_get)
  77. class Query(ArgConverter):
  78. def __init__(self, model_cls):
  79. super().__init__(self._convert)
  80. self.model_cls = model_cls
  81. self._filters = []
  82. self._order_by = None
  83. self._limit = None
  84. def filter(self, fn):
  85. self._filters.append(fn)
  86. return self
  87. def order_by(self, fn):
  88. self._order_by = fn
  89. return self
  90. def limit(self, n):
  91. self._limit = n
  92. return self
  93. def _convert(self, cmd, db, msg):
  94. query = db.query(self.model_cls)
  95. for fn in self._filters:
  96. query = query.filter(_run_converter_function(fn, cmd, db, msg))
  97. if self._order_by is not None:
  98. query = query.order_by(_run_converter_function(self._order_by, cmd, db, msg))
  99. if self._limit is None:
  100. return query.all()
  101. else:
  102. return query.limit(self._limit)
  103. class Lazy(ArgConverter):
  104. def __init__(self, child):
  105. super().__init__(self._convert)
  106. self.child = child
  107. self.result = None
  108. self.run = False
  109. def _convert(self, cmd, db, msg):
  110. self.cmd = cmd
  111. self.db = db
  112. self.msg = msg
  113. return self
  114. def get(self):
  115. if not self.run:
  116. self.result = self.child.conv(self.cmd, self.db, self.msg)
  117. return self.result
  118. class _Executor:
  119. def __init__(self, cmd, db, msg):
  120. self.cmd = cmd
  121. self.db = db
  122. self.msg = msg
  123. def run_with_deps(self, fn):
  124. return _run_converter_function(fn, self.cmd, self.db, self.msg)
  125. Injector = ArgConverter(_Executor)