# types.py
"""Core data types: attachments, messages, commands, responses, the handler
context, handler type aliases, and command configuration."""
from __future__ import annotations

from collections.abc import Callable, Collection, Container, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Optional, Union

from aiohttp import ClientSession
from aiosqlite import Connection
  8. __all__ = [
  9. "Attachment",
  10. "Message",
  11. "Command",
  12. "Response",
  13. "Context",
  14. "CommandType",
  15. "StartupShutdownType",
  16. "CommandConfiguration",
  17. ]
  18. @dataclass
  19. class Attachment:
  20. name: str
  21. body: Union[str, bytes]
  22. @dataclass
  23. class Message:
  24. origin_id: str
  25. channel_id: str
  26. sender_id: str
  27. timestamp: datetime
  28. origin_admin: bool
  29. channel_admin: bool
  30. text: Optional[str] = None
  31. attachments: list[Attachment] = field(default_factory=list)
  32. def __post_init__(self):
  33. self.command = None
  34. @dataclass
  35. class Command:
  36. bang: str
  37. name: str
  38. args: str
  39. @staticmethod
  40. def from_text(text: str) -> Optional[Command]:
  41. cleaned = text.lstrip()
  42. if len(cleaned) < 2:
  43. return None
  44. parts = cleaned[1:].lstrip().split(maxsplit=1)
  45. if len(parts) == 0:
  46. return None
  47. return Command(
  48. bang=cleaned[0],
  49. name=parts[0],
  50. args=parts[1] if len(parts) > 1 else "",
  51. )
  52. @dataclass
  53. class Response:
  54. origin_id: str
  55. channel_id: str
  56. text: Optional[str] = None
  57. attachments: Optional[list[Attachment]] = None
  58. @staticmethod
  59. def from_message(
  60. msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
  61. ) -> Response:
  62. return Response(
  63. origin_id=msg.origin_id,
  64. channel_id=msg.channel_id,
  65. text=text,
  66. attachments=attachments or [],
  67. )
  68. @dataclass
  69. class Context:
  70. config: Callable[[str], Any]
  71. respond: Callable[[], Coroutine[None, None, None]]
  72. request: ClientSession
  73. database: Callable[[], Coroutine[None, None, Connection]]
  74. CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
  75. StartupShutdownType = Callable[[Context], Coroutine[None, None, None]]
  76. @dataclass
  77. class CommandConfiguration:
  78. commands: dict[str, CommandType] = field(default_factory=dict)
  79. call_and_response: dict[str, str] = field(default_factory=dict)
  80. aliases: dict[str, str] = field(default_factory=dict)
  81. bangs: Container[str] = ("!",)
  82. startup: list[StartupShutdownType] = field(default_factory=list)
  83. shutdown: list[StartupShutdownType] = field(default_factory=list)
  84. def extend(self, other: CommandConfiguration) -> CommandConfiguration:
  85. return CommandConfiguration(
  86. commands={
  87. **self.commands,
  88. **other.commands,
  89. },
  90. call_and_response={
  91. **self.call_and_response,
  92. **other.call_and_response,
  93. },
  94. aliases={
  95. **self.aliases,
  96. **other.aliases,
  97. },
  98. bangs=(*self.bangs, *other.bangs),
  99. startup=[*self.startup, *other.startup],
  100. shutdown=[*self.shutdown, *other.shutdown],
  101. )