123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- from collections.abc import Callable
- from typing import Union
- from functools import wraps
- import inspect
- import asyncio
- import dataclasses
- import json
- from ..types import (
- Message,
- Context,
- CommandType,
- Response,
- StartupShutdownType,
- CommandConfiguration,
- )
- from .failure import RollbotFailureException
- from .injection import Injector, inject_message, inject_context
# Module-level registries populated by the decorators in this file and
# handed out together by get_command_config().

# Callables registered through @on_startup, in registration order.
decorated_startup: list[StartupShutdownType] = []

# Callables registered through @on_shutdown, in registration order.
decorated_shutdown: list[StartupShutdownType] = []

# Command name -> wrapped command implementation, filled by @as_command.
decorated_commands: dict[str, CommandType] = {}
def on_startup(fn):
    """Decorator: record ``fn`` in ``decorated_startup`` and return it unchanged.

    The function itself is not wrapped or called here; it is only appended
    to the registry list.
    """
    decorated_startup.append(fn)
    return fn
def on_shutdown(fn):
    """Decorator: record ``fn`` in ``decorated_shutdown`` and return it unchanged.

    The function itself is not wrapped or called here; it is only appended
    to the registry list.
    """
    decorated_shutdown.append(fn)
    return fn
def as_command(arg: Union[str, Callable]):
    """Register a function as a command in ``decorated_commands``.

    Two spellings are supported:

    * ``@as_command`` -- ``arg`` is the function itself; it is registered
      under its ``__name__``.
    * ``@as_command("name")`` -- ``arg`` is the command name; the returned
      decorator registers the function under that name.

    The decorated function may be a plain function, a generator function, a
    coroutine function or an async generator function.  Every one of its
    parameters must be annotated with ``Message``, ``Context`` or an
    ``Injector`` instance; the annotation selects how that argument is
    produced each time the command runs.

    Returns:
        The original function, unchanged (for the named form, a decorator
        that returns the original function).

    Raises:
        ValueError: if the decorated object is not one of the supported
            function kinds, or if a parameter is unannotated or carries an
            unsupported annotation.
    """

    def impl(name, fn):
        # Normalise every supported flavour of callable to an async
        # generator function so command_impl has a single iteration path.
        if inspect.isasyncgenfunction(fn):
            lifted = fn
        elif inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def lifted(*args):
                yield await fn(*args)

        elif inspect.isgeneratorfunction(fn):

            @wraps(fn)
            async def lifted(*args):
                for res in fn(*args):
                    yield res

        elif inspect.isfunction(fn):

            @wraps(fn)
            async def lifted(*args):
                yield fn(*args)

        else:
            raise ValueError(
                f"cannot register {fn!r} as command {name!r}: expected a function, "
                "generator function, coroutine function or async generator function"
            )

        # One injector per parameter, in signature order; each is later
        # awaited with (message, context) to produce that argument.
        injectors = []
        for param_name, param in inspect.signature(fn).parameters.items():
            annot = param.annotation
            if annot is inspect.Parameter.empty:
                raise ValueError(
                    f"command {name!r}: parameter {param_name!r} has no annotation; "
                    "expected Message, Context or an Injector instance"
                )
            if annot == Message:
                injectors.append(inject_message)
            elif annot == Context:
                injectors.append(inject_context)
            elif isinstance(annot, Injector):
                injectors.append(annot.inject)
            else:
                raise ValueError(
                    f"command {name!r}: parameter {param_name!r} has unsupported "
                    f"annotation {annot!r}; expected Message, Context or an Injector instance"
                )

        async def command_impl(message: Message, context: Context):
            # NOTE(review): injection runs outside the try block, so a
            # RollbotFailureException raised by an injector propagates to the
            # caller instead of being turned into a response -- confirm intended.
            args = await asyncio.gather(*[inj(message, context) for inj in injectors])
            try:
                async for result in lifted(*args):
                    if isinstance(result, Response):
                        response = result
                    elif isinstance(result, str):
                        response = Response.from_message(message, text=result)
                    # TODO handle attachments, other special returns
                    else:
                        response = Response.from_message(message, str(result))
                    await context.respond(response)
            except RollbotFailureException as exc:
                # TODO handle errors more specifically
                await context.respond(Response.from_message(message, str(exc.failure)))

        decorated_commands[name] = command_impl
        # Hand back the undecorated function so the module-level name still
        # refers to the author's original callable.
        return fn

    if isinstance(arg, str):
        return lambda fn: impl(arg, fn)
    else:
        return impl(arg.__name__, arg)
def get_command_config() -> CommandConfiguration:
    """Bundle everything registered by this module's decorators.

    The returned configuration references the module-level registries
    directly (no copies are made); call-and-response, aliases and bangs are
    left empty.
    """
    settings = dict(
        commands=decorated_commands,
        call_and_response={},
        aliases={},
        bangs=(),
        startup=decorated_startup,
        shutdown=decorated_shutdown,
    )
    return CommandConfiguration(**settings)
def as_data(cls):
    """Class decorator: make ``cls`` a dataclass stored as a JSON blob in SQLite.

    A table name is derived from the class name by prefixing each ASCII
    capital with ``_``, lower-casing it, and stripping leading/trailing
    underscores (``MyThing`` -> ``my_thing``).  The table has a ``key``
    primary key and a ``body`` text column; for every attribute annotated
    exactly ``int``, ``float`` or ``str`` a virtual generated column
    ``__<field>`` extracting that field from the JSON body is added, along
    with a matching lookup query.  Other annotations get no column.

    Calling the decorator registers an ``on_startup`` hook that creates the
    table if it does not exist.

    Returns:
        A new subclass of the dataclass-converted ``cls`` carrying the
        attributes ``__field_queries``, ``__create_query``,
        ``__select_query``, ``__save_query``, ``__to_blob`` and
        ``__from_blob``.  It keeps the original class's ``__name__``,
        ``__qualname__``, ``__module__`` and ``__doc__``.
    """
    table_name = "".join(("_" + c.lower()) if "A" <= c <= "Z" else c for c in cls.__name__).strip("_")
    columns = [
        "key TEXT NOT NULL PRIMARY KEY",
        # Single quotes: a double-quoted "" is only accepted as a string
        # literal through SQLite's legacy DQS quirk, which builds can disable.
        "body TEXT DEFAULT ''",
    ]
    queries = {}
    for k, v in cls.__annotations__.items():
        if v == int:
            t = "INT"
        elif v == float:
            t = "REAL"
        elif v == str:
            t = "TEXT"
        else:
            # Only scalar fields get a queryable generated column.
            continue
        columns.append(f"__{k} {t} GENERATED ALWAYS AS (json_extract(body, '$.{k}')) VIRTUAL")
        queries[k] = f"SELECT body FROM {table_name} WHERE __{k} = ?"
    create_query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"

    @on_startup
    async def create_table(context: Context):
        async with context.database() as db:
            await db.execute(create_query)
            await db.commit()

    # dataclass() may fill in cls.__doc__, so read the metadata afterwards.
    data_cls = dataclasses.dataclass(cls)
    namespace = {
        # Without these, type() would stamp the new class with this module's
        # name and drop the docstring, which misleads repr/pickle/debugging.
        "__module__": data_cls.__module__,
        "__qualname__": data_cls.__qualname__,
        "__doc__": data_cls.__doc__,
        "__field_queries": queries,
        "__create_query": create_query,
        "__select_query": f"SELECT body FROM {table_name} WHERE key=:key",
        "__save_query": f"INSERT INTO {table_name} VALUES (:key, :body) ON CONFLICT(key) DO UPDATE SET body=:body",
        "__to_blob": lambda self: json.dumps(dataclasses.asdict(self)),
        # new_class is resolved when the lambda is called, after it is bound below.
        "__from_blob": staticmethod(lambda blob: new_class(**json.loads(blob))),
    }
    new_class = type(cls.__name__, (data_cls,), namespace)
    return new_class
|