types.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
from __future__ import annotations

import time
from collections.abc import Callable, Collection, Container, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from logging import Logger
from typing import Any, Optional, Union

from aiohttp import ClientSession
from aiosqlite import Connection
  10. __all__ = [
  11. "Attachment",
  12. "Message",
  13. "Command",
  14. "Response",
  15. "Context",
  16. "CommandType",
  17. "StartupShutdownType",
  18. "CommandConfiguration",
  19. ]
  20. @dataclass
  21. class Attachment:
  22. name: str
  23. body: Union[str, bytes, None, Callable[[], Coroutine[Any, None, Any]]]
  24. @dataclass
  25. class Message:
  26. origin_id: str
  27. channel_id: str
  28. sender_id: str
  29. timestamp: datetime
  30. origin_admin: bool
  31. channel_admin: bool
  32. sender_name: Optional[str] = None
  33. text: Optional[str] = None
  34. attachments: list[Attachment] = field(default_factory=list)
  35. message_id: Optional[str] = None
  36. received_at: float = field(default_factory=time.time)
  37. def __post_init__(self):
  38. self.command = None
  39. @dataclass
  40. class Command:
  41. bang: str
  42. name: str
  43. raw_name: str
  44. args: str
  45. def __post_init__(self):
  46. # cache is used by injectors to store results and avoid recomputation
  47. self.cache = {}
  48. @staticmethod
  49. def from_text(text: str) -> Optional[Command]:
  50. cleaned = text.lstrip()
  51. if len(cleaned) < 2:
  52. return None
  53. parts = cleaned[1:].lstrip().split(maxsplit=1)
  54. if len(parts) == 0:
  55. return None
  56. return Command(
  57. bang=cleaned[0],
  58. name=parts[0].lower(),
  59. raw_name=parts[0],
  60. args=parts[1] if len(parts) > 1 else "",
  61. )
  62. def get_subcommand(self, inherit_bang=True) -> Command:
  63. saved = self.cache.get(("subcommand", inherit_bang), None)
  64. if saved is None:
  65. if inherit_bang and not self.args.startswith(self.bang):
  66. saved = Command.from_text(self.bang + self.args)
  67. else:
  68. saved = Command.from_text(self.args)
  69. self.cache[("subcommand", inherit_bang)] = saved
  70. return saved
  71. @dataclass
  72. class Response:
  73. origin_id: str
  74. channel_id: str
  75. text: Optional[str] = None
  76. attachments: Optional[list[Attachment]] = None
  77. cause: Optional[Message] = None
  78. @staticmethod
  79. def from_message(
  80. msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
  81. ) -> Response:
  82. return Response(
  83. origin_id=msg.origin_id,
  84. channel_id=msg.channel_id,
  85. text=text,
  86. attachments=attachments or [],
  87. cause=msg,
  88. )
  89. @dataclass
  90. class Context:
  91. config: Callable[[str], Any]
  92. respond: Callable[[], Coroutine[None, None, None]]
  93. request: ClientSession
  94. database: Callable[[], Coroutine[None, None, Connection]]
  95. logger: Logger
  96. debugging: Optional[str] = None
  97. def get_debugging(self) -> Optional[str]:
  98. old = self.debugging
  99. self.debugging = None
  100. return old
  101. CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
  102. StartupShutdownType = Callable[[Context], Coroutine[None, None, None]]
  103. @dataclass
  104. class CommandConfiguration:
  105. commands: dict[str, CommandType] = field(default_factory=dict)
  106. call_and_response: dict[str, str] = field(default_factory=dict)
  107. aliases: dict[str, str] = field(default_factory=dict)
  108. bangs: Container[str] = ("!",)
  109. startup: list[StartupShutdownType] = field(default_factory=list)
  110. shutdown: list[StartupShutdownType] = field(default_factory=list)
  111. def extend(self, other: CommandConfiguration) -> CommandConfiguration:
  112. return CommandConfiguration(
  113. commands={
  114. **self.commands,
  115. **other.commands,
  116. },
  117. call_and_response={
  118. **self.call_and_response,
  119. **other.call_and_response,
  120. },
  121. aliases={
  122. **self.aliases,
  123. **other.aliases,
  124. },
  125. bangs=(*self.bangs, *other.bangs),
  126. startup=[*self.startup, *other.startup],
  127. shutdown=[*self.shutdown, *other.shutdown],
  128. )