344: Create tasks for text labels (#381)

* Implement label task for initial prompts and replies

* Resolve formatting

* Include missing argument

* Modify text_labels API to match new model, update DB schema accordingly

* Send valid labels as part of label tasks

* Send correctly formatted valid_labels list

* Fix request format

* Fix request details for text-frontend reply label task

* Include message_id in tasks

* Address review comments

* Fix alembic tree
This commit is contained in:
Oliver Stanley
2023-01-06 17:39:04 +00:00
committed by GitHub
parent 05c4550569
commit 69bc799cd9
11 changed files with 357 additions and 25 deletions
@@ -0,0 +1,30 @@
"""Added user to TextLabels
Revision ID: 20cd871f4ec7
Revises: d4161e384f83
Create Date: 2023-01-05 17:45:15.696468
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "20cd871f4ec7"
down_revision = "3b0adfadbef9"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("text_labels", sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False))
op.create_foreign_key(None, "text_labels", "user", ["user_id"], ["id"])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "text_labels", type_="foreignkey")
op.drop_column("text_labels", "user_id")
# ### end Alembic commands ###
+39
View File
@@ -119,6 +119,38 @@ def generate_task(
conversation=protocol_schema.Conversation(messages=task_messages),
replies=replies,
)
case protocol_schema.TaskRequestType.label_initial_prompt:
logger.info("Generating a LabelInitialPromptTask.")
message = pr.fetch_random_initial_prompts(1)[0]
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.payload.payload.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case protocol_schema.TaskRequestType.label_prompter_reply:
logger.info("Generating a LabelPrompterReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
message = messages[0].payload.payload.text
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case protocol_schema.TaskRequestType.label_assistant_reply:
logger.info("Generating a LabelAssistantReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
message = messages[0].payload.payload.text
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case _:
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
@@ -256,6 +288,13 @@ def tasks_interaction(
pr.store_ranking(interaction)
# here we would store the ranking in the database
return protocol_schema.TaskDone()
case protocol_schema.TextLabels:
logger.info(
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
)
# TODO: check if the labels are valid?
pr.store_text_labels(interaction)
return protocol_schema.TaskDone()
case _:
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
except OasstError:
+4 -10
View File
@@ -1,4 +1,3 @@
import pydantic
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
@@ -11,17 +10,12 @@ from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
class LabelTextRequest(pydantic.BaseModel):
text_labels: protocol_schema.TextLabels
user: protocol_schema.User
@router.post("/", status_code=HTTP_204_NO_CONTENT)
def label_text(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
request: LabelTextRequest,
text_labels: protocol_schema.TextLabels,
) -> None:
"""
Label a piece of text.
@@ -29,9 +23,9 @@ def label_text(
api_client = deps.api_auth(api_key, db)
try:
logger.info(f"Labeling text {request=}.")
pr = PromptRepository(db, api_client, user=request.user)
pr.store_text_labels(request.text_labels)
logger.info(f"Labeling text {text_labels=}.")
pr = PromptRepository(db, api_client, user=text_labels.user)
pr.store_text_labels(text_labels)
except Exception:
logger.exception("Failed to store label.")
@@ -1,4 +1,5 @@
from typing import Literal
from uuid import UUID
from oasst_backend.models.payload_column_type import payload_type
from oasst_shared.schemas import protocol as protocol_schema
@@ -91,3 +92,37 @@ class RankAssistantRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of assistant replies to a conversation."""
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
@payload_type
class LabelInitialPromptPayload(TaskPayload):
"""A task to label an initial prompt."""
type: Literal["label_initial_prompt"] = "label_initial_prompt"
message_id: UUID
prompt: str
valid_labels: list[str]
@payload_type
class LabelConversationReplyPayload(TaskPayload):
"""A task to label a conversation reply."""
message_id: UUID
conversation: protocol_schema.Conversation
reply: str
valid_labels: list[str]
@payload_type
class LabelPrompterReplyPayload(LabelConversationReplyPayload):
"""A task to label a prompter reply."""
type: Literal["label_prompter_reply"] = "label_prompter_reply"
@payload_type
class LabelAssistantReplyPayload(LabelConversationReplyPayload):
"""A task to label an assistant reply."""
type: Literal["label_assistant_reply"] = "label_assistant_reply"
@@ -15,6 +15,7 @@ class TextLabels(SQLModel, table=True):
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
+29 -6
View File
@@ -282,16 +282,39 @@ class PromptRepository:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case _:
@@ -388,12 +411,12 @@ class PromptRepository:
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
model = TextLabels(
api_client_id=self.api_client.id,
message_id=text_labels.message_id,
user_id=self.user_id,
text=text_labels.text,
labels=text_labels.labels,
)
if text_labels.has_message_id:
self.fetch_message_by_frontend_message_id(text_labels.message_id, fail_if_missing=True)
model.message_id = text_labels.message_id
self.db.add(model)
self.db.commit()
self.db.refresh(model)
+52
View File
@@ -10,10 +10,14 @@ import miru
from aiosqlite import Connection
from bot.messages import (
assistant_reply_message,
confirm_label_response_message,
confirm_ranking_response_message,
confirm_text_response_message,
initial_prompt_message,
invalid_user_input_embed,
label_assistant_reply_message,
label_initial_prompt_message,
label_prompter_reply_message,
plain_embed,
prompter_reply_message,
rank_assistant_reply_message,
@@ -145,6 +149,8 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
content = confirm_ranking_response_message(event.content, task.replies)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
content = confirm_ranking_response_message(event.content, task.prompts)
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
content = confirm_label_response_message(event.content)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
content = confirm_text_response_message(event.content)
else:
@@ -171,6 +177,17 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
labels = event.content.replace(" ", "").split(",")
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
reply = protocol_schema.TextLabels(
message_id=task.message_id,
labels=labels_dict,
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
@@ -300,6 +317,21 @@ async def _send_task(
logger.debug("sending rank assistant reply task")
content = rank_assistant_reply_message(task)
elif task.type == TaskRequestType.label_initial_prompt:
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
logger.debug("sending label initial prompt task")
content = label_initial_prompt_message(task)
elif task.type == TaskRequestType.label_prompter_reply:
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
logger.debug("sending label prompter reply task")
content = label_prompter_reply_message(task)
elif task.type == TaskRequestType.label_assistant_reply:
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
logger.debug("sending label assistant reply task")
content = label_assistant_reply_message(task)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
@@ -382,6 +414,26 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tup
"Message must contain numbers for all prompts.",
)
# Labels tasks
elif task.type in (
TaskRequestType.label_initial_prompt,
TaskRequestType.label_prompter_reply,
TaskRequestType.label_assistant_reply,
):
assert isinstance(
task,
protocol_schema.LabelInitialPromptTask
| protocol_schema.LabelPrompterReplyTask
| protocol_schema.LabelAssistantReplyTask,
)
labels = content.replace(" ", "").split(",")
valid_labels = set(task.valid_labels)
return (
set(labels).issubset(valid_labels),
"Message must only contain labels from predefined set of labels.",
)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
+58
View File
@@ -33,6 +33,10 @@ def _ranking_prompt(text: str) -> str:
return f":trophy: _{text}_"
def _label_prompt(text: str) -> str:
return f":question: _{text}"
def _response_prompt(text: str) -> str:
return f":speech_balloon: _{text}_"
@@ -129,6 +133,49 @@ def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask)
"""
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
"""Creates the message that gets sent to users when they request a `label_initial_prompt` task."""
return f"""\
{_h1("LABEL INITIAL PROMPT")}
{task.prompt}
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")}
"""
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str:
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
return f"""\
{_h1("LABEL PROMPTER REPLY")}
{_conversation(task.conversation)}
{_user(None)}
{task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
"""
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str:
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
return f"""\
{_h1("LABEL ASSISTANT REPLY")}
{_conversation(task.conversation)}
{_assistant(None)}
{task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
"""
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
return f"""\
@@ -175,6 +222,17 @@ def confirm_ranking_response_message(content: str, items: list[str]) -> str:
"""
def confirm_label_response_message(content: str) -> str:
user_labels = content.lower().replace(" ", "").split(",")
user_labels_str = ", ".join(user_labels)
return f"""\
{_h2("CONFIRM RESPONSE")}
{user_labels_str}
"""
###
# Embeds
###
+6
View File
@@ -24,6 +24,9 @@ class TaskType(str, enum.Enum):
rank_initial_prompts = "rank_initial_prompts"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
label_initial_prompt = "label_initial_prompt"
label_assistant_reply = "label_assistant_reply"
label_prompter_reply = "label_prompter_reply"
done = "task_done"
@@ -56,6 +59,9 @@ class OasstApiClient:
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.label_initial_prompt: protocol_schema.LabelInitialPromptTask,
TaskType.label_prompter_reply: protocol_schema.LabelPrompterReplyTask,
TaskType.label_assistant_reply: protocol_schema.LabelAssistantReplyTask,
TaskType.done: protocol_schema.TaskDone,
}
+49 -9
View File
@@ -18,6 +18,9 @@ class TaskRequestType(str, enum.Enum):
rank_initial_prompts = "rank_initial_prompts"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
label_initial_prompt = "label_initial_prompt"
label_assistant_reply = "label_assistant_reply"
label_prompter_reply = "label_prompter_reply"
class User(BaseModel):
@@ -169,6 +172,37 @@ class RankAssistantRepliesTask(RankConversationRepliesTask):
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
class LabelInitialPromptTask(Task):
"""A task to label an initial prompt."""
type: Literal["label_initial_prompt"] = "label_initial_prompt"
message_id: UUID
prompt: str
valid_labels: list[str]
class LabelConversationReplyTask(Task):
"""A task to label a reply to a conversation."""
type: Literal["label_conversation_reply"] = "label_conversation_reply"
conversation: Conversation # the conversation so far
message_id: UUID
reply: str
valid_labels: list[str]
class LabelPrompterReplyTask(LabelConversationReplyTask):
"""A task to label a prompter reply to a conversation."""
type: Literal["label_prompter_reply"] = "label_prompter_reply"
class LabelAssistantReplyTask(LabelConversationReplyTask):
"""A task to label an assistant reply to a conversation."""
type: Literal["label_assistant_reply"] = "label_assistant_reply"
class TaskDone(Task):
"""Signals to the frontend that the task is done."""
@@ -187,6 +221,10 @@ AnyTask = Union[
RankConversationRepliesTask,
RankPrompterRepliesTask,
RankAssistantRepliesTask,
LabelInitialPromptTask,
LabelConversationReplyTask,
LabelPrompterReplyTask,
LabelAssistantReplyTask,
]
@@ -222,13 +260,6 @@ class MessageRanking(Interaction):
ranking: conlist(item_type=int, min_items=1)
AnyInteraction = Union[
TextReplyToMessage,
MessageRating,
MessageRanking,
]
class TextLabel(str, enum.Enum):
"""A label for a piece of text."""
@@ -256,12 +287,13 @@ class TextLabel(str, enum.Enum):
slang = "slang"
class TextLabels(BaseModel):
class TextLabels(Interaction):
"""A set of labels for a piece of text."""
type: Literal["text_labels"] = "text_labels"
text: str
labels: dict[TextLabel, float]
message_id: str | None = None
message_id: UUID
@property
def has_message_id(self) -> bool:
@@ -277,6 +309,14 @@ class TextLabels(BaseModel):
return v
AnyInteraction = Union[
TextReplyToMessage,
MessageRating,
MessageRanking,
TextLabels,
]
class SystemStats(BaseModel):
all: int = 0
active: int = 0
+54
View File
@@ -203,6 +203,60 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
)
tasks.append(new_task)
case "label_initial_prompt":
typer.echo("Label the following prompt:")
typer.echo(task["prompt"])
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
# send ranking
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"text": task["prompt"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "label_prompter_reply" | "label_assistant_reply":
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
typer.echo(_render_message(message))
typer.echo("Label the following reply:")
typer.echo(task["reply"])
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
# send ranking
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"text": task["prompt"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "task_done":
typer.echo("Task done!")
case _: