add initial task loop for initial_prompt and rank_initial_prompts

This commit is contained in:
AlexanderHOtt
2022-12-29 14:20:56 -08:00
parent 99303ed265
commit 9fd2e76917
9 changed files with 733 additions and 78 deletions
-75
View File
@@ -1,75 +0,0 @@
# -*- coding: utf-8 -*-
import enum
from typing import Optional, Type
import typing as t
import requests
from oasst_shared.schemas import protocol as protocol_schema
class TaskType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
class ApiClient:
def __init__(self, backend_url: str, api_key: str):
self.backend_url = backend_url
self.api_key = api_key
task_models_map: dict[str, Type[protocol_schema.Task]] = {
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.user_reply: protocol_schema.UserReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
self.task_models_map = task_models_map
def post(self, path: str, json: dict) -> dict:
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
response.raise_for_status()
return response.json()
def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task:
if not isinstance(data, dict):
raise ValueError("dict expected")
task_type = data.get("type")
if task_type not in self.task_models_map:
raise RuntimeError(f"Unsupported task type: {task_type}")
return self.task_models_map[task_type].parse_obj(data)
def fetch_task(
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
) -> protocol_schema.Task:
req = protocol_schema.TaskRequest(type=task_type, user=user)
data = self.post("/api/v1/tasks/", req.dict())
return self._parse_task(data)
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
def ack_task(self, task_id: str, post_id: str) -> None:
req = protocol_schema.TaskAck(post_id=post_id)
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
def nack_task(self, task_id: str, reason: str) -> None:
req = protocol_schema.TaskNAck(reason=reason)
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
data = self.post("/api/v1/tasks/interaction", interaction.dict())
return self._parse_task(data)
+130
View File
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
import asyncio
import enum
import typing as t
from typing import Optional, Type
from uuid import UUID
import aiohttp
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
class TaskType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
class OasstApiClient:
"""API Client for interacting with the OASST backend."""
def __init__(self, backend_url: str, api_key: str):
"""Create a new OasstApiClient.
Args:
backend_url (str): The base backend URL.
api_key (str): The API key to use for authentication.
"""
logger.debug("Opening OasstApiClient session")
self.session = aiohttp.ClientSession()
self.backend_url = backend_url
self.api_key = api_key
self.task_models_map: dict[str, Type[protocol_schema.Task]] = {
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.user_reply: protocol_schema.UserReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
async def post(self, path: str, data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response.raise_for_status()
return await response.json()
def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task:
task_type = data.get("type")
if not isinstance(task_type, str):
logger.error(f"task type must be a `str`: {task_type}")
raise ValueError(f"task type must be a `str`: {task_type}")
model = self.task_models_map.get(task_type)
if not model:
logger.error(f"Unsupported task type: {task_type}")
raise ValueError(f"Unsupported task type: {task_type}")
return self.task_models_map[task_type].parse_obj(data)
async def fetch_task(
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
) -> protocol_schema.Task:
"""Fetch a task from the backend."""
logger.debug(f"Fetching task {task_type} for user {user}")
req = protocol_schema.TaskRequest(type=task_type.value, user=user)
resp = await self.post(f"/api/v1/tasks/", data=req.dict())
print("resp", resp)
return self._parse_task(resp)
async def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
"""Fetch a random task from the backend."""
logger.debug(f"Fetching random for user {user}")
return await self.fetch_task(protocol_schema.TaskRequestType.random, user)
async def ack_task(self, task_id: str | UUID, post_id: str):
"""Send an ACK for a task to the backend."""
logger.debug(f"ACK task {task_id} with post {post_id}")
req = protocol_schema.TaskAck(post_id=post_id)
return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())
async def nack_task(self, task_id: str | UUID, reason: str):
"""Send a NACK for a task to the backend."""
logger.debug(f"NACK task {task_id} with reason {reason}")
req = protocol_schema.TaskNAck(reason=reason)
return await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict())
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
"""Send a completed task to the backend."""
logger.debug(f"Interaction: {interaction}")
resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict())
return self._parse_task(resp)
async def close(self):
logger.debug("Closing OasstApiClient session")
await self.session.close()
async def main():
api = OasstApiClient("http://localhost:8080", "test")
try:
task = await api.fetch_task(protocol_schema.TaskRequestType.initial_prompt, None)
print(task)
finally:
await api.close()
# session = aiohttp.ClientSession()
# try:
# resp = await session.post("http://localhost:8080/api/v1/tasks/", json={"type": "initial_prompt", "user": None})
# resp.raise_for_status()
# print(await resp.text())
# finally:
# await session.close()
if __name__ == "__main__":
asyncio.run(main())
+4
View File
@@ -6,6 +6,7 @@ import lightbulb
import miru
from bot.config import Config
from bot.api_client import OasstApiClient
config = Config.from_env()
@@ -29,8 +30,11 @@ async def on_starting(event: hikari.StartingEvent):
await bot.d.db.executescript(open("./bot/db/schema.sql").read())
await bot.d.db.commit()
bot.d.oasst_api = OasstApiClient("http://localhost:8080", "any_key")
@bot.listen()
async def on_stopping(event: hikari.StoppingEvent):
"""Cleanup."""
await bot.d.db.close()
await bot.d.oasst_api.close()
+6 -1
View File
@@ -1,5 +1,6 @@
# TODO: Convert file to markdown
# -*- coding: utf-8 -*-
"""Example plugins for reference.
"""Example plugin for reference.
Because this file starts with an `_`, it cannot be loaded by the bot. To see the example plugin in action, rename this file to `example.py`.
"""
@@ -396,6 +397,10 @@ async def modal_example(ctx: lightbulb.SlashContext) -> None:
await view.start(await resp.message())
# TODO: Database example
# TODO: Rest client example
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
+1 -1
View File
@@ -15,7 +15,7 @@ EXTENSIONS_FOLDER = "bot/extensions"
def _get_extensions() -> list[str]:
# Recursively get all the .py files in the extensions directory not starting with an `_`.
exts = glob("bot/extensions/**/*[!_].py", recursive=True)
exts = glob("bot/extensions/**/[!_]*.py", recursive=True)
# Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension")
return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts]
+302
View File
@@ -0,0 +1,302 @@
# -*- coding: utf-8 -*-
"""Task plugin for testing different data collection methods."""
import asyncio
import logging
import typing as t
from datetime import datetime, timedelta
import hikari
import lightbulb
import lightbulb.decorators
import miru
from bot.utils import format_time
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("TaskPlugin")
MAX_TASK_TIME = 60 * 60
MAX_TASK_ACCEPT_TIME = 60
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def task_thread(ctx: lightbulb.SlashContext):
"""Request a task from the backend."""
typ: str = ctx.options.type
# Create a thread for the task
thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}")
await ctx.respond(f"Please complete the task in the thread: {thread.mention}")
# Send the task in the thread
# TODO: Request task from the backend
await thread.send(
f"Please complete the task.\nSample Task\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}"
)
# Wait for the user to respond
try:
event = await ctx.bot.wait_for(
hikari.GuildMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id,
)
await ctx.respond(f"Received message: {event.message.content}")
# TODO: Send the message to the backend
except asyncio.TimeoutError:
await ctx.respond("You took too long to respond.")
finally:
await thread.delete()
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def task_dm(ctx: lightbulb.Context):
"""Request a task from the backend."""
typ: str = ctx.options.type
# Create a thread for the task
await ctx.respond(f"Please complete the task in your DMs")
# Send the task in the thread
# TODO: Request task from the backend
await ctx.author.send(
f"Please complete the task.\nSample Task ({typ})\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}"
)
# Wait for the user to respond
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id,
)
await ctx.respond(f"Received message: {event.message.content}")
# TODO: Send the message to the backend
except asyncio.TimeoutError:
await ctx.respond("You took too long to respond.")
class TaskModal(miru.Modal):
"""Modal for submitting a task."""
response = miru.TextInput(
label="Response",
placeholder="Enter your response!",
required=True,
style=hikari.TextInputStyle.PARAGRAPH,
row=2,
)
async def callback(self, context: miru.ModalContext) -> None:
await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL)
# TODO: Send the message to the backend
class ModalView(miru.View):
"""View for opening a modal."""
def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.modal_title = modal_title
self.task = task
@miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY)
async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
modal = TaskModal(title=self.modal_title)
modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1))
await ctx.respond_with_modal(modal)
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def task_modal(ctx: lightbulb.SlashContext):
"""Request a task from the backend."""
# typ: str = ctx.options.type
view = ModalView(
modal_title=f"Assistant Response",
task="Please explain the moon landing to a six year old.",
timeout=MAX_TASK_TIME,
)
resp = await ctx.respond(
"Task - Respond to the prompt as if you were the Assistant:",
flags=hikari.MessageFlag.EPHEMERAL,
components=view,
)
await view.start(await resp.message())
class RatingView(miru.View):
"""View for rating a task."""
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.presses: list[str] = []
def _close_if_all_pressed(self) -> None:
if len(self.presses) == 5:
self.stop()
@miru.button(label="1", style=hikari.ButtonStyle.PRIMARY)
async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("1")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="2", style=hikari.ButtonStyle.PRIMARY)
async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("2")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="3", style=hikari.ButtonStyle.PRIMARY)
async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("3")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="4", style=hikari.ButtonStyle.PRIMARY)
async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("4")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="5", style=hikari.ButtonStyle.PRIMARY)
async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("5")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="Reset", style=hikari.ButtonStyle.DANGER)
async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.presses = []
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
class SelectRating(miru.View):
@miru.select(
options=[
hikari.SelectMenuOption(
label="1",
value="1",
description=None,
emoji=None,
is_default=False,
),
hikari.SelectMenuOption(
label="2",
value="2",
description=None,
emoji=None,
is_default=False,
),
hikari.SelectMenuOption(
label="3",
value="3",
description=None,
emoji=None,
is_default=False,
),
],
placeholder="Select the good responses",
min_values=0,
max_values=3,
row=3,
)
async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL)
@plugin.command
@lightbulb.command("rating_task", "Rate stuff.")
@lightbulb.implements(lightbulb.SlashCommand)
async def rating_task(ctx: lightbulb.SlashContext):
"""Rate stuff."""
# Message Based rating
await ctx.respond(
"List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL
)
try:
event = await ctx.bot.wait_for(
hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
)
except asyncio.TimeoutError:
await ctx.respond("Timed out waiting for response")
return
if event.content is None:
await ctx.respond("No content in message")
return
ratings = event.content.replace(" ", "").split(",")
# Check if the ratings are valid
if len(ratings) != 5:
await ctx.respond("Invalid number of ratings")
if not all([rating in ("1", "2", "3", "4", "5") for rating in ratings]):
await ctx.respond("Invalid rating")
await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL)
# Button Based rating
view = RatingView(timeout=MAX_TASK_TIME)
resp = await ctx.respond("Click the buttons in order of best to worst response", components=view)
await view.start(await resp.message())
await view.wait()
await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL)
await resp.delete()
# Select Based rating
select_view = SelectRating(timeout=MAX_TASK_TIME)
resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL)
await select_view.start(await resp_2.message())
await select_view.wait()
await resp_2.delete()
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
+281
View File
@@ -0,0 +1,281 @@
# -*- coding: utf-8 -*-
"""Work plugin for collecting user data."""
import asyncio
import logging
import typing as t
from datetime import datetime
import hikari
import lightbulb
import lightbulb.decorators
import miru
from bot.api_client import OasstApiClient, TaskType
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TaskRequestType
from bot.utils import ZWJ
plugin = lightbulb.Plugin("WorkPlugin")
MAX_TASK_TIME = 60 * 60 # 1 hour
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
required=False,
default=str(TaskRequestType.rank_initial_prompts), # TODO: change back to random
type=str,
)
@lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand)
async def work(ctx: lightbulb.SlashContext):
"""Create and handle a task."""
task_type: TaskRequestType = TaskRequestType(ctx.options.type)
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}")
await _handle_task(ctx, task_type)
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
"""Handle creating and collecting user input for a task.
Continually present tasks to the user until they select one, cancel, or time out.
If they select one, present the task steps until a `task_done` task is received.
Finally, ask the user if they want to perform another task (of the same type).
"""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
# Continue to complete tasks until the user doesn't want to do another
done = False
while not done:
# Loop until the user accepts a task
task, msg_id = await _select_task(ctx, task_type)
if task is None:
return
# Task action loop
completed = False
while not completed:
await ctx.author.send("Please type your response here:")
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
)
except asyncio.TimeoutError:
await ctx.author.send("Task timed out. Exiting")
# TODO: NACK task maybe?
return
# Invalid response
if event.content is None:
await ctx.author.send("No content in message")
continue
logger.info(f"User input received: {event.content}")
# Send the response to the backend
reply = protocol_schema.TextReplyToPost(
post_id=str(msg_id),
user_post_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
text=event.content,
)
logger.debug(f"Sending reply to backend: {reply!r}")
# Get next task
new_task = await oasst_api.post_interaction(reply)
logger.info(f"New task {new_task}")
if new_task.type == TaskType.done:
await ctx.author.send("Task completed")
completed = True
continue
else:
logger.fatal(f"Unexpected task type received: {new_task.type}")
# ask the user if they want to do another task
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send("Would you like another task?", components=choice_view)
await choice_view.start(msg)
await choice_view.wait()
match choice_view.choice:
case False | None:
done = True
await ctx.author.send("Exiting, goodbye!")
case True:
pass
async def _select_task(
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
) -> tuple[protocol_schema.Task | None, str]:
"""Present tasks to the user until they accept one, cancel, or time out."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
logger.debug(f"Starting task selection for {task_type}")
# Loop until the user accepts a task, cancels, or times out
while True:
logger.debug(f"Requesting task of type {task_type}")
task = await oasst_api.fetch_task(task_type, user)
resp, msg_id = await _send_task(ctx, task)
logger.debug(f"user choice: {resp}")
match resp:
case "accept":
logger.info(f"Task {task.id} accepted, sending ACK")
await oasst_api.ack_task(task.id, msg_id)
return task, msg_id
case "next":
logger.info(f"Task {task.id} rejected, sending NACK")
await oasst_api.nack_task(task.id, "rejected")
await ctx.author.send("Sending next task...")
continue
case "cancel":
logger.info(f"Task {task.id} canceled, sending NACK")
await oasst_api.nack_task(task.id, "canceled")
await ctx.author.send("Task canceled. Exiting")
return None, msg_id
case None:
logger.info(f"Task {task.id} timed out, sending NACK")
await oasst_api.nack_task(task.id, "timed out")
await ctx.author.send("Task timed out. Exiting")
return None, msg_id
async def _send_task(
ctx: lightbulb.SlashContext, task: protocol_schema.Task
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
"""Send a task to the user.
Returns the user's choice and the message ID of the task message."""
# The clean way to do this would be to attach a `to_embed` method to the task classes
# but the tasks aren't discord specific so that doesn't really make sense.
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.info("sending initial prompt task")
embed = _initial_prompt_embed(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.info("sending rank initial prompt task")
embed = _rank_initial_prompt_embed(task)
else:
logger.error(f"unknown task type {task.type}")
msg = await ctx.author.send(
ZWJ,
embed=embed,
components=view,
)
assert msg is not None
await view.start(msg)
await view.wait()
return view.choice, str(msg.id)
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
return (
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
.set_image(
"https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80",
)
.set_footer(text=f"OASST Assistant | {task.id}")
)
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank Initial Prompt",
description=f"Rank the following tasks from best to worst (1,2,3,4,5)",
timestamp=datetime.now().astimezone(),
)
.set_image(
"https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80",
)
.set_footer(text=f"OASST Assistant | {task.id}")
)
for i, prompt in enumerate(task.prompts):
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
return embed
class TaskAcceptView(miru.View):
"""View with three buttons: accept, next, and cancel.
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
"""
choice: t.Literal["accept", "next", "cancel"] | None = None
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Accept button pressed")
self.choice = "accept"
self.stop()
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Next button pressed")
self.choice = "next"
self.stop()
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Cancel button pressed")
self.choice = "cancel"
self.stop()
class ChoiceView(miru.View):
choice: bool | None = None
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = True
self.stop()
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = False
self.stop()
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
+7
View File
@@ -21,3 +21,10 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s
return f"<t:{dt.timestamp():.0f}:{fmt}>"
case _:
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
ZWJ = "\u200d"
"""Zero-width joiner.
This appears as an empty message in Discord.
"""
+2 -1
View File
@@ -7,4 +7,5 @@ hikari-miru # modals and buttons
python-dotenv # .env file support
aiosqlite # database
aiohttp # http client
aiohttp[speedups] # speedups for aiohttp
aiohttp[speedups] # speedups for aiohttp
loguru