# types.py
from __future__ import annotations

from collections.abc import Callable, Collection, Container, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Optional, Union

from aiohttp import ClientSession
from aiosqlite import Connection
  8. __all__ = [
  9. "Attachment",
  10. "Message",
  11. "Command",
  12. "Response",
  13. "Context",
  14. "CommandType",
  15. "StartupShutdownType",
  16. "CommandConfiguration",
  17. ]
  18. @dataclass
  19. class Attachment:
  20. name: str
  21. body: Union[str, bytes]
  22. @dataclass
  23. class Message:
  24. origin_id: str
  25. channel_id: str
  26. sender_id: str
  27. timestamp: datetime
  28. origin_admin: bool
  29. channel_admin: bool
  30. text: Optional[str] = None
  31. attachments: list[Attachment] = field(default_factory=list)
  32. def __post_init__(self):
  33. self.command = None
  34. @dataclass
  35. class Command:
  36. bang: str
  37. name: str
  38. args: str
  39. def __post_init__(self):
  40. # cache is used by injectors to store results and avoid recomputation
  41. self.cache = {}
  42. @staticmethod
  43. def from_text(text: str) -> Optional[Command]:
  44. cleaned = text.lstrip()
  45. if len(cleaned) < 2:
  46. return None
  47. parts = cleaned[1:].lstrip().split(maxsplit=1)
  48. if len(parts) == 0:
  49. return None
  50. return Command(
  51. bang=cleaned[0],
  52. name=parts[0],
  53. args=parts[1] if len(parts) > 1 else "",
  54. )
  55. @dataclass
  56. class Response:
  57. origin_id: str
  58. channel_id: str
  59. text: Optional[str] = None
  60. attachments: Optional[list[Attachment]] = None
  61. @staticmethod
  62. def from_message(
  63. msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
  64. ) -> Response:
  65. return Response(
  66. origin_id=msg.origin_id,
  67. channel_id=msg.channel_id,
  68. text=text,
  69. attachments=attachments or [],
  70. )
  71. @dataclass
  72. class Context:
  73. config: Callable[[str], Any]
  74. respond: Callable[[], Coroutine[None, None, None]]
  75. request: ClientSession
  76. database: Callable[[], Coroutine[None, None, Connection]]
  77. CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
  78. StartupShutdownType = Callable[[Context], Coroutine[None, None, None]]
  79. @dataclass
  80. class CommandConfiguration:
  81. commands: dict[str, CommandType] = field(default_factory=dict)
  82. call_and_response: dict[str, str] = field(default_factory=dict)
  83. aliases: dict[str, str] = field(default_factory=dict)
  84. bangs: Container[str] = ("!",)
  85. startup: list[StartupShutdownType] = field(default_factory=list)
  86. shutdown: list[StartupShutdownType] = field(default_factory=list)
  87. def extend(self, other: CommandConfiguration) -> CommandConfiguration:
  88. return CommandConfiguration(
  89. commands={
  90. **self.commands,
  91. **other.commands,
  92. },
  93. call_and_response={
  94. **self.call_and_response,
  95. **other.call_and_response,
  96. },
  97. aliases={
  98. **self.aliases,
  99. **other.aliases,
  100. },
  101. bangs=(*self.bangs, *other.bangs),
  102. startup=[*self.startup, *other.startup],
  103. shutdown=[*self.shutdown, *other.shutdown],
  104. )