diff --git a/discord-bot/.env.example b/discord-bot/.env.example index 5cd18fac..ec114c8f 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -1,7 +1,7 @@ BOT_TOKEN= DECLARE_GLOBAL_COMMANDS= OWNER_IDS=[, ] -PREFIX="./" +PREFIX="/" # Don't change, this allows for slash commands in DMs OASST_API_URL="http://localhost:8080" # No trailing '/' OASST_API_KEY="" diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 19802c64..c905e7a0 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -8,7 +8,6 @@ import lightbulb import lightbulb.decorators import miru from aiosqlite import Connection -from bot.db.schemas import GuildSettings from bot.utils import EMPTY from loguru import logger from oasst_shared.api_client import OasstApiClient, TaskType @@ -31,8 +30,8 @@ MAX_TASK_ACCEPT_TIME = 60 # 1 minute type=str, ) @lightbulb.command("work", "Complete a task.") -@lightbulb.implements(lightbulb.SlashCommand) -async def work(ctx: lightbulb.SlashContext): +@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) +async def work(ctx: lightbulb.Context): """Create and handle a task.""" # make sure the user isn't currently doing a task currently_working: set[hikari.Snowflakeish] = ctx.bot.d.currently_working @@ -55,7 +54,7 @@ async def work(ctx: lightbulb.SlashContext): currently_working.remove(ctx.author.id) -async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None: +async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None: """Handle creating and collecting user input for a task. Continually present tasks to the user until they select one, cancel, or time out. @@ -117,16 +116,16 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) else: logger.critical(f"Unexpected task type received: {new_task.type}") - # Send a message in the log channel that the task is complete - # TODO: Maybe do something with the msg ID so users can rate the "answer" - assert ctx.guild_id is not None + # Send a message in all the log channels that the task is complete conn: Connection = ctx.bot.d.db - guild_settings = await GuildSettings.from_db(conn, ctx.guild_id) + async with conn.cursor() as cursor: + await cursor.execute("SELECT log_channel_id FROM guild_settings") + log_channel_ids = await cursor.fetchall() - if guild_settings is not None and guild_settings.log_channel_id is not None: - - channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id) - assert isinstance(channel, hikari.TextableChannel) # option converter + channels = [ + ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0]) + for id in log_channel_ids + ] done_embed = ( hikari.Embed( @@ -140,7 +139,10 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) .add_field("Global Ranking", "0/0", inline=True) .set_footer(f"Task ID: {task.id}") ) - await channel.send(EMPTY, embed=done_embed) + # This will definitely get the bot rate limited, but that's a future problem + asyncio.gather( + *(ch.send(EMPTY, embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)) + ) # ask the user if they want to do another task choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME) @@ -157,7 +159,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) async def _select_task( - ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None + ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None ) -> tuple[protocol_schema.Task | None, str]: """Present tasks to the user until they accept one, cancel, or time out.""" oasst_api: OasstApiClient = ctx.bot.d.oasst_api @@ -196,7 +198,7 @@ async def _select_task( async def _send_task( - ctx: lightbulb.SlashContext, task: protocol_schema.Task + ctx: lightbulb.Context, task: protocol_schema.Task ) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]: """Send a task to the user. diff --git a/discord-bot/bot/settings.py b/discord-bot/bot/settings.py index 24c837a3..a2e2c2ba 100644 --- a/discord-bot/bot/settings.py +++ b/discord-bot/bot/settings.py @@ -8,7 +8,7 @@ class Settings(BaseSettings): bot_token: str = Field(env="BOT_TOKEN", default="") declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0) owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list) - prefix: str = Field(env="PREFIX", default="./") + prefix: str = Field(env="PREFIX", default="/") oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080") oasst_api_key: str = Field(env="OASST_API_KEY", default="")