args.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from argparse import ArgumentParser, Namespace
  2. from typing import Optional, TypeVar
  3. from collections.abc import Callable
  4. import shlex
  5. from ..types import Message, Context
  6. from ..failure import RollbotFailure
  7. from .base import Injector, Simple
  8. __all__ = [
  9. "Args",
  10. "ArgList",
  11. "ArgListSplitOn",
  12. "ArgParse",
  13. "Arg",
  14. ]
  15. Args = Simple[str](lambda m, c: m.command.args)
  16. class ArgListSplitOn(Injector[list[str]]):
  17. def __init__(self, split: Optional[str] = None):
  18. self.split = split
  19. async def inject(self, message: Message, context: Context) -> str:
  20. cache_key = (ArgListSplitOn.__name__, self.split)
  21. result = message.command.cache.get(cache_key, None)
  22. if result is not None:
  23. return result
  24. if self.split is not None:
  25. result = message.command.args.split(self.split)
  26. else:
  27. result = message.command.args.split()
  28. message.command.cache[cache_key] = result
  29. return result
  30. class ArgParse(Injector[Namespace]):
  31. def __init__(self, parser: ArgumentParser):
  32. self.parser = parser
  33. async def inject(self, message: Message, context: Context) -> Namespace:
  34. return self.parser.parse_args(shlex.split(message.command.args))
  35. ArgList = ArgListSplitOn()
  36. ArgType = TypeVar("ArgType")
  37. class Arg(Injector[ArgType]):
  38. def __init__(
  39. self,
  40. index: int = 0,
  41. convert: Callable[[str], ArgType] = str,
  42. required: bool = True,
  43. default: Optional[ArgType] = None,
  44. fail_msg: Optional[str] = None,
  45. ):
  46. self.index = index
  47. self.convert = convert
  48. self.required = required
  49. self.default = default
  50. self.fail_msg = fail_msg or "Invalid argument: {}"
  51. async def inject(self, message: Message, context: Context) -> str:
  52. try:
  53. arg = (await ArgList.inject(message, context))[self.index]
  54. except IndexError:
  55. if self.required:
  56. RollbotFailure.INVALID_ARGUMENTS.with_reason(
  57. f"Missing argument {self.index}"
  58. ).raise_exc()
  59. else:
  60. return self.default
  61. try:
  62. return self.convert(arg)
  63. except ValueError:
  64. RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()