discord_driver.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from __future__ import annotations
  2. import io
  3. import asyncio
  4. import logging.config
  5. import os
  6. import tomllib
  7. from typing import Any
  8. import discord
  9. import rollbot
  10. from commands import config
  11. logging.config.fileConfig("logging.conf", disable_existing_loggers=False)
  12. with open(os.environ.get("SECRET_FILE", "secrets.toml"), "rb") as sfile:
  13. secrets = tomllib.load(sfile)
  14. database_file = os.environ.get("DATABASE_FILE", secrets["database_file"])
  15. config.bangs = tuple(t for t in secrets["discord"].get("bangs", "!/"))
  16. class DiscordBot(rollbot.Rollbot[discord.Message]):
  17. def __init__(self, client, loop):
  18. super().__init__(config, database_file, loop=loop)
  19. self.discord_client = client
  20. client.event(self.on_message)
  21. client.event(self.on_reaction_add)
  22. self.missing_pings = []
  23. def read_config(self, key: str) -> Any:
  24. cfg = secrets
  25. for part in key.split("."):
  26. cfg = cfg.get(part, None)
  27. if cfg is None:
  28. return None
  29. return cfg
  30. async def on_command(
  31. self, raw: discord.Message, message: rollbot.Message, command: str
  32. ):
  33. async with raw.channel.typing():
  34. return await super().on_command(raw, message, command)
  35. async def on_reaction_add(
  36. self, reaction: discord.Reaction, user: discord.Member | discord.User
  37. ):
  38. try:
  39. if (
  40. not isinstance(reaction.emoji, str)
  41. and reaction.emoji.name == "missingping"
  42. ):
  43. if not self.missing_pings:
  44. # note this probably breaks if used across multiple servers...
  45. await client.fetch_guild(reaction.message.guild.id)
  46. self.missing_pings = [
  47. discord.utils.get(client.emojis, name=name)
  48. for name in (
  49. "missingping2",
  50. "missingping3",
  51. "missingping4",
  52. "missingping5",
  53. )
  54. ]
  55. for ping in self.missing_pings:
  56. await reaction.message.add_reaction(ping)
  57. except Exception:
  58. self.context.logger.exception("Failed to do the funny missing ping bit")
  59. sender_id = getattr(reaction.message.author, "id", None)
  60. if (
  61. str(sender_id) not in secrets["discord"].get("wants_react_notifs", [])
  62. or user.id == sender_id
  63. ):
  64. return
  65. channel_name = getattr(reaction.message.channel, "name", "UNKNOWN CHANNEL")
  66. content = reaction.message.content
  67. content = (content[:17] + "...") if len(content) > 17 else content
  68. react_name = (
  69. reaction.emoji
  70. if isinstance(reaction.emoji, str)
  71. else f":{reaction.emoji.name}:"
  72. )
  73. notif = f"{user.name} {react_name}'d your message '{content}' in {channel_name}"
  74. user = await self.discord_client.fetch_user(sender_id)
  75. await user.send(notif)
  76. async def parse(self, msg: discord.Message) -> rollbot.Message:
  77. # TODO might be nice to only read attachments lazily
  78. attachments = [
  79. rollbot.Attachment(
  80. name=att.filename,
  81. body=await att.read(),
  82. )
  83. for att in msg.attachments
  84. ]
  85. if msg.reference is not None and msg.reference.resolved is not None:
  86. channel = await self.discord_client.fetch_channel(msg.channel.id)
  87. attachments.append(
  88. rollbot.Attachment(
  89. name="reply",
  90. body=await channel.fetch_message(msg.reference.resolved.id),
  91. )
  92. )
  93. return rollbot.Message(
  94. origin_id="DISCORD",
  95. channel_id=str(msg.channel.id),
  96. sender_id=str(msg.author.id),
  97. message_id=str(msg.id),
  98. timestamp=msg.created_at,
  99. origin_admin="RollbotAdmin"
  100. in [r.name for r in getattr(msg.author, "roles", [])],
  101. channel_admin=False, # TODO - implement this if discord allows it
  102. sender_name=msg.author.name,
  103. text=msg.content,
  104. attachments=attachments,
  105. force_command=(
  106. isinstance(msg.channel, discord.DMChannel)
  107. and msg.author != msg.channel.me
  108. ),
  109. )
  110. async def respond(self, response: rollbot.Response):
  111. if response.origin_id != "DISCORD":
  112. self.context.logger.error(f"Unable to respond to {response.origin_id}")
  113. return
  114. channel = await self.discord_client.fetch_channel(response.channel_id)
  115. args = {}
  116. args["content"] = response.text or ""
  117. attachments = []
  118. files = []
  119. reacts = []
  120. pin = False
  121. if response.attachments is not None:
  122. for att in response.attachments:
  123. if att.name == "image":
  124. if isinstance(att.body, bytes):
  125. embed = discord.Embed()
  126. file = discord.File(io.BytesIO(att.body), filename="image.png")
  127. embed.set_image(url="attachment://image.png")
  128. args["embed"] = embed
  129. # TODO might eventually be nice to figure out a way of doing multiple embeds
  130. files.append(file)
  131. else:
  132. args["content"] += "\n" + att.body
  133. elif att.name == "reply":
  134. if att.body is None or not isinstance(att.body, str):
  135. raise ValueError("Invalid reply body type, must be message ID")
  136. args["reference"] = await channel.fetch_message(int(att.body))
  137. elif att.name == "react":
  138. reacts.append(att.body)
  139. elif att.name == "pin":
  140. pin = True
  141. elif isinstance(att.body, discord.Attachment):
  142. attachments.append(att.body)
  143. if len(attachments) > 0:
  144. args["attachments"] = attachments
  145. if len(files) > 0:
  146. args["files"] = files
  147. # TODO add abilitly to disable silent?
  148. message = await channel.send(silent=True, **args)
  149. for react in reacts:
  150. await message.add_reaction(react)
  151. if pin:
  152. await message.pin()
  153. if __name__ == "__main__":
  154. loop = asyncio.get_event_loop()
  155. intents = discord.Intents.default()
  156. intents.message_content = True
  157. intents.reactions = True
  158. client = discord.Client(intents=intents, loop=loop)
  159. bot = DiscordBot(client, loop)
  160. try:
  161. loop.run_until_complete(bot.on_startup())
  162. loop.run_until_complete(client.start(secrets["discord"]["token"]))
  163. except KeyboardInterrupt:
  164. loop.run_until_complete(client.close())
  165. finally:
  166. loop.run_until_complete(bot.on_shutdown())