discord_driver.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from __future__ import annotations
  2. import io
  3. import asyncio
  4. import logging.config
  5. import os
  6. import tomllib
  7. from typing import Any
  8. import discord
  9. import rollbot
  10. from commands import config
  11. logging.config.fileConfig("logging.conf", disable_existing_loggers=False)
  12. with open(os.environ.get("SECRET_FILE", "secrets.toml"), "rb") as sfile:
  13. secrets = tomllib.load(sfile)
  14. database_file = os.environ.get("DATABASE_FILE", secrets["database_file"])
  15. config.bangs = tuple(t for t in secrets["discord"].get("bangs", "!/"))
  16. class DiscordBot(rollbot.Rollbot[discord.Message]):
  17. def __init__(self, client, loop):
  18. super().__init__(config, database_file, loop=loop)
  19. self.discord_client = client
  20. client.event(self.on_message)
  21. if secrets["discord"].get("enable_react_notifs", False):
  22. client.event(self.on_reaction_add)
  23. def read_config(self, key: str) -> Any:
  24. cfg = secrets
  25. for part in key.split("."):
  26. cfg = cfg.get(part, None)
  27. if cfg is None:
  28. return None
  29. return cfg
  30. async def on_command(
  31. self, raw: discord.Message, message: rollbot.Message, command: str
  32. ):
  33. async with raw.channel.typing():
  34. return await super().on_command(raw, message, command)
  35. async def on_reaction_add(
  36. self, reaction: discord.Reaction, user: discord.Member | discord.User
  37. ):
  38. sender_id = getattr(reaction.message.author, "id", None)
  39. if (
  40. str(sender_id) not in secrets["discord"].get("wants_react_notifs", [])
  41. or user.id == sender_id
  42. ):
  43. return
  44. channel_name = getattr(reaction.message.channel, "name", "UNKNOWN CHANNEL")
  45. content = reaction.message.content
  46. content = (content[:17] + "...") if len(content) > 17 else content
  47. react_name = (
  48. reaction.emoji
  49. if isinstance(reaction.emoji, str)
  50. else f":{reaction.emoji.name}:"
  51. )
  52. notif = f"{user.name} {react_name}'d your message '{content}' in {channel_name}"
  53. user = await self.discord_client.fetch_user(sender_id)
  54. await user.send(notif)
  55. async def parse(self, msg: discord.Message) -> rollbot.Message:
  56. # TODO might be nice to only read attachments lazily
  57. attachments = [
  58. rollbot.Attachment(
  59. name=att.filename,
  60. body=await att.read(),
  61. )
  62. for att in msg.attachments
  63. ]
  64. if msg.reference is not None and msg.reference.resolved is not None:
  65. channel = await self.discord_client.fetch_channel(msg.channel.id)
  66. attachments.append(
  67. rollbot.Attachment(
  68. name="reply",
  69. body=await channel.fetch_message(msg.reference.resolved.id),
  70. )
  71. )
  72. return rollbot.Message(
  73. origin_id="DISCORD",
  74. channel_id=str(msg.channel.id),
  75. sender_id=str(msg.author.id),
  76. message_id=str(msg.id),
  77. timestamp=msg.created_at,
  78. origin_admin="RollbotAdmin"
  79. in [r.name for r in getattr(msg.author, "roles", [])],
  80. channel_admin=False, # TODO - implement this if discord allows it
  81. sender_name=msg.author.name,
  82. text=msg.content,
  83. attachments=attachments,
  84. force_command=(
  85. isinstance(msg.channel, discord.DMChannel)
  86. and msg.author != msg.channel.me
  87. ),
  88. )
  89. async def respond(self, response: rollbot.Response):
  90. if response.origin_id != "DISCORD":
  91. self.context.logger.error(f"Unable to respond to {response.origin_id}")
  92. return
  93. channel = await self.discord_client.fetch_channel(response.channel_id)
  94. args = {}
  95. args["content"] = response.text or ""
  96. attachments = []
  97. files = []
  98. reacts = []
  99. pin = False
  100. if response.attachments is not None:
  101. for att in response.attachments:
  102. if att.name == "image":
  103. if isinstance(att.body, bytes):
  104. embed = discord.Embed()
  105. file = discord.File(io.BytesIO(att.body), filename="image.png")
  106. embed.set_image(url="attachment://image.png")
  107. args["embed"] = embed
  108. # TODO might eventually be nice to figure out a way of doing multiple embeds
  109. files.append(file)
  110. else:
  111. args["content"] += "\n" + att.body
  112. elif att.name == "reply":
  113. if att.body is None or not isinstance(att.body, str):
  114. raise ValueError("Invalid reply body type, must be message ID")
  115. args["reference"] = await channel.fetch_message(int(att.body))
  116. elif att.name == "react":
  117. reacts.append(att.body)
  118. elif att.name == "pin":
  119. pin = True
  120. elif isinstance(att.body, discord.Attachment):
  121. attachments.append(att.body)
  122. if len(attachments) > 0:
  123. args["attachments"] = attachments
  124. if len(files) > 0:
  125. args["files"] = files
  126. # TODO add abilitly to disable silent?
  127. message = await channel.send(silent=True, **args)
  128. for react in reacts:
  129. await message.add_reaction(react)
  130. if pin:
  131. await message.pin()
  132. if __name__ == "__main__":
  133. loop = asyncio.get_event_loop()
  134. intents = discord.Intents.default()
  135. intents.message_content = True
  136. intents.reactions = True
  137. client = discord.Client(intents=intents, loop=loop)
  138. bot = DiscordBot(client, loop)
  139. try:
  140. loop.run_until_complete(bot.on_startup())
  141. loop.run_until_complete(client.start(secrets["discord"]["token"]))
  142. except KeyboardInterrupt:
  143. loop.run_until_complete(client.close())
  144. finally:
  145. loop.run_until_complete(bot.on_shutdown())