mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
add initial task loop for initial_prompt and rank_initial_prompts
This commit is contained in:
@@ -1,75 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Optional, Type
|
||||
import typing as t
|
||||
|
||||
import requests
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
class TaskType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
task_models_map: dict[str, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.user_reply: protocol_schema.UserReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
self.task_models_map = task_models_map
|
||||
|
||||
def post(self, path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task:
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("dict expected")
|
||||
|
||||
task_type = data.get("type")
|
||||
if task_type not in self.task_models_map:
|
||||
raise RuntimeError(f"Unsupported task type: {task_type}")
|
||||
|
||||
return self.task_models_map[task_type].parse_obj(data)
|
||||
|
||||
def fetch_task(
|
||||
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
|
||||
) -> protocol_schema.Task:
|
||||
req = protocol_schema.TaskRequest(type=task_type, user=user)
|
||||
data = self.post("/api/v1/tasks/", req.dict())
|
||||
return self._parse_task(data)
|
||||
|
||||
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
|
||||
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
|
||||
|
||||
def ack_task(self, task_id: str, post_id: str) -> None:
|
||||
req = protocol_schema.TaskAck(post_id=post_id)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
|
||||
|
||||
def nack_task(self, task_id: str, reason: str) -> None:
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
|
||||
|
||||
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
data = self.post("/api/v1/tasks/interaction", interaction.dict())
|
||||
return self._parse_task(data)
|
||||
@@ -0,0 +1,130 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import enum
|
||||
import typing as t
|
||||
from typing import Optional, Type
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
class TaskType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class OasstApiClient:
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
"""Create a new OasstApiClient.
|
||||
|
||||
Args:
|
||||
backend_url (str): The base backend URL.
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
self.task_models_map: dict[str, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.user_reply: protocol_schema.UserReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
|
||||
async def post(self, path: str, data: dict[str, t.Any]) -> dict[str, t.Any]:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
|
||||
def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task:
|
||||
task_type = data.get("type")
|
||||
|
||||
if not isinstance(task_type, str):
|
||||
logger.error(f"task type must be a `str`: {task_type}")
|
||||
raise ValueError(f"task type must be a `str`: {task_type}")
|
||||
|
||||
model = self.task_models_map.get(task_type)
|
||||
if not model:
|
||||
logger.error(f"Unsupported task type: {task_type}")
|
||||
raise ValueError(f"Unsupported task type: {task_type}")
|
||||
return self.task_models_map[task_type].parse_obj(data)
|
||||
|
||||
async def fetch_task(
|
||||
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a task from the backend."""
|
||||
logger.debug(f"Fetching task {task_type} for user {user}")
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user)
|
||||
resp = await self.post(f"/api/v1/tasks/", data=req.dict())
|
||||
print("resp", resp)
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
|
||||
"""Fetch a random task from the backend."""
|
||||
logger.debug(f"Fetching random for user {user}")
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user)
|
||||
|
||||
async def ack_task(self, task_id: str | UUID, post_id: str):
|
||||
"""Send an ACK for a task to the backend."""
|
||||
logger.debug(f"ACK task {task_id} with post {post_id}")
|
||||
req = protocol_schema.TaskAck(post_id=post_id)
|
||||
return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())
|
||||
|
||||
async def nack_task(self, task_id: str | UUID, reason: str):
|
||||
"""Send a NACK for a task to the backend."""
|
||||
logger.debug(f"NACK task {task_id} with reason {reason}")
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict())
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
"""Send a completed task to the backend."""
|
||||
logger.debug(f"Interaction: {interaction}")
|
||||
resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict())
|
||||
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def close(self):
|
||||
logger.debug("Closing OasstApiClient session")
|
||||
await self.session.close()
|
||||
|
||||
|
||||
async def main():
|
||||
api = OasstApiClient("http://localhost:8080", "test")
|
||||
try:
|
||||
task = await api.fetch_task(protocol_schema.TaskRequestType.initial_prompt, None)
|
||||
print(task)
|
||||
finally:
|
||||
|
||||
await api.close()
|
||||
# session = aiohttp.ClientSession()
|
||||
# try:
|
||||
# resp = await session.post("http://localhost:8080/api/v1/tasks/", json={"type": "initial_prompt", "user": None})
|
||||
# resp.raise_for_status()
|
||||
# print(await resp.text())
|
||||
# finally:
|
||||
# await session.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -6,6 +6,7 @@ import lightbulb
|
||||
import miru
|
||||
|
||||
from bot.config import Config
|
||||
from bot.api_client import OasstApiClient
|
||||
|
||||
config = Config.from_env()
|
||||
|
||||
@@ -29,8 +30,11 @@ async def on_starting(event: hikari.StartingEvent):
|
||||
await bot.d.db.executescript(open("./bot/db/schema.sql").read())
|
||||
await bot.d.db.commit()
|
||||
|
||||
bot.d.oasst_api = OasstApiClient("http://localhost:8080", "any_key")
|
||||
|
||||
|
||||
@bot.listen()
|
||||
async def on_stopping(event: hikari.StoppingEvent):
|
||||
"""Cleanup."""
|
||||
await bot.d.db.close()
|
||||
await bot.d.oasst_api.close()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# TODO: Convert file to markdown
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Example plugins for reference.
|
||||
"""Example plugin for reference.
|
||||
|
||||
Because this file starts with an `_`, it cannot be loaded by the bot. To see the example plugin in action, rename this file to `example.py`.
|
||||
"""
|
||||
@@ -396,6 +397,10 @@ async def modal_example(ctx: lightbulb.SlashContext) -> None:
|
||||
await view.start(await resp.message())
|
||||
|
||||
|
||||
# TODO: Database example
|
||||
# TODO: Rest client example
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
@@ -15,7 +15,7 @@ EXTENSIONS_FOLDER = "bot/extensions"
|
||||
|
||||
def _get_extensions() -> list[str]:
|
||||
# Recursively get all the .py files in the extensions directory not starting with an `_`.
|
||||
exts = glob("bot/extensions/**/*[!_].py", recursive=True)
|
||||
exts = glob("bot/extensions/**/[!_]*.py", recursive=True)
|
||||
# Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension")
|
||||
return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts]
|
||||
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task plugin for testing different data collection methods."""
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import hikari
|
||||
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from bot.utils import format_time
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
plugin = lightbulb.Plugin("TaskPlugin")
|
||||
|
||||
MAX_TASK_TIME = 60 * 60
|
||||
MAX_TASK_ACCEPT_TIME = 60
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def task_thread(ctx: lightbulb.SlashContext):
|
||||
"""Request a task from the backend."""
|
||||
typ: str = ctx.options.type
|
||||
|
||||
# Create a thread for the task
|
||||
thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}")
|
||||
|
||||
await ctx.respond(f"Please complete the task in the thread: {thread.mention}")
|
||||
|
||||
# Send the task in the thread
|
||||
# TODO: Request task from the backend
|
||||
await thread.send(
|
||||
f"Please complete the task.\nSample Task\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}"
|
||||
)
|
||||
|
||||
# Wait for the user to respond
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.GuildMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id,
|
||||
)
|
||||
await ctx.respond(f"Received message: {event.message.content}")
|
||||
# TODO: Send the message to the backend
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("You took too long to respond.")
|
||||
finally:
|
||||
await thread.delete()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def task_dm(ctx: lightbulb.Context):
|
||||
"""Request a task from the backend."""
|
||||
typ: str = ctx.options.type
|
||||
|
||||
# Create a thread for the task
|
||||
|
||||
await ctx.respond(f"Please complete the task in your DMs")
|
||||
|
||||
# Send the task in the thread
|
||||
# TODO: Request task from the backend
|
||||
await ctx.author.send(
|
||||
f"Please complete the task.\nSample Task ({typ})\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}"
|
||||
)
|
||||
|
||||
# Wait for the user to respond
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id,
|
||||
)
|
||||
await ctx.respond(f"Received message: {event.message.content}")
|
||||
# TODO: Send the message to the backend
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("You took too long to respond.")
|
||||
|
||||
|
||||
class TaskModal(miru.Modal):
|
||||
"""Modal for submitting a task."""
|
||||
|
||||
response = miru.TextInput(
|
||||
label="Response",
|
||||
placeholder="Enter your response!",
|
||||
required=True,
|
||||
style=hikari.TextInputStyle.PARAGRAPH,
|
||||
row=2,
|
||||
)
|
||||
|
||||
async def callback(self, context: miru.ModalContext) -> None:
|
||||
await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
# TODO: Send the message to the backend
|
||||
|
||||
|
||||
class ModalView(miru.View):
|
||||
"""View for opening a modal."""
|
||||
|
||||
def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.modal_title = modal_title
|
||||
self.task = task
|
||||
|
||||
@miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
modal = TaskModal(title=self.modal_title)
|
||||
modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1))
|
||||
await ctx.respond_with_modal(modal)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def task_modal(ctx: lightbulb.SlashContext):
|
||||
"""Request a task from the backend."""
|
||||
# typ: str = ctx.options.type
|
||||
view = ModalView(
|
||||
modal_title=f"Assistant Response",
|
||||
task="Please explain the moon landing to a six year old.",
|
||||
timeout=MAX_TASK_TIME,
|
||||
)
|
||||
resp = await ctx.respond(
|
||||
"Task - Respond to the prompt as if you were the Assistant:",
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
components=view,
|
||||
)
|
||||
await view.start(await resp.message())
|
||||
|
||||
|
||||
class RatingView(miru.View):
|
||||
"""View for rating a task."""
|
||||
|
||||
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.presses: list[str] = []
|
||||
|
||||
def _close_if_all_pressed(self) -> None:
|
||||
if len(self.presses) == 5:
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="1", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("1")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="2", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("2")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="3", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("3")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="4", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("4")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="5", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("5")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="Reset", style=hikari.ButtonStyle.DANGER)
|
||||
async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.presses = []
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
class SelectRating(miru.View):
|
||||
@miru.select(
|
||||
options=[
|
||||
hikari.SelectMenuOption(
|
||||
label="1",
|
||||
value="1",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="2",
|
||||
value="2",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="3",
|
||||
value="3",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
],
|
||||
placeholder="Select the good responses",
|
||||
min_values=0,
|
||||
max_values=3,
|
||||
row=3,
|
||||
)
|
||||
async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
|
||||
await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.command("rating_task", "Rate stuff.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def rating_task(ctx: lightbulb.SlashContext):
|
||||
"""Rate stuff."""
|
||||
|
||||
# Message Based rating
|
||||
await ctx.respond(
|
||||
"List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL
|
||||
)
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("Timed out waiting for response")
|
||||
return
|
||||
|
||||
if event.content is None:
|
||||
await ctx.respond("No content in message")
|
||||
return
|
||||
ratings = event.content.replace(" ", "").split(",")
|
||||
|
||||
# Check if the ratings are valid
|
||||
if len(ratings) != 5:
|
||||
await ctx.respond("Invalid number of ratings")
|
||||
if not all([rating in ("1", "2", "3", "4", "5") for rating in ratings]):
|
||||
await ctx.respond("Invalid rating")
|
||||
|
||||
await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
# Button Based rating
|
||||
view = RatingView(timeout=MAX_TASK_TIME)
|
||||
|
||||
resp = await ctx.respond("Click the buttons in order of best to worst response", components=view)
|
||||
await view.start(await resp.message())
|
||||
await view.wait()
|
||||
await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
await resp.delete()
|
||||
|
||||
# Select Based rating
|
||||
select_view = SelectRating(timeout=MAX_TASK_TIME)
|
||||
resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
await select_view.start(await resp_2.message())
|
||||
await select_view.wait()
|
||||
await resp_2.delete()
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,281 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Work plugin for collecting user data."""
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from bot.api_client import OasstApiClient, TaskType
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
from bot.utils import ZWJ
|
||||
|
||||
plugin = lightbulb.Plugin("WorkPlugin")
|
||||
|
||||
MAX_TASK_TIME = 60 * 60 # 1 hour
|
||||
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=str(TaskRequestType.rank_initial_prompts), # TODO: change back to random
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def work(ctx: lightbulb.SlashContext):
|
||||
"""Create and handle a task."""
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type)
|
||||
|
||||
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}")
|
||||
|
||||
await _handle_task(ctx, task_type)
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
If they select one, present the task steps until a `task_done` task is received.
|
||||
Finally, ask the user if they want to perform another task (of the same type).
|
||||
"""
|
||||
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
|
||||
# Continue to complete tasks until the user doesn't want to do another
|
||||
done = False
|
||||
while not done:
|
||||
|
||||
# Loop until the user accepts a task
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send("Please type your response here:")
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
# TODO: NACK task maybe?
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
if event.content is None:
|
||||
await ctx.author.send("No content in message")
|
||||
continue
|
||||
|
||||
logger.info(f"User input received: {event.content}")
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToPost(
|
||||
post_id=str(msg_id),
|
||||
user_post_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
new_task = await oasst_api.post_interaction(reply)
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send("Task completed")
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.fatal(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# ask the user if they want to do another task
|
||||
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send("Would you like another task?", components=choice_view)
|
||||
await choice_view.start(msg)
|
||||
await choice_view.wait()
|
||||
|
||||
match choice_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await ctx.author.send("Exiting, goodbye!")
|
||||
case True:
|
||||
pass
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg_id = await _send_task(ctx, task)
|
||||
|
||||
logger.debug(f"user choice: {resp}")
|
||||
match resp:
|
||||
case "accept":
|
||||
logger.info(f"Task {task.id} accepted, sending ACK")
|
||||
await oasst_api.ack_task(task.id, msg_id)
|
||||
return task, msg_id
|
||||
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
await ctx.author.send("Sending next task...")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send("Task canceled. Exiting")
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.SlashContext, task: protocol_schema.Task
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message."""
|
||||
|
||||
# The clean way to do this would be to attach a `to_embed` method to the task classes
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.info("sending initial prompt task")
|
||||
embed = _initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.info("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
|
||||
else:
|
||||
logger.error(f"unknown task type {task.type}")
|
||||
|
||||
msg = await ctx.author.send(
|
||||
ZWJ,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, str(msg.id)
|
||||
|
||||
|
||||
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
|
||||
.set_image(
|
||||
"https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80",
|
||||
)
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Initial Prompt",
|
||||
description=f"Rank the following tasks from best to worst (1,2,3,4,5)",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image(
|
||||
"https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80",
|
||||
)
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(task.prompts):
|
||||
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
class TaskAcceptView(miru.View):
|
||||
"""View with three buttons: accept, next, and cancel.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
"""
|
||||
|
||||
choice: t.Literal["accept", "next", "cancel"] | None = None
|
||||
|
||||
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
|
||||
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Accept button pressed")
|
||||
self.choice = "accept"
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
||||
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Next button pressed")
|
||||
self.choice = "next"
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
||||
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Cancel button pressed")
|
||||
self.choice = "cancel"
|
||||
self.stop()
|
||||
|
||||
|
||||
class ChoiceView(miru.View):
|
||||
choice: bool | None = None
|
||||
|
||||
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
||||
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = True
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
||||
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = False
|
||||
self.stop()
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -21,3 +21,10 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s
|
||||
return f"<t:{dt.timestamp():.0f}:{fmt}>"
|
||||
case _:
|
||||
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
|
||||
|
||||
|
||||
ZWJ = "\u200d"
|
||||
"""Zero-width joiner.
|
||||
|
||||
This appears as an empty message in Discord.
|
||||
"""
|
||||
|
||||
@@ -7,4 +7,5 @@ hikari-miru # modals and buttons
|
||||
python-dotenv # .env file support
|
||||
aiosqlite # database
|
||||
aiohttp # http client
|
||||
aiohttp[speedups] # speedups for aiohttp
|
||||
aiohttp[speedups] # speedups for aiohttp
|
||||
loguru
|
||||
Reference in New Issue
Block a user