12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- from argparse import ArgumentParser, Namespace
- from typing import Optional, TypeVar
- from collections.abc import Callable
- import shlex
- from ..types import Message, Context
- from ..failure import RollbotFailure
- from .base import Injector, Simple
# Public API of this module.
__all__ = ["Args", "ArgList", "ArgListSplitOn", "ArgParse", "Arg"]

# Injector yielding the command's raw argument string, unsplit and unparsed.
Args = Simple[str](lambda m, c: m.command.args)
class ArgAccessorBase:
    """Mixin that gives injectors one place to read a message's raw arguments."""

    def get_args(self, message: Message) -> str:
        """Return the unparsed argument text attached to ``message``'s command."""
        command = message.command
        return command.args
class ArgListSplitOn(Injector[list[str]], ArgAccessorBase):
    """Inject the command's arguments split into a list of strings.

    ``split`` is handed directly to ``str.split``:

    - ``None`` (the default) splits on runs of whitespace.
    - Any other string splits on that exact separator.

    The result is memoized in ``message.command.cache`` under a key built from
    this class name and the separator. Every instance using the same separator
    therefore shares one split per message.
    """

    def __init__(self, split: Optional[str] = None):
        self.split = split

    async def inject(self, message: Message, context: Context) -> list[str]:
        """Return the split argument list, computing and caching it on first use.

        NOTE: the same list object is returned to every caller that hits the
        cache, so callers should treat it as read-only.
        """
        cache = message.command.cache
        cache_key = (ArgListSplitOn.__name__, self.split)
        result = cache.get(cache_key, None)
        if result is None:
            # str.split(None) behaves exactly like str.split(), so one call
            # covers both the whitespace and the explicit-separator cases.
            result = self.get_args(message).split(self.split)
            cache[cache_key] = result
        return result
class ArgParse(Injector[Namespace], ArgAccessorBase):
    """Inject the command's arguments parsed by an ``argparse.ArgumentParser``.

    The raw argument string is tokenized with ``shlex.split``, so users can
    quote values that contain spaces. The tokens are then passed to
    ``parser.parse_args``.
    """

    def __init__(self, parser: ArgumentParser):
        self.parser = parser

    async def inject(self, message: Message, context: Context) -> Namespace:
        """Tokenize and parse the message's arguments.

        Failures are reported through
        ``RollbotFailure.INVALID_ARGUMENTS.raise_exc`` (the same mechanism
        ``Arg`` uses). They cover arguments that cannot be tokenized and
        arguments that the parser rejects.
        """
        try:
            tokens = shlex.split(self.get_args(message))
        except ValueError as err:
            # shlex raises ValueError for malformed input such as an
            # unbalanced quote; report it as bad user input.
            RollbotFailure.INVALID_ARGUMENTS.raise_exc(detail=str(err))
        try:
            return self.parser.parse_args(tokens)
        except SystemExit:
            # argparse reports invalid arguments (and --help) by calling
            # sys.exit(). Letting SystemExit escape a coroutine can stop the
            # whole event loop, so convert it into a bot-level failure that
            # carries the parser's usage string.
            RollbotFailure.INVALID_ARGUMENTS.raise_exc(
                detail=self.parser.format_usage()
            )
# Type variable for the value an ``Arg`` produces via its ``convert`` callable.
ArgType = TypeVar("ArgType")

# Shared injector for the whitespace-split argument list (separator ``None``).
ArgList = ArgListSplitOn(split=None)
class Arg(Injector[ArgType]):
    """Inject one positional argument, optionally converted to another type.

    The argument is taken from the list produced by ``arg_source()``, which by
    default is ``ArgList`` (the whitespace-split command arguments).

    Args:
        index: Position of the argument in that list. Negative values count
            from the end, as with ordinary list indexing.
        convert: Callable applied to the raw string argument. A ``ValueError``
            it raises is reported as an invalid-arguments failure.
        required: When true, a missing argument is reported as an
            invalid-arguments failure. When false, ``default`` is returned.
        default: Value returned, without conversion, when the argument is
            missing and not required.
        missing_msg: Failure detail for a missing required argument.
        fail_msg: Failure detail template for a conversion error. ``{}`` is
            replaced with the raw argument string.
    """

    def __init__(
        self,
        index: int = 0,
        convert: Callable[[str], ArgType] = str,
        required: bool = True,
        default: Optional[ArgType] = None,
        missing_msg: Optional[str] = None,
        fail_msg: Optional[str] = None,
    ):
        self.index = index
        self.convert = convert
        self.required = required
        self.default = default
        self.missing_msg = missing_msg or f"Missing argument {self.index}"
        self.fail_msg = fail_msg or "Invalid argument: {}"

    def arg_source(self) -> Injector[list[str]]:
        """Return the injector that supplies the split argument list."""
        return ArgList

    async def inject(self, message: Message, context: Context) -> Optional[ArgType]:
        """Return the converted argument, or ``default`` if optional and absent.

        Missing required arguments and conversion ``ValueError``s are
        reported through ``RollbotFailure.INVALID_ARGUMENTS.raise_exc``.
        """
        # Fetch the list outside the try so that an IndexError raised inside
        # the source injector is not misreported as a missing argument.
        args = await self.arg_source().inject(message, context)
        try:
            arg = args[self.index]
        except IndexError:
            if self.required:
                RollbotFailure.INVALID_ARGUMENTS.raise_exc(detail=self.missing_msg)
            else:
                return self.default
        try:
            return self.convert(arg)
        except ValueError:
            RollbotFailure.INVALID_ARGUMENTS.raise_exc(detail=self.fail_msg.format(arg))
|