command_system.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. import logging
  2. from dataclasses import dataclass
  3. from enum import Enum, auto
  4. import inspect
  5. import functools
  6. from sqlalchemy import Column, DateTime, Binary, String, Float, Integer
  7. from sqlalchemy.ext.declarative import declarative_base
  8. BANGS = ('!',)
  9. ModelBase = declarative_base()
  10. class GroupBasedSingleton(ModelBase):
  11. __tablename__ = "group_based_singleton"
  12. group_id = Column(String, primary_key=True)
  13. command_name = Column(String, primary_key=True)
  14. subpart_name = Column(String, primary_key=True)
  15. integer_data = Column(Integer)
  16. float_data = Column(Float)
  17. string_data = Column(String)
  18. binary_data = Column(Binary)
  19. datetime_data = Column(DateTime)
  20. @staticmethod
  21. def get_or_create(db, group_id, command_name, subpart_name):
  22. sing = db.query(GroupBasedSingleton).get((group_id, command_name, subpart_name))
  23. if sing is None:
  24. sing = GroupBasedSingleton(
  25. group_id=group_id,
  26. command_name=command_name,
  27. subpart_name=subpart_name,
  28. integer_data=None,
  29. float_data=None,
  30. string_data=None,
  31. binary_data=None,
  32. datetime_data=None
  33. )
  34. db.add(sing)
  35. return sing
  36. def pop_arg(text):
  37. if text is None:
  38. return None, None
  39. parts = text.split(maxsplit=1)
  40. if len(parts) == 1:
  41. return parts[0], None
  42. return parts[0], parts[1].strip()
  43. @dataclass
  44. class RollbotMessage:
  45. src: str
  46. name: str
  47. sender_id: str
  48. group_id: str
  49. message_id: str
  50. message_txt: str
  51. from_admin: bool
  52. def __post_init__(self):
  53. self.is_command = False
  54. if len(self.message_txt) > 0 and self.message_txt[0] in BANGS:
  55. cmd, raw = pop_arg(self.message_txt[1:].strip())
  56. if cmd is not None:
  57. self.is_command = True
  58. self.command = cmd.lower()
  59. self.raw_args = raw
  60. @staticmethod
  61. def from_groupme(msg, global_admins=(), group_admins={}):
  62. sender_id = msg["sender_id"]
  63. group_id = msg["group_id"]
  64. return RollbotMessage(
  65. "GROUPME",
  66. msg["name"],
  67. sender_id,
  68. group_id,
  69. msg["id"],
  70. msg["text"].strip(),
  71. sender_id in global_admins or (
  72. group_id in group_admins and
  73. sender_id in group_admins[group_id])
  74. )
  75. @staticmethod
  76. def from_discord(msg, global_admins=(), group_admins={}):
  77. sender_id = str(msg.author.id)
  78. group_id = str(msg.channel.id)
  79. return RollbotMessage(
  80. "DISCORD",
  81. msg.author.name,
  82. sender_id,
  83. group_id,
  84. msg.id,
  85. msg.content.strip(),
  86. sender_id in global_admins or (
  87. group_id in group_admins and
  88. sender_id in group_admins[group_id])
  89. )
  90. def args(self, normalize=True):
  91. arg, rest = pop_arg(self.raw_args)
  92. while arg is not None:
  93. yield arg.lower() if normalize else arg
  94. arg, rest = pop_arg(rest)
  95. class RollbotFailure(Enum):
  96. INVALID_COMMAND = auto()
  97. MISSING_SUBCOMMAND = auto()
  98. INVALID_SUBCOMMAND = auto()
  99. INVALID_ARGUMENTS = auto()
  100. SERVICE_DOWN = auto()
  101. PERMISSIONS = auto()
  102. INTERNAL_ERROR = auto()
  103. _RESPONSE_TEMPLATE = """Response{
  104. Original Message: %s,
  105. Text Response: %s,
  106. Image Response: %s,
  107. Respond: %s,
  108. Failure Reason: %s,
  109. Failure Notes: %s
  110. }"""
  111. @dataclass
  112. class RollbotResponse:
  113. msg: RollbotMessage
  114. txt: str = None
  115. img: str = None
  116. respond: bool = True
  117. failure: RollbotFailure = None
  118. debugging: dict = None
  119. def __post_init__(self):
  120. self.info = _RESPONSE_TEMPLATE % (self.msg, self.txt, self.img, self.respond, self.failure, self.debugging)
  121. self.is_success = self.failure is None
  122. if self.failure is None:
  123. self.failure_msg = None
  124. elif self.failure == RollbotFailure.INVALID_COMMAND:
  125. self.failure_msg = "Sorry - I don't think I understand the command '!%s'... " % self.msg.command \
  126. + "I'll try to figure it out and get back to you!"
  127. elif self.failure == RollbotFailure.MISSING_SUBCOMMAND:
  128. self.failure_msg = "Sorry - !%s requires a sub-command." % self.msg.command
  129. elif self.failure == RollbotFailure.INVALID_SUBCOMMAND:
  130. self.failure_msg = "Sorry - the sub-command you used for %s was not valid." % self.msg.command
  131. elif self.failure == RollbotFailure.INVALID_ARGUMENTS:
  132. self.failure_msg = "Sorry - %s cannot use those arguments!" % self.msg.command
  133. elif self.failure == RollbotFailure.SERVICE_DOWN:
  134. self.failure_msg = "Sorry - %s relies on a service I couldn't reach!" % self.msg.command
  135. elif self.failure == RollbotFailure.PERMISSIONS:
  136. self.failure_msg = "Sorry - you don't have permission to use that command or sub-command in this chat!"
  137. elif self.failure == RollbotFailure.INTERNAL_ERROR:
  138. self.failure_msg = "Sorry - I encountered an unrecoverable error, please review internal logs."
  139. if self.debugging is not None and "explain" in self.debugging:
  140. self.failure_msg += " " + self.debugging["explain"]
  141. class RollbotPlugin:
  142. def __init__(self, command, bot, logger=logging.getLogger(__name__)):
  143. self.command = command
  144. self.bot = bot
  145. self.logger = logger
  146. self.logger.info(f"Intializing {type(self).__name__} matching {command}")
  147. def on_start(self, db):
  148. self.logger.info(f"No on_start initialization of {type(self).__name__}")
  149. def on_shutdown(self, db):
  150. self.logger.info(f"No on_shutdown de-initialization of {type(self).__name__}")
  151. def on_command(self, db, message):
  152. raise NotImplementedError
  153. def as_plugin(command):
  154. if isinstance(command, str):
  155. command_name = command
  156. else:
  157. command_name = command.__name__
  158. def init_standin(self, bot, logger=logging.getLogger(__name__)):
  159. RollbotPlugin.__init__(self, command_name, bot, logger=logger)
  160. def decorator(fn):
  161. sig = inspect.signature(fn)
  162. converters = []
  163. for p in sig.parameters:
  164. if p in ("msg", "message", "_msg"):
  165. converters.append(lambda self, db, msg: msg)
  166. elif p in ("db", "database"):
  167. converters.append(lambda self, db, msg: db)
  168. elif p in ("log", "logger"):
  169. converters.append(lambda self, db, msg: self.logger)
  170. elif p in ("bot", "rollbot"):
  171. converters.append(lambda self, db, msg: self.bot)
  172. elif p.startswith("data") or p.endswith("data") or p in ("group_singleton", "singleton"):
  173. subp = fn.__annotations__.get(p, "")
  174. converters.append(lambda self, db, msg, subp=subp: GroupBasedSingleton.get_or_create(db, msg.group_id, self.command, subp))
  175. else:
  176. raise ValueError(f"Illegal argument name {p} in decorated plugin {command_name}")
  177. def on_command_standin(self, db, msg):
  178. res = fn(*[c(self, db, msg) for c in converters])
  179. if isinstance(res, RollbotResponse):
  180. return res
  181. else:
  182. return RollbotResponse(msg, txt=str(res))
  183. return type(
  184. f"AutoGenerated`{command_name}`Command",
  185. (RollbotPlugin,),
  186. dict(
  187. __init__=init_standin,
  188. on_command=on_command_standin,
  189. )
  190. )
  191. if isinstance(command, str):
  192. return decorator
  193. else:
  194. return decorator(command)