diff --git a/backend/main.py b/backend/main.py index 29a0dc07..387d4e51 100644 --- a/backend/main.py +++ b/backend/main.py @@ -87,7 +87,7 @@ if settings.DEBUG_USE_SEED_DATA: user_message_id="6f1d0711", parent_message_id=None, text="Hi!", - role="user", + role="prompter", ), DummyMessage( task_message_id="74c381d4", @@ -101,14 +101,14 @@ if settings.DEBUG_USE_SEED_DATA: user_message_id="a8c01c04", parent_message_id="4a24530b", text="Do you have a recipe for potato soup?", - role="user", + role="prompter", ), DummyMessage( task_message_id="643716c1", user_message_id="f43a93b7", parent_message_id="4a24530b", text="Who were the 8 presidents before George Washington?", - role="user", + role="prompter", ), DummyMessage( task_message_id="2e4e1e6", @@ -122,7 +122,7 @@ if settings.DEBUG_USE_SEED_DATA: user_message_id="cec432cf", parent_message_id=None, text="euirdteunvglfe23908230892309832098 AAAAAAAA", - role="user", + role="prompter", ), DummyMessage( task_message_id="6066118e", diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index ee200958..bb8c2efc 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -54,7 +54,7 @@ def generate_task( task = protocol_schema.InitialPromptTask( hint="Ask the assistant about a current event." # this is optional ) - case protocol_schema.TaskRequestType.user_reply: + case protocol_schema.TaskRequestType.prompter_reply: logger.info("Generating a UserReplyTask.") messages = pr.fetch_random_conversation("assistant") task_messages = [ @@ -64,7 +64,7 @@ def generate_task( for msg in messages ] - task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=task_messages)) + task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages)) message_tree_id = messages[-1].message_tree_id parent_message_id = messages[-1].id case protocol_schema.TaskRequestType.assistant_reply: @@ -85,7 +85,7 @@ def generate_task( messages = pr.fetch_random_initial_prompts() task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages]) - case protocol_schema.TaskRequestType.rank_user_replies: + case protocol_schema.TaskRequestType.rank_prompter_replies: logger.info("Generating a RankUserRepliesTask.") conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant") @@ -97,7 +97,7 @@ def generate_task( for p in conversation ] replies = [p.payload.payload.text for p in replies] - task = protocol_schema.RankUserRepliesTask( + task = protocol_schema.RankPrompterRepliesTask( conversation=protocol_schema.Conversation( messages=task_messages, ), @@ -106,7 +106,7 @@ def generate_task( case protocol_schema.TaskRequestType.rank_assistant_replies: logger.info("Generating a RankAssistantRepliesTask.") - conversation, replies = pr.fetch_multiple_random_replies(message_role="user") + conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter") task_messages = [ protocol_schema.ConversationMessage( diff --git a/backend/oasst_backend/models/db_payload.py b/backend/oasst_backend/models/db_payload.py index b44228e0..62dffa51 100644 --- a/backend/oasst_backend/models/db_payload.py +++ b/backend/oasst_backend/models/db_payload.py @@ -32,8 +32,8 @@ class InitialPromptPayload(TaskPayload): @payload_type -class UserReplyPayload(TaskPayload): - type: Literal["user_reply"] = "user_reply" +class PrompterReplyPayload(TaskPayload): + type: Literal["prompter_reply"] = "prompter_reply" conversation: protocol_schema.Conversation hint: str | None @@ -81,10 +81,10 @@ class RankInitialPromptsPayload(TaskPayload): @payload_type -class RankUserRepliesPayload(RankConversationRepliesPayload): - """A task to rank a set of user replies to a conversation.""" +class RankPrompterRepliesPayload(RankConversationRepliesPayload): + """A task to rank a set of prompter replies to a conversation.""" - type: Literal["rank_user_replies"] = "rank_user_replies" + type: Literal["rank_prompter_replies"] = "rank_prompter_replies" @payload_type diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 37babdbb..1425ce98 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -23,7 +23,7 @@ class Message(SQLModel, table=True): message_tree_id: UUID = Field(nullable=False, index=True) task_id: UUID = Field(nullable=True, index=True) user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True) - role: str = Field(nullable=False, max_length=128) + role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant" api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") frontend_message_id: str = Field(max_length=200, nullable=False) created_date: Optional[datetime] = Field( diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index c741db05..5606ca69 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -132,7 +132,7 @@ class PromptRepository: raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE) # If there's no parent message assume user started new conversation - role = "user" + role = "prompter" depth = 0 if task.parent_message_id: @@ -142,7 +142,7 @@ class PromptRepository: depth = parent_message.depth + 1 if parent_message.role == "assistant": - role = "user" + role = "prompter" else: role = "assistant" @@ -206,7 +206,7 @@ class PromptRepository: match type(task_payload): - case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload: + case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload: # validate ranking num_replies = len(task_payload.replies) if sorted(ranking.ranking) != list(range(num_replies)): @@ -269,8 +269,8 @@ class PromptRepository: case protocol_schema.InitialPromptTask: payload = db_payload.InitialPromptPayload(hint=task.hint) - case protocol_schema.UserReplyTask: - payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint) + case protocol_schema.PrompterReplyTask: + payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint) case protocol_schema.AssistantReplyTask: payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation) @@ -278,8 +278,8 @@ class PromptRepository: case protocol_schema.RankInitialPromptsTask: payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts) - case protocol_schema.RankUserRepliesTask: - payload = db_payload.RankUserRepliesPayload( + case protocol_schema.RankPrompterRepliesTask: + payload = db_payload.RankPrompterRepliesPayload( tpye=task.type, conversation=task.conversation, replies=task.replies ) diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py index 4e9ce612..7c2e8d5a 100644 --- a/discord-bot/api_client.py +++ b/discord-bot/api_client.py @@ -10,10 +10,10 @@ class TaskType(str, enum.Enum): summarize_story = "summarize_story" rate_summary = "rate_summary" initial_prompt = "initial_prompt" - user_reply = "user_reply" + prompter_reply = "prompter_reply" assistant_reply = "assistant_reply" rank_initial_prompts = "rank_initial_prompts" - rank_user_replies = "rank_user_replies" + rank_prompter_replies = "rank_prompter_replies" rank_assistant_replies = "rank_assistant_replies" done = "task_done" @@ -27,10 +27,10 @@ class ApiClient: TaskType.summarize_story: protocol_schema.SummarizeStoryTask, TaskType.rate_summary: protocol_schema.RateSummaryTask, TaskType.initial_prompt: protocol_schema.InitialPromptTask, - TaskType.user_reply: protocol_schema.UserReplyTask, + TaskType.prompter_reply: protocol_schema.PrompterReplyTask, TaskType.assistant_reply: protocol_schema.AssistantReplyTask, TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, - TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, + TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask, TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, TaskType.done: protocol_schema.TaskDone, } diff --git a/discord-bot/bot.py b/discord-bot/bot.py index a19fdfe1..c54c4a5f 100644 --- a/discord-bot/bot.py +++ b/discord-bot/bot.py @@ -137,13 +137,13 @@ class OpenAssistantBot(BotBase): handler = task_handlers.RateSummaryHandler() case TaskType.initial_prompt: handler = task_handlers.InitialPromptHandler() - case TaskType.user_reply: - handler = task_handlers.UserReplyHandler() + case TaskType.prompter_reply: + handler = task_handlers.PrompterReplyHandler() case TaskType.assistant_reply: handler = task_handlers.AssistantReplyHandler() case TaskType.rank_initial_prompts: handler = task_handlers.RankInitialPromptsHandler() - case TaskType.rank_user_replies | TaskType.rank_assistant_replies: + case TaskType.rank_prompter_replies | TaskType.rank_assistant_replies: handler = task_handlers.RankConversationsHandler() case _: logger.warning(f"Unsupported task type received: {task.type}") diff --git a/discord-bot/task_handlers.py b/discord-bot/task_handlers.py index 9213ac30..488f91b1 100644 --- a/discord-bot/task_handlers.py +++ b/discord-bot/task_handlers.py @@ -146,15 +146,15 @@ class InitialPromptHandler(ChannelTaskBase): await self.handle_text_reply_to_post(msg) -class UserReplyHandler(ChannelTaskBase): - task: protocol_schema.UserReplyTask +class PrompterReplyHandler(ChannelTaskBase): + task: protocol_schema.PrompterReplyTask thread_name: str = "User replies" async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_user_reply.msg") + return await self.post_teaser_msg("teaser_prompter_reply.msg") async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task) + await self.bot.post_template("task_prompter_reply.msg", channel=thread, task=self.task) async def handler_loop(self): while True: diff --git a/discord-bot/templates/task_user_reply.msg b/discord-bot/templates/task_prompter_reply.msg similarity index 100% rename from discord-bot/templates/task_user_reply.msg rename to discord-bot/templates/task_prompter_reply.msg diff --git a/discord-bot/templates/teaser_user_reply.msg b/discord-bot/templates/teaser_prompter_reply.msg similarity index 100% rename from discord-bot/templates/teaser_user_reply.msg rename to discord-bot/templates/teaser_prompter_reply.msg diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 59780d01..8fe8bdea 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -12,10 +12,10 @@ class TaskRequestType(str, enum.Enum): summarize_story = "summarize_story" rate_summary = "rate_summary" initial_prompt = "initial_prompt" - user_reply = "user_reply" + prompter_reply = "prompter_reply" assistant_reply = "assistant_reply" rank_initial_prompts = "rank_initial_prompts" - rank_user_replies = "rank_user_replies" + rank_prompter_replies = "rank_prompter_replies" rank_assistant_replies = "rank_assistant_replies" @@ -33,7 +33,7 @@ class ConversationMessage(BaseModel): class Conversation(BaseModel): - """Represents a conversation between the user and the assistant.""" + """Represents a conversation between the prompter and the assistant.""" messages: list[ConversationMessage] = [] @@ -114,10 +114,10 @@ class ReplyToConversationTask(Task): conversation: Conversation # the conversation so far -class UserReplyTask(ReplyToConversationTask, WithHintMixin): +class PrompterReplyTask(ReplyToConversationTask, WithHintMixin): """A task to prompt the user to submit a reply to the assistant.""" - type: Literal["user_reply"] = "user_reply" + type: Literal["prompter_reply"] = "prompter_reply" class AssistantReplyTask(ReplyToConversationTask): @@ -141,10 +141,10 @@ class RankConversationRepliesTask(Task): replies: list[str] -class RankUserRepliesTask(RankConversationRepliesTask): - """A task to rank a set of user replies to a conversation.""" +class RankPrompterRepliesTask(RankConversationRepliesTask): + """A task to rank a set of prompter replies to a conversation.""" - type: Literal["rank_user_replies"] = "rank_user_replies" + type: Literal["rank_prompter_replies"] = "rank_prompter_replies" class RankAssistantRepliesTask(RankConversationRepliesTask): @@ -165,11 +165,11 @@ AnyTask = Union[ RateSummaryTask, InitialPromptTask, ReplyToConversationTask, - UserReplyTask, + PrompterReplyTask, AssistantReplyTask, RankInitialPromptsTask, RankConversationRepliesTask, - RankUserRepliesTask, + RankPrompterRepliesTask, RankAssistantRepliesTask, ] diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index 54601c22..2bec4942 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -21,7 +21,7 @@ def _render_message(message: dict) -> str: """Render a message to the user.""" if message["is_assistant"]: return f"Assistant: {message['text']}" - return f"User: {message['text']}" + return f"Prompter: {message['text']}" @app.command() @@ -107,7 +107,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") ) tasks.append(new_task) - case "user_reply": + case "prompter_reply": typer.echo("Please provide a reply to the assistant.") typer.echo("Here is the conversation so far:") for message in task["conversation"]["messages"]: @@ -178,7 +178,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") ) tasks.append(new_task) - case "rank_user_replies" | "rank_assistant_replies": + case "rank_prompter_replies" | "rank_assistant_replies": typer.echo("Here is the conversation so far:") for message in task["conversation"]["messages"]: typer.echo(_render_message(message))