From 082981b9d92cb5164829ec3ab2e8a148ba7c6771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 16 Dec 2022 12:26:52 +0100 Subject: [PATCH 1/5] insert tasks into work_package table, add PromptRepository --- backend/alembic.ini | 2 +- backend/app/api/deps.py | 14 ++- backend/app/api/v1/api.py | 4 +- backend/app/api/v1/tasks2.py | 173 +++++++++++++++++++++++++++++++ backend/app/prompt_repository.py | 116 +++++++++++++++++++++ 5 files changed, 306 insertions(+), 3 deletions(-) create mode 100644 backend/app/api/v1/tasks2.py create mode 100644 backend/app/prompt_repository.py diff --git a/backend/alembic.ini b/backend/alembic.ini index 39874aee..36a8c3ae 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # output_encoding = utf-8 # sqlalchemy.url = postgresql://:@/ - +sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres [post_write_hooks] # post_write_hooks defines scripts or Python functions that are run diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index e8b780c1..0790cd4d 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from secrets import token_hex from typing import Generator from uuid import UUID @@ -7,6 +8,7 @@ from app.database import engine from app.models import ApiClient from fastapi import HTTPException, Security from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery +from loguru import logger from sqlmodel import Session from starlette.status import HTTP_403_FORBIDDEN @@ -37,7 +39,17 @@ def api_auth( if api_key is not None: if settings.ALLOW_ANY_API_KEY: - return ApiClient(id=UUID("00000000-1111-2222-3333-444444444444"), api_key=api_key, name=api_key) + # make sure that a dummy api key exits in db (foreign key references) + ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") + api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() + if api_client is None: + token = token_hex(32) + logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") + api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") + db.add(api_client) + db.commit() + return api_client + api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() if api_client is not None and api_client.enabled: return api_client diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index 5a704c2d..7e3ea5eb 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- -from app.api.v1 import labelers, prompts, tasks +from app.api.v1 import labelers, prompts, tasks, tasks2 from fastapi import APIRouter api_router = APIRouter() api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"]) api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) + +api_router.include_router(tasks2.router, prefix="/task2", tags=["task2"]) # temporary diff --git a/backend/app/api/v1/tasks2.py b/backend/app/api/v1/tasks2.py new file mode 100644 index 00000000..23986fef --- /dev/null +++ b/backend/app/api/v1/tasks2.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +import random +from typing import Any +from uuid import UUID + +from app.api import deps +from app.prompt_repository import PromptRepository +from app.schemas import protocol as protocol_schema +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security.api_key import APIKey +from loguru import logger +from sqlmodel import Session +from starlette.status import HTTP_400_BAD_REQUEST + +router = APIRouter() + + +def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: + match (request.type): + case protocol_schema.TaskRequestType.random: + logger.info("Frontend requested a random task.") + while request.type == protocol_schema.TaskRequestType.random: + request.type = random.choice(list(protocol_schema.TaskRequestType)).value + return generate_task(request) + case protocol_schema.TaskRequestType.summarize_story: + logger.info("Generating a SummarizeStoryTask.") + task = protocol_schema.SummarizeStoryTask( + story="This is a story. A very long story. So long, it needs to be summarized.", + ) + case protocol_schema.TaskRequestType.rate_summary: + logger.info("Generating a RateSummaryTask.") + task = protocol_schema.RateSummaryTask( + full_text="This is a story. A very long story. So long, it needs to be summarized.", + summary="This is a summary.", + scale=protocol_schema.RatingScale(min=1, max=5), + ) + case protocol_schema.TaskRequestType.initial_prompt: + logger.info("Generating an InitialPromptTask.") + task = protocol_schema.InitialPromptTask( + hint="Ask the assistant about a current event." # this is optional + ) + case protocol_schema.TaskRequestType.user_reply: + logger.info("Generating a UserReplyTask.") + task = protocol_schema.UserReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, what's going on in the world?", + is_assistant=False, + ), + protocol_schema.ConversationMessage( + text="I'm not sure I understood correctly, could you rephrase that?", + is_assistant=True, + ), + ], + ) + ) + case protocol_schema.TaskRequestType.assistant_reply: + logger.info("Generating a AssistantReplyTask.") + task = protocol_schema.AssistantReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, write me an English essay about water.", + is_assistant=False, + ), + ], + ) + ) + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid request type.", + ) + logger.info(f"Generated {task=}.") + if request.user is not None: + task.addressed_user = request.user + + return task + + +@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added +def request_task( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + request: protocol_schema.TaskRequest, +) -> Any: + """ + Create new task. + """ + api_client = deps.api_auth(api_key, db) + + try: + task = generate_task(request) + + pr = PromptRepository(db, api_client, request.user) + pr.store_task(task) + + except Exception: + logger.exception("Failed to generate task.") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + ) + return task + + +@router.post("/{task_id}/ack") +def acknowledge_task( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + task_id: UUID, + response: protocol_schema.AnyTaskResponse, +) -> Any: + """ + The frontend acknowledges a task. + """ + deps.api_auth(api_key, db) + + match (type(response)): + case protocol_schema.PostCreatedTaskResponse: + logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.") + # here we would store the post id in the database for the task + case protocol_schema.RatingCreatedTaskResponse: + logger.info(f"Frontend acknowledged {task_id=} for {response.post_id=}.") + # here we would store the rating id in the database for the task + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) + + return {} + + +@router.post("/interaction") +def post_interaction( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + interaction: protocol_schema.AnyInteraction, +) -> Any: + """ + The frontend reports an interaction. + """ + deps.api_auth(api_key, db) + + match (type(interaction)): + case protocol_schema.TextReplyToPost: + logger.info( + f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}." + ) + # here we would store the text reply in the database + return protocol_schema.TaskDone( + reply_to_post_id=interaction.user_post_id, + addressed_user=interaction.user, + ) + case protocol_schema.PostRating: + logger.info( + f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}." + ) + # check if rating in range + # here we would store the rating in the database + return protocol_schema.TaskDone( + reply_to_post_id=interaction.post_id, + addressed_user=interaction.user, + ) + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py new file mode 100644 index 00000000..539ef1ab --- /dev/null +++ b/backend/app/prompt_repository.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +from typing import Literal, Optional +from uuid import UUID + +# from app.models import ApiClient, Person, PersonStats, Post, PostReaction, WorkPackage +from app.models import ApiClient, Person, WorkPackage +from app.models.payload_column_type import PayloadContainer, payload_tpye +from app.schemas import protocol as protocol_schema +from pydantic import BaseModel +from sqlmodel import Session + + +@payload_tpye +class TaskPayload(BaseModel): + type: str + + +@payload_tpye +class SummarizationStoryPayload(TaskPayload): + type: Literal["summarize_story"] = "summarize_story" + story: str + + +@payload_tpye +class RateSummaryPayload(TaskPayload): + type: Literal["rate_summary"] = "rate_summary" + full_text: str + summary: str + scale: protocol_schema.RatingScale + + +@payload_tpye +class InitialPromptPayload(TaskPayload): + type: Literal["initial_prompt"] = "initial_prompt" + hint: str + + +@payload_tpye +class UserReplyPayload(TaskPayload): + type: Literal["user_reply"] = "user_reply" + conversation: protocol_schema.Conversation + hint: str | None + + +@payload_tpye +class AssistantReplyPayload(TaskPayload): + type: Literal["assistant_reply"] = "assistant_reply" + conversation: protocol_schema.Conversation + + +class PromptRepository: + def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]): + self.db = db + self.api_client = api_client + self.person = self.lookup_person(user) + self.person_id = self.person.id if self.person else None + + def lookup_person(self, user: protocol_schema.User) -> Person: + person: Person = ( + self.db.query(Person) + .filter(Person.api_client_id == self.api_client.id and Person.username == user.id) + .first() + ) + if person is None: + # user is unknown, create new record + person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id) + self.db.add(person) + self.db.commit() + self.db.refresh(person) + elif user.display_name and user.display_name != person.display_name: + # we found the user but the display name changed + person.display_name = user.display_name + self.db.add(person) + self.db.commit() + return person + + def store_task(self, task: protocol_schema.Task) -> WorkPackage: + payload: TaskPayload = None + match type(task): + case protocol_schema.SummarizeStoryTask: + payload = SummarizationStoryPayload(story=task.story) + + case protocol_schema.RateSummaryTask: + payload = RateSummaryPayload(full_text=task.full_text, summary=task.summary, scale=task.scale) + + case protocol_schema.InitialPromptTask: + payload = InitialPromptPayload(hint=task.hint) + + case protocol_schema.UserReplyTask: + payload = UserReplyPayload(conversation=task.conversation, hint=task.hint) + + case protocol_schema.AssistantReplyTask: + payload = AssistantReplyPayload(type=task.type, conversation=task.conversation) + + case _: + raise RuntimeError( + detail="Invalid task type.", + ) + + wp = self.insert_work_package(payload=payload, id=task.id) + assert wp.id == task.id + return wp + + def insert_work_package(self, payload: TaskPayload, id: UUID = None) -> WorkPackage: + c = PayloadContainer(payload=payload) + wp = WorkPackage( + id=id, + person_id=self.person_id, + payload_type=type(payload).__name__, + payload=c, + api_client_id=self.api_client.id, + ) + self.db.add(wp) + self.db.commit() + self.db.refresh(wp) + return wp From 1f31e6a499a0a3911708d4624a65b61f7beb706a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 16 Dec 2022 13:32:06 +0100 Subject: [PATCH 2/5] add frontend post_id binding and post reply handing --- .../versions/cd7de470586e_v1_db_structure.py | 2 +- backend/app/api/v1/tasks2.py | 29 ++++- backend/app/models/post.py | 2 +- backend/app/prompt_repository.py | 107 +++++++++++++++++- 4 files changed, 129 insertions(+), 11 deletions(-) diff --git a/backend/alembic/versions/cd7de470586e_v1_db_structure.py b/backend/alembic/versions/cd7de470586e_v1_db_structure.py index d1eac36f..67488e4b 100644 --- a/backend/alembic/versions/cd7de470586e_v1_db_structure.py +++ b/backend/alembic/versions/cd7de470586e_v1_db_structure.py @@ -94,7 +94,7 @@ def upgrade() -> None: sa.Column("frontend_post_id", sa.String(200), nullable=False), # unique together with api_client_id sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid - sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=True), sa.PrimaryKeyConstraint("id"), sa.ForeignKeyConstraint(["person_id"], ["person.id"]), sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]), diff --git a/backend/app/api/v1/tasks2.py b/backend/app/api/v1/tasks2.py index 23986fef..4a359ab0 100644 --- a/backend/app/api/v1/tasks2.py +++ b/backend/app/api/v1/tasks2.py @@ -4,7 +4,7 @@ from typing import Any from uuid import UUID from app.api import deps -from app.prompt_repository import PromptRepository +from app.prompt_repository import PromptRepository, TaskPayload from app.schemas import protocol as protocol_schema from fastapi import APIRouter, Depends, HTTPException from fastapi.security.api_key import APIKey @@ -116,15 +116,24 @@ def acknowledge_task( """ The frontend acknowledges a task. """ - deps.api_auth(api_key, db) + api_client = deps.api_auth(api_key, db) + pr = PromptRepository(db, api_client, user=None) match (type(response)): case protocol_schema.PostCreatedTaskResponse: logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.") - # here we would store the post id in the database for the task + + if response.status == "success": + # here we store the post id in the database for the task + pr.bind_frontend_post_id(task_id=task_id, post_id=response.post_id) + case protocol_schema.RatingCreatedTaskResponse: logger.info(f"Frontend acknowledged {task_id=} for {response.post_id=}.") - # here we would store the rating id in the database for the task + + if response.status == "success": + # here we would store the rating id in the database for the task + pr.bind_frontend_post_id(task_id=task_id, post_id=response.post_id) + case _: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, @@ -144,14 +153,22 @@ def post_interaction( """ The frontend reports an interaction. """ - deps.api_auth(api_key, db) + api_client = deps.api_auth(api_key, db) + pr = PromptRepository(db, api_client, user=interaction.user) match (type(interaction)): case protocol_schema.TextReplyToPost: logger.info( f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}." ) - # here we would store the text reply in the database + + work_package = pr.fetch_workpackage_by_postid(interaction.post_id) + work_payload: TaskPayload = work_package.payload.payload + logger.info(f"found task work package in db: {work_payload}") + + # here we store the text reply in the database + pr.store_text_reply(interaction) + return protocol_schema.TaskDone( reply_to_post_id=interaction.user_post_id, addressed_user=interaction.user, diff --git a/backend/app/models/post.py b/backend/app/models/post.py index fb6d5160..d1569a67 100644 --- a/backend/app/models/post.py +++ b/backend/app/models/post.py @@ -30,4 +30,4 @@ class Post(SQLModel, table=True): sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) ) payload_type: str = Field(nullable=False, max_length=200) - payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False)) + payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)) diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py index 539ef1ab..d2145e2e 100644 --- a/backend/app/prompt_repository.py +++ b/backend/app/prompt_repository.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- +from datetime import datetime from typing import Literal, Optional -from uuid import UUID +from uuid import UUID, uuid4 -# from app.models import ApiClient, Person, PersonStats, Post, PostReaction, WorkPackage -from app.models import ApiClient, Person, WorkPackage +from app.models import ApiClient, Person, Post, WorkPackage from app.models.payload_column_type import PayloadContainer, payload_tpye from app.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -74,6 +74,107 @@ class PromptRepository: self.db.commit() return person + def validate_post_id(self, post_id: str) -> None: + if not isinstance(post_id, str): + raise TypeError("post_id must be string") + if not post_id: + raise ValueError("post_id must not be empty") + + def bind_frontend_post_id(self, task_id: UUID, post_id: str): + self.validate_post_id(post_id) + + # find work package + work_pack: WorkPackage = ( + self.db.query(WorkPackage) + .filter(WorkPackage.id == task_id and WorkPackage.api_client_id == self.api_client.id) + .first() + ) + if work_pack is None: + raise RuntimeError(f"WorkPackage for task {task_id} not found") + if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date: + raise RuntimeError("WorkPackage already expired.") + + # ToDo: check race-condition, transaction + + # check if task thread exits + thread_root = ( + self.db.query(Post) + .filter( + Post.workpackage_id == work_pack.id + and Post.frontend_post_id == post_id + and Post.parent_id is None + and self.api_client == self.api_client + ) + .one_or_none() + ) + if thread_root is None: + thread_id = uuid4() + thread_root = Post( + id=thread_id, + thread_id=thread_id, + role="system", + person_id=work_pack.person_id, + workpackage_id=work_pack.id, + frontend_post_id=post_id, + api_client_id=self.api_client.id, + payload_type="bind", + ) + self.db.add(thread_root) + self.db.commit() + self.db.refresh(thread_root) + return thread_root + + def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage: + self.validate_post_id(post_id) + post: Post = ( + self.db.query(Post) + .filter(Post.api_client_id == self.api_client.id and Post.frontend_post_id == post_id) + .one_or_none() + ) + if post is None: + raise RuntimeError(f"Post with post_id {post_id} not found.") + + work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one() + return work_pack + + def store_text_reply(self, reply: protocol_schema.TextReplyToPost) -> Post: + self.validate_post_id(reply.post_id) + self.validate_post_id(reply.user_post_id) + + # find post with post-id + parent_post: Post = ( + self.db.query(Post) + .filter( + Post.api_client_id == self.api_client.id + and Post.frontend_post_id == reply.post_id + and Post.person_id == self.person_id + ) + .one_or_none() + ) + if parent_post is None: + raise RuntimeError(f"Post for post_id {reply.post_id} not found.") + + # create reply post + user_post_id = uuid4() + # ToDo: role user or agent? + user_post = Post( + id=user_post_id, + parent_id=parent_post.id, + thread_id=parent_post.thread_id, + workpackage_id=parent_post.workpackage_id, + person_id=self.person_id, + role="unknown", + frontend_post_id=reply.user_post_id, + api_client_id=self.api_client.id, + ) + self.db.add(user_post) + self.db.commit() + self.db.refresh(user_post) + return user_post + + def store_rating(self, rating: protocol_schema.PostRating) -> Post: + pass + def store_task(self, task: protocol_schema.Task) -> WorkPackage: payload: TaskPayload = None match type(task): From 82965798955c223a369d21fb7c5fd3fcc74ae07d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 16 Dec 2022 14:07:43 +0100 Subject: [PATCH 3/5] added rating reaction --- backend/app/api/v1/tasks2.py | 3 +- backend/app/prompt_repository.py | 61 ++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/backend/app/api/v1/tasks2.py b/backend/app/api/v1/tasks2.py index 4a359ab0..77b30ef0 100644 --- a/backend/app/api/v1/tasks2.py +++ b/backend/app/api/v1/tasks2.py @@ -167,7 +167,8 @@ def post_interaction( logger.info(f"found task work package in db: {work_payload}") # here we store the text reply in the database - pr.store_text_reply(interaction) + # ToDo: role user or agent? + pr.store_text_reply(interaction, role="unknown") return protocol_schema.TaskDone( reply_to_post_id=interaction.user_post_id, diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py index d2145e2e..24a9ea1c 100644 --- a/backend/app/prompt_repository.py +++ b/backend/app/prompt_repository.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Literal, Optional from uuid import UUID, uuid4 -from app.models import ApiClient, Person, Post, WorkPackage +from app.models import ApiClient, Person, Post, PostReaction, WorkPackage from app.models.payload_column_type import PayloadContainer, payload_tpye from app.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -48,6 +48,17 @@ class AssistantReplyPayload(TaskPayload): conversation: protocol_schema.Conversation +@payload_tpye +class ReactionPayload(BaseModel): + type: str + + +@payload_tpye +class RatingReactionPayload(ReactionPayload): + type: Literal["post_rating"] = "post_rating" + rating: str + + class PromptRepository: def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]): self.db = db @@ -124,20 +135,24 @@ class PromptRepository: self.db.refresh(thread_root) return thread_root - def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage: - self.validate_post_id(post_id) + def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post: + self.validate_post_id(frontend_post_id) post: Post = ( self.db.query(Post) - .filter(Post.api_client_id == self.api_client.id and Post.frontend_post_id == post_id) + .filter(Post.api_client_id == self.api_client.id and Post.frontend_post_id == frontend_post_id) .one_or_none() ) - if post is None: - raise RuntimeError(f"Post with post_id {post_id} not found.") + if fail_if_missing and post is None: + raise RuntimeError(f"Post with post_id {frontend_post_id} not found.") + return post + def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage: + self.validate_post_id(post_id) + post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True) work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one() return work_pack - def store_text_reply(self, reply: protocol_schema.TextReplyToPost) -> Post: + def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post: self.validate_post_id(reply.post_id) self.validate_post_id(reply.user_post_id) @@ -156,14 +171,14 @@ class PromptRepository: # create reply post user_post_id = uuid4() - # ToDo: role user or agent? + user_post = Post( id=user_post_id, parent_id=parent_post.id, thread_id=parent_post.thread_id, workpackage_id=parent_post.workpackage_id, person_id=self.person_id, - role="unknown", + role=role, frontend_post_id=reply.user_post_id, api_client_id=self.api_client.id, ) @@ -173,7 +188,33 @@ class PromptRepository: return user_post def store_rating(self, rating: protocol_schema.PostRating) -> Post: - pass + post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True) + + work_package = self.fetch_workpackage_by_postid(rating.post_id) + work_payload: RateSummaryPayload = work_package.payload.payload + if type(work_payload) != RateSummaryPayload: + raise RuntimeError("work_package payload type missmatch") + + if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max: + raise ValueError("Invalid rating value") + + # store reaction to post + reaction_payload = RatingReactionPayload(rating=rating.rating) + reaction = self.insert_reaction(post.id, reaction_payload) + return reaction + + def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction: + if self.person_id is None: + raise RuntimeError("User required") + + container = PayloadContainer(payload=payload) + reaction = PostReaction( + post_id=post_id, person_id=self.person_id, payload=container, api_client_id=self.api_client.id + ) + self.db.add(reaction) + self.db.commit() + self.db.refresh(reaction) + return reaction def store_task(self, task: protocol_schema.Task) -> WorkPackage: payload: TaskPayload = None From 5ac985e435222d6ad01763fd40052733593319ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 16 Dec 2022 14:09:48 +0100 Subject: [PATCH 4/5] fix typo --- backend/app/api/v1/tasks2.py | 17 ++++++++++++++++- backend/app/models/payload_column_type.py | 2 +- backend/app/prompt_repository.py | 18 +++++++++--------- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/backend/app/api/v1/tasks2.py b/backend/app/api/v1/tasks2.py index 77b30ef0..a334d0b3 100644 --- a/backend/app/api/v1/tasks2.py +++ b/backend/app/api/v1/tasks2.py @@ -4,7 +4,7 @@ from typing import Any from uuid import UUID from app.api import deps -from app.prompt_repository import PromptRepository, TaskPayload +from app.prompt_repository import PromptRepository, RateSummaryPayload, TaskPayload from app.schemas import protocol as protocol_schema from fastapi import APIRouter, Depends, HTTPException from fastapi.security.api_key import APIKey @@ -179,6 +179,21 @@ def post_interaction( f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}." ) # check if rating in range + + work_package = pr.fetch_workpackage_by_postid(interaction.post_id) + work_payload: RateSummaryPayload = work_package.payload.payload + if ( + type(work_payload) != RateSummaryPayload + or interaction.rating < work_payload.scale.min + or interaction.rating > work_payload.scale.max + ): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) + + pr.store_rating(interaction) + # here we would store the rating in the database return protocol_schema.TaskDone( reply_to_post_id=interaction.post_id, diff --git a/backend/app/models/payload_column_type.py b/backend/app/models/payload_column_type.py index a95ccb53..fbda51ce 100644 --- a/backend/app/models/payload_column_type.py +++ b/backend/app/models/payload_column_type.py @@ -14,7 +14,7 @@ payload_type_registry = {} P = TypeVar("P", bound=BaseModel) -def payload_tpye(cls: Type[P]) -> Type[P]: +def payload_type(cls: Type[P]) -> Type[P]: payload_type_registry[cls.__name__] = cls return cls diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py index 24a9ea1c..85f372d3 100644 --- a/backend/app/prompt_repository.py +++ b/backend/app/prompt_repository.py @@ -4,24 +4,24 @@ from typing import Literal, Optional from uuid import UUID, uuid4 from app.models import ApiClient, Person, Post, PostReaction, WorkPackage -from app.models.payload_column_type import PayloadContainer, payload_tpye +from app.models.payload_column_type import PayloadContainer, payload_type from app.schemas import protocol as protocol_schema from pydantic import BaseModel from sqlmodel import Session -@payload_tpye +@payload_type class TaskPayload(BaseModel): type: str -@payload_tpye +@payload_type class SummarizationStoryPayload(TaskPayload): type: Literal["summarize_story"] = "summarize_story" story: str -@payload_tpye +@payload_type class RateSummaryPayload(TaskPayload): type: Literal["rate_summary"] = "rate_summary" full_text: str @@ -29,31 +29,31 @@ class RateSummaryPayload(TaskPayload): scale: protocol_schema.RatingScale -@payload_tpye +@payload_type class InitialPromptPayload(TaskPayload): type: Literal["initial_prompt"] = "initial_prompt" hint: str -@payload_tpye +@payload_type class UserReplyPayload(TaskPayload): type: Literal["user_reply"] = "user_reply" conversation: protocol_schema.Conversation hint: str | None -@payload_tpye +@payload_type class AssistantReplyPayload(TaskPayload): type: Literal["assistant_reply"] = "assistant_reply" conversation: protocol_schema.Conversation -@payload_tpye +@payload_type class ReactionPayload(BaseModel): type: str -@payload_tpye +@payload_type class RatingReactionPayload(ReactionPayload): type: Literal["post_rating"] = "post_rating" rating: str From b20cee5685bdf805554208ef9bdca4059f5bb40c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 16 Dec 2022 14:53:34 +0100 Subject: [PATCH 5/5] tested initial PromptRepository --- backend/app/prompt_repository.py | 123 ++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 43 deletions(-) diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py index 85f372d3..0baac989 100644 --- a/backend/app/prompt_repository.py +++ b/backend/app/prompt_repository.py @@ -48,6 +48,11 @@ class AssistantReplyPayload(TaskPayload): conversation: protocol_schema.Conversation +@payload_type +class PostPayload(BaseModel): + text: str + + @payload_type class ReactionPayload(BaseModel): type: str @@ -67,10 +72,10 @@ class PromptRepository: self.person_id = self.person.id if self.person else None def lookup_person(self, user: protocol_schema.User) -> Person: + if not user: + return None person: Person = ( - self.db.query(Person) - .filter(Person.api_client_id == self.api_client.id and Person.username == user.id) - .first() + self.db.query(Person).filter(Person.api_client_id == self.api_client.id, Person.username == user.id).first() ) if person is None: # user is unknown, create new record @@ -97,7 +102,7 @@ class PromptRepository: # find work package work_pack: WorkPackage = ( self.db.query(WorkPackage) - .filter(WorkPackage.id == task_id and WorkPackage.api_client_id == self.api_client.id) + .filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id) .first() ) if work_pack is None: @@ -111,35 +116,32 @@ class PromptRepository: thread_root = ( self.db.query(Post) .filter( - Post.workpackage_id == work_pack.id - and Post.frontend_post_id == post_id - and Post.parent_id is None - and self.api_client == self.api_client + Post.workpackage_id == work_pack.id, + Post.frontend_post_id == post_id, + Post.parent_id is None, + self.api_client == self.api_client, ) .one_or_none() ) if thread_root is None: thread_id = uuid4() - thread_root = Post( - id=thread_id, + thread_root = self.insert_post( + post_id=thread_id, thread_id=thread_id, - role="system", - person_id=work_pack.person_id, - workpackage_id=work_pack.id, frontend_post_id=post_id, - api_client_id=self.api_client.id, + parent_id=None, + role="system", + workpackage_id=work_pack.id, + payload=None, payload_type="bind", ) - self.db.add(thread_root) - self.db.commit() - self.db.refresh(thread_root) return thread_root def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post: self.validate_post_id(frontend_post_id) post: Post = ( self.db.query(Post) - .filter(Post.api_client_id == self.api_client.id and Post.frontend_post_id == frontend_post_id) + .filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id) .one_or_none() ) if fail_if_missing and post is None: @@ -160,31 +162,27 @@ class PromptRepository: parent_post: Post = ( self.db.query(Post) .filter( - Post.api_client_id == self.api_client.id - and Post.frontend_post_id == reply.post_id - and Post.person_id == self.person_id + Post.api_client_id == self.api_client.id, + Post.frontend_post_id == reply.post_id, + # Post.person_id == self.person_id ) .one_or_none() ) + if parent_post is None: raise RuntimeError(f"Post for post_id {reply.post_id} not found.") # create reply post user_post_id = uuid4() - - user_post = Post( - id=user_post_id, + user_post = self.insert_post( + post_id=user_post_id, + frontend_post_id=reply.user_post_id, parent_id=parent_post.id, thread_id=parent_post.thread_id, workpackage_id=parent_post.workpackage_id, - person_id=self.person_id, role=role, - frontend_post_id=reply.user_post_id, - api_client_id=self.api_client.id, + payload=PostPayload(text=reply.text), ) - self.db.add(user_post) - self.db.commit() - self.db.refresh(user_post) return user_post def store_rating(self, rating: protocol_schema.PostRating) -> Post: @@ -203,19 +201,6 @@ class PromptRepository: reaction = self.insert_reaction(post.id, reaction_payload) return reaction - def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction: - if self.person_id is None: - raise RuntimeError("User required") - - container = PayloadContainer(payload=payload) - reaction = PostReaction( - post_id=post_id, person_id=self.person_id, payload=container, api_client_id=self.api_client.id - ) - self.db.add(reaction) - self.db.commit() - self.db.refresh(reaction) - return reaction - def store_task(self, task: protocol_schema.Task) -> WorkPackage: payload: TaskPayload = None match type(task): @@ -256,3 +241,55 @@ class PromptRepository: self.db.commit() self.db.refresh(wp) return wp + + def insert_post( + self, + *, + post_id: UUID, + frontend_post_id: str, + parent_id: UUID, + thread_id: UUID, + workpackage_id: UUID, + role: str, + payload: PostPayload, + payload_type: str = None, + ) -> Post: + if payload_type is None: + if payload is None: + payload_type = "null" + else: + payload_type = type(payload).__name__ + + post = Post( + id=post_id, + parent_id=parent_id, + thread_id=thread_id, + workpackage_id=workpackage_id, + person_id=self.person_id, + role=role, + frontend_post_id=frontend_post_id, + api_client_id=self.api_client.id, + payload_type=payload_type, + payload=PayloadContainer(payload=payload), + ) + self.db.add(post) + self.db.commit() + self.db.refresh(post) + return post + + def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction: + if self.person_id is None: + raise RuntimeError("User required") + + container = PayloadContainer(payload=payload) + reaction = PostReaction( + post_id=post_id, + person_id=self.person_id, + payload=container, + api_client_id=self.api_client.id, + payload_type=type(payload).__name__, + ) + self.db.add(reaction) + self.db.commit() + self.db.refresh(reaction) + return reaction