|
@@ -1,5 +1,8 @@
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
from typing import Any, Generic, TypeVar
|
|
from typing import Any, Generic, TypeVar
|
|
|
|
+import asyncio
|
|
|
|
+
|
|
|
|
+import aiosqlite
|
|
|
|
|
|
from .types import CommandConfiguration, Message, Response, Context
|
|
from .types import CommandConfiguration, Message, Response, Context
|
|
|
|
|
|
@@ -12,6 +15,13 @@ class Rollbot(Generic[RawMsg]):
|
|
command_config: CommandConfiguration
|
|
command_config: CommandConfiguration
|
|
database_file: str
|
|
database_file: str
|
|
|
|
|
|
|
|
+ def __post_init__(self):
|
|
|
|
+ self.context = Context(
|
|
|
|
+ config=self.read_config,
|
|
|
|
+ respond=self.respond,
|
|
|
|
+ database=lambda: aiosqlite.connect(self.database_file),
|
|
|
|
+ )
|
|
|
|
+
|
|
def read_config(self, key: str) -> Any:
|
|
def read_config(self, key: str) -> Any:
|
|
raise NotImplemented("Must be implemented by driver")
|
|
raise NotImplemented("Must be implemented by driver")
|
|
|
|
|
|
@@ -21,6 +31,12 @@ class Rollbot(Generic[RawMsg]):
|
|
async def respond(self, response: Response):
|
|
async def respond(self, response: Response):
|
|
raise NotImplemented("Must be implemented by driver")
|
|
raise NotImplemented("Must be implemented by driver")
|
|
|
|
|
|
|
|
+ async def on_startup(self):
|
|
|
|
+ await asyncio.gather(*[task(self.context) for task in self.command_config.startup])
|
|
|
|
+
|
|
|
|
+ async def on_shutdown(self):
|
|
|
|
+ await asyncio.gather(*[task(self.context) for task in self.command_config.shutdown])
|
|
|
|
+
|
|
async def on_message(self, incoming: RawMsg):
|
|
async def on_message(self, incoming: RawMsg):
|
|
message = self.parse(incoming)
|
|
message = self.parse(incoming)
|
|
if message.text is None:
|
|
if message.text is None:
|
|
@@ -45,9 +61,4 @@ class Rollbot(Generic[RawMsg]):
|
|
await self.respond(Response.from_message(message, f"Sorry! I don't know the command {command}."))
|
|
await self.respond(Response.from_message(message, f"Sorry! I don't know the command {command}."))
|
|
return
|
|
return
|
|
|
|
|
|
- await command_call(Context(
|
|
|
|
- message=message,
|
|
|
|
- config=self.read_config,
|
|
|
|
- respond=self.respond,
|
|
|
|
- # database=..., # TODO database
|
|
|
|
- ))
|
|
|
|
|
|
+ await command_call(message, self.context)
|