types.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from __future__ import annotations
  2. from logging import Logger
  3. from dataclasses import dataclass, field
  4. from datetime import datetime
  5. from collections.abc import Callable, Coroutine, Container
  6. from typing import Union, Any, Optional
  7. from aiosqlite import Connection
  8. from aiohttp import ClientSession
  9. __all__ = [
  10. "Attachment",
  11. "Message",
  12. "Command",
  13. "Response",
  14. "Context",
  15. "CommandType",
  16. "StartupShutdownType",
  17. "CommandConfiguration",
  18. ]
  19. @dataclass
  20. class Attachment:
  21. name: str
  22. body: Union[str, bytes]
  23. @dataclass
  24. class Message:
  25. origin_id: str
  26. channel_id: str
  27. sender_id: str
  28. timestamp: datetime
  29. origin_admin: bool
  30. channel_admin: bool
  31. sender_name: Optional[str] = None
  32. text: Optional[str] = None
  33. attachments: list[Attachment] = field(default_factory=list)
  34. def __post_init__(self):
  35. self.command = None
  36. @dataclass
  37. class Command:
  38. bang: str
  39. name: str
  40. raw_name: str
  41. args: str
  42. def __post_init__(self):
  43. # cache is used by injectors to store results and avoid recomputation
  44. self.cache = {}
  45. @staticmethod
  46. def from_text(text: str) -> Optional[Command]:
  47. cleaned = text.lstrip()
  48. if len(cleaned) < 2:
  49. return None
  50. parts = cleaned[1:].lstrip().split(maxsplit=1)
  51. if len(parts) == 0:
  52. return None
  53. return Command(
  54. bang=cleaned[0],
  55. name=parts[0].lower(),
  56. raw_name=parts[0],
  57. args=parts[1] if len(parts) > 1 else "",
  58. )
  59. def get_subcommand(self, inherit_bang=True) -> Command:
  60. saved = self.cache.get(("subcommand", inherit_bang), None)
  61. if saved is None:
  62. if inherit_bang and not self.args.startswith(self.bang):
  63. saved = Command.from_text(self.bang + self.args)
  64. else:
  65. saved = Command.from_text(self.args)
  66. self.cache[("subcommand", inherit_bang)] = saved
  67. return saved
  68. @dataclass
  69. class Response:
  70. origin_id: str
  71. channel_id: str
  72. text: Optional[str] = None
  73. attachments: Optional[list[Attachment]] = None
  74. @staticmethod
  75. def from_message(
  76. msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
  77. ) -> Response:
  78. return Response(
  79. origin_id=msg.origin_id,
  80. channel_id=msg.channel_id,
  81. text=text,
  82. attachments=attachments or [],
  83. )
  84. @dataclass
  85. class Context:
  86. config: Callable[[str], Any]
  87. respond: Callable[[], Coroutine[None, None, None]]
  88. request: ClientSession
  89. database: Callable[[], Coroutine[None, None, Connection]]
  90. logger: Logger
  91. debugging: Optional[str] = None
  92. def get_debugging(self) -> Optional[str]:
  93. old = self.debugging
  94. self.debugging = None
  95. return old
  96. CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
  97. StartupShutdownType = Callable[[Context], Coroutine[None, None, None]]
  98. @dataclass
  99. class CommandConfiguration:
  100. commands: dict[str, CommandType] = field(default_factory=dict)
  101. call_and_response: dict[str, str] = field(default_factory=dict)
  102. aliases: dict[str, str] = field(default_factory=dict)
  103. bangs: Container[str] = ("!",)
  104. startup: list[StartupShutdownType] = field(default_factory=list)
  105. shutdown: list[StartupShutdownType] = field(default_factory=list)
  106. def extend(self, other: CommandConfiguration) -> CommandConfiguration:
  107. return CommandConfiguration(
  108. commands={
  109. **self.commands,
  110. **other.commands,
  111. },
  112. call_and_response={
  113. **self.call_and_response,
  114. **other.call_and_response,
  115. },
  116. aliases={
  117. **self.aliases,
  118. **other.aliases,
  119. },
  120. bangs=(*self.bangs, *other.bangs),
  121. startup=[*self.startup, *other.startup],
  122. shutdown=[*self.shutdown, *other.shutdown],
  123. )