diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 3a85cd5b..0e469602 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps +from oasst_backend.api.v1.utils import prepare_conversation from oasst_backend.prompt_repository import PromptRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema @@ -126,7 +127,7 @@ def generate_task( ] replies = [p.text for p in replies] task = protocol_schema.RankAssistantRepliesTask( - conversation=protocol_schema.Conversation(messages=task_messages), + conversation=prepare_conversation(conversation), replies=replies, ) @@ -142,22 +143,22 @@ def generate_task( case protocol_schema.TaskRequestType.label_prompter_reply: logger.info("Generating a LabelPrompterReplyTask.") conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant") - message = messages[0].text + message = messages[0] task = protocol_schema.LabelPrompterReplyTask( message_id=message.id, - conversation=conversation, - reply=message, + conversation=prepare_conversation(conversation), + reply=message.text, valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), ) case protocol_schema.TaskRequestType.label_assistant_reply: logger.info("Generating a LabelAssistantReplyTask.") conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter") - message = messages[0].text + message = messages[0] task = protocol_schema.LabelAssistantReplyTask( message_id=message.id, - conversation=conversation, - reply=message, + conversation=prepare_conversation(conversation), + reply=message.text, valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), ) diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index 2060498d..a6e5f947 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -1,5 +1,6 @@ """Simple REPL frontend.""" +import http import random import requests @@ -30,6 +31,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") def _post(path: str, json: dict) -> dict: response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key}) response.raise_for_status() + if response.status_code == http.HTTPStatus.NO_CONTENT: + return None return response.json() typer.echo("Requesting work...") @@ -191,7 +194,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") ranking_str = typer.prompt("Enter the reply numbers in order of preference, separated by commas") ranking = [int(x) - 1 for x in ranking_str.split(",")] - # send ranking + # send labels new_task = _post( "/api/v1/tasks/interaction", { @@ -223,7 +226,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") invalid_labels = [label for label in labels if label not in valid_labels] typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") - # send ranking + # send labels new_task = _post( "/api/v1/tasks/interaction", { @@ -260,13 +263,13 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") invalid_labels = [label for label in labels if label not in valid_labels] typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") - # send ranking + # send labels new_task = _post( "/api/v1/tasks/interaction", { "type": "text_labels", "message_id": task["message_id"], - "text": task["prompt"], + "text": task["reply"], "labels": labels_dict, "user": USER, },