discord_driver.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from __future__ import annotations
  2. import io
  3. import asyncio
  4. import logging.config
  5. import os
  6. import tomllib
  7. from typing import Any
  8. import discord
  9. import rollbot
  10. from commands import config
  11. logging.config.fileConfig("logging.conf", disable_existing_loggers=False)
  12. with open(os.environ.get("SECRET_FILE", "secrets.toml"), "rb") as sfile:
  13. secrets = tomllib.load(sfile)
  14. database_file = os.environ.get("DATABASE_FILE", secrets["database_file"])
  15. config.bangs = tuple(t for t in secrets["discord"].get("bangs", "!/"))
  16. class DiscordBot(rollbot.Rollbot[discord.Message]):
  17. def __init__(self, client):
  18. super().__init__(config, database_file)
  19. self.discord_client = client
  20. client.event(self.on_message)
  21. if secrets["discord"].get("enable_react_notifs", False):
  22. client.event(self.on_reaction_add)
  23. def read_config(self, key: str) -> Any:
  24. cfg = secrets
  25. for part in key.split("."):
  26. cfg = cfg.get(part, None)
  27. if cfg is None:
  28. return None
  29. return cfg
  30. async def on_command(
  31. self, raw: discord.Message, message: rollbot.Message, command: str
  32. ):
  33. async with raw.channel.typing():
  34. return await super().on_command(raw, message, command)
  35. async def on_reaction_add(
  36. self, reaction: discord.Reaction, user: discord.Member | discord.User
  37. ):
  38. sender_id = getattr(reaction.message.author, "id", None)
  39. if (
  40. str(sender_id) not in secrets["discord"].get("wants_react_notifs", [])
  41. or user.id == sender_id
  42. ):
  43. return
  44. channel_name = getattr(reaction.message.channel, "name", "UNKNOWN CHANNEL")
  45. content = reaction.message.content
  46. content = (content[:17] + "...") if len(content) > 17 else content
  47. react_name = (
  48. reaction.emoji
  49. if isinstance(reaction.emoji, str)
  50. else f":{reaction.emoji.name}:"
  51. )
  52. notif = f"{user.name} {react_name}'d your message '{content}' in {channel_name}"
  53. user = await self.discord_client.fetch_user(sender_id)
  54. await user.send(notif)
  55. async def parse(self, msg: discord.Message) -> rollbot.Message:
  56. # TODO might be nice to only read attachments lazily
  57. attachments = [
  58. rollbot.Attachment(
  59. name=att.filename,
  60. body=await att.read(),
  61. )
  62. for att in msg.attachments
  63. ]
  64. if msg.reference is not None and msg.reference.resolved is not None:
  65. channel = await self.discord_client.fetch_channel(msg.channel.id)
  66. attachments.append(
  67. rollbot.Attachment(
  68. name="reply",
  69. body=await channel.fetch_message(msg.reference.resolved.id),
  70. )
  71. )
  72. return rollbot.Message(
  73. origin_id="DISCORD",
  74. channel_id=str(msg.channel.id),
  75. sender_id=str(msg.author.id),
  76. message_id=str(msg.id),
  77. timestamp=msg.created_at,
  78. origin_admin="RollbotAdmin"
  79. in [r.name for r in getattr(msg.author, "roles", [])],
  80. channel_admin=False, # TODO - implement this if discord allows it
  81. sender_name=msg.author.name,
  82. text=msg.content,
  83. attachments=attachments,
  84. )
  85. async def respond(self, response: rollbot.Response):
  86. if response.origin_id != "DISCORD":
  87. self.context.logger.error(f"Unable to respond to {response.origin_id}")
  88. return
  89. channel = await self.discord_client.fetch_channel(response.channel_id)
  90. args = {}
  91. args["content"] = response.text or ""
  92. attachments = []
  93. files = []
  94. reacts = []
  95. pin = False
  96. if response.attachments is not None:
  97. for att in response.attachments:
  98. if att.name == "image":
  99. if isinstance(att.body, bytes):
  100. embed = discord.Embed()
  101. file = discord.File(io.BytesIO(att.body), filename="image.png")
  102. embed.set_image(url="attachment://image.png")
  103. args["embed"] = embed
  104. # TODO might eventually be nice to figure out a way of doing multiple embeds
  105. files.append(file)
  106. else:
  107. args["content"] += "\n" + att.body
  108. elif att.name == "reply":
  109. if att.body is None or not isinstance(att.body, str):
  110. raise ValueError("Invalid reply body type, must be message ID")
  111. args["reference"] = await channel.fetch_message(int(att.body))
  112. elif att.name == "react":
  113. reacts.append(att.body)
  114. elif att.name == "pin":
  115. pin = True
  116. elif isinstance(att.body, discord.Attachment):
  117. attachments.append(att.body)
  118. if len(attachments) > 0:
  119. args["attachments"] = attachments
  120. if len(files) > 0:
  121. args["files"] = files
  122. # TODO add abilitly to disable silent?
  123. message = await channel.send(silent=True, **args)
  124. for react in reacts:
  125. await message.add_reaction(react)
  126. if pin:
  127. await message.pin()
  128. if __name__ == "__main__":
  129. loop = asyncio.get_event_loop()
  130. intents = discord.Intents.default()
  131. intents.message_content = True
  132. intents.reactions = True
  133. client = discord.Client(intents=intents, loop=loop)
  134. bot = DiscordBot(client)
  135. try:
  136. loop.run_until_complete(bot.on_startup())
  137. loop.run_until_complete(client.start(secrets["discord"]["token"]))
  138. except KeyboardInterrupt:
  139. loop.run_until_complete(client.close())
  140. finally:
  141. loop.run_until_complete(bot.on_shutdown())