|
@@ -1,8 +1,10 @@
|
|
|
from argparse import ArgumentParser, Namespace
|
|
|
-from typing import Optional
|
|
|
+from typing import Optional, TypeVar
|
|
|
+from collections.abc import Callable
|
|
|
import shlex
|
|
|
|
|
|
from ..types import Message, Context
|
|
|
+from ..failure import RollbotFailure
|
|
|
from .base import Injector
|
|
|
|
|
|
__all__ = [
|
|
@@ -10,6 +12,7 @@ __all__ = [
|
|
|
"ArgList",
|
|
|
"ArgListSplitOn",
|
|
|
"ArgParse",
|
|
|
+ "Arg",
|
|
|
]
|
|
|
|
|
|
|
|
@@ -23,10 +26,16 @@ class ArgListSplitOn(Injector[list[str]]):
|
|
|
self.split = split
|
|
|
|
|
|
async def inject(self, message: Message, context: Context) -> str:
|
|
|
+ cache_key = (ArgListSplitOn.__name__, self.split)
|
|
|
+ result = message.command.cache.get(cache_key, None)
|
|
|
+ if result is not None:
|
|
|
+ return result
|
|
|
if self.split is not None:
|
|
|
- return message.command.args.split(self.split)
|
|
|
+ result = message.command.args.split(self.split)
|
|
|
else:
|
|
|
- return message.command.args.split()
|
|
|
+ result = message.command.args.split()
|
|
|
+ message.command.cache[cache_key] = result
|
|
|
+ return result
|
|
|
|
|
|
|
|
|
class ArgParse(Injector[Namespace]):
|
|
@@ -39,4 +48,36 @@ class ArgParse(Injector[Namespace]):
|
|
|
|
|
|
Args = ArgsInjector()
|
|
|
ArgList = ArgListSplitOn()
|
|
|
-# TODO Arg(n)
|
|
|
+
|
|
|
+ArgType = TypeVar("ArgType")
|
|
|
+
|
|
|
+
|
|
|
+class Arg(Injector[ArgType]):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ index: int = 0,
|
|
|
+ convert: Callable[[str], ArgType] = str,
|
|
|
+ required: bool = True,
|
|
|
+ default: Optional[ArgType] = None,
|
|
|
+ fail_msg: Optional[str] = None,
|
|
|
+ ):
|
|
|
+ self.index = index
|
|
|
+ self.convert = convert
|
|
|
+ self.required = required
|
|
|
+ self.default = default
|
|
|
+ self.fail_msg = fail_msg or "Invalid argument: {}"
|
|
|
+
|
|
|
+ async def inject(self, message: Message, context: Context) -> str:
|
|
|
+ try:
|
|
|
+ arg = (await ArgList.inject(message, context))[self.index]
|
|
|
+ except IndexError:
|
|
|
+ if self.required:
|
|
|
+ RollbotFailure.INVALID_ARGUMENTS.with_reason(
|
|
|
+ f"Missing argument {self.index}"
|
|
|
+ ).raise_exc()
|
|
|
+ else:
|
|
|
+ return self.default
|
|
|
+ try:
|
|
|
+ return self.convert(arg)
|
|
|
+ except ValueError:
|
|
|
+ RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()
|