fix text_frontend 204 & creation of labeling tasks (#487)

This commit is contained in:
Andreas Köpf
2023-01-07 21:58:29 +01:00
committed by GitHub
parent 5e01f421aa
commit d910c310c0
2 changed files with 15 additions and 11 deletions
+8 -7
View File
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
@@ -126,7 +127,7 @@ def generate_task(
]
replies = [p.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(messages=task_messages),
conversation=prepare_conversation(conversation),
replies=replies,
)
@@ -142,22 +143,22 @@ def generate_task(
case protocol_schema.TaskRequestType.label_prompter_reply:
logger.info("Generating a LabelPrompterReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
message = messages[0].text
message = messages[0]
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message,
conversation=prepare_conversation(conversation),
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case protocol_schema.TaskRequestType.label_assistant_reply:
logger.info("Generating a LabelAssistantReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
message = messages[0].text
message = messages[0]
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message,
conversation=prepare_conversation(conversation),
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
+7 -4
View File
@@ -1,5 +1,6 @@
"""Simple REPL frontend."""
import http
import random
import requests
@@ -30,6 +31,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()
if response.status_code == http.HTTPStatus.NO_CONTENT:
return None
return response.json()
typer.echo("Requesting work...")
@@ -191,7 +194,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
ranking_str = typer.prompt("Enter the reply numbers in order of preference, separated by commas")
ranking = [int(x) - 1 for x in ranking_str.split(",")]
# send ranking
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
@@ -223,7 +226,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send ranking
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
@@ -260,13 +263,13 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send ranking
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"text": task["prompt"],
"text": task["reply"],
"labels": labels_dict,
"user": USER,
},