types.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
from __future__ import annotations

from collections.abc import Callable, Collection, Container, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from logging import Logger
from typing import Any, Optional, Union

from aiohttp import ClientSession
from aiosqlite import Connection
  9. __all__ = [
  10. "Attachment",
  11. "Message",
  12. "Command",
  13. "Response",
  14. "Context",
  15. "CommandType",
  16. "StartupShutdownType",
  17. "CommandConfiguration",
  18. ]
  19. @dataclass
  20. class Attachment:
  21. name: str
  22. body: Union[str, bytes]
  23. @dataclass
  24. class Message:
  25. origin_id: str
  26. channel_id: str
  27. sender_id: str
  28. timestamp: datetime
  29. origin_admin: bool
  30. channel_admin: bool
  31. text: Optional[str] = None
  32. attachments: list[Attachment] = field(default_factory=list)
  33. def __post_init__(self):
  34. self.command = None
  35. @dataclass
  36. class Command:
  37. bang: str
  38. name: str
  39. args: str
  40. def __post_init__(self):
  41. # cache is used by injectors to store results and avoid recomputation
  42. self.cache = {}
  43. @staticmethod
  44. def from_text(text: str) -> Optional[Command]:
  45. cleaned = text.lstrip()
  46. if len(cleaned) < 2:
  47. return None
  48. parts = cleaned[1:].lstrip().split(maxsplit=1)
  49. if len(parts) == 0:
  50. return None
  51. return Command(
  52. bang=cleaned[0],
  53. name=parts[0],
  54. args=parts[1] if len(parts) > 1 else "",
  55. )
  56. def get_subcommand(self, inherit_bang=True) -> Command:
  57. saved = self.cache.get(("subcommand", inherit_bang), None)
  58. if saved is None:
  59. if inherit_bang and not self.args.startswith(self.bang):
  60. saved = Command.from_text(self.bang + self.args)
  61. else:
  62. saved = Command.from_text(self.args)
  63. self.cache[("subcommand", inherit_bang)] = saved
  64. return saved
  65. @dataclass
  66. class Response:
  67. origin_id: str
  68. channel_id: str
  69. text: Optional[str] = None
  70. attachments: Optional[list[Attachment]] = None
  71. @staticmethod
  72. def from_message(
  73. msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
  74. ) -> Response:
  75. return Response(
  76. origin_id=msg.origin_id,
  77. channel_id=msg.channel_id,
  78. text=text,
  79. attachments=attachments or [],
  80. )
  81. @dataclass
  82. class Context:
  83. config: Callable[[str], Any]
  84. respond: Callable[[], Coroutine[None, None, None]]
  85. request: ClientSession
  86. database: Callable[[], Coroutine[None, None, Connection]]
  87. logger: Logger
  88. CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
  89. StartupShutdownType = Callable[[Context], Coroutine[None, None, None]]
  90. @dataclass
  91. class CommandConfiguration:
  92. commands: dict[str, CommandType] = field(default_factory=dict)
  93. call_and_response: dict[str, str] = field(default_factory=dict)
  94. aliases: dict[str, str] = field(default_factory=dict)
  95. bangs: Container[str] = ("!",)
  96. startup: list[StartupShutdownType] = field(default_factory=list)
  97. shutdown: list[StartupShutdownType] = field(default_factory=list)
  98. def extend(self, other: CommandConfiguration) -> CommandConfiguration:
  99. return CommandConfiguration(
  100. commands={
  101. **self.commands,
  102. **other.commands,
  103. },
  104. call_and_response={
  105. **self.call_and_response,
  106. **other.call_and_response,
  107. },
  108. aliases={
  109. **self.aliases,
  110. **other.aliases,
  111. },
  112. bangs=(*self.bangs, *other.bangs),
  113. startup=[*self.startup, *other.startup],
  114. shutdown=[*self.shutdown, *other.shutdown],
  115. )