Selaa lähdekoodia

Add single arg injector

Kirk Trombley 4 vuotta sitten
vanhempi
commit
b139dd2bab
2 muutettua tiedostoa jossa 47 lisäystä ja 7 poistoa
  1. 45 4
      lib/rollbot/injection/args.py
  2. 2 3
      repl_driver.py

+ 45 - 4
lib/rollbot/injection/args.py

@@ -1,8 +1,10 @@
 from argparse import ArgumentParser, Namespace
-from typing import Optional
+from typing import Optional, TypeVar
+from collections.abc import Callable
 import shlex
 
 from ..types import Message, Context
+from ..failure import RollbotFailure
 from .base import Injector
 
 __all__ = [
@@ -10,6 +12,7 @@ __all__ = [
     "ArgList",
     "ArgListSplitOn",
     "ArgParse",
+    "Arg",
 ]
 
 
@@ -23,10 +26,16 @@ class ArgListSplitOn(Injector[list[str]]):
         self.split = split
 
     async def inject(self, message: Message, context: Context) -> str:
+        cache_key = (ArgListSplitOn.__name__, self.split)
+        result = message.command.cache.get(cache_key, None)
+        if result is not None:
+            return result
         if self.split is not None:
-            return message.command.args.split(self.split)
+            result = message.command.args.split(self.split)
         else:
-            return message.command.args.split()
+            result = message.command.args.split()
+        message.command.cache[cache_key] = result
+        return result
 
 
 class ArgParse(Injector[Namespace]):
@@ -39,4 +48,36 @@ class ArgParse(Injector[Namespace]):
 
 Args = ArgsInjector()
 ArgList = ArgListSplitOn()
-# TODO Arg(n)
+
+ArgType = TypeVar("ArgType")
+
+
+class Arg(Injector[ArgType]):
+    def __init__(
+        self,
+        index: int = 0,
+        convert: Callable[[str], ArgType] = str,
+        required: bool = True,
+        default: Optional[ArgType] = None,
+        fail_msg: Optional[str] = None,
+    ):
+        self.index = index
+        self.convert = convert
+        self.required = required
+        self.default = default
+        self.fail_msg = fail_msg or "Invalid argument: {}"
+
+    async def inject(self, message: Message, context: Context) -> str:
+        try:
+            arg = (await ArgList.inject(message, context))[self.index]
+        except IndexError:
+            if self.required:
+                RollbotFailure.INVALID_ARGUMENTS.with_reason(
+                    f"Missing argument {self.index}"
+                ).raise_exc()
+            else:
+                return self.default
+        try:
+            return self.convert(arg)
+        except ValueError:
+            RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()

+ 2 - 3
repl_driver.py

@@ -2,7 +2,7 @@ from datetime import datetime
 import asyncio
 
 import rollbot
-from rollbot.injection import ArgList, Lazy, Database, Data, Args
+from rollbot.injection import ArgList, Lazy, Database, Data, Args, Arg
 
 
 @rollbot.initialize_data
@@ -51,8 +51,7 @@ async def count_command(message, context):
 
 
 @rollbot.as_command
-async def count2(args: ArgList, connect: Lazy(Database)):
-    name = args[0] if len(args) > 0 else "main"
+async def count2(name: Arg(0, required=False, default="main"), connect: Lazy(Database)):
     db = await connect()
     await db.execute(
         "INSERT INTO counter VALUES (?, 1) \