mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Allow user to start work from DMs. (#267)
* merge upstream/main * update permissions check for guild settings * add error handler for the bot * allow users to start work from DMs and broadcast task completion messages to all log channels * remove print statement
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
BOT_TOKEN=<discord bot token>
|
||||
DECLARE_GLOBAL_COMMANDS=<testing guild id>
|
||||
OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="./"
|
||||
PREFIX="/" # Don't change, this allows for slash commands in DMs
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
|
||||
@@ -8,7 +8,6 @@ import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
from oasst_shared.api_client import OasstApiClient, TaskType
|
||||
@@ -31,8 +30,8 @@ MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def work(ctx: lightbulb.SlashContext):
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def work(ctx: lightbulb.Context):
|
||||
"""Create and handle a task."""
|
||||
# make sure the user isn't currently doing a task
|
||||
currently_working: set[hikari.Snowflakeish] = ctx.bot.d.currently_working
|
||||
@@ -55,7 +54,7 @@ async def work(ctx: lightbulb.SlashContext):
|
||||
currently_working.remove(ctx.author.id)
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
|
||||
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
@@ -117,16 +116,16 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in the log channel that the task is complete
|
||||
# TODO: Maybe do something with the msg ID so users can rate the "answer"
|
||||
assert ctx.guild_id is not None
|
||||
# Send a message in all the log channels that the task is complete
|
||||
conn: Connection = ctx.bot.d.db
|
||||
guild_settings = await GuildSettings.from_db(conn, ctx.guild_id)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT log_channel_id FROM guild_settings")
|
||||
log_channel_ids = await cursor.fetchall()
|
||||
|
||||
if guild_settings is not None and guild_settings.log_channel_id is not None:
|
||||
|
||||
channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel) # option converter
|
||||
channels = [
|
||||
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
|
||||
for id in log_channel_ids
|
||||
]
|
||||
|
||||
done_embed = (
|
||||
hikari.Embed(
|
||||
@@ -140,7 +139,10 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
await channel.send(EMPTY, embed=done_embed)
|
||||
# This will definitely get the bot rate limited, but that's a future problem
|
||||
asyncio.gather(
|
||||
*(ch.send(EMPTY, embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel))
|
||||
)
|
||||
|
||||
# ask the user if they want to do another task
|
||||
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
@@ -157,7 +159,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
@@ -196,7 +198,7 @@ async def _select_task(
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.SlashContext, task: protocol_schema.Task
|
||||
ctx: lightbulb.Context, task: protocol_schema.Task
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
|
||||
"""Send a task to the user.
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ class Settings(BaseSettings):
|
||||
bot_token: str = Field(env="BOT_TOKEN", default="")
|
||||
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
|
||||
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
|
||||
prefix: str = Field(env="PREFIX", default="./")
|
||||
prefix: str = Field(env="PREFIX", default="/")
|
||||
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
|
||||
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user