From 082981b9d92cb5164829ec3ab2e8a148ba7c6771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 16 Dec 2022 12:26:52 +0100 Subject: [PATCH] insert tasks into work_package table, add PromptRepository --- backend/alembic.ini | 2 +- backend/app/api/deps.py | 14 ++- backend/app/api/v1/api.py | 4 +- backend/app/api/v1/tasks2.py | 173 +++++++++++++++++++++++++++++++ backend/app/prompt_repository.py | 116 +++++++++++++++++++++ 5 files changed, 306 insertions(+), 3 deletions(-) create mode 100644 backend/app/api/v1/tasks2.py create mode 100644 backend/app/prompt_repository.py diff --git a/backend/alembic.ini b/backend/alembic.ini index 39874aee..36a8c3ae 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # output_encoding = utf-8 # sqlalchemy.url = postgresql://:@/ - +sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres [post_write_hooks] # post_write_hooks defines scripts or Python functions that are run diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index e8b780c1..0790cd4d 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from secrets import token_hex from typing import Generator from uuid import UUID @@ -7,6 +8,7 @@ from app.database import engine from app.models import ApiClient from fastapi import HTTPException, Security from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery +from loguru import logger from sqlmodel import Session from starlette.status import HTTP_403_FORBIDDEN @@ -37,7 +39,17 @@ def api_auth( if api_key is not None: if settings.ALLOW_ANY_API_KEY: - return ApiClient(id=UUID("00000000-1111-2222-3333-444444444444"), api_key=api_key, name=api_key) + # make sure that a dummy api key exits in db (foreign key references) + ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") + api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() + if api_client is None: + token = token_hex(32) + logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") + api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") + db.add(api_client) + db.commit() + return api_client + api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() if api_client is not None and api_client.enabled: return api_client diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index 5a704c2d..7e3ea5eb 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- -from app.api.v1 import labelers, prompts, tasks +from app.api.v1 import labelers, prompts, tasks, tasks2 from fastapi import APIRouter api_router = APIRouter() api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"]) api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) + +api_router.include_router(tasks2.router, prefix="/task2", tags=["task2"]) # temporary diff --git a/backend/app/api/v1/tasks2.py b/backend/app/api/v1/tasks2.py new file mode 100644 index 00000000..23986fef --- /dev/null +++ b/backend/app/api/v1/tasks2.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +import random +from typing import Any +from uuid import UUID + +from app.api import deps +from app.prompt_repository import PromptRepository +from app.schemas import protocol as protocol_schema +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security.api_key import APIKey +from loguru import logger +from sqlmodel import Session +from starlette.status import HTTP_400_BAD_REQUEST + +router = APIRouter() + + +def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: + match (request.type): + case protocol_schema.TaskRequestType.random: + logger.info("Frontend requested a random task.") + while request.type == protocol_schema.TaskRequestType.random: + request.type = random.choice(list(protocol_schema.TaskRequestType)).value + return generate_task(request) + case protocol_schema.TaskRequestType.summarize_story: + logger.info("Generating a SummarizeStoryTask.") + task = protocol_schema.SummarizeStoryTask( + story="This is a story. A very long story. So long, it needs to be summarized.", + ) + case protocol_schema.TaskRequestType.rate_summary: + logger.info("Generating a RateSummaryTask.") + task = protocol_schema.RateSummaryTask( + full_text="This is a story. A very long story. So long, it needs to be summarized.", + summary="This is a summary.", + scale=protocol_schema.RatingScale(min=1, max=5), + ) + case protocol_schema.TaskRequestType.initial_prompt: + logger.info("Generating an InitialPromptTask.") + task = protocol_schema.InitialPromptTask( + hint="Ask the assistant about a current event." # this is optional + ) + case protocol_schema.TaskRequestType.user_reply: + logger.info("Generating a UserReplyTask.") + task = protocol_schema.UserReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, what's going on in the world?", + is_assistant=False, + ), + protocol_schema.ConversationMessage( + text="I'm not sure I understood correctly, could you rephrase that?", + is_assistant=True, + ), + ], + ) + ) + case protocol_schema.TaskRequestType.assistant_reply: + logger.info("Generating a AssistantReplyTask.") + task = protocol_schema.AssistantReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, write me an English essay about water.", + is_assistant=False, + ), + ], + ) + ) + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid request type.", + ) + logger.info(f"Generated {task=}.") + if request.user is not None: + task.addressed_user = request.user + + return task + + +@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added +def request_task( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + request: protocol_schema.TaskRequest, +) -> Any: + """ + Create new task. + """ + api_client = deps.api_auth(api_key, db) + + try: + task = generate_task(request) + + pr = PromptRepository(db, api_client, request.user) + pr.store_task(task) + + except Exception: + logger.exception("Failed to generate task.") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + ) + return task + + +@router.post("/{task_id}/ack") +def acknowledge_task( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + task_id: UUID, + response: protocol_schema.AnyTaskResponse, +) -> Any: + """ + The frontend acknowledges a task. + """ + deps.api_auth(api_key, db) + + match (type(response)): + case protocol_schema.PostCreatedTaskResponse: + logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.") + # here we would store the post id in the database for the task + case protocol_schema.RatingCreatedTaskResponse: + logger.info(f"Frontend acknowledged {task_id=} for {response.post_id=}.") + # here we would store the rating id in the database for the task + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) + + return {} + + +@router.post("/interaction") +def post_interaction( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + interaction: protocol_schema.AnyInteraction, +) -> Any: + """ + The frontend reports an interaction. + """ + deps.api_auth(api_key, db) + + match (type(interaction)): + case protocol_schema.TextReplyToPost: + logger.info( + f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}." + ) + # here we would store the text reply in the database + return protocol_schema.TaskDone( + reply_to_post_id=interaction.user_post_id, + addressed_user=interaction.user, + ) + case protocol_schema.PostRating: + logger.info( + f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}." + ) + # check if rating in range + # here we would store the rating in the database + return protocol_schema.TaskDone( + reply_to_post_id=interaction.post_id, + addressed_user=interaction.user, + ) + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py new file mode 100644 index 00000000..539ef1ab --- /dev/null +++ b/backend/app/prompt_repository.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +from typing import Literal, Optional +from uuid import UUID + +# from app.models import ApiClient, Person, PersonStats, Post, PostReaction, WorkPackage +from app.models import ApiClient, Person, WorkPackage +from app.models.payload_column_type import PayloadContainer, payload_tpye +from app.schemas import protocol as protocol_schema +from pydantic import BaseModel +from sqlmodel import Session + + +@payload_tpye +class TaskPayload(BaseModel): + type: str + + +@payload_tpye +class SummarizationStoryPayload(TaskPayload): + type: Literal["summarize_story"] = "summarize_story" + story: str + + +@payload_tpye +class RateSummaryPayload(TaskPayload): + type: Literal["rate_summary"] = "rate_summary" + full_text: str + summary: str + scale: protocol_schema.RatingScale + + +@payload_tpye +class InitialPromptPayload(TaskPayload): + type: Literal["initial_prompt"] = "initial_prompt" + hint: str + + +@payload_tpye +class UserReplyPayload(TaskPayload): + type: Literal["user_reply"] = "user_reply" + conversation: protocol_schema.Conversation + hint: str | None + + +@payload_tpye +class AssistantReplyPayload(TaskPayload): + type: Literal["assistant_reply"] = "assistant_reply" + conversation: protocol_schema.Conversation + + +class PromptRepository: + def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]): + self.db = db + self.api_client = api_client + self.person = self.lookup_person(user) + self.person_id = self.person.id if self.person else None + + def lookup_person(self, user: protocol_schema.User) -> Person: + person: Person = ( + self.db.query(Person) + .filter(Person.api_client_id == self.api_client.id and Person.username == user.id) + .first() + ) + if person is None: + # user is unknown, create new record + person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id) + self.db.add(person) + self.db.commit() + self.db.refresh(person) + elif user.display_name and user.display_name != person.display_name: + # we found the user but the display name changed + person.display_name = user.display_name + self.db.add(person) + self.db.commit() + return person + + def store_task(self, task: protocol_schema.Task) -> WorkPackage: + payload: TaskPayload = None + match type(task): + case protocol_schema.SummarizeStoryTask: + payload = SummarizationStoryPayload(story=task.story) + + case protocol_schema.RateSummaryTask: + payload = RateSummaryPayload(full_text=task.full_text, summary=task.summary, scale=task.scale) + + case protocol_schema.InitialPromptTask: + payload = InitialPromptPayload(hint=task.hint) + + case protocol_schema.UserReplyTask: + payload = UserReplyPayload(conversation=task.conversation, hint=task.hint) + + case protocol_schema.AssistantReplyTask: + payload = AssistantReplyPayload(type=task.type, conversation=task.conversation) + + case _: + raise RuntimeError( + detail="Invalid task type.", + ) + + wp = self.insert_work_package(payload=payload, id=task.id) + assert wp.id == task.id + return wp + + def insert_work_package(self, payload: TaskPayload, id: UUID = None) -> WorkPackage: + c = PayloadContainer(payload=payload) + wp = WorkPackage( + id=id, + person_id=self.person_id, + payload_type=type(payload).__name__, + payload=c, + api_client_id=self.api_client.id, + ) + self.db.add(wp) + self.db.commit() + self.db.refresh(wp) + return wp