args.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from argparse import ArgumentParser, Namespace
  2. from typing import Optional, TypeVar
  3. from collections.abc import Callable
  4. import shlex
  5. from ..types import Message, Context
  6. from ..failure import RollbotFailure
  7. from .base import Injector, Simple
  8. __all__ = [
  9. "Args",
  10. "ArgList",
  11. "ArgListSplitOn",
  12. "ArgParse",
  13. "Arg",
  14. ]
# ``Simple`` (from ``.base``) wrapping a two-argument callable that returns
# ``command.args`` of its first argument unchanged.
# NOTE(review): the lambda parameters are presumably (message, context),
# mirroring ``Injector.inject`` -- confirm against ``.base.Simple``.
Args = Simple[str](lambda m, c: m.command.args)
  16. class ArgAccessorBase:
  17. def get_args(self, message: Message) -> str:
  18. return message.command.args
  19. class ArgListSplitOn(Injector[list[str]], ArgAccessorBase):
  20. def __init__(self, split: Optional[str] = None):
  21. self.split = split
  22. async def inject(self, message: Message, context: Context) -> str:
  23. cache_key = (ArgListSplitOn.__name__, self.split)
  24. result = message.command.cache.get(cache_key, None)
  25. if result is not None:
  26. return result
  27. if self.split is not None:
  28. result = self.get_args(message).split(self.split)
  29. else:
  30. result = self.get_args(message).split()
  31. message.command.cache[cache_key] = result
  32. return result
  33. class ArgParse(Injector[Namespace], ArgAccessorBase):
  34. def __init__(self, parser: ArgumentParser):
  35. self.parser = parser
  36. async def inject(self, message: Message, context: Context) -> Namespace:
  37. return self.parser.parse_args(shlex.split(self.get_args(message)))
# Shared instance built with the default ``split=None``, i.e. the argument
# string is split on whitespace.
ArgList = ArgListSplitOn()
# Type variable for the value an ``Arg`` injector produces.
ArgType = TypeVar("ArgType")
  40. class Arg(Injector[ArgType]):
  41. def __init__(
  42. self,
  43. index: int = 0,
  44. convert: Callable[[str], ArgType] = str,
  45. required: bool = True,
  46. default: Optional[ArgType] = None,
  47. missing_msg: Optional[str] = None,
  48. fail_msg: Optional[str] = None,
  49. ):
  50. self.index = index
  51. self.convert = convert
  52. self.required = required
  53. self.default = default
  54. self.missing_msg = missing_msg or f"Missing argument {self.index}"
  55. self.fail_msg = fail_msg or "Invalid argument: {}"
  56. def arg_source(self) -> Injector[list[str]]:
  57. return ArgList
  58. async def inject(self, message: Message, context: Context) -> str:
  59. try:
  60. arg = (await self.arg_source().inject(message, context))[self.index]
  61. except IndexError:
  62. if self.required:
  63. RollbotFailure.INVALID_ARGUMENTS.raise_exc(detail=self.missing_msg)
  64. else:
  65. return self.default
  66. try:
  67. return self.convert(arg)
  68. except ValueError:
  69. RollbotFailure.INVALID_ARGUMENTS.raise_exc(detail=self.fail_msg.format(arg))