discord_driver.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from __future__ import annotations
  2. import io
  3. import asyncio
  4. import logging.config
  5. import os
  6. import tomllib
  7. from typing import Any
  8. import discord
  9. import rollbot
  10. from commands import config
  11. logging.config.fileConfig("logging.conf", disable_existing_loggers=False)
  12. with open(os.environ.get("SECRET_FILE", "secrets.toml"), "rb") as sfile:
  13. secrets = tomllib.load(sfile)
  14. database_file = os.environ.get("DATABASE_FILE", secrets["database_file"])
  15. bangs = tuple(t for t in secrets["discord"].get("bangs", "!/"))
  16. class DiscordBot(rollbot.Rollbot[discord.Message]):
  17. def __init__(self, client):
  18. super().__init__(config.extend(rollbot.CommandConfiguration(bangs=bangs)), database_file)
  19. self.discord_client = client
  20. client.event(self.on_message)
  21. if secrets["discord"].get("enable_react_notifs", False):
  22. client.event(self.on_reaction_add)
  23. def read_config(self, key: str) -> Any:
  24. cfg = secrets
  25. for part in key.split("."):
  26. cfg = cfg.get(part, None)
  27. if cfg is None:
  28. return None
  29. return cfg
  30. async def on_command(self, raw: discord.Message, message: rollbot.Message, command: str):
  31. async with raw.channel.typing():
  32. return await super().on_command(raw, message, command)
  33. async def on_reaction_add(self, reaction: discord.Reaction, user: discord.Member | discord.User):
  34. sender_id = getattr(reaction.message.author, "id", None)
  35. if str(sender_id) not in secrets["discord"].get("wants_react_notifs", []) or user.id == sender_id:
  36. return
  37. channel_name = getattr(reaction.message.channel, "name", "UNKNOWN CHANNEL")
  38. content = reaction.message.content
  39. content = (content[:17] + '...') if len(content) > 17 else content
  40. react_name = reaction.emoji if isinstance(reaction.emoji, str) else f":{reaction.emoji.name}:"
  41. notif = f"{user.name} {react_name}'d your message '{content}' in {channel_name}"
  42. user = await self.discord_client.fetch_user(sender_id)
  43. await user.send(notif)
  44. async def parse(self, msg: discord.Message) -> rollbot.Message:
  45. # TODO might be nice to only read attachments lazily
  46. attachments = [rollbot.Attachment(
  47. name=att.filename,
  48. body=await att.read(),
  49. ) for att in msg.attachments]
  50. if msg.reference is not None:
  51. channel = await self.discord_client.fetch_channel(msg.channel.id)
  52. attachments.append(rollbot.Attachment(
  53. name="reply",
  54. body=await channel.fetch_message(msg.reference.resolved.id),
  55. ))
  56. return rollbot.Message(
  57. origin_id="DISCORD",
  58. channel_id=str(msg.channel.id),
  59. sender_id=str(msg.author.id),
  60. message_id=str(msg.id),
  61. timestamp=msg.created_at,
  62. origin_admin="RollbotAdmin" in [r.name for r in getattr(msg.author, "roles", [])],
  63. channel_admin=False, # TODO - implement this if discord allows it
  64. sender_name=msg.author.name,
  65. text=msg.content,
  66. attachments=attachments,
  67. )
  68. async def respond(self, response: rollbot.Response):
  69. if response.origin_id != "DISCORD":
  70. self.context.logger.error(f"Unable to respond to {response.origin_id}")
  71. return
  72. channel = await self.discord_client.fetch_channel(response.channel_id)
  73. args = {}
  74. args["content"] = response.text or ""
  75. attachments = []
  76. files = []
  77. if response.attachments is not None:
  78. for att in response.attachments:
  79. if att.name == "image":
  80. if isinstance(att.body, bytes):
  81. embed = discord.Embed(description="Embedded Image")
  82. file = discord.File(io.BytesIO(att.body), filename="image.png")
  83. embed.set_image(url="attachment://image.png")
  84. args["embed"] = embed
  85. # TODO might eventually be nice to figure out a way of doing multiple embeds
  86. files.append(file)
  87. else:
  88. args["content"] += "\n" + att.body
  89. elif att.name == "reply":
  90. if att.body is None or not isinstance(att.body, str):
  91. raise ValueError("Invalid reply body type, must be message ID")
  92. args["reference"] = await channel.fetch_message(int(att.body))
  93. elif isinstance(att.body, discord.Attachment):
  94. attachments.append(att.body)
  95. if len(attachments) > 0:
  96. args["attachments"] = attachments
  97. if len(files) > 0:
  98. args["files"] = files
  99. await channel.send(**args)
  100. if __name__ == "__main__":
  101. loop = asyncio.get_event_loop()
  102. intents = discord.Intents.default()
  103. intents.message_content = True
  104. intents.reactions = True
  105. client = discord.Client(intents=intents, loop=loop)
  106. bot = DiscordBot(client)
  107. try:
  108. loop.run_until_complete(bot.on_startup())
  109. loop.run_until_complete(client.start(secrets["discord"]["token"]))
  110. except KeyboardInterrupt:
  111. loop.run_until_complete(client.close())
  112. finally:
  113. loop.run_until_complete(bot.on_shutdown())