from collections.abc import Callable
from typing import Union
from functools import wraps
import inspect
import asyncio
import dataclasses
import json

from ..types import (
    Message,
    Context,
    CommandType,
    Response,
    StartupShutdownType,
    CommandConfiguration,
)
from .failure import RollbotFailureException
from .injection import Injector, inject_message, inject_context

# Module-level registries populated by the decorators below and read back by
# get_command_config().
decorated_startup: list[StartupShutdownType] = []
decorated_shutdown: list[StartupShutdownType] = []
decorated_commands: dict[str, CommandType] = {}


def on_startup(fn):
    """Append ``fn`` to ``decorated_startup`` and return it unchanged."""
    decorated_startup.append(fn)
    return fn


def on_shutdown(fn):
    """Append ``fn`` to ``decorated_shutdown`` and return it unchanged."""
    decorated_shutdown.append(fn)
    return fn


def _lift_to_async_generator(fn):
    """Return an async-generator function equivalent to ``fn``.

    Async generator functions are returned as-is.  Coroutine functions,
    generator functions, and plain functions are wrapped so that their
    result(s) are yielded from an async generator.

    Raises:
        ValueError: if ``fn`` is none of the supported kinds of function.
    """
    if inspect.isasyncgenfunction(fn):
        return fn

    if inspect.iscoroutinefunction(fn):

        @wraps(fn)
        async def lifted(*args):
            yield await fn(*args)

    elif inspect.isgeneratorfunction(fn):

        @wraps(fn)
        async def lifted(*args):
            for res in fn(*args):
                yield res

    elif inspect.isfunction(fn):

        @wraps(fn)
        async def lifted(*args):
            yield fn(*args)

    else:
        raise ValueError(
            f"as_command can only decorate functions, coroutine functions, "
            f"generator functions or async generator functions; got {fn!r}"
        )
    return lifted


def _resolve_injectors(fn):
    """Build one injector callable per parameter of ``fn``, in signature order.

    A parameter annotated ``Message`` or ``Context`` maps to
    ``inject_message`` / ``inject_context``; a parameter whose annotation is
    an ``Injector`` instance maps to that instance's ``inject`` attribute.

    Raises:
        ValueError: if a parameter is unannotated or has any other annotation.
    """
    missing = inspect.Parameter.empty
    fn_name = getattr(fn, "__name__", repr(fn))
    injectors = []
    for param_name in inspect.signature(fn).parameters:
        # .get() so an unannotated parameter is reported through the same
        # ValueError path as an unsupported annotation, not as a KeyError.
        annot = fn.__annotations__.get(param_name, missing)
        if annot == Message:
            injectors.append(inject_message)
        elif annot == Context:
            injectors.append(inject_context)
        elif isinstance(annot, Injector):
            injectors.append(annot.inject)
        elif annot is missing:
            raise ValueError(
                f"parameter {param_name!r} of command {fn_name!r} has no "
                f"annotation; expected Message, Context or an Injector instance"
            )
        else:
            raise ValueError(
                f"parameter {param_name!r} of command {fn_name!r} has unsupported "
                f"annotation {annot!r}; expected Message, Context or an Injector "
                f"instance"
            )
    return injectors


def as_command(arg: Union[str, Callable]):
    """Register a function in ``decorated_commands``.

    Usable either bare (``@as_command``), in which case the command is
    registered under the function's ``__name__``, or with an explicit name
    (``@as_command("name")``).

    The stored entry is an async callable taking ``(message, context)``.  It
    awaits every parameter injector, calls the decorated function with the
    injected values, and passes each produced result to ``context.respond``.
    The decorated function itself is returned unchanged.

    Raises:
        ValueError: at decoration time, if the target is not a supported kind
            of function or one of its parameters cannot be injected.
    """

    def impl(name, fn):
        lifted = _lift_to_async_generator(fn)
        injectors = _resolve_injectors(fn)

        async def command_impl(message: Message, context: Context):
            args = await asyncio.gather(*[inj(message, context) for inj in injectors])
            try:
                async for result in lifted(*args):
                    if isinstance(result, Response):
                        response = result
                    elif isinstance(result, str):
                        response = Response.from_message(message, text=result)
                    # TODO handle attachments, other special returns
                    else:
                        response = Response.from_message(message, str(result))
                    await context.respond(response)
            except RollbotFailureException as exc:
                # TODO handle errors more specifically
                await context.respond(Response.from_message(message, str(exc.failure)))

        decorated_commands[name] = command_impl
        return fn

    if isinstance(arg, str):
        return lambda fn: impl(arg, fn)
    else:
        return impl(arg.__name__, arg)


def get_command_config() -> CommandConfiguration:
    """Return a ``CommandConfiguration`` built from the decorator registries.

    ``commands``, ``startup`` and ``shutdown`` are the module-level registry
    objects themselves (not copies); ``call_and_response``, ``aliases`` and
    ``bangs`` are passed as empty values.
    """
    return CommandConfiguration(
        commands=decorated_commands,
        call_and_response={},
        aliases={},
        bangs=(),
        startup=decorated_startup,
        shutdown=decorated_shutdown,
    )


def as_data(cls):
    """Class decorator: make ``cls`` a dataclass with SQL/JSON storage helpers.

    The table name is derived from ``cls.__name__`` by prefixing every ASCII
    uppercase letter with ``_``, lowercasing it, and stripping leading and
    trailing underscores (``MyData`` -> ``my_data``).  The table has a ``key``
    primary-key column and a ``body`` text column; for every field annotated
    exactly ``int``, ``float`` or ``str`` a generated virtual column
    ``__<field>`` is added that extracts ``$.<field>`` from ``body``, together
    with a query selecting ``body`` by that column.  Fields with any other
    annotation get no column and no query.

    A startup hook that executes the ``CREATE TABLE IF NOT EXISTS`` statement
    is registered via ``on_startup`` each time this decorator is applied.

    Returns:
        A new class with the same name, subclassing ``dataclasses.dataclass(cls)``
        and carrying the attributes ``__field_queries``, ``__create_query``,
        ``__select_query``, ``__save_query``, ``__to_blob`` and ``__from_blob``.
    """
    table_name = "".join(
        ("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__
    ).strip("_")

    columns = [
        "key TEXT NOT NULL PRIMARY KEY",
        'body TEXT DEFAULT ""',
    ]
    queries = {}
    # Compared with == (not looked up by hash) to keep the original matching
    # semantics for arbitrary annotation objects.
    sql_types = ((int, "INT"), (float, "REAL"), (str, "TEXT"))
    for k, v in cls.__annotations__.items():
        t = next((sql for py_type, sql in sql_types if v == py_type), None)
        if t is None:
            continue
        columns.append(
            f"__{k} {t} GENERATED ALWAYS AS (json_extract(body, '$.{k}')) VIRTUAL"
        )
        queries[k] = f"SELECT body FROM {table_name} WHERE __{k} = ?"

    create_query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"

    @on_startup
    async def create_table(context: Context):
        async with context.database() as db:
            await db.execute(create_query)
            await db.commit()

    new_class = type(
        cls.__name__,
        (dataclasses.dataclass(cls),),
        dict(
            __field_queries=queries,
            __create_query=create_query,
            __select_query=f"SELECT body FROM {table_name} WHERE key=:key",
            __save_query=f"INSERT INTO {table_name} VALUES (:key, :body) ON CONFLICT(key) DO UPDATE SET body=:body",
            __to_blob=lambda self: json.dumps(dataclasses.asdict(self)),
            # The lambda resolves new_class at call time, after the assignment
            # below has completed.
            __from_blob=staticmethod(lambda blob: new_class(**json.loads(blob))),
        ),
    )
    return new_class