123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
from __future__ import annotations

from collections.abc import Callable, Collection, Container, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from logging import Logger
from typing import Any, Optional, Union

from aiohttp import ClientSession
from aiosqlite import Connection
# Public API exported by ``from <this module> import *``.
__all__ = [
    # Message / command data model.
    "Attachment",
    "Message",
    "Command",
    "Response",
    # Handler plumbing and configuration.
    "Context",
    "CommandType",
    "StartupShutdownType",
    "CommandConfiguration",
]
@dataclass
class Attachment:
    """A named payload attached to a message or response.

    ``body`` may be either text (``str``) or raw binary data (``bytes``).
    """

    # Presumably a file name shown to recipients — confirm with producers.
    name: str
    body: Union[str, bytes]
@dataclass
class Message:
    """An incoming message together with where, when and from whom it came.

    ``text`` defaults to None. Each instance gets its own fresh
    ``attachments`` list. After construction an extra ``command`` attribute
    is initialised to None. It is not a dataclass field.
    """

    # Where the message came from and who sent it.
    origin_id: str
    channel_id: str
    sender_id: str
    timestamp: datetime
    # Privilege flags. Presumably they say whether the sender is an admin of
    # the origin or of the channel; confirm against whoever builds Message.
    origin_admin: bool
    channel_admin: bool
    # Optional payload.
    text: Optional[str] = None
    attachments: list[Attachment] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Plain attribute, not a field, so it is excluded from __init__,
        # __repr__ and __eq__. Presumably a parsed Command is stored here
        # later (not shown in this file).
        self.command = None
@dataclass
class Command:
    """A parsed bang-command such as ``!name some arguments``.

    ``bang`` is the single prefix character and ``name`` is the first word
    after it. ``args`` is the rest of the text with leading whitespace
    removed; it may be empty.
    """

    bang: str
    name: str
    args: str

    def __post_init__(self) -> None:
        # cache is used by injectors to store results and avoid recomputation.
        # It is a plain attribute rather than a dataclass field, so it does not
        # take part in __init__, __repr__ or __eq__.
        self.cache = {}

    @staticmethod
    def from_text(text: str) -> Optional[Command]:
        """Parse ``text`` into a Command.

        Leading whitespace is ignored. The first remaining character, whatever
        it is, becomes the bang. Whitespace between the bang and the name is
        allowed.

        Returns None when there is no name after the bang, e.g. ``""``,
        ``"!"`` or ``"!   "``.
        """
        cleaned = text.lstrip()
        if len(cleaned) < 2:
            return None
        parts = cleaned[1:].lstrip().split(maxsplit=1)
        if not parts:
            return None
        return Command(
            bang=cleaned[0],
            name=parts[0],
            args=parts[1] if len(parts) > 1 else "",
        )

    def get_subcommand(self, inherit_bang: bool = True) -> Optional[Command]:
        """Parse ``args`` as a nested command. The result is memoised in ``cache``.

        With ``inherit_bang`` (the default), this command's bang is prepended
        to ``args`` unless ``args`` already starts with it. So ``!help ping``
        and ``!help !ping`` both yield ``!ping``. Without it, ``args`` is
        parsed as-is and its first character becomes the bang.

        Returns None when ``args`` holds no sub-command name. That None is
        cached like any other result.
        """
        key = ("subcommand", inherit_bang)
        # Check membership instead of comparing the cached value to None:
        # None is a legitimate result and must be reused, not recomputed.
        if key not in self.cache:
            if inherit_bang and not self.args.startswith(self.bang):
                text = self.bang + self.args
            else:
                text = self.args
            self.cache[key] = Command.from_text(text)
        return self.cache[key]
@dataclass
class Response:
    """An outgoing reply addressed to an origin/channel pair.

    When constructed directly, ``text`` and ``attachments`` default to None.
    ``from_message`` always sets ``attachments`` to a list.
    """

    origin_id: str
    channel_id: str
    text: Optional[str] = None
    attachments: Optional[list[Attachment]] = None

    @staticmethod
    def from_message(
        msg: Message,
        text: Optional[str] = None,
        attachments: Optional[list[Attachment]] = None,
    ) -> Response:
        """Build a Response addressed to the same origin and channel as ``msg``.

        Args:
            msg: The message being replied to. Only its ``origin_id`` and
                ``channel_id`` are read.
            text: Optional reply text.
            attachments: Optional attachments. None becomes an empty list.

        Returns:
            A new Response whose ``attachments`` is never None.
        """
        return Response(
            origin_id=msg.origin_id,
            channel_id=msg.channel_id,
            text=text,
            # `or` also swaps a caller's empty list for a fresh one. A
            # non-empty list is passed through as the same object.
            attachments=attachments or [],
        )
@dataclass
class Context:
    """Services shared by command handlers and startup/shutdown hooks.

    This class only groups the handles; whoever builds the Context supplies
    them.
    """

    # Configuration lookup: called with a string key, returns any value.
    config: Callable[[str], Any]
    # Async callable. NOTE(review): annotated as taking no arguments —
    # confirm against real call sites.
    respond: Callable[[], Coroutine[None, None, None]]
    # aiohttp client session, presumably shared for outgoing HTTP requests.
    request: ClientSession
    # Async callable that produces an aiosqlite Connection.
    database: Callable[[], Coroutine[None, None, Connection]]
    logger: Logger
# A command handler: an async callable that takes the triggering Message and
# the shared Context and returns a coroutine producing None.
CommandType = Callable[
    [Message, Context],
    Coroutine[None, None, None],
]

# A startup or shutdown hook: like a handler, but it receives only the Context.
StartupShutdownType = Callable[
    [Context],
    Coroutine[None, None, None],
]
@dataclass
class CommandConfiguration:
    """A bundle of command handlers and related settings.

    Configurations can be layered with ``extend``.
    """

    # Command name -> async handler.
    commands: dict[str, CommandType] = field(default_factory=dict)
    # Presumably trigger text -> canned reply; confirm against the dispatcher.
    call_and_response: dict[str, str] = field(default_factory=dict)
    # Presumably alias -> canonical command name; confirm against the dispatcher.
    aliases: dict[str, str] = field(default_factory=dict)
    # Prefix characters accepted as bangs. `extend` iterates this, so it must
    # be a Collection (iterable and supports `in`), not merely a Container.
    bangs: Collection[str] = ("!",)
    # Hooks, stored in list order.
    startup: list[StartupShutdownType] = field(default_factory=list)
    shutdown: list[StartupShutdownType] = field(default_factory=list)

    def extend(self, other: CommandConfiguration) -> CommandConfiguration:
        """Return a new configuration with ``other`` layered on top of ``self``.

        Neither input is modified.
        - Mappings: on a key collision, the entry from ``other`` wins.
        - Bangs: concatenated, then deduplicated, keeping the first occurrence.
        - Startup and shutdown hooks: concatenated, ``self``'s first.
        """
        return CommandConfiguration(
            commands={**self.commands, **other.commands},
            call_and_response={
                **self.call_and_response,
                **other.call_and_response,
            },
            aliases={**self.aliases, **other.aliases},
            # dict.fromkeys removes duplicates while preserving order, so
            # repeated extends don't pile up copies of the default "!".
            bangs=tuple(dict.fromkeys((*self.bangs, *other.bangs))),
            startup=[*self.startup, *other.startup],
            shutdown=[*self.shutdown, *other.shutdown],
        )
|