123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- import inspect
- from ..messaging import RollbotMessage, RollbotFailure
class ArgConverter:
    """Annotation object describing how to produce one injected parameter.

    ``conv`` is a callable that is later invoked as ``conv(cmd, db, msg)``;
    whatever it returns is passed to the annotated parameter.
    """

    def __init__(self, conv):
        # Kept as a plain attribute: subclasses pass a bound method, module
        # constants pass lambdas, and Injector passes a class.
        self.conv = conv
# Converters exposing pieces of the invocation context itself.
# Every ``conv`` is called as conv(cmd, db, msg).
Message = ArgConverter(lambda cmd, db, msg: msg)
Database = ArgConverter(lambda cmd, db, msg: db)
Logger = ArgConverter(lambda cmd, db, msg: cmd.logger)
Bot = ArgConverter(lambda cmd, db, msg: cmd.bot)

# Converters exposing the message's argument text.
ArgString = ArgConverter(lambda cmd, db, msg: msg.raw_args)
ArgList = ArgConverter(lambda cmd, db, msg: msg.arg_list())

# ``Subcommand`` yields RollbotMessage.from_subcommand(msg); its attributes
# apply the argument converters above to that derived message instead.
Subcommand = ArgConverter(lambda cmd, db, msg: RollbotMessage.from_subcommand(msg))
Subcommand.ArgString = ArgConverter(
    lambda cmd, db, msg: RollbotMessage.from_subcommand(msg).raw_args
)
Subcommand.ArgList = ArgConverter(
    lambda cmd, db, msg: RollbotMessage.from_subcommand(msg).arg_list()
)
class Arg(ArgConverter):
    """Converter for a required positional argument of the message.

    ``index`` selects an entry of ``msg.arg_list()`` and ``conversion`` is
    applied to that raw entry. A missing entry, or a conversion raising
    ValueError, is reported through
    ``RollbotFailure.INVALID_ARGUMENTS.with_reason(...).raise_exc()``.
    """

    def __init__(self, index, conversion=str, fail_msg=None):
        super().__init__(self._convert)
        self.index, self.conversion = index, conversion
        # The "{}" placeholder is filled with the offending raw argument.
        self.fail_msg = fail_msg or "Invalid argument: {}"

    def _convert(self, cmd, db, msg):
        try:
            raw = msg.arg_list()[self.index]
        except IndexError:
            RollbotFailure.INVALID_ARGUMENTS.with_reason(f"Missing argument {self.index}").raise_exc()

        try:
            return self.conversion(raw)
        except ValueError:
            reason = self.fail_msg.format(raw)
            RollbotFailure.INVALID_ARGUMENTS.with_reason(reason).raise_exc()
class OptArg(ArgConverter):
    """Converter for an optional positional argument of the message.

    Returns ``default`` (unconverted) when ``msg.raw_args`` is None or when
    ``index`` is past the end of ``msg.arg_list()``. Otherwise ``conversion``
    is applied to the raw entry. A ValueError from the conversion is reported
    through ``RollbotFailure.INVALID_ARGUMENTS.with_reason(...).raise_exc()``.
    """

    def __init__(self, index, conversion=str, default=None, fail_msg=None):
        super().__init__(self._convert)
        self.index, self.conversion, self.default = index, conversion, default
        # The "{}" placeholder is filled with the offending raw argument.
        self.fail_msg = fail_msg or "Invalid argument: {}"

    def _convert(self, cmd, db, msg):
        # No argument text at all: don't even build the argument list.
        if msg.raw_args is None:
            return self.default

        try:
            raw = msg.arg_list()[self.index]
        except IndexError:
            return self.default

        try:
            return self.conversion(raw)
        except ValueError:
            reason = self.fail_msg.format(raw)
            RollbotFailure.INVALID_ARGUMENTS.with_reason(reason).raise_exc()
class _SubcArg(Arg):
    """Required positional argument read from the subcommand's arguments.

    Same as Arg, except the lookup runs against
    ``RollbotMessage.from_subcommand(msg)`` rather than ``msg``.
    """

    def _convert(self, cmd, db, msg):
        # super()._convert is already bound to self. The previous version
        # passed self explicitly as well, which shifted every argument and
        # made each call raise TypeError.
        return super()._convert(cmd, db, RollbotMessage.from_subcommand(msg))
class _SubcOptArg(OptArg):
    """Optional positional argument read from the subcommand's arguments.

    Same as OptArg, except the lookup runs against
    ``RollbotMessage.from_subcommand(msg)`` rather than ``msg``.
    """

    def _convert(self, cmd, db, msg):
        # super()._convert is already bound to self. Passing self a second
        # time (as before) shifted every argument and raised TypeError.
        return super()._convert(cmd, db, RollbotMessage.from_subcommand(msg))
# Expose the subcommand-scoped argument converter classes on Subcommand,
# next to Subcommand.ArgString / Subcommand.ArgList.
Subcommand.Arg, Subcommand.OptArg = _SubcArg, _SubcOptArg
class Config(ArgConverter):
    """Converter reading the bot's configuration.

    Without ``key`` the parameter receives ``cmd.bot.config`` itself.
    With ``key`` it receives ``cmd.bot.config.get(key)``.
    """

    def __init__(self, key=None):
        if key is not None:
            # ``key`` is bound as a lambda default parameter.
            super().__init__(lambda cmd, db, msg, key=key: cmd.bot.config.get(key))
            return
        super().__init__(lambda cmd, db, msg: cmd.bot.config)
- def _run_converter_function(fn, cmd, db, msg):
- return fn(*[fn.__annotations__[param].conv(cmd, db, msg) for param in inspect.signature(fn).parameters])
class Singleton(ArgConverter):
    """Converter producing a ``model_cls`` instance for the command.

    Used directly as an annotation, it supplies
    ``model_cls.get_or_create(db, msg)``. ``by`` and ``by_all`` build
    converters that look rows up by keys computed from the invocation.
    """

    def __init__(self, model_cls):
        super().__init__(lambda _, db, msg, model_cls=model_cls: model_cls.get_or_create(db, msg))
        self.model_cls = model_cls

    def _fetch_or_create(self, db, key):
        """Return the row for ``key`` or, if the lookup gives None, ``create_from_key(key)``."""
        instance = db.query(self.model_cls).get(key)
        # Test for None explicitly: the previous `get(key) or create(...)`
        # would replace an existing row whose model happens to be falsy
        # (custom __bool__/__len__). Whether create_from_key persists the new
        # object depends on model_cls -- not visible here.
        if instance is None:
            instance = self.model_cls.create_from_key(key)
        return instance

    def by(self, fn):
        """Converter: ``fn`` (dependency-injected) returns one key; yields its row."""
        def do_get(cmd, db, msg):
            key = _run_converter_function(fn, cmd, db, msg)
            return self._fetch_or_create(db, key)
        return ArgConverter(do_get)

    def by_all(self, fn):
        """Converter: ``fn`` returns an iterable of keys; yields a list of rows in the same order."""
        def do_get(cmd, db, msg):
            keys = _run_converter_function(fn, cmd, db, msg)
            return [self._fetch_or_create(db, key) for key in keys]
        return ArgConverter(do_get)
class Query(ArgConverter):
    """Converter that runs ``db.query(model_cls)`` and injects the results.

    Filters, a single ordering and a row limit are configured fluently. Each
    filter / ordering function is itself dependency-injected: its annotated
    parameters are resolved from ``(cmd, db, msg)``, and its return value is
    passed to ``query.filter`` / ``query.order_by``.

    Note: the setters mutate and return this same instance, so chaining off a
    shared base Query modifies that base.
    """

    def __init__(self, model_cls):
        super().__init__(self._convert)
        self.model_cls = model_cls
        self._filters = []
        self._order_by = None
        self._limit = None

    def filter(self, fn):
        """Add a filter-producing function; returns ``self``."""
        self._filters.append(fn)
        return self

    def order_by(self, fn):
        """Set (replace) the ordering-producing function; returns ``self``."""
        self._order_by = fn
        return self

    def limit(self, n):
        """Cap the number of returned rows at ``n``; returns ``self``."""
        self._limit = n
        return self

    def _convert(self, cmd, db, msg):
        query = db.query(self.model_cls)
        for fn in self._filters:
            query = query.filter(_run_converter_function(fn, cmd, db, msg))
        if self._order_by is not None:
            query = query.order_by(_run_converter_function(self._order_by, cmd, db, msg))
        if self._limit is not None:
            query = query.limit(self._limit)
        # Always materialize. Previously a limited Query injected the
        # unevaluated query object while an unlimited one injected a list.
        return query.all()
class Lazy(ArgConverter):
    """Converter deferring ``child`` until the command calls ``get()``.

    Each invocation injects a fresh Lazy bound to that invocation's
    ``(cmd, db, msg)``. Its first ``get()`` runs ``child.conv(cmd, db, msg)``;
    later ``get()`` calls return the cached result.
    """

    def __init__(self, child):
        super().__init__(self._convert)
        self.child = child
        self.result = None
        self.run = False

    def _convert(self, cmd, db, msg):
        # The annotation object is shared by every invocation of the command,
        # so per-call state must live on a new instance. Storing it on self
        # (as before) let concurrent or later invocations clobber each other.
        bound = type(self)(self.child)
        bound.cmd = cmd
        bound.db = db
        bound.msg = msg
        return bound

    def get(self):
        """Run the child converter once and return its (cached) result."""
        if not self.run:
            self.result = self.child.conv(self.cmd, self.db, self.msg)
            # Previously never set, so the child re-ran on every get().
            self.run = True
        return self.result
- class _Executor:
- def __init__(self, cmd, db, msg):
- self.cmd = cmd
- self.db = db
- self.msg = msg
- def run_with_deps(self, fn):
- return _run_converter_function(fn, self.cmd, self.db, self.msg)
-
# Annotating a parameter with Injector supplies an _Executor for the current
# (cmd, db, msg): the class itself serves as the converter callable.
Injector = ArgConverter(conv=_Executor)
|