discord_driver.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. from __future__ import annotations
  2. import io
  3. import asyncio
  4. import logging.config
  5. import os
  6. from typing import Any
  7. import toml
  8. import discord
  9. import rollbot
  10. from commands import config
  11. logging.config.fileConfig("logging.conf", disable_existing_loggers=False)
  12. with open(os.environ.get("SECRET_FILE", "secrets.toml"), "r") as sfile:
  13. secrets = toml.load(sfile)
  14. database_file = os.environ.get("DATABASE_FILE", secrets["database_file"])
  15. bangs = tuple(t for t in secrets["discord"].get("bangs", "!/"))
  16. class DiscordBot(rollbot.Rollbot[discord.Message]):
  17. def __init__(self, client):
  18. super().__init__(config.extend(rollbot.CommandConfiguration(bangs=bangs)), database_file)
  19. self.discord_client = client
  20. client.event(self.on_message)
  21. def read_config(self, key: str) -> Any:
  22. cfg = secrets
  23. for part in key.split("."):
  24. cfg = cfg.get(part, None)
  25. if cfg is None:
  26. return None
  27. return cfg
  28. async def parse(self, msg: discord.Message) -> rollbot.Message:
  29. # TODO might be nice to only read attachments lazily
  30. attachments = [rollbot.Attachment(
  31. name=att.filename,
  32. body=await att.read(),
  33. ) for att in msg.attachments]
  34. if msg.reference is not None:
  35. channel = await self.discord_client.fetch_channel(msg.channel.id)
  36. attachments.append(rollbot.Attachment(
  37. name="reply",
  38. body=await channel.fetch_message(msg.reference.resolved.id),
  39. ))
  40. return rollbot.Message(
  41. origin_id="DISCORD",
  42. channel_id=str(msg.channel.id),
  43. sender_id=str(msg.author.id),
  44. message_id=str(msg.id),
  45. timestamp=msg.created_at,
  46. origin_admin="RollbotAdmin" in [r.name for r in msg.author.roles],
  47. channel_admin=False, # TODO - implement this if discord allows it
  48. sender_name=msg.author.name,
  49. text=msg.content,
  50. attachments=attachments,
  51. )
  52. async def respond(self, response: rollbot.Response):
  53. if response.origin_id != "DISCORD":
  54. self.context.logger.error(f"Unable to respond to {response.origin_id}")
  55. return
  56. channel = await self.discord_client.fetch_channel(response.channel_id)
  57. args = {}
  58. args["content"] = response.text or ""
  59. attachments = []
  60. files = []
  61. if response.attachments is not None:
  62. for att in response.attachments:
  63. if att.name == "image":
  64. if isinstance(att.body, bytes):
  65. embed = discord.Embed(description="Embedded Image")
  66. file = discord.File(io.BytesIO(att.body), filename="image.png")
  67. embed.set_image(url="attachment://image.png")
  68. args["embed"] = embed
  69. # TODO might eventually be nice to figure out a way of doing multiple embeds
  70. files.append(file)
  71. else:
  72. args["content"] += "\n" + att.body
  73. elif att.name == "reply":
  74. if att.body is None or not isinstance(att.body, str):
  75. raise ValueError("Invalid reply body type, must be message ID")
  76. args["reference"] = await channel.fetch_message(int(att.body))
  77. elif isinstance(att.body, discord.Attachment):
  78. attachments.append(att.body)
  79. if len(attachments) > 0:
  80. args["attachments"] = attachments
  81. if len(files) > 0:
  82. args["files"] = files
  83. await channel.send(**args)
  84. if __name__ == "__main__":
  85. loop = asyncio.get_event_loop()
  86. intents = discord.Intents.default()
  87. intents.message_content = True
  88. client = discord.Client(intents=intents, loop=loop)
  89. bot = DiscordBot(client)
  90. try:
  91. loop.run_until_complete(bot.on_startup())
  92. loop.run_until_complete(client.start(secrets["discord"]["token"]))
  93. except KeyboardInterrupt:
  94. loop.run_until_complete(client.close())
  95. finally:
  96. loop.run_until_complete(bot.on_shutdown())