From 0f2a8971e5796d045adafdd174f9f37af229957a Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 16 Dec 2022 10:36:40 +0100 Subject: [PATCH] added tasks to act as user or assistant --- backend/app/api/v1/tasks.py | 40 ++++++++++++--- backend/app/schemas/protocol.py | 43 +++++++++++++++-- text-frontend/__main__.py | 86 ++++++++++++++++++++++++++++----- 3 files changed, 145 insertions(+), 24 deletions(-) diff --git a/backend/app/api/v1/tasks.py b/backend/app/api/v1/tasks.py index 145ba4af..2431b70d 100644 --- a/backend/app/api/v1/tasks.py +++ b/backend/app/api/v1/tasks.py @@ -16,9 +16,9 @@ router = APIRouter() def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: match (request.type): - case protocol_schema.TaskRequestType.generic: - logger.info("Frontend requested a generic task.") - while request.type == protocol_schema.TaskRequestType.generic: + case protocol_schema.TaskRequestType.random: + logger.info("Frontend requested a random task.") + while request.type == protocol_schema.TaskRequestType.random: request.type = random.choice(list(protocol_schema.TaskRequestType)).value return generate_task(request) case protocol_schema.TaskRequestType.summarize_story: @@ -38,6 +38,34 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: task = protocol_schema.InitialPromptTask( hint="Ask the assistant about a current event." # this is optional ) + case protocol_schema.TaskRequestType.user_reply: + logger.info("Generating a UserReplyTask.") + task = protocol_schema.UserReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, what's going on in the world?", + is_assistant=False, + ), + protocol_schema.ConversationMessage( + text="I'm not sure I understood correctly, could you rephrase that?", + is_assistant=True, + ), + ], + ) + ) + case protocol_schema.TaskRequestType.assistant_reply: + logger.info("Generating a AssistantReplyTask.") + task = protocol_schema.AssistantReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, write me an English essay about water.", + is_assistant=False, + ), + ], + ) + ) case _: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, @@ -45,7 +73,7 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: ) logger.info(f"Generated {task=}.") if request.user is not None: - task.addressed_users = [request.user] + task.addressed_user = request.user return task @@ -122,7 +150,7 @@ def post_interaction( # here we would store the text reply in the database return protocol_schema.TaskDone( reply_to_post_id=interaction.user_post_id, - addressed_users=[interaction.user], + addressed_user=interaction.user, ) case protocol_schema.PostRating: logger.info( @@ -132,7 +160,7 @@ def post_interaction( # here we would store the rating in the database return protocol_schema.TaskDone( reply_to_post_id=interaction.post_id, - addressed_users=[interaction.user], + addressed_user=interaction.user, ) case _: raise HTTPException( diff --git a/backend/app/schemas/protocol.py b/backend/app/schemas/protocol.py index 6086bc4a..0abbf142 100644 --- a/backend/app/schemas/protocol.py +++ b/backend/app/schemas/protocol.py @@ -8,21 +8,37 @@ from pydantic import BaseModel class TaskRequestType(str, enum.Enum): - generic = "generic" + random = "random" summarize_story = "summarize_story" rate_summary = "rate_summary" initial_prompt = "initial_prompt" + user_reply = "user_reply" + assistant_reply = "assistant_reply" class User(BaseModel): id: str - name: str + display_name: str + auth_method: Literal["discord", "local"] + + +class ConversationMessage(BaseModel): + """Represents a message in a conversation between the user and the assistant.""" + + text: str + is_assistant: bool + + +class Conversation(BaseModel): + """Represents a conversation between the user and the assistant.""" + + messages: list[ConversationMessage] = [] class TaskRequest(BaseModel): """The frontend asks the backend for a task.""" - type: TaskRequestType = TaskRequestType.generic + type: TaskRequestType = TaskRequestType.random user: Optional[User] = None @@ -31,7 +47,7 @@ class Task(BaseModel): id: UUID = pydantic.Field(default_factory=uuid4) type: str - addressed_users: Optional[list[User]] = None + addressed_user: Optional[User] = None class TaskResponse(BaseModel): @@ -91,6 +107,21 @@ class InitialPromptTask(Task): ) +class UserReplyTask(Task): + """A task to prompt the user to submit a reply to the assistant.""" + + type: Literal["user_reply"] = "user_reply" + conversation: Conversation # the conversation so far + hint: str | None = None # e.g. "Try to ask for clarification." + + +class AssistantReplyTask(Task): + """A task to prompt the user to act as the assistant.""" + + type: Literal["assistant_reply"] = "assistant_reply" + conversation: Conversation # the conversation so far + + class TaskDone(Task): """Signals to the frontend that the task is done.""" @@ -99,10 +130,12 @@ class TaskDone(Task): AnyTask = Union[ + TaskDone, SummarizeStoryTask, RateSummaryTask, InitialPromptTask, - TaskDone, + UserReplyTask, + AssistantReplyTask, ] diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index d912fd60..44b7cbc5 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -7,6 +7,19 @@ import typer app = typer.Typer() +# debug constants +POST_ID = "1234" +USER_POST_ID = "5678" +USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"} + + +def _render_message(message: dict) -> str: + """Render a message to the user.""" + if message["is_assistant"]: + return f"Assistant: {message['text']}" + return f"User: {message['text']}" + + @app.command() def main(backend_url: str, api_key: str): """Simple REPL frontend.""" @@ -17,7 +30,7 @@ def main(backend_url: str, api_key: str): return response.json() typer.echo("Requesting work...") - tasks = [_post("/api/v1/tasks/", {"type": "generic"})] + tasks = [_post("/api/v1/tasks/", {"type": "random"})] while tasks: task = tasks.pop(0) match (task["type"]): @@ -26,7 +39,7 @@ def main(backend_url: str, api_key: str): typer.echo(task["story"]) # acknowledge task - _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": "1234"}) + _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID}) summary = typer.prompt("Enter your summary") @@ -35,10 +48,10 @@ def main(backend_url: str, api_key: str): "/api/v1/tasks/interaction", { "type": "text_reply_to_post", - "post_id": "1234", - "user_post_id": "5678", + "post_id": POST_ID, + "user_post_id": USER_POST_ID, "text": summary, - "user": {"id": "1234", "name": "John Doe"}, + "user": USER, }, ) tasks.append(new_task) @@ -50,7 +63,7 @@ def main(backend_url: str, api_key: str): typer.echo(f"Rating scale: {task['scale']['min']} - {task['scale']['max']}") # acknowledge task - _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "rating_created", "post_id": "1234"}) + _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "rating_created", "post_id": POST_ID}) rating = typer.prompt("Enter your rating", type=int) # send interaction @@ -58,9 +71,9 @@ def main(backend_url: str, api_key: str): "/api/v1/tasks/interaction", { "type": "post_rating", - "post_id": "1234", + "post_id": POST_ID, "rating": rating, - "user": {"id": "1234", "name": "John Doe"}, + "user": USER, }, ) tasks.append(new_task) @@ -69,23 +82,70 @@ def main(backend_url: str, api_key: str): if task["hint"]: typer.echo(f"Hint: {task['hint']}") # acknowledge task - _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": "1234"}) + _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID}) prompt = typer.prompt("Enter your prompt") # send interaction new_task = _post( "/api/v1/tasks/interaction", { "type": "text_reply_to_post", - "post_id": "1234", - "user_post_id": "5678", + "post_id": POST_ID, + "user_post_id": USER_POST_ID, "text": prompt, - "user": {"id": "1234", "name": "John Doe"}, + "user": USER, + }, + ) + tasks.append(new_task) + + case "user_reply": + typer.echo("Please provide a reply to the assistant.") + typer.echo("Here is the conversation so far:") + for message in task["conversation"]["messages"]: + typer.echo(_render_message(message)) + if task["hint"]: + typer.echo(f"Hint: {task['hint']}") + # acknowledge task + _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID}) + reply = typer.prompt("Enter your reply") + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_post", + "post_id": POST_ID, + "user_post_id": USER_POST_ID, + "text": reply, + "user": USER, + }, + ) + tasks.append(new_task) + + case "assistant_reply": + typer.echo("Act as the assistant and reply to the user.") + typer.echo("Here is the conversation so far:") + for message in task["conversation"]["messages"]: + typer.echo(_render_message(message)) + # acknowledge task + _post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID}) + reply = typer.prompt("Enter your reply") + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_post", + "post_id": POST_ID, + "user_post_id": USER_POST_ID, + "text": reply, + "user": USER, }, ) tasks.append(new_task) case "task_done": - typer.echo("Task done!") + if addressed_user := task["addressed_user"]: + typer.echo(f"Hey, {addressed_user['display_name']}! Thank you!") + else: + typer.echo("Task done!") case _: typer.echo(f"Unknown task type {task['type']}")