types.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
from __future__ import annotations

from collections.abc import Callable, Collection, Container, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Optional, Union

from aiohttp import ClientSession
from aiosqlite import Connection
  8. __all__ = [
  9. "Attachment",
  10. "Message",
  11. "Command",
  12. "Response",
  13. "Context",
  14. "CommandType",
  15. "StartupShutdownType",
  16. "CommandConfiguration",
  17. ]
  18. @dataclass
  19. class Attachment:
  20. name: str
  21. body: Union[str, bytes]
  22. @dataclass
  23. class Message:
  24. origin_id: str
  25. channel_id: str
  26. sender_id: str
  27. timestamp: datetime
  28. origin_admin: bool
  29. channel_admin: bool
  30. text: Optional[str] = None
  31. attachments: list[Attachment] = field(default_factory=list)
  32. def __post_init__(self):
  33. self.command = None
  34. @dataclass
  35. class Command:
  36. bang: str
  37. name: str
  38. args: str
  39. def __post_init__(self):
  40. # cache is used by injectors to store results and avoid recomputation
  41. self.cache = {}
  42. @staticmethod
  43. def from_text(text: str) -> Optional[Command]:
  44. cleaned = text.lstrip()
  45. if len(cleaned) < 2:
  46. return None
  47. parts = cleaned[1:].lstrip().split(maxsplit=1)
  48. if len(parts) == 0:
  49. return None
  50. return Command(
  51. bang=cleaned[0],
  52. name=parts[0],
  53. args=parts[1] if len(parts) > 1 else "",
  54. )
  55. def get_subcommand(self, inherit_bang=True) -> Command:
  56. if inherit_bang and not self.args.startswith(self.bang):
  57. return Command.from_text(self.bang + self.args)
  58. return Command.from_text(self.args)
  59. @dataclass
  60. class Response:
  61. origin_id: str
  62. channel_id: str
  63. text: Optional[str] = None
  64. attachments: Optional[list[Attachment]] = None
  65. @staticmethod
  66. def from_message(
  67. msg: Message, text: Optional[str] = None, attachments: list[Attachment] = None
  68. ) -> Response:
  69. return Response(
  70. origin_id=msg.origin_id,
  71. channel_id=msg.channel_id,
  72. text=text,
  73. attachments=attachments or [],
  74. )
  75. @dataclass
  76. class Context:
  77. config: Callable[[str], Any]
  78. respond: Callable[[], Coroutine[None, None, None]]
  79. request: ClientSession
  80. database: Callable[[], Coroutine[None, None, Connection]]
  81. CommandType = Callable[[Message, Context], Coroutine[None, None, None]]
  82. StartupShutdownType = Callable[[Context], Coroutine[None, None, None]]
  83. @dataclass
  84. class CommandConfiguration:
  85. commands: dict[str, CommandType] = field(default_factory=dict)
  86. call_and_response: dict[str, str] = field(default_factory=dict)
  87. aliases: dict[str, str] = field(default_factory=dict)
  88. bangs: Container[str] = ("!",)
  89. startup: list[StartupShutdownType] = field(default_factory=list)
  90. shutdown: list[StartupShutdownType] = field(default_factory=list)
  91. def extend(self, other: CommandConfiguration) -> CommandConfiguration:
  92. return CommandConfiguration(
  93. commands={
  94. **self.commands,
  95. **other.commands,
  96. },
  97. call_and_response={
  98. **self.call_and_response,
  99. **other.call_and_response,
  100. },
  101. aliases={
  102. **self.aliases,
  103. **other.aliases,
  104. },
  105. bangs=(*self.bangs, *other.bangs),
  106. startup=[*self.startup, *other.startup],
  107. shutdown=[*self.shutdown, *other.shutdown],
  108. )