types.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
from __future__ import annotations

from collections.abc import Callable, Collection, Container, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from logging import Logger
from typing import Any, Optional, Union

from aiohttp import ClientSession
from aiosqlite import Connection
  9. __all__ = [
  10. "Attachment",
  11. "Message",
  12. "Command",
  13. "Response",
  14. "Context",
  15. "CommandType",
  16. "StartupShutdownType",
  17. "CommandConfiguration",
  18. ]
  19. @dataclass
  20. class Attachment:
  21. name: str
  22. body: Union[str, bytes]
  23. @dataclass
  24. class Message:
  25. origin_id: str
  26. channel_id: str
  27. sender_id: str
  28. timestamp: datetime
  29. origin_admin: bool
  30. channel_admin: bool
  31. sender_name: Optional[str] = None
  32. text: Optional[str] = None
  33. attachments: list[Attachment] = field(default_factory=list)
  34. def __post_init__(self):
  35. self.command = None
  36. @dataclass
  37. class Command:
  38. bang: str
  39. name: str
  40. args: str
  41. def __post_init__(self):
  42. # cache is used by injectors to store results and avoid recomputation
  43. self.cache = {}
  44. @staticmethod
  45. def from_text(text: str) -> Optional[Command]:
  46. cleaned = text.lstrip()
  47. if len(cleaned) < 2:
  48. return None
  49. parts = cleaned[1:].lstrip().split(maxsplit=1)
  50. if len(parts) == 0:
  51. return None
  52. return Command(
  53. bang=cleaned[0],
  54. name=parts[0],
  55. args=parts[1] if len(parts) > 1 else "",
  56. )
  57. def get_subcommand(self, inherit_bang=True) -> Command:
  58. saved = self.cache.get(("subcommand", inherit_bang), None)
  59. if saved is None:
  60. if inherit_bang and not self.args.startswith(self.bang):
  61. saved = Command.from_text(self.bang + self.args)
  62. else:
  63. saved = Command.from_text(self.args)
  64. self.cache[("subcommand", inherit_bang)] = saved
  65. return saved
  66. @dataclass
  67. class Response:
  68. origin_id: str
  69. channel_id: str
  70. text: Optional[str] = None
  71. attachments: Optional[list[Attachment]] = None
  72. @staticmethod
  73. def from_message(
  74. msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
  75. ) -> Response:
  76. return Response(
  77. origin_id=msg.origin_id,
  78. channel_id=msg.channel_id,
  79. text=text,
  80. attachments=attachments or [],
  81. )
  82. @dataclass
  83. class Context:
  84. config: Callable[[str], Any]
  85. respond: Callable[[], Coroutine[None, None, None]]
  86. request: ClientSession
  87. database: Callable[[], Coroutine[None, None, Connection]]
  88. logger: Logger
  89. debugging: Optional[str] = None
  90. def get_debugging(self) -> Optional[str]:
  91. old = self.debugging
  92. self.debugging = None
  93. return old
  94. CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
  95. StartupShutdownType = Callable[[Context], Coroutine[None, None, None]]
  96. @dataclass
  97. class CommandConfiguration:
  98. commands: dict[str, CommandType] = field(default_factory=dict)
  99. call_and_response: dict[str, str] = field(default_factory=dict)
  100. aliases: dict[str, str] = field(default_factory=dict)
  101. bangs: Container[str] = ("!",)
  102. startup: list[StartupShutdownType] = field(default_factory=list)
  103. shutdown: list[StartupShutdownType] = field(default_factory=list)
  104. def extend(self, other: CommandConfiguration) -> CommandConfiguration:
  105. return CommandConfiguration(
  106. commands={
  107. **self.commands,
  108. **other.commands,
  109. },
  110. call_and_response={
  111. **self.call_and_response,
  112. **other.call_and_response,
  113. },
  114. aliases={
  115. **self.aliases,
  116. **other.aliases,
  117. },
  118. bangs=(*self.bangs, *other.bangs),
  119. startup=[*self.startup, *other.startup],
  120. shutdown=[*self.shutdown, *other.shutdown],
  121. )