|
@@ -0,0 +1,57 @@
|
|
|
+from typing import Generic, TypeVar, Optional
|
|
|
+from argparse import ArgumentParser, Namespace
|
|
|
+import shlex
|
|
|
+
|
|
|
+from aiosqlite.core import Connection
|
|
|
+
|
|
|
+from ..types import Message, Context
|
|
|
+
|
|
|
+
|
|
|
+async def inject_message(message: Message, context: Context) -> Message:
|
|
|
+ return message
|
|
|
+
|
|
|
+
|
|
|
+async def inject_context(message: Message, context: Context) -> Context:
|
|
|
+ return context
|
|
|
+
|
|
|
+
|
|
|
+Dep = TypeVar("DepType")
|
|
|
+
|
|
|
+
|
|
|
+class Injector(Generic[Dep]):
|
|
|
+ async def inject(self, message: Message, context: Context) -> Dep:
|
|
|
+ raise NotImplemented
|
|
|
+
|
|
|
+
|
|
|
+class ArgsInjector(Injector[str]):
|
|
|
+ async def inject(self, message: Message, context: Context) -> str:
|
|
|
+ return message.command.args
|
|
|
+
|
|
|
+
|
|
|
+class ArgListSplitOn(Injector[list[str]]):
|
|
|
+ def __init__(self, split: Optional[str] = None):
|
|
|
+ self.split = split
|
|
|
+
|
|
|
+ async def inject(self, message: Message, context: Context) -> str:
|
|
|
+ if self.split is not None:
|
|
|
+ return message.command.args.split(self.split)
|
|
|
+ else:
|
|
|
+ return message.command.args.split()
|
|
|
+
|
|
|
+
|
|
|
+class ArgParse(Injector[Namespace]):
|
|
|
+ def __init__(self, parser: ArgumentParser):
|
|
|
+ self.parser = parser
|
|
|
+
|
|
|
+ async def inject(self, message: Message, context: Context) -> Namespace:
|
|
|
+ return self.parser.parse_args(shlex.split(message.text))
|
|
|
+
|
|
|
+
|
|
|
+class DatabaseInjector(Injector[Connection]):
|
|
|
+ async def inject(self, message: Message, context: Context) -> Connection:
|
|
|
+ return context.database()
|
|
|
+
|
|
|
+
|
|
|
+Args = ArgsInjector()
|
|
|
+ArgList = ArgListSplitOn()
|
|
|
+Database = DatabaseInjector()
|