Эх сурвалжийг харах

WIP. Bunch of arg extraction and validation logic. Again, might be overkill

Kirk Trombley 4 жил өмнө
parent
commit
120a56870b

+ 4 - 1
lib/rollbot/messaging.py

@@ -74,6 +74,7 @@ class RollbotMessage:
                 self.raw_command = cmd
                 self.command = cmd.lower()
                 self.raw_args = raw
+                self._arg_list_thunk = None
 
     @staticmethod
     def from_subcommand(msg):
@@ -150,7 +151,9 @@ class RollbotMessage:
         Behavior is undefined if this method is called on a message whose
         is_command field is false.
         """
-        return self.raw_args.split()
+        if self._arg_list_thunk is None:
+            self._arg_list_thunk = self.raw_args.split()
+        return self._arg_list_thunk
 
 
 class RollbotFailureException(BaseException):

+ 2 - 1
lib/rollbot/plugins/decorators/__init__.py

@@ -1,2 +1,3 @@
 from .as_plugin import as_plugin
-from .attachers import with_help, with_startup
+from .attachers import with_help, with_startup
+from .requires import require_min_args, require_args

+ 28 - 11
lib/rollbot/plugins/decorators/arg_wiring.py

@@ -1,6 +1,6 @@
 import inspect
 
-from ...messaging import RollbotMessage
+from ...messaging import RollbotMessage, RollbotFailure
 
 
 class ArgConverter:
@@ -16,6 +16,24 @@ ArgList = ArgConverter(lambda _, __, msg: msg.arg_list())
 Subcommand = ArgConverter(lambda _, __, msg: RollbotMessage.from_subcommand(msg))
 
 
+class Arg(ArgConverter):
+    def __init__(self, index, conversion=str, fail_msg=None):
+        super().__init__(self._convert)
+        self.index = index
+        self.conversion = conversion
+        self.fail_msg = fail_msg
+    
+    def _convert(self, cmd, db, msg):
+        try:
+            arg = msg.arg_list()[self.index]
+        except IndexError:
+            RollbotFailure.INVALID_ARGUMENTS.with_reason(f"Missing argument {self.index}").raise_exc()
+        try:
+            return self.conversion(arg)
+        except ValueError:
+            RollbotFailure.INVALID_ARGUMENTS.with_reason(self.fail_msg.format(arg)).raise_exc()
+
+
 class Config(ArgConverter):
     def __init__(self, key=None):
         if key is None:
@@ -88,24 +106,23 @@ def get_converters(parameters, annotations):
         if isinstance(annot, ArgConverter):
             converters.append(annot.conv)
         elif p in ("msg", "message", "_msg"):
-            converters.append(lambda cmd, db, msg: msg)
+            converters.append(Message.conv)
         elif p in ("db", "database"):
-            converters.append(lambda cmd, db, msg: db)
+            converters.append(Database.conv)
         elif p in ("log", "logger"):
-            converters.append(lambda cmd, db, msg: cmd.logger)
+            converters.append(Logger.conv)
         elif p in ("bot", "rollbot"):
-            converters.append(lambda cmd, db, msg: cmd.bot)
+            converters.append(Bot.conv)
         elif p in ("args", "arg_list"):
-            converters.append(lambda cmd, db, msg: msg.arg_list())
+            converters.append(ArgList.conv)
         elif p in ("subc", "subcommand"):
-            converters.append(lambda cmd, db, msg: RollbotMessage.from_subcommand(msg))
+            converters.append(Subcommand.conv)
         elif p in ("cfg", "config"):
-            converters.append(lambda cmd, db, msg: cmd.bot.config)
+            converters.append(Config())
         elif p.startswith("cfg") or p.endswith("cfg"):
-            annot = annot or p
-            converters.append(lambda cmd, db, msg, key=annot: cmd.bot.config.get(key))
+            converters.append(Config(annot or p).conv)
         elif p.startswith("data") or p.endswith("data"):
-            converters.append(lambda cmd, db, msg, sing_cls=annot: sing_cls.get_or_create(db, msg))
+            converters.append(Singleton(annot).conv)
         else:
             raise ValueError(p)
     return converters

+ 32 - 0
lib/rollbot/plugins/decorators/requires.py

@@ -0,0 +1,32 @@
+from functools import wraps
+
+from ...messaging import RollbotResponse, RollbotFailure
+
+
+def require_min_args(n, alert_response=None):
+    def decorator(cls):
+        old_on_command = cls.on_command
+        @wraps(old_on_command)
+        def wrapper(self, db, message):
+            if len(message.arg_list()) < n:
+                failure = RollbotFailure.INVALID_ARGUMENTS.with_reason(alert_response or f"{cls.command.title()} requires at least {n} argument(s)")
+                return RollbotResponse(message, failure=failure, debugging=failure.get_debugging())
+            return old_on_command(self, db, message)
+        setattr(cls, "on_command", wrapper)
+        return cls
+    return decorator
+
+
+# TODO refactor to share code?
+def require_args(n, alert_response=None):
+    def decorator(cls):
+        old_on_command = cls.on_command
+        @wraps(old_on_command)
+        def wrapper(self, db, message):
+            if len(message.arg_list()) < n:
+                failure = RollbotFailure.INVALID_ARGUMENTS.with_reason(alert_response or f"{cls.command.title()} requires exactly {n} argument(s)")
+                return RollbotResponse(message, failure=failure, debugging=failure.get_debugging())
+            return old_on_command(self, db, message)
+        setattr(cls, "on_command", wrapper)
+        return cls
+    return decorator

+ 17 - 29
src/plugins/rollcoin.py

@@ -4,8 +4,9 @@ import pickle
 
 from sudoku_py import SudokuGenerator, Sudoku
 
-from rollbot import as_plugin, with_help, with_startup, as_sender_singleton, as_group_singleton, RollbotFailure
-from rollbot.plugins.decorators.arg_wiring import Database, Config, ArgList, Singleton, Message, Lazy, Query
+from rollbot import as_plugin, with_help, as_sender_singleton, as_group_singleton, RollbotFailure
+from rollbot.plugins.decorators import with_startup, require_min_args, require_args
+from rollbot.plugins.decorators.arg_wiring import Database, Config, Arg, Singleton, Message, Lazy, Query
 
 @as_sender_singleton
 class RollcoinWallet:
@@ -73,15 +74,10 @@ def balance(wallet: Singleton(RollcoinWallet), holdings: Singleton(RollcoinGambl
 SPECIAL_AMOUNTS = ("ALL", "FRAC", "-ALL", "-FRAC")
 
 
-def pop_amount_arg(args):
-    raw_amount, *rest = args
-    if (up_amount := raw_amount.upper()) in SPECIAL_AMOUNTS:
-        return up_amount, rest
-    try:
-        amount = float(raw_amount)
-    except ValueError:
-        RollbotFailure.INVALID_ARGUMENTS.with_reason(f"Amount should be a number or ALL or FRAC - not {raw_amount}").raise_exc()
-    return amount, rest
+def parse_amount(amount):
+    if (up_amount := amount.upper()) in SPECIAL_AMOUNTS:
+        return up_amount
+    return float(amount)
 
 
 def assert_positive(amount):
@@ -93,8 +89,8 @@ def fractional_part(number):
     return float("0." + str(float(number)).split(".")[1])
 
 
-def lookup_target(args: ArgList, wallets: Config("rollcoin.wallets")):
-    target_id = wallets.get(args[0].lower(), None)
+def lookup_target(target_name: Arg(0), wallets: Config("rollcoin.wallets")):
+    target_id = wallets.get(target_name.lower(), None)
     if target_id is None:
         RollbotFailure.INVALID_ARGUMENTS.with_reason(f"Could not find wallet-holder {target}").raise_exc()
     
@@ -102,14 +98,12 @@ def lookup_target(args: ArgList, wallets: Config("rollcoin.wallets")):
 
 
 @with_help("Tip someone else some Rollcoins: !tip target amount")
+@require_min_args(2, "Tip requires 2 arguments - target and amount")
 @as_plugin
-def tip(args: ArgList, 
+def tip(target_name: Arg(0),
+        amount: Arg(1, parse_amount, "Amount should be a number or ALL or FRAC - not {}"),
         sender_wallet: Singleton(RollcoinWallet), 
         target_wallet: Lazy(Singleton(RollcoinWallet).by(lookup_target))):
-    if len(args) < 2:
-        RollbotFailure.INVALID_ARGUMENTS.with_reason("Tip requires 2 arguments - target and amount").raise_exc()
-
-    amount, _ = pop_amount_arg(args[1:])
     if amount in SPECIAL_AMOUNTS:
         if amount == "ALL":
             amount = sender_wallet.balance
@@ -128,7 +122,7 @@ def tip(args: ArgList,
 
     sender_wallet.balance = sender_wallet.balance - amount
     target_wallet.balance = target_wallet.balance + amount
-    return f"Done! {args[0]} now has {target_wallet.balance} Rollcoins"
+    return f"Done! {target_name} now has {target_wallet.balance} Rollcoins"
 
 
 def get_non_sender_ids(wallets: Config("rollcoin.wallets"), msg: Message):
@@ -136,14 +130,11 @@ def get_non_sender_ids(wallets: Config("rollcoin.wallets"), msg: Message):
 
 
 @with_help("Donate money to be distributed evenly among all wallets")
+@require_min_args(1, "Need an amount to donate")
 @as_plugin
-def donate(args: ArgList, 
+def donate(amount: Arg(0, parse_amount, "Amount should be a number or ALL or FRAC - not {}"), 
            sender_wallet: Singleton(RollcoinWallet),
            other_wallets: Lazy(Singleton(RollcoinWallet).by_all(get_non_sender_ids))):
-    if len(args) < 1:
-        return RollbotFailure.INVALID_ARGUMENTS.with_reason("Need an amount to donate")
-
-    amount, _ = pop_amount_arg(args)
     if amount in SPECIAL_AMOUNTS:
         if amount == "ALL":
             amount = sender_wallet.balance
@@ -256,8 +247,9 @@ def not_sender_gambling_wallet(msg: Message):
 
 
 @with_help("Gamble some Rollcoins")
+@require_args(1, "Gambling requires exactly one argument: the amount to transfer (can be positive or negative)")
 @as_plugin
-def gamble(args: ArgList, 
+def gamble(amount: Arg(0, parse_amount, "Amount should be a number or ALL or FRAC - not {}"), 
            sender_wallet: Singleton(RollcoinWallet), 
            sender_holdings: Singleton(RollcoinGamblingWallet), 
            market: Singleton(RollcoinMarket), 
@@ -270,10 +262,6 @@ def gamble(args: ArgList,
     #   - current market state
     #   - rng
     # market enters new state and affects all gambling wallets, multiplying their balances
-    if len(args) != 1:
-        RollbotFailure.INVALID_ARGUMENTS.with_reason("Gambling requires exactly one argument: the amount to transfer (can be positive or negative)").raise_exc()
-
-    amount, _ = pop_amount_arg(args)
     if amount in SPECIAL_AMOUNTS:
         if amount == "ALL":
             amount = sender_wallet.balance