injection.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import inspect
  2. from ..messaging import RollbotMessage, RollbotFailure
  3. class ArgConverter:
  4. def __init__(self, conv):
  5. self.conv = conv
  6. Message = ArgConverter(lambda _, __, msg: msg)
  7. Database = ArgConverter(lambda _, db, __: db)
  8. Logger = ArgConverter(lambda cmd, _, __: cmd.logger)
  9. Bot = ArgConverter(lambda cmd, _, __: cmd.bot)
  10. ArgList = ArgConverter(lambda _, __, msg: msg.arg_list())
  11. Subcommand = ArgConverter(lambda _, __, msg: RollbotMessage.from_subcommand(msg))
  12. class Arg(ArgConverter):
  13. def __init__(self, index, conversion=str, fail_msg=None):
  14. super().__init__(self._convert)
  15. self.index = index
  16. self.conversion = conversion
  17. self.fail_msg = fail_msg
  18. def _convert(self, cmd, db, msg):
  19. try:
  20. arg = msg.arg_list()[self.index]
  21. except IndexError:
  22. RollbotFailure.INVALID_ARGUMENTS.with_reason(f"Missing argument {self.index}").raise_exc()
  23. try:
  24. return self.conversion(arg)
  25. except ValueError:
  26. RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()
  27. class Config(ArgConverter):
  28. def __init__(self, key=None):
  29. if key is None:
  30. super().__init__(lambda cmd, _, __: cmd.bot.config)
  31. else:
  32. super().__init__(lambda cmd, _, __, key=key: cmd.bot.config.get(key))
  33. def _run_converter_function(fn, cmd, db, msg):
  34. return fn(*[fn.__annotations__[param].conv(cmd, db, msg) for param in inspect.signature(fn).parameters])
  35. class Singleton(ArgConverter):
  36. def __init__(self, model_cls):
  37. super().__init__(lambda _, db, msg, model_cls=model_cls: model_cls.get_or_create(db, msg))
  38. self.model_cls = model_cls
  39. def by(self, fn):
  40. def do_get(cmd, db, msg):
  41. key = _run_converter_function(fn, cmd, db, msg)
  42. return db.query(self.model_cls).get(key) or self.model_cls.create_from_key(key)
  43. return ArgConverter(do_get)
  44. def by_all(self, fn):
  45. def do_get(cmd, db, msg):
  46. return [db.query(self.model_cls).get(x) or self.model_cls.create_from_key(x) for x in _run_converter_function(fn, cmd, db, msg)]
  47. return ArgConverter(do_get)
  48. class Query(ArgConverter):
  49. def __init__(self, model_cls):
  50. super().__init__(self._convert)
  51. self.model_cls = model_cls
  52. self.filters = []
  53. def filter(self, fn):
  54. self.filters.append(fn)
  55. return self
  56. def _convert(self, cmd, db, msg):
  57. query = db.query(self.model_cls)
  58. for fn in self.filters:
  59. query = query.filter(_run_converter_function(fn, cmd, db, msg))
  60. return query.all()
  61. class Lazy(ArgConverter):
  62. def __init__(self, child):
  63. super().__init__(self._convert)
  64. self.child = child
  65. self.result = None
  66. self.run = False
  67. def _convert(self, cmd, db, msg):
  68. self.cmd = cmd
  69. self.db = db
  70. self.msg = msg
  71. return self
  72. def get(self):
  73. if not self.run:
  74. self.result = self.child.conv(self.cmd, self.db, self.msg)
  75. return self.result