diff --git a/backend/alembic/versions/2023_01_05_1745-20cd871f4ec7_added_user_to_textlabels.py b/backend/alembic/versions/2023_01_05_1745-20cd871f4ec7_added_user_to_textlabels.py
new file mode 100644
index 00000000..f042642e
--- /dev/null
+++ b/backend/alembic/versions/2023_01_05_1745-20cd871f4ec7_added_user_to_textlabels.py
@@ -0,0 +1,31 @@
+"""Added user to TextLabels
+
+Revision ID: 20cd871f4ec7
+Revises: 3b0adfadbef9
+Create Date: 2023-01-05 17:45:15.696468
+
+"""
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "20cd871f4ec7"
+down_revision = "3b0adfadbef9"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("text_labels", sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False))
+    # Named explicitly (PostgreSQL's default "<table>_<column>_fkey") so downgrade() can drop it by name.
+    op.create_foreign_key("text_labels_user_id_fkey", "text_labels", "user", ["user_id"], ["id"])
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_constraint("text_labels_user_id_fkey", "text_labels", type_="foreignkey")
+    op.drop_column("text_labels", "user_id")
+    # ### end Alembic commands ###
diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py
index 05dc92a9..9f81eabb 100644
--- a/backend/oasst_backend/api/v1/tasks.py
+++ b/backend/oasst_backend/api/v1/tasks.py
@@ -119,6 +119,38 @@ def generate_task(
                 conversation=protocol_schema.Conversation(messages=task_messages),
                 replies=replies,
             )
+
+        case protocol_schema.TaskRequestType.label_initial_prompt:
+            logger.info("Generating a LabelInitialPromptTask.")
+            message = pr.fetch_random_initial_prompts(1)[0]
+            task = protocol_schema.LabelInitialPromptTask(
+                message_id=message.id,
+                prompt=message.payload.payload.text,
+                valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
+            )
+
+        case protocol_schema.TaskRequestType.label_prompter_reply:
+            logger.info("Generating a LabelPrompterReplyTask.")
+            conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
+            message = messages[0]
+            task = protocol_schema.LabelPrompterReplyTask(
+                message_id=message.id,
+                conversation=conversation,
+                reply=message.payload.payload.text,
+                valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
+            )
+
+        case protocol_schema.TaskRequestType.label_assistant_reply:
+            logger.info("Generating a LabelAssistantReplyTask.")
+            conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
+            message = messages[0]
+            task = protocol_schema.LabelAssistantReplyTask(
+                message_id=message.id,
+                conversation=conversation,
+                reply=message.payload.payload.text,
+                valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
+            )
+
         case _:
             raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
 
@@ -256,6 +288,13 @@ def tasks_interaction(
                 pr.store_ranking(interaction)
                 # here we would store the ranking in the database
                 return protocol_schema.TaskDone()
+            case protocol_schema.TextLabels:
+                logger.info(
+                    f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
+                )
+                # TODO: check if the labels are valid?
+                pr.store_text_labels(interaction)
+                return protocol_schema.TaskDone()
             case _:
                 raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
     except OasstError:
diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py
index 0613711c..97422119 100644
--- a/backend/oasst_backend/api/v1/text_labels.py
+++ b/backend/oasst_backend/api/v1/text_labels.py
@@ -1,4 +1,3 @@
-import pydantic
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi.security.api_key import APIKey
 from loguru import logger
@@ -11,17 +10,12 @@ from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
 router = APIRouter()
 
 
-class LabelTextRequest(pydantic.BaseModel):
-    text_labels: protocol_schema.TextLabels
-    user: protocol_schema.User
-
-
 @router.post("/", status_code=HTTP_204_NO_CONTENT)
 def label_text(
     *,
     db: Session = Depends(deps.get_db),
     api_key: APIKey = Depends(deps.get_api_key),
-    request: LabelTextRequest,
+    text_labels: protocol_schema.TextLabels,
 ) -> None:
     """
     Label a piece of text.
@@ -29,9 +23,9 @@ def label_text(
     api_client = deps.api_auth(api_key, db)
 
     try:
-        logger.info(f"Labeling text {request=}.")
-        pr = PromptRepository(db, api_client, user=request.user)
-        pr.store_text_labels(request.text_labels)
+        logger.info(f"Labeling text {text_labels=}.")
+        pr = PromptRepository(db, api_client, user=text_labels.user)
+        pr.store_text_labels(text_labels)
 
     except Exception:
         logger.exception("Failed to store label.")
diff --git a/backend/oasst_backend/models/db_payload.py b/backend/oasst_backend/models/db_payload.py
index 9a6fabb6..fed60dd8 100644
--- a/backend/oasst_backend/models/db_payload.py
+++ b/backend/oasst_backend/models/db_payload.py
@@ -1,4 +1,5 @@
 from typing import Literal
+from uuid import UUID
 
 from oasst_backend.models.payload_column_type import payload_type
 from oasst_shared.schemas import protocol as protocol_schema
@@ -91,3 +92,37 @@ class RankAssistantRepliesPayload(RankConversationRepliesPayload):
     """A task to rank a set of assistant replies to a conversation."""
 
     type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
+
+
+@payload_type
+class LabelInitialPromptPayload(TaskPayload):
+    """A task to label an initial prompt."""
+
+    type: Literal["label_initial_prompt"] = "label_initial_prompt"
+    message_id: UUID
+    prompt: str
+    valid_labels: list[str]
+
+
+@payload_type
+class LabelConversationReplyPayload(TaskPayload):
+    """A task to label a conversation reply."""
+
+    message_id: UUID
+    conversation: protocol_schema.Conversation
+    reply: str
+    valid_labels: list[str]
+
+
+@payload_type
+class LabelPrompterReplyPayload(LabelConversationReplyPayload):
+    """A task to label a prompter reply."""
+
+    type: Literal["label_prompter_reply"] = "label_prompter_reply"
+
+
+@payload_type
+class LabelAssistantReplyPayload(LabelConversationReplyPayload):
+    """A task to label an assistant reply."""
+
+    type: Literal["label_assistant_reply"] = "label_assistant_reply"
diff --git a/backend/oasst_backend/models/text_labels.py b/backend/oasst_backend/models/text_labels.py
index ec10dca6..e6878a87 100644
--- a/backend/oasst_backend/models/text_labels.py
+++ b/backend/oasst_backend/models/text_labels.py
@@ -15,6 +15,7 @@ class TextLabels(SQLModel, table=True):
             pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
         ),
     )
+    user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
     created_date: Optional[datetime] = Field(
         sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
     )
diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py
index 157e42a7..7c7dd7b6 100644
--- a/backend/oasst_backend/prompt_repository.py
+++ b/backend/oasst_backend/prompt_repository.py
@@ -282,16 +282,39 @@ class PromptRepository:
                 payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
 
             case protocol_schema.RankInitialPromptsTask:
-                payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
+                payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
 
             case protocol_schema.RankPrompterRepliesTask:
                 payload = db_payload.RankPrompterRepliesPayload(
-                    tpye=task.type, conversation=task.conversation, replies=task.replies
+                    type=task.type, conversation=task.conversation, replies=task.replies
                 )
 
             case protocol_schema.RankAssistantRepliesTask:
                 payload = db_payload.RankAssistantRepliesPayload(
-                    tpye=task.type, conversation=task.conversation, replies=task.replies
+                    type=task.type, conversation=task.conversation, replies=task.replies
+                )
+
+            case protocol_schema.LabelInitialPromptTask:
+                payload = db_payload.LabelInitialPromptPayload(
+                    type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
+                )
+
+            case protocol_schema.LabelPrompterReplyTask:
+                payload = db_payload.LabelPrompterReplyPayload(
+                    type=task.type,
+                    message_id=task.message_id,
+                    conversation=task.conversation,
+                    reply=task.reply,
+                    valid_labels=task.valid_labels,
+                )
+
+            case protocol_schema.LabelAssistantReplyTask:
+                payload = db_payload.LabelAssistantReplyPayload(
+                    type=task.type,
+                    message_id=task.message_id,
+                    conversation=task.conversation,
+                    reply=task.reply,
+                    valid_labels=task.valid_labels,
                 )
 
             case _:
@@ -388,12 +411,12 @@ class PromptRepository:
     def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
         model = TextLabels(
             api_client_id=self.api_client.id,
+            message_id=text_labels.message_id,
+            user_id=self.user_id,
             text=text_labels.text,
             labels=text_labels.labels,
         )
-        if text_labels.has_message_id:
-            self.fetch_message_by_frontend_message_id(text_labels.message_id, fail_if_missing=True)
-            model.message_id = text_labels.message_id
+
         self.db.add(model)
         self.db.commit()
         self.db.refresh(model)
diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py
index 0561039d..51daca3b 100644
--- a/discord-bot/bot/extensions/work.py
+++ b/discord-bot/bot/extensions/work.py
@@ -10,10 +10,14 @@ import miru
 from aiosqlite import Connection
 from bot.messages import (
     assistant_reply_message,
+    confirm_label_response_message,
     confirm_ranking_response_message,
     confirm_text_response_message,
     initial_prompt_message,
     invalid_user_input_embed,
+    label_assistant_reply_message,
+    label_initial_prompt_message,
+    label_prompter_reply_message,
     plain_embed,
     prompter_reply_message,
     rank_assistant_reply_message,
@@ -145,6 +149,8 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
                 content = confirm_ranking_response_message(event.content, task.replies)
             elif isinstance(task, protocol_schema.RankInitialPromptsTask):
                 content = confirm_ranking_response_message(event.content, task.prompts)
+            elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
+                content = confirm_label_response_message(event.content)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): content = confirm_text_response_message(event.content) else: @@ -171,6 +177,17 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username ), ) + elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask): + labels = event.content.replace(" ", "").split(",") + labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels} + + reply = protocol_schema.TextLabels( + message_id=task.message_id, + labels=labels_dict, + user=protocol_schema.User( + auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username + ), + ) elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): reply = protocol_schema.TextReplyToMessage( message_id=str(msg_id), @@ -300,6 +317,21 @@ async def _send_task( logger.debug("sending rank assistant reply task") content = rank_assistant_reply_message(task) + elif task.type == TaskRequestType.label_initial_prompt: + assert isinstance(task, protocol_schema.LabelInitialPromptTask) + logger.debug("sending label initial prompt task") + content = label_initial_prompt_message(task) + + elif task.type == TaskRequestType.label_prompter_reply: + assert isinstance(task, protocol_schema.LabelPrompterReplyTask) + logger.debug("sending label prompter reply task") + content = label_prompter_reply_message(task) + + elif task.type == TaskRequestType.label_assistant_reply: + assert isinstance(task, protocol_schema.LabelAssistantReplyTask) + logger.debug("sending label assistant reply task") + content = label_assistant_reply_message(task) + elif task.type == TaskRequestType.prompter_reply: assert isinstance(task, protocol_schema.PrompterReplyTask) logger.debug("sending user reply task") @@ -382,6 +414,26 @@ def _validate_user_input(content: str | 
None, task: protocol_schema.Task) -> tup "Message must contain numbers for all prompts.", ) + # Labels tasks + elif task.type in ( + TaskRequestType.label_initial_prompt, + TaskRequestType.label_prompter_reply, + TaskRequestType.label_assistant_reply, + ): + assert isinstance( + task, + protocol_schema.LabelInitialPromptTask + | protocol_schema.LabelPrompterReplyTask + | protocol_schema.LabelAssistantReplyTask, + ) + + labels = content.replace(" ", "").split(",") + valid_labels = set(task.valid_labels) + return ( + set(labels).issubset(valid_labels), + "Message must only contain labels from predefined set of labels.", + ) + elif task.type == TaskRequestType.summarize_story: raise NotImplementedError elif task.type == TaskRequestType.rate_summary: diff --git a/discord-bot/bot/messages.py b/discord-bot/bot/messages.py index 8db54e37..c1a6d355 100644 --- a/discord-bot/bot/messages.py +++ b/discord-bot/bot/messages.py @@ -33,6 +33,10 @@ def _ranking_prompt(text: str) -> str: return f":trophy: _{text}_" +def _label_prompt(text: str) -> str: + return f":question: _{text}" + + def _response_prompt(text: str) -> str: return f":speech_balloon: _{text}_" @@ -129,6 +133,49 @@ def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) """ +def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str: + """Creates the message that gets sent to users when they request a `label_initial_prompt` task.""" + return f"""\ + +{_h1("LABEL INITIAL PROMPT")} + + +{task.prompt} + +{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")} +""" + + +def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str: + """Creates the message that gets sent to users when they request a `label_prompter_reply` task.""" + return f"""\ + +{_h1("LABEL PROMPTER REPLY")} + + +{_conversation(task.conversation)} +{_user(None)} +{task.reply} + +{_label_prompt("Reply with labels for the reply 
separated by commas (example: 'profanity,misleading')")} +""" + + +def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str: + """Creates the message that gets sent to users when they request a `label_assistant_reply` task.""" + return f"""\ + +{_h1("LABEL ASSISTANT REPLY")} + + +{_conversation(task.conversation)} +{_assistant(None)} +{task.reply} + +{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")} +""" + + def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str: """Creates the message that gets sent to users when they request a `prompter_reply` task.""" return f"""\ @@ -175,6 +222,17 @@ def confirm_ranking_response_message(content: str, items: list[str]) -> str: """ +def confirm_label_response_message(content: str) -> str: + user_labels = content.lower().replace(" ", "").split(",") + user_labels_str = ", ".join(user_labels) + + return f"""\ +{_h2("CONFIRM RESPONSE")} + +{user_labels_str} +""" + + ### # Embeds ### diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index 404521db..1ee2865b 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -24,6 +24,9 @@ class TaskType(str, enum.Enum): rank_initial_prompts = "rank_initial_prompts" rank_prompter_replies = "rank_prompter_replies" rank_assistant_replies = "rank_assistant_replies" + label_initial_prompt = "label_initial_prompt" + label_assistant_reply = "label_assistant_reply" + label_prompter_reply = "label_prompter_reply" done = "task_done" @@ -56,6 +59,9 @@ class OasstApiClient: TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask, TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, + TaskType.label_initial_prompt: protocol_schema.LabelInitialPromptTask, + TaskType.label_prompter_reply: 
protocol_schema.LabelPrompterReplyTask, + TaskType.label_assistant_reply: protocol_schema.LabelAssistantReplyTask, TaskType.done: protocol_schema.TaskDone, } diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index e035d387..1cafc93d 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -18,6 +18,9 @@ class TaskRequestType(str, enum.Enum): rank_initial_prompts = "rank_initial_prompts" rank_prompter_replies = "rank_prompter_replies" rank_assistant_replies = "rank_assistant_replies" + label_initial_prompt = "label_initial_prompt" + label_assistant_reply = "label_assistant_reply" + label_prompter_reply = "label_prompter_reply" class User(BaseModel): @@ -169,6 +172,37 @@ class RankAssistantRepliesTask(RankConversationRepliesTask): type: Literal["rank_assistant_replies"] = "rank_assistant_replies" +class LabelInitialPromptTask(Task): + """A task to label an initial prompt.""" + + type: Literal["label_initial_prompt"] = "label_initial_prompt" + message_id: UUID + prompt: str + valid_labels: list[str] + + +class LabelConversationReplyTask(Task): + """A task to label a reply to a conversation.""" + + type: Literal["label_conversation_reply"] = "label_conversation_reply" + conversation: Conversation # the conversation so far + message_id: UUID + reply: str + valid_labels: list[str] + + +class LabelPrompterReplyTask(LabelConversationReplyTask): + """A task to label a prompter reply to a conversation.""" + + type: Literal["label_prompter_reply"] = "label_prompter_reply" + + +class LabelAssistantReplyTask(LabelConversationReplyTask): + """A task to label an assistant reply to a conversation.""" + + type: Literal["label_assistant_reply"] = "label_assistant_reply" + + class TaskDone(Task): """Signals to the frontend that the task is done.""" @@ -187,6 +221,10 @@ AnyTask = Union[ RankConversationRepliesTask, RankPrompterRepliesTask, RankAssistantRepliesTask, + 
LabelInitialPromptTask, + LabelConversationReplyTask, + LabelPrompterReplyTask, + LabelAssistantReplyTask, ] @@ -222,13 +260,6 @@ class MessageRanking(Interaction): ranking: conlist(item_type=int, min_items=1) -AnyInteraction = Union[ - TextReplyToMessage, - MessageRating, - MessageRanking, -] - - class TextLabel(str, enum.Enum): """A label for a piece of text.""" @@ -256,12 +287,13 @@ class TextLabel(str, enum.Enum): slang = "slang" -class TextLabels(BaseModel): +class TextLabels(Interaction): """A set of labels for a piece of text.""" + type: Literal["text_labels"] = "text_labels" text: str labels: dict[TextLabel, float] - message_id: str | None = None + message_id: UUID @property def has_message_id(self) -> bool: @@ -277,6 +309,14 @@ class TextLabels(BaseModel): return v +AnyInteraction = Union[ + TextReplyToMessage, + MessageRating, + MessageRanking, + TextLabels, +] + + class SystemStats(BaseModel): all: int = 0 active: int = 0 diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index 39cc7b26..de65749a 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -203,6 +203,60 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") ) tasks.append(new_task) + case "label_initial_prompt": + typer.echo("Label the following prompt:") + typer.echo(task["prompt"]) + # acknowledge task + message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + + valid_labels = task["valid_labels"] + labels_str: str = typer.prompt("Enter labels, separated by commas") + labels = labels_str.lower().replace(" ", "").split(",") + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + + # send ranking + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_labels", + "message_id": task["message_id"], + "text": task["prompt"], + "labels": labels_dict, + "user": USER, + }, + ) + tasks.append(new_task) + + case "label_prompter_reply" | 
"label_assistant_reply": + typer.echo("Here is the conversation so far:") + for message in task["conversation"]["messages"]: + typer.echo(_render_message(message)) + + typer.echo("Label the following reply:") + typer.echo(task["reply"]) + # acknowledge task + message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + + valid_labels = task["valid_labels"] + labels_str: str = typer.prompt("Enter labels, separated by commas") + labels = labels_str.lower().replace(" ", "").split(",") + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + + # send ranking + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_labels", + "message_id": task["message_id"], + "text": task["prompt"], + "labels": labels_dict, + "user": USER, + }, + ) + tasks.append(new_task) + case "task_done": typer.echo("Task done!") case _: