mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'bot_new_api' into extract-shared-python-code
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from bot_settings import settings
|
||||
|
||||
from bot import OpenAssistantBot
|
||||
|
||||
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot = OpenAssistantBot(
|
||||
settings.BOT_TOKEN,
|
||||
bot_channel_name=settings.BOT_CHANNEL_NAME,
|
||||
backend_url=settings.BACKEND_URL,
|
||||
api_key=settings.API_KEY,
|
||||
)
|
||||
bot.run()
|
||||
@@ -0,0 +1,74 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Optional, Type
|
||||
|
||||
import requests
|
||||
from schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
class TaskType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
task_models_map: dict[str, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.user_reply: protocol_schema.UserReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
self.task_models_map = task_models_map
|
||||
|
||||
def post(self, path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _parse_task(self, data: dict) -> protocol_schema.Task:
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("dict expected")
|
||||
|
||||
task_type = data.get("type")
|
||||
if task_type not in self.task_models_map:
|
||||
raise RuntimeError(f"Unsupported task type: {task_type}")
|
||||
|
||||
return self.task_models_map[task_type].parse_obj(data)
|
||||
|
||||
def fetch_task(
|
||||
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
|
||||
) -> protocol_schema.Task:
|
||||
req = protocol_schema.TaskRequest(type=task_type, user=user)
|
||||
data = self.post("/api/v1/tasks/", req.dict())
|
||||
return self._parse_task(data)
|
||||
|
||||
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
|
||||
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
|
||||
|
||||
def ack_task(self, task_id: str, post_id: str) -> None:
|
||||
req = protocol_schema.TaskAck(post_id=post_id)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
|
||||
|
||||
def nack_task(self, task_id: str, reason: str) -> None:
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
|
||||
|
||||
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone:
|
||||
data = self.post("/api/v1/tasks/interaction", interaction.dict())
|
||||
return self._parse_task(data)
|
||||
+221
-200
@@ -1,215 +1,236 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
import requests
|
||||
from discord import app_commands
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
bot_url = "https://discord.com/api/oauth2/authorize?client_id=1051614245940375683&permissions=8&scope=bot"
|
||||
|
||||
# Load up all the important environment variables.
|
||||
load_dotenv()
|
||||
|
||||
# For authentication.
|
||||
TOKEN = os.getenv("DISCORD_TOKEN")
|
||||
|
||||
# For Backends.
|
||||
API_SERVER_URL = os.getenv("API_SERVER_URL")
|
||||
API_SERVER_KEY = os.getenv("API_SERVER_KEY")
|
||||
|
||||
labelers_url = f"{API_SERVER_URL}/api/v1/labelers/"
|
||||
prompts_url = f"{API_SERVER_URL}/api/v1/prompts/"
|
||||
headers = {"X-API-Key": API_SERVER_KEY}
|
||||
|
||||
# For testing only.
|
||||
TEST_GUILD = os.getenv("TEST_GUILD")
|
||||
TEST_GUILD_LAION = os.getenv("TEST_GUILD_LAION")
|
||||
# TEST_GUILD = False
|
||||
guild_ids = [TEST_GUILD, TEST_GUILD_LAION]
|
||||
from api_client import ApiClient, TaskType
|
||||
from schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
# Initiate the client and command tree to create slash commands.
|
||||
class OpenAssistantClient(discord.Client):
|
||||
def __init__(self, *, intents: discord.Intents):
|
||||
super().__init__(intents=intents)
|
||||
self.tree = app_commands.CommandTree(self)
|
||||
class RatingButton(discord.ui.Button):
|
||||
def __init__(self, label, value, response_handler):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green)
|
||||
self.value = value
|
||||
self.response_handler = response_handler
|
||||
|
||||
async def callback(self, interaction):
|
||||
await self.response_handler(self.value, interaction)
|
||||
|
||||
|
||||
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
|
||||
view = discord.ui.View()
|
||||
for i in range(lo, hi + 1):
|
||||
view.add_item(RatingButton(str(i), i, response_handler))
|
||||
return view
|
||||
|
||||
|
||||
class ModifiedClient(discord.Client):
|
||||
def __init__(self, *, intents: discord.Intents, **options: Any):
|
||||
super().__init__(intents=intents, **options)
|
||||
|
||||
async def setup_hook(self):
|
||||
if TEST_GUILD:
|
||||
# When testing the bot it's handy to run in a single server (called a
|
||||
# Guide in the API). This is relatively fast.
|
||||
for guild_id in guild_ids:
|
||||
guild = discord.Object(id=guild_id)
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
|
||||
# guild = discord.Object(id=TEST_GUILD)
|
||||
# self.tree.copy_global_to(guild=guild)
|
||||
# await self.tree.sync(guild=guild)
|
||||
else:
|
||||
# This can take up to an hour for the commands to be registered.
|
||||
await self.tree.sync()
|
||||
logger.debug("Ready!")
|
||||
print("setup")
|
||||
|
||||
|
||||
# List the set of intents needed for commands to operate properly.
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
client = OpenAssistantClient(intents=intents)
|
||||
class OpenAssistantBot:
|
||||
def __init__(self, bot_token: str, bot_channel_name: str, backend_url: str, api_key: str):
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
self.bot_token = bot_token
|
||||
client = ModifiedClient(intents=intents)
|
||||
self.client = client
|
||||
self.bot_channel: discord.TextChannel = None
|
||||
self.backend = ApiClient(backend_url, api_key)
|
||||
self.reply_handlers = {} # handlers by msg_id
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
|
||||
|
||||
class LikeButton(discord.ui.Button):
|
||||
def __init__(self, label, channel, username, prompt):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👍")
|
||||
self.channel = channel
|
||||
self.username = username
|
||||
self.prompt = prompt
|
||||
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
|
||||
print(f"{client.user} is now running!")
|
||||
|
||||
async def callback(self, interaction):
|
||||
# interaction holds the interaction object
|
||||
# await interaction.response.defer()
|
||||
await interaction.response.send_message("Thanks for your feedback. You liked this 👍 ")
|
||||
|
||||
|
||||
class NeutralButton(discord.ui.Button):
|
||||
def __init__(self, label, channel, username, prompt):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="😐")
|
||||
self.channel = channel
|
||||
self.username = username
|
||||
self.prompt = prompt
|
||||
|
||||
async def callback(self, interaction):
|
||||
# interaction holds the interaction object
|
||||
# await interaction.response.defer()
|
||||
await interaction.response.send_message("Thanks for your feedback. You thought this was neutral 😐 ")
|
||||
|
||||
|
||||
class DislikeButton(discord.ui.Button):
|
||||
def __init__(self, label, channel, username, prompt):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👎")
|
||||
self.channel = channel
|
||||
self.username = username
|
||||
self.prompt = prompt
|
||||
|
||||
async def callback(self, interaction):
|
||||
# interaction holds the interaction object
|
||||
# await interaction.response.defer()
|
||||
# send the feedback to the backend #
|
||||
await interaction.response.send_message("Thanks for your feedback. You disliked this 👎 ")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def register(interaction: discord.Interaction):
|
||||
"""Registers the user for submissions."""
|
||||
labeler = {
|
||||
"discord_username": f"{interaction.user.id}",
|
||||
"display_name": interaction.user.name,
|
||||
"is_enabled": True,
|
||||
}
|
||||
response = requests.post(labelers_url, headers=headers, json=labeler)
|
||||
if response.status_code == 200:
|
||||
await interaction.response.send_message(f"Added you {interaction.user.name}")
|
||||
else:
|
||||
logger.debug(response)
|
||||
await interaction.response.send_message("Failed to add you")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def list_participants(interaction: discord.Interaction):
|
||||
"""Reports the set of registered participants."""
|
||||
response = requests.get(labelers_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
names = ",".join([labeler["display_name"] for labeler in response.json()])
|
||||
await interaction.response.send_message(f"Found these users: {names}")
|
||||
else:
|
||||
await interaction.response.send_message("Failed to fetch participants")
|
||||
|
||||
|
||||
async def send_prompt_with_response_and_button(channel, username, prompt, response):
|
||||
await channel.send(f"What do you think about the following interaction: \nprompt: {prompt} \nresponse: {response}")
|
||||
# await channel.send(f'Please click on the button that best describes your reaction to the response:')
|
||||
|
||||
# add buttons
|
||||
view = discord.ui.View()
|
||||
like = LikeButton(label="Like", channel=channel, username=username, prompt=prompt)
|
||||
neutral = NeutralButton(label="Neutral", channel=channel, username=username, prompt=prompt)
|
||||
dislike = DislikeButton(label="Dislike", channel=channel, username=username, prompt=prompt)
|
||||
|
||||
view.add_item(item=like)
|
||||
view.add_item(item=neutral)
|
||||
view.add_item(item=dislike)
|
||||
await channel.send(view=view)
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def review_prompts(interaction: discord.Interaction, number_of_prompts: int):
|
||||
# get the prompt from the db
|
||||
url = f"{prompts_url}?begin_id=0&limit={number_of_prompts}"
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
prompts = response.json()
|
||||
logger.debug("the responses are:", prompts)
|
||||
for prompt in prompts:
|
||||
await send_prompt_with_response_and_button(
|
||||
interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
|
||||
)
|
||||
else:
|
||||
await interaction.response.send_message("Failed to get prompts for review")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def add_prompt(interaction: discord.Interaction, prompt: str, response: str, language: str = "en"):
|
||||
"""Uploads a single prompt to the server."""
|
||||
prompt = {
|
||||
"discord_username": f"{interaction.user.id}",
|
||||
"labeler_id": 5,
|
||||
"prompt": prompt,
|
||||
"response": response,
|
||||
"lang": language,
|
||||
}
|
||||
response = requests.post(prompts_url, headers=headers, json=prompt)
|
||||
if response.status_code == 200:
|
||||
await send_prompt_with_response_and_button(
|
||||
interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
|
||||
)
|
||||
# send the prompt back with buttons for the user to click on
|
||||
# await interaction.response.send_message("Added your prompt")
|
||||
else:
|
||||
await interaction.response.send_message("Failed to add the prompt")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def add_prompts_set(interaction: discord.Interaction, prompts: discord.Attachment):
|
||||
"""Uploads a batch of prompts to the server."""
|
||||
# Loading a bunch of prompts from a file can take a while. So first defer
|
||||
# the response to ensure we're able to later tell the user what happened.
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Read the prompts and load them one by one.
|
||||
# TODO: Upload a batch when the API supports it.
|
||||
# TODO: Handle incorrect file types and parsing errors.
|
||||
prompts_raw = await prompts.read()
|
||||
prompts_loaded = json.loads(prompts_raw)
|
||||
count = 0
|
||||
for entry in prompts_loaded:
|
||||
for response in entry["responses"]:
|
||||
prompt = {
|
||||
"discord_username": f"{interaction.user.id}",
|
||||
"labeler_id": 5,
|
||||
"prompt": entry["prompt"],
|
||||
"response": response,
|
||||
"lang": "en",
|
||||
}
|
||||
response = requests.post(prompts_url, headers=headers, json=prompt)
|
||||
if response.status_code != 200:
|
||||
await interaction.followup.send("Failed to upload")
|
||||
@client.event
|
||||
async def on_message(message: discord.Message):
|
||||
# ignore own messages
|
||||
if message.author == client.user:
|
||||
return
|
||||
count += 1
|
||||
await interaction.followup.send(f"Loaded up {count} prompts")
|
||||
|
||||
await self.handle_message(message)
|
||||
|
||||
client.run(TOKEN)
|
||||
async def generate_summarize_story(self, task: protocol_schema.SummarizeStoryTask):
|
||||
text = f"Summarize to the following story:\n{task.story}"
|
||||
msg: discord.Message = await self.bot_channel.send(text)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
print("on_summarize_story_reply", message)
|
||||
await message.reply("thx, on_summarize_story_reply")
|
||||
|
||||
self.reply_handlers[msg.id] = on_reply
|
||||
|
||||
return msg
|
||||
|
||||
async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask):
|
||||
s = [
|
||||
"Rate the following summary:",
|
||||
task.summary,
|
||||
"Full text:",
|
||||
task.full_text,
|
||||
f"Rating scale: {task.scale.min} - {task.scale.max}",
|
||||
]
|
||||
text = "\n".join(s)
|
||||
|
||||
async def rating_response_handler(score, interaction: discord.Interaction):
|
||||
print("rating_response_handler", score)
|
||||
await interaction.response.send_message(f"got your feedback: {score}")
|
||||
|
||||
view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler)
|
||||
msg: discord.Message = await self.bot_channel.send(text, view=view)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
print("on_summary_reply", message)
|
||||
await message.reply("thx, on_summary_reply")
|
||||
|
||||
self.reply_handlers[msg.id] = on_reply
|
||||
|
||||
return msg
|
||||
|
||||
async def generate_initial_prompt(self, task: protocol_schema.InitialPromptTask):
|
||||
text = "Please provide an initial prompt to the assistant."
|
||||
if task.hint:
|
||||
text += f"\nHint: {task.hint}"
|
||||
msg: discord.Message = await self.bot_channel.send(text)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
print("on_initial_prompt_reply", message)
|
||||
await message.reply("thx, on_initial_prompt_reply")
|
||||
|
||||
self.reply_handlers[msg.id] = on_reply
|
||||
|
||||
return msg
|
||||
|
||||
def _render_message(self, message: protocol_schema.ConversationMessage) -> str:
|
||||
"""Render a message to the user."""
|
||||
if message.is_assistant:
|
||||
return f"Assistant: {message.text}"
|
||||
return f"User: {message.text}"
|
||||
|
||||
async def generate_user_reply(self, task: protocol_schema.UserReplyTask):
|
||||
s = ["Please provide a reply to the assistant.", "Here is the conversation so far:"]
|
||||
for message in task.conversation.messages:
|
||||
s.append(self._render_message(message))
|
||||
if task.hint:
|
||||
s.append(f"Hint: {task.hint}")
|
||||
text = "\n".join(s)
|
||||
msg: discord.Message = await self.bot_channel.send(text)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
print("on_user_reply_reply", message)
|
||||
await message.reply("thx, on_user_reply_reply")
|
||||
|
||||
self.reply_handlers[msg.id] = on_reply
|
||||
|
||||
return msg
|
||||
|
||||
async def generate_assistant_reply(self, task: protocol_schema.AssistantReplyTask):
|
||||
s = ["Act as the assistant and reply to the user.", "Here is the conversation so far:"]
|
||||
for message in task.conversation.messages:
|
||||
s.append(self._render_message(message))
|
||||
text = "\n".join(s)
|
||||
msg: discord.Message = await self.bot_channel.send(text)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
print("on_assistant_reply_reply", message)
|
||||
await message.reply("thx, on_assistant_reply_reply")
|
||||
|
||||
self.reply_handlers[msg.id] = on_reply
|
||||
|
||||
return msg
|
||||
|
||||
async def generate_rank_initial_prompts(self, task: protocol_schema.RankInitialPromptsTask):
|
||||
s = ["Rank the following prompts:"]
|
||||
for idx, prompt in enumerate(task.prompts, start=1):
|
||||
s.append(f"{idx}: {prompt}")
|
||||
text = "\n".join(s)
|
||||
msg: discord.Message = await self.bot_channel.send(text)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
print("on_rank_initial_prompts_reply", message)
|
||||
await message.reply("thx, on_rank_initial_prompts_reply")
|
||||
|
||||
self.reply_handlers[msg.id] = on_reply
|
||||
|
||||
return msg
|
||||
|
||||
async def generate_rank_conversation(self, task: protocol_schema.RankConversationRepliesTask):
|
||||
s = ["Here is the conversation so far:"]
|
||||
for message in task.conversation.messages:
|
||||
s.append(self._render_message(message))
|
||||
s.append("Rank the following replies:")
|
||||
for idx, reply in enumerate(task.replies, start=1):
|
||||
s.append(f"{idx}: {reply}")
|
||||
text = "\n".join(s)
|
||||
msg: discord.Message = await self.bot_channel.send(text)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
print("on_rrank_conversation_reply", message)
|
||||
message
|
||||
|
||||
self.reply_handlers[msg.id] = on_reply
|
||||
|
||||
return msg
|
||||
|
||||
async def next_task(self):
|
||||
task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
|
||||
# task = self.backend.fetch_random_task(user=None)
|
||||
|
||||
msg: discord.Message = None
|
||||
match task.type:
|
||||
case TaskType.summarize_story:
|
||||
msg = await self.generate_summarize_story(task)
|
||||
case TaskType.rate_summary:
|
||||
msg = await self.generate_rate_summary(task)
|
||||
case TaskType.initial_prompt:
|
||||
msg = await self.generate_initial_prompt(task)
|
||||
case TaskType.user_reply:
|
||||
msg = await self.generate_user_reply(task)
|
||||
case TaskType.assistant_reply:
|
||||
msg = await self.generate_assistant_reply(task)
|
||||
case TaskType.rank_initial_prompts:
|
||||
msg = await self.generate_rank_initial_prompts(task)
|
||||
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
|
||||
msg = await self.generate_rank_conversation(task)
|
||||
|
||||
if msg is not None:
|
||||
await self.backend.ack_task(task.id, msg.id)
|
||||
else:
|
||||
await self.backend.nack_task(task.id, "not supported")
|
||||
|
||||
async def background_timer(self):
|
||||
while True:
|
||||
if self.bot_channel:
|
||||
try:
|
||||
await self.next_task()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
await asyncio.sleep(30)
|
||||
|
||||
def run(self):
|
||||
"""Run bot loop blocking."""
|
||||
self.client.run(self.bot_token)
|
||||
|
||||
async def handle_message(self, message: discord.Message):
|
||||
user_id = message.author.id
|
||||
user_display_name = message.author.name
|
||||
|
||||
if message.reference:
|
||||
handler = self.reply_handlers.get(message.reference.message_id)
|
||||
if handler:
|
||||
await handler(message)
|
||||
|
||||
print(user_id, user_display_name, message.content, type(message.content))
|
||||
|
||||
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
|
||||
for channel in self.client.get_all_channels():
|
||||
if channel.type == discord.ChannelType.text and channel.name == channel_name:
|
||||
return channel
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import AnyHttpUrl, BaseSettings
|
||||
|
||||
|
||||
class BotSettings(BaseSettings):
|
||||
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
|
||||
API_KEY: str = "any_key"
|
||||
BOT_TOKEN: str
|
||||
BOT_CHANNEL_NAME: str = "bot"
|
||||
TEST_GUILD: str = None
|
||||
|
||||
|
||||
settings = BotSettings(_env_file=".env")
|
||||
@@ -1,2 +1,4 @@
|
||||
discord.py==2.1.0
|
||||
pydantic==1.9.1
|
||||
python-dotenv==0.21.0
|
||||
requests==2.28.1
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Literal, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TaskRequestType(str, enum.Enum):
|
||||
random = "random"
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
display_name: str
|
||||
auth_method: Literal["discord", "local"]
|
||||
|
||||
|
||||
class ConversationMessage(BaseModel):
|
||||
"""Represents a message in a conversation between the user and the assistant."""
|
||||
|
||||
text: str
|
||||
is_assistant: bool
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""Represents a conversation between the user and the assistant."""
|
||||
|
||||
messages: list[ConversationMessage] = []
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
"""The frontend asks the backend for a task."""
|
||||
|
||||
type: TaskRequestType = TaskRequestType.random
|
||||
user: Optional[User] = None
|
||||
|
||||
|
||||
class TaskAck(BaseModel):
|
||||
"""The frontend acknowledges that it has received a task and created a post."""
|
||||
|
||||
post_id: str
|
||||
|
||||
|
||||
class TaskNAck(BaseModel):
|
||||
"""The frontend acknowledges that it has received a task but cannot create a post."""
|
||||
|
||||
reason: str
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""A task is a unit of work that the backend gives to the frontend."""
|
||||
|
||||
id: UUID = pydantic.Field(default_factory=uuid4)
|
||||
type: str
|
||||
|
||||
|
||||
class SummarizeStoryTask(Task):
|
||||
"""A task to summarize a story."""
|
||||
|
||||
type: Literal["summarize_story"] = "summarize_story"
|
||||
story: str
|
||||
|
||||
|
||||
class RatingScale(BaseModel):
|
||||
min: int
|
||||
max: int
|
||||
|
||||
|
||||
class AbstractRatingTask(Task):
|
||||
"""A task to rate something."""
|
||||
|
||||
scale: RatingScale = RatingScale(min=1, max=5)
|
||||
|
||||
|
||||
class RateSummaryTask(AbstractRatingTask):
|
||||
"""A task to rate a summary."""
|
||||
|
||||
type: Literal["rate_summary"] = "rate_summary"
|
||||
full_text: str
|
||||
summary: str
|
||||
|
||||
|
||||
class WithHintMixin(BaseModel):
|
||||
hint: str | None = None # provide a hint to the user to spark their imagination
|
||||
|
||||
|
||||
class InitialPromptTask(Task, WithHintMixin):
|
||||
"""A task to prompt the user to submit an initial prompt to the assistant."""
|
||||
|
||||
type: Literal["initial_prompt"] = "initial_prompt"
|
||||
|
||||
|
||||
class ReplyToConversationTask(Task):
|
||||
"""A task to prompt the user to submit a reply to a conversation."""
|
||||
|
||||
type: Literal["reply_to_conversation"] = "reply_to_conversation"
|
||||
conversation: Conversation # the conversation so far
|
||||
|
||||
|
||||
class UserReplyTask(ReplyToConversationTask, WithHintMixin):
|
||||
"""A task to prompt the user to submit a reply to the assistant."""
|
||||
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
|
||||
|
||||
class AssistantReplyTask(ReplyToConversationTask):
|
||||
"""A task to prompt the user to act as the assistant."""
|
||||
|
||||
type: Literal["assistant_reply"] = "assistant_reply"
|
||||
|
||||
|
||||
class RankInitialPromptsTask(Task):
|
||||
"""A task to rank a set of initial prompts."""
|
||||
|
||||
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
class RankConversationRepliesTask(Task):
|
||||
"""A task to rank a set of replies to a conversation."""
|
||||
|
||||
type: Literal["rank_conversation_replies"] = "rank_conversation_replies"
|
||||
conversation: Conversation # the conversation so far
|
||||
replies: list[str]
|
||||
|
||||
|
||||
class RankUserRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
|
||||
|
||||
class RankAssistantRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of assistant replies to a conversation."""
|
||||
|
||||
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
|
||||
|
||||
|
||||
class TaskDone(Task):
|
||||
"""Signals to the frontend that the task is done."""
|
||||
|
||||
type: Literal["task_done"] = "task_done"
|
||||
|
||||
|
||||
AnyTask = Union[
|
||||
TaskDone,
|
||||
SummarizeStoryTask,
|
||||
RateSummaryTask,
|
||||
InitialPromptTask,
|
||||
ReplyToConversationTask,
|
||||
UserReplyTask,
|
||||
AssistantReplyTask,
|
||||
RankInitialPromptsTask,
|
||||
RankConversationRepliesTask,
|
||||
RankUserRepliesTask,
|
||||
RankAssistantRepliesTask,
|
||||
]
|
||||
|
||||
|
||||
class Interaction(BaseModel):
|
||||
"""An interaction is a user-generated action in the frontend."""
|
||||
|
||||
type: str
|
||||
user: User
|
||||
|
||||
|
||||
class TextReplyToPost(Interaction):
|
||||
"""A user has replied to a post with text."""
|
||||
|
||||
type: Literal["text_reply_to_post"] = "text_reply_to_post"
|
||||
post_id: str
|
||||
user_post_id: str
|
||||
text: str
|
||||
|
||||
|
||||
class PostRating(Interaction):
|
||||
"""A user has rated a post."""
|
||||
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
post_id: str
|
||||
rating: int
|
||||
|
||||
|
||||
class PostRanking(Interaction):
|
||||
"""A user has given a ranking for a post."""
|
||||
|
||||
type: Literal["post_ranking"] = "post_ranking"
|
||||
post_id: str
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
AnyInteraction = Union[
|
||||
TextReplyToPost,
|
||||
PostRating,
|
||||
PostRanking,
|
||||
]
|
||||
Reference in New Issue
Block a user