mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
344: Create tasks for text labels (#381)
* Implement label task for initial prompts and replies * Resolve formatting * Include missing argument * Modify text_labels API to match new model, update DB schema accordingly * Send valid labels as part of label tasks * Send correctly formatted valid_labels list * Fix request format * Fix request details for text-frontend reply label task * Include message_id in tasks * Address review comments * Fix alembic tree
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
"""Added user to TextLabels
|
||||
|
||||
Revision ID: 20cd871f4ec7
|
||||
Revises: d4161e384f83
|
||||
Create Date: 2023-01-05 17:45:15.696468
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "20cd871f4ec7"
|
||||
down_revision = "3b0adfadbef9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("text_labels", sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False))
|
||||
op.create_foreign_key(None, "text_labels", "user", ["user_id"], ["id"])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, "text_labels", type_="foreignkey")
|
||||
op.drop_column("text_labels", "user_id")
|
||||
# ### end Alembic commands ###
|
||||
@@ -119,6 +119,38 @@ def generate_task(
|
||||
conversation=protocol_schema.Conversation(messages=task_messages),
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_initial_prompt:
|
||||
logger.info("Generating a LabelInitialPromptTask.")
|
||||
message = pr.fetch_random_initial_prompts(1)[0]
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.payload.payload.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_prompter_reply:
|
||||
logger.info("Generating a LabelPrompterReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
|
||||
message = messages[0].payload.payload.text
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_assistant_reply:
|
||||
logger.info("Generating a LabelAssistantReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
|
||||
message = messages[0].payload.payload.text
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
|
||||
|
||||
@@ -256,6 +288,13 @@ def tasks_interaction(
|
||||
pr.store_ranking(interaction)
|
||||
# here we would store the ranking in the database
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
|
||||
)
|
||||
# TODO: check if the labels are valid?
|
||||
pr.store_text_labels(interaction)
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
|
||||
except OasstError:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
@@ -11,17 +10,12 @@ from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class LabelTextRequest(pydantic.BaseModel):
|
||||
text_labels: protocol_schema.TextLabels
|
||||
user: protocol_schema.User
|
||||
|
||||
|
||||
@router.post("/", status_code=HTTP_204_NO_CONTENT)
|
||||
def label_text(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
request: LabelTextRequest,
|
||||
text_labels: protocol_schema.TextLabels,
|
||||
) -> None:
|
||||
"""
|
||||
Label a piece of text.
|
||||
@@ -29,9 +23,9 @@ def label_text(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
logger.info(f"Labeling text {request=}.")
|
||||
pr = PromptRepository(db, api_client, user=request.user)
|
||||
pr.store_text_labels(request.text_labels)
|
||||
logger.info(f"Labeling text {text_labels=}.")
|
||||
pr = PromptRepository(db, api_client, user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to store label.")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models.payload_column_type import payload_type
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -91,3 +92,37 @@ class RankAssistantRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of assistant replies to a conversation."""
|
||||
|
||||
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
|
||||
|
||||
|
||||
@payload_type
|
||||
class LabelInitialPromptPayload(TaskPayload):
|
||||
"""A task to label an initial prompt."""
|
||||
|
||||
type: Literal["label_initial_prompt"] = "label_initial_prompt"
|
||||
message_id: UUID
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class LabelConversationReplyPayload(TaskPayload):
|
||||
"""A task to label a conversation reply."""
|
||||
|
||||
message_id: UUID
|
||||
conversation: protocol_schema.Conversation
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class LabelPrompterReplyPayload(LabelConversationReplyPayload):
|
||||
"""A task to label a prompter reply."""
|
||||
|
||||
type: Literal["label_prompter_reply"] = "label_prompter_reply"
|
||||
|
||||
|
||||
@payload_type
|
||||
class LabelAssistantReplyPayload(LabelConversationReplyPayload):
|
||||
"""A task to label an assistant reply."""
|
||||
|
||||
type: Literal["label_assistant_reply"] = "label_assistant_reply"
|
||||
|
||||
@@ -15,6 +15,7 @@ class TextLabels(SQLModel, table=True):
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
|
||||
@@ -282,16 +282,39 @@ class PromptRepository:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
payload = db_payload.LabelPrompterReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
payload = db_payload.LabelAssistantReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
@@ -388,12 +411,12 @@ class PromptRepository:
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
|
||||
model = TextLabels(
|
||||
api_client_id=self.api_client.id,
|
||||
message_id=text_labels.message_id,
|
||||
user_id=self.user_id,
|
||||
text=text_labels.text,
|
||||
labels=text_labels.labels,
|
||||
)
|
||||
if text_labels.has_message_id:
|
||||
self.fetch_message_by_frontend_message_id(text_labels.message_id, fail_if_missing=True)
|
||||
model.message_id = text_labels.message_id
|
||||
|
||||
self.db.add(model)
|
||||
self.db.commit()
|
||||
self.db.refresh(model)
|
||||
|
||||
@@ -10,10 +10,14 @@ import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.messages import (
|
||||
assistant_reply_message,
|
||||
confirm_label_response_message,
|
||||
confirm_ranking_response_message,
|
||||
confirm_text_response_message,
|
||||
initial_prompt_message,
|
||||
invalid_user_input_embed,
|
||||
label_assistant_reply_message,
|
||||
label_initial_prompt_message,
|
||||
label_prompter_reply_message,
|
||||
plain_embed,
|
||||
prompter_reply_message,
|
||||
rank_assistant_reply_message,
|
||||
@@ -145,6 +149,8 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
|
||||
content = confirm_ranking_response_message(event.content, task.replies)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
content = confirm_ranking_response_message(event.content, task.prompts)
|
||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
||||
content = confirm_label_response_message(event.content)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
content = confirm_text_response_message(event.content)
|
||||
else:
|
||||
@@ -171,6 +177,17 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
||||
labels = event.content.replace(" ", "").split(",")
|
||||
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
|
||||
|
||||
reply = protocol_schema.TextLabels(
|
||||
message_id=task.message_id,
|
||||
labels=labels_dict,
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
@@ -300,6 +317,21 @@ async def _send_task(
|
||||
logger.debug("sending rank assistant reply task")
|
||||
content = rank_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_initial_prompt:
|
||||
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
|
||||
logger.debug("sending label initial prompt task")
|
||||
content = label_initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_prompter_reply:
|
||||
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
|
||||
logger.debug("sending label prompter reply task")
|
||||
content = label_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_assistant_reply:
|
||||
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
|
||||
logger.debug("sending label assistant reply task")
|
||||
content = label_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
@@ -382,6 +414,26 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tup
|
||||
"Message must contain numbers for all prompts.",
|
||||
)
|
||||
|
||||
# Labels tasks
|
||||
elif task.type in (
|
||||
TaskRequestType.label_initial_prompt,
|
||||
TaskRequestType.label_prompter_reply,
|
||||
TaskRequestType.label_assistant_reply,
|
||||
):
|
||||
assert isinstance(
|
||||
task,
|
||||
protocol_schema.LabelInitialPromptTask
|
||||
| protocol_schema.LabelPrompterReplyTask
|
||||
| protocol_schema.LabelAssistantReplyTask,
|
||||
)
|
||||
|
||||
labels = content.replace(" ", "").split(",")
|
||||
valid_labels = set(task.valid_labels)
|
||||
return (
|
||||
set(labels).issubset(valid_labels),
|
||||
"Message must only contain labels from predefined set of labels.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
|
||||
@@ -33,6 +33,10 @@ def _ranking_prompt(text: str) -> str:
|
||||
return f":trophy: _{text}_"
|
||||
|
||||
|
||||
def _label_prompt(text: str) -> str:
|
||||
return f":question: _{text}"
|
||||
|
||||
|
||||
def _response_prompt(text: str) -> str:
|
||||
return f":speech_balloon: _{text}_"
|
||||
|
||||
@@ -129,6 +133,49 @@ def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask)
|
||||
"""
|
||||
|
||||
|
||||
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `label_initial_prompt` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("LABEL INITIAL PROMPT")}
|
||||
|
||||
|
||||
{task.prompt}
|
||||
|
||||
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")}
|
||||
"""
|
||||
|
||||
|
||||
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("LABEL PROMPTER REPLY")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
{task.reply}
|
||||
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
||||
"""
|
||||
|
||||
|
||||
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("LABEL ASSISTANT REPLY")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_assistant(None)}
|
||||
{task.reply}
|
||||
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
||||
"""
|
||||
|
||||
|
||||
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
return f"""\
|
||||
@@ -175,6 +222,17 @@ def confirm_ranking_response_message(content: str, items: list[str]) -> str:
|
||||
"""
|
||||
|
||||
|
||||
def confirm_label_response_message(content: str) -> str:
|
||||
user_labels = content.lower().replace(" ", "").split(",")
|
||||
user_labels_str = ", ".join(user_labels)
|
||||
|
||||
return f"""\
|
||||
{_h2("CONFIRM RESPONSE")}
|
||||
|
||||
{user_labels_str}
|
||||
"""
|
||||
|
||||
|
||||
###
|
||||
# Embeds
|
||||
###
|
||||
|
||||
@@ -24,6 +24,9 @@ class TaskType(str, enum.Enum):
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
label_initial_prompt = "label_initial_prompt"
|
||||
label_assistant_reply = "label_assistant_reply"
|
||||
label_prompter_reply = "label_prompter_reply"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
@@ -56,6 +59,9 @@ class OasstApiClient:
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.label_initial_prompt: protocol_schema.LabelInitialPromptTask,
|
||||
TaskType.label_prompter_reply: protocol_schema.LabelPrompterReplyTask,
|
||||
TaskType.label_assistant_reply: protocol_schema.LabelAssistantReplyTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,9 @@ class TaskRequestType(str, enum.Enum):
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
label_initial_prompt = "label_initial_prompt"
|
||||
label_assistant_reply = "label_assistant_reply"
|
||||
label_prompter_reply = "label_prompter_reply"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
@@ -169,6 +172,37 @@ class RankAssistantRepliesTask(RankConversationRepliesTask):
|
||||
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
|
||||
|
||||
|
||||
class LabelInitialPromptTask(Task):
|
||||
"""A task to label an initial prompt."""
|
||||
|
||||
type: Literal["label_initial_prompt"] = "label_initial_prompt"
|
||||
message_id: UUID
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
|
||||
|
||||
class LabelConversationReplyTask(Task):
|
||||
"""A task to label a reply to a conversation."""
|
||||
|
||||
type: Literal["label_conversation_reply"] = "label_conversation_reply"
|
||||
conversation: Conversation # the conversation so far
|
||||
message_id: UUID
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
|
||||
|
||||
class LabelPrompterReplyTask(LabelConversationReplyTask):
|
||||
"""A task to label a prompter reply to a conversation."""
|
||||
|
||||
type: Literal["label_prompter_reply"] = "label_prompter_reply"
|
||||
|
||||
|
||||
class LabelAssistantReplyTask(LabelConversationReplyTask):
|
||||
"""A task to label an assistant reply to a conversation."""
|
||||
|
||||
type: Literal["label_assistant_reply"] = "label_assistant_reply"
|
||||
|
||||
|
||||
class TaskDone(Task):
|
||||
"""Signals to the frontend that the task is done."""
|
||||
|
||||
@@ -187,6 +221,10 @@ AnyTask = Union[
|
||||
RankConversationRepliesTask,
|
||||
RankPrompterRepliesTask,
|
||||
RankAssistantRepliesTask,
|
||||
LabelInitialPromptTask,
|
||||
LabelConversationReplyTask,
|
||||
LabelPrompterReplyTask,
|
||||
LabelAssistantReplyTask,
|
||||
]
|
||||
|
||||
|
||||
@@ -222,13 +260,6 @@ class MessageRanking(Interaction):
|
||||
ranking: conlist(item_type=int, min_items=1)
|
||||
|
||||
|
||||
AnyInteraction = Union[
|
||||
TextReplyToMessage,
|
||||
MessageRating,
|
||||
MessageRanking,
|
||||
]
|
||||
|
||||
|
||||
class TextLabel(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
@@ -256,12 +287,13 @@ class TextLabel(str, enum.Enum):
|
||||
slang = "slang"
|
||||
|
||||
|
||||
class TextLabels(BaseModel):
|
||||
class TextLabels(Interaction):
|
||||
"""A set of labels for a piece of text."""
|
||||
|
||||
type: Literal["text_labels"] = "text_labels"
|
||||
text: str
|
||||
labels: dict[TextLabel, float]
|
||||
message_id: str | None = None
|
||||
message_id: UUID
|
||||
|
||||
@property
|
||||
def has_message_id(self) -> bool:
|
||||
@@ -277,6 +309,14 @@ class TextLabels(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
AnyInteraction = Union[
|
||||
TextReplyToMessage,
|
||||
MessageRating,
|
||||
MessageRanking,
|
||||
TextLabels,
|
||||
]
|
||||
|
||||
|
||||
class SystemStats(BaseModel):
|
||||
all: int = 0
|
||||
active: int = 0
|
||||
|
||||
@@ -203,6 +203,60 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "label_initial_prompt":
|
||||
typer.echo("Label the following prompt:")
|
||||
typer.echo(task["prompt"])
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
valid_labels = task["valid_labels"]
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
|
||||
# send ranking
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"text": task["prompt"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "label_prompter_reply" | "label_assistant_reply":
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
|
||||
typer.echo("Label the following reply:")
|
||||
typer.echo(task["reply"])
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
valid_labels = task["valid_labels"]
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
|
||||
# send ranking
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"text": task["prompt"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "task_done":
|
||||
typer.echo("Task done!")
|
||||
case _:
|
||||
|
||||
Reference in New Issue
Block a user