mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
fix text_frontend 204 & creation of labeling tasks (#487)
This commit is contained in:
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -126,7 +127,7 @@ def generate_task(
|
||||
]
|
||||
replies = [p.text for p in replies]
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(messages=task_messages),
|
||||
conversation=prepare_conversation(conversation),
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
@@ -142,22 +143,22 @@ def generate_task(
|
||||
case protocol_schema.TaskRequestType.label_prompter_reply:
|
||||
logger.info("Generating a LabelPrompterReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
|
||||
message = messages[0].text
|
||||
message = messages[0]
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message,
|
||||
conversation=prepare_conversation(conversation),
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_assistant_reply:
|
||||
logger.info("Generating a LabelAssistantReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
|
||||
message = messages[0].text
|
||||
message = messages[0]
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message,
|
||||
conversation=prepare_conversation(conversation),
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
import http
|
||||
import random
|
||||
|
||||
import requests
|
||||
@@ -30,6 +31,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
if response.status_code == http.HTTPStatus.NO_CONTENT:
|
||||
return None
|
||||
return response.json()
|
||||
|
||||
typer.echo("Requesting work...")
|
||||
@@ -191,7 +194,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
ranking_str = typer.prompt("Enter the reply numbers in order of preference, separated by commas")
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
|
||||
# send ranking
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
@@ -223,7 +226,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send ranking
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
@@ -260,13 +263,13 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send ranking
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"text": task["prompt"],
|
||||
"text": task["reply"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user