# args.py

  1. from argparse import ArgumentParser, Namespace
  2. from typing import Optional, TypeVar
  3. from collections.abc import Callable
  4. import shlex
  5. from ..types import Message, Context
  6. from ..failure import RollbotFailure
  7. from .base import Injector
  8. __all__ = [
  9. "Args",
  10. "ArgList",
  11. "ArgListSplitOn",
  12. "ArgParse",
  13. "Arg",
  14. ]
  15. class ArgsInjector(Injector[str]):
  16. async def inject(self, message: Message, context: Context) -> str:
  17. return message.command.args
  18. class ArgListSplitOn(Injector[list[str]]):
  19. def __init__(self, split: Optional[str] = None):
  20. self.split = split
  21. async def inject(self, message: Message, context: Context) -> str:
  22. cache_key = (ArgListSplitOn.__name__, self.split)
  23. result = message.command.cache.get(cache_key, None)
  24. if result is not None:
  25. return result
  26. if self.split is not None:
  27. result = message.command.args.split(self.split)
  28. else:
  29. result = message.command.args.split()
  30. message.command.cache[cache_key] = result
  31. return result
  32. class ArgParse(Injector[Namespace]):
  33. def __init__(self, parser: ArgumentParser):
  34. self.parser = parser
  35. async def inject(self, message: Message, context: Context) -> Namespace:
  36. return self.parser.parse_args(shlex.split(message.text))
  37. Args = ArgsInjector()
  38. ArgList = ArgListSplitOn()
  39. ArgType = TypeVar("ArgType")
  40. class Arg(Injector[ArgType]):
  41. def __init__(
  42. self,
  43. index: int = 0,
  44. convert: Callable[[str], ArgType] = str,
  45. required: bool = True,
  46. default: Optional[ArgType] = None,
  47. fail_msg: Optional[str] = None,
  48. ):
  49. self.index = index
  50. self.convert = convert
  51. self.required = required
  52. self.default = default
  53. self.fail_msg = fail_msg or "Invalid argument: {}"
  54. async def inject(self, message: Message, context: Context) -> str:
  55. try:
  56. arg = (await ArgList.inject(message, context))[self.index]
  57. except IndexError:
  58. if self.required:
  59. RollbotFailure.INVALID_ARGUMENTS.with_reason(
  60. f"Missing argument {self.index}"
  61. ).raise_exc()
  62. else:
  63. return self.default
  64. try:
  65. return self.convert(arg)
  66. except ValueError:
  67. RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()