diff --git a/backend/alembic.ini b/backend/alembic.ini index 39874aee..36a8c3ae 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # output_encoding = utf-8 # sqlalchemy.url = postgresql://:@/ - +sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres [post_write_hooks] # post_write_hooks defines scripts or Python functions that are run diff --git a/backend/alembic/versions/cd7de470586e_v1_db_structure.py b/backend/alembic/versions/cd7de470586e_v1_db_structure.py index d1eac36f..67488e4b 100644 --- a/backend/alembic/versions/cd7de470586e_v1_db_structure.py +++ b/backend/alembic/versions/cd7de470586e_v1_db_structure.py @@ -94,7 +94,7 @@ def upgrade() -> None: sa.Column("frontend_post_id", sa.String(200), nullable=False), # unique together with api_client_id sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid - sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=True), sa.PrimaryKeyConstraint("id"), sa.ForeignKeyConstraint(["person_id"], ["person.id"]), sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]), diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index e8b780c1..0790cd4d 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from secrets import token_hex from typing import Generator from uuid import UUID @@ -7,6 +8,7 @@ from app.database import engine from app.models import ApiClient from fastapi import HTTPException, Security from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery +from loguru import logger from sqlmodel import Session from starlette.status import HTTP_403_FORBIDDEN @@ -37,7 +39,17 @@ def api_auth( if api_key is not 
None: if settings.ALLOW_ANY_API_KEY: - return ApiClient(id=UUID("00000000-1111-2222-3333-444444444444"), api_key=api_key, name=api_key) + # make sure that a dummy api key exits in db (foreign key references) + ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") + api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() + if api_client is None: + token = token_hex(32) + logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") + api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") + db.add(api_client) + db.commit() + return api_client + api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() if api_client is not None and api_client.enabled: return api_client diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index 5a704c2d..7e3ea5eb 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- -from app.api.v1 import labelers, prompts, tasks +from app.api.v1 import labelers, prompts, tasks, tasks2 from fastapi import APIRouter api_router = APIRouter() api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"]) api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) + +api_router.include_router(tasks2.router, prefix="/task2", tags=["task2"]) # temporary diff --git a/backend/app/api/v1/tasks2.py b/backend/app/api/v1/tasks2.py new file mode 100644 index 00000000..a334d0b3 --- /dev/null +++ b/backend/app/api/v1/tasks2.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- +import random +from typing import Any +from uuid import UUID + +from app.api import deps +from app.prompt_repository import PromptRepository, RateSummaryPayload, TaskPayload +from app.schemas import protocol as protocol_schema +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security.api_key 
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
    """Build a canned placeholder task for the requested task type.

    A ``random`` request is rewritten in place: ``request.type`` is re-drawn from
    ``TaskRequestType`` until it is no longer ``random``, then the function calls
    itself with the updated request. Every other known type yields a hard-coded
    task. An unknown type is rejected with an HTTP 400 error.

    If the request carries a user, the generated task is addressed to that user.
    """
    kinds = protocol_schema.TaskRequestType
    requested = request.type

    # Guard clause: resolve "random" to a concrete type first, then start over.
    if requested == kinds.random:
        logger.info("Frontend requested a random task.")
        while request.type == kinds.random:
            request.type = random.choice(list(kinds)).value
        return generate_task(request)

    if requested == kinds.summarize_story:
        logger.info("Generating a SummarizeStoryTask.")
        task = protocol_schema.SummarizeStoryTask(
            story="This is a story. A very long story. So long, it needs to be summarized.",
        )
    elif requested == kinds.rate_summary:
        logger.info("Generating a RateSummaryTask.")
        rating_scale = protocol_schema.RatingScale(min=1, max=5)
        task = protocol_schema.RateSummaryTask(
            full_text="This is a story. A very long story. So long, it needs to be summarized.",
            summary="This is a summary.",
            scale=rating_scale,
        )
    elif requested == kinds.initial_prompt:
        logger.info("Generating an InitialPromptTask.")
        # The hint is optional.
        task = protocol_schema.InitialPromptTask(hint="Ask the assistant about a current event.")
    elif requested == kinds.user_reply:
        logger.info("Generating a UserReplyTask.")
        user_turn = protocol_schema.ConversationMessage(
            text="Hey, assistant, what's going on in the world?",
            is_assistant=False,
        )
        assistant_turn = protocol_schema.ConversationMessage(
            text="I'm not sure I understood correctly, could you rephrase that?",
            is_assistant=True,
        )
        task = protocol_schema.UserReplyTask(
            conversation=protocol_schema.Conversation(messages=[user_turn, assistant_turn]),
        )
    elif requested == kinds.assistant_reply:
        logger.info("Generating a AssistantReplyTask.")
        user_turn = protocol_schema.ConversationMessage(
            text="Hey, assistant, write me an English essay about water.",
            is_assistant=False,
        )
        task = protocol_schema.AssistantReplyTask(
            conversation=protocol_schema.Conversation(messages=[user_turn]),
        )
    else:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="Invalid request type.",
        )

    logger.info(f"Generated {task=}.")
    if request.user is not None:
        task.addressed_user = request.user

    return task
+ """ + api_client = deps.api_auth(api_key, db) + pr = PromptRepository(db, api_client, user=None) + + match (type(response)): + case protocol_schema.PostCreatedTaskResponse: + logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.") + + if response.status == "success": + # here we store the post id in the database for the task + pr.bind_frontend_post_id(task_id=task_id, post_id=response.post_id) + + case protocol_schema.RatingCreatedTaskResponse: + logger.info(f"Frontend acknowledged {task_id=} for {response.post_id=}.") + + if response.status == "success": + # here we would store the rating id in the database for the task + pr.bind_frontend_post_id(task_id=task_id, post_id=response.post_id) + + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) + + return {} + + +@router.post("/interaction") +def post_interaction( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + interaction: protocol_schema.AnyInteraction, +) -> Any: + """ + The frontend reports an interaction. + """ + api_client = deps.api_auth(api_key, db) + pr = PromptRepository(db, api_client, user=interaction.user) + + match (type(interaction)): + case protocol_schema.TextReplyToPost: + logger.info( + f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}." + ) + + work_package = pr.fetch_workpackage_by_postid(interaction.post_id) + work_payload: TaskPayload = work_package.payload.payload + logger.info(f"found task work package in db: {work_payload}") + + # here we store the text reply in the database + # ToDo: role user or agent? 
+ pr.store_text_reply(interaction, role="unknown") + + return protocol_schema.TaskDone( + reply_to_post_id=interaction.user_post_id, + addressed_user=interaction.user, + ) + case protocol_schema.PostRating: + logger.info( + f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}." + ) + # check if rating in range + + work_package = pr.fetch_workpackage_by_postid(interaction.post_id) + work_payload: RateSummaryPayload = work_package.payload.payload + if ( + type(work_payload) != RateSummaryPayload + or interaction.rating < work_payload.scale.min + or interaction.rating > work_payload.scale.max + ): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) + + pr.store_rating(interaction) + + # here we would store the rating in the database + return protocol_schema.TaskDone( + reply_to_post_id=interaction.post_id, + addressed_user=interaction.user, + ) + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) diff --git a/backend/app/models/payload_column_type.py b/backend/app/models/payload_column_type.py index a95ccb53..fbda51ce 100644 --- a/backend/app/models/payload_column_type.py +++ b/backend/app/models/payload_column_type.py @@ -14,7 +14,7 @@ payload_type_registry = {} P = TypeVar("P", bound=BaseModel) -def payload_tpye(cls: Type[P]) -> Type[P]: +def payload_type(cls: Type[P]) -> Type[P]: payload_type_registry[cls.__name__] = cls return cls diff --git a/backend/app/models/post.py b/backend/app/models/post.py index fb6d5160..d1569a67 100644 --- a/backend/app/models/post.py +++ b/backend/app/models/post.py @@ -30,4 +30,4 @@ class Post(SQLModel, table=True): sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) ) payload_type: str = Field(nullable=False, max_length=200) - payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False)) + 
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)) diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py new file mode 100644 index 00000000..0baac989 --- /dev/null +++ b/backend/app/prompt_repository.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +from datetime import datetime +from typing import Literal, Optional +from uuid import UUID, uuid4 + +from app.models import ApiClient, Person, Post, PostReaction, WorkPackage +from app.models.payload_column_type import PayloadContainer, payload_type +from app.schemas import protocol as protocol_schema +from pydantic import BaseModel +from sqlmodel import Session + + +@payload_type +class TaskPayload(BaseModel): + type: str + + +@payload_type +class SummarizationStoryPayload(TaskPayload): + type: Literal["summarize_story"] = "summarize_story" + story: str + + +@payload_type +class RateSummaryPayload(TaskPayload): + type: Literal["rate_summary"] = "rate_summary" + full_text: str + summary: str + scale: protocol_schema.RatingScale + + +@payload_type +class InitialPromptPayload(TaskPayload): + type: Literal["initial_prompt"] = "initial_prompt" + hint: str + + +@payload_type +class UserReplyPayload(TaskPayload): + type: Literal["user_reply"] = "user_reply" + conversation: protocol_schema.Conversation + hint: str | None + + +@payload_type +class AssistantReplyPayload(TaskPayload): + type: Literal["assistant_reply"] = "assistant_reply" + conversation: protocol_schema.Conversation + + +@payload_type +class PostPayload(BaseModel): + text: str + + +@payload_type +class ReactionPayload(BaseModel): + type: str + + +@payload_type +class RatingReactionPayload(ReactionPayload): + type: Literal["post_rating"] = "post_rating" + rating: str + + +class PromptRepository: + def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]): + self.db = db + self.api_client = api_client + self.person = self.lookup_person(user) + 
self.person_id = self.person.id if self.person else None + + def lookup_person(self, user: protocol_schema.User) -> Person: + if not user: + return None + person: Person = ( + self.db.query(Person).filter(Person.api_client_id == self.api_client.id, Person.username == user.id).first() + ) + if person is None: + # user is unknown, create new record + person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id) + self.db.add(person) + self.db.commit() + self.db.refresh(person) + elif user.display_name and user.display_name != person.display_name: + # we found the user but the display name changed + person.display_name = user.display_name + self.db.add(person) + self.db.commit() + return person + + def validate_post_id(self, post_id: str) -> None: + if not isinstance(post_id, str): + raise TypeError("post_id must be string") + if not post_id: + raise ValueError("post_id must not be empty") + + def bind_frontend_post_id(self, task_id: UUID, post_id: str): + self.validate_post_id(post_id) + + # find work package + work_pack: WorkPackage = ( + self.db.query(WorkPackage) + .filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id) + .first() + ) + if work_pack is None: + raise RuntimeError(f"WorkPackage for task {task_id} not found") + if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date: + raise RuntimeError("WorkPackage already expired.") + + # ToDo: check race-condition, transaction + + # check if task thread exits + thread_root = ( + self.db.query(Post) + .filter( + Post.workpackage_id == work_pack.id, + Post.frontend_post_id == post_id, + Post.parent_id is None, + self.api_client == self.api_client, + ) + .one_or_none() + ) + if thread_root is None: + thread_id = uuid4() + thread_root = self.insert_post( + post_id=thread_id, + thread_id=thread_id, + frontend_post_id=post_id, + parent_id=None, + role="system", + workpackage_id=work_pack.id, + payload=None, + 
payload_type="bind", + ) + return thread_root + + def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post: + self.validate_post_id(frontend_post_id) + post: Post = ( + self.db.query(Post) + .filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id) + .one_or_none() + ) + if fail_if_missing and post is None: + raise RuntimeError(f"Post with post_id {frontend_post_id} not found.") + return post + + def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage: + self.validate_post_id(post_id) + post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True) + work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one() + return work_pack + + def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post: + self.validate_post_id(reply.post_id) + self.validate_post_id(reply.user_post_id) + + # find post with post-id + parent_post: Post = ( + self.db.query(Post) + .filter( + Post.api_client_id == self.api_client.id, + Post.frontend_post_id == reply.post_id, + # Post.person_id == self.person_id + ) + .one_or_none() + ) + + if parent_post is None: + raise RuntimeError(f"Post for post_id {reply.post_id} not found.") + + # create reply post + user_post_id = uuid4() + user_post = self.insert_post( + post_id=user_post_id, + frontend_post_id=reply.user_post_id, + parent_id=parent_post.id, + thread_id=parent_post.thread_id, + workpackage_id=parent_post.workpackage_id, + role=role, + payload=PostPayload(text=reply.text), + ) + return user_post + + def store_rating(self, rating: protocol_schema.PostRating) -> Post: + post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True) + + work_package = self.fetch_workpackage_by_postid(rating.post_id) + work_payload: RateSummaryPayload = work_package.payload.payload + if type(work_payload) != RateSummaryPayload: + raise RuntimeError("work_package payload type missmatch") + 
+ if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max: + raise ValueError("Invalid rating value") + + # store reaction to post + reaction_payload = RatingReactionPayload(rating=rating.rating) + reaction = self.insert_reaction(post.id, reaction_payload) + return reaction + + def store_task(self, task: protocol_schema.Task) -> WorkPackage: + payload: TaskPayload = None + match type(task): + case protocol_schema.SummarizeStoryTask: + payload = SummarizationStoryPayload(story=task.story) + + case protocol_schema.RateSummaryTask: + payload = RateSummaryPayload(full_text=task.full_text, summary=task.summary, scale=task.scale) + + case protocol_schema.InitialPromptTask: + payload = InitialPromptPayload(hint=task.hint) + + case protocol_schema.UserReplyTask: + payload = UserReplyPayload(conversation=task.conversation, hint=task.hint) + + case protocol_schema.AssistantReplyTask: + payload = AssistantReplyPayload(type=task.type, conversation=task.conversation) + + case _: + raise RuntimeError( + detail="Invalid task type.", + ) + + wp = self.insert_work_package(payload=payload, id=task.id) + assert wp.id == task.id + return wp + + def insert_work_package(self, payload: TaskPayload, id: UUID = None) -> WorkPackage: + c = PayloadContainer(payload=payload) + wp = WorkPackage( + id=id, + person_id=self.person_id, + payload_type=type(payload).__name__, + payload=c, + api_client_id=self.api_client.id, + ) + self.db.add(wp) + self.db.commit() + self.db.refresh(wp) + return wp + + def insert_post( + self, + *, + post_id: UUID, + frontend_post_id: str, + parent_id: UUID, + thread_id: UUID, + workpackage_id: UUID, + role: str, + payload: PostPayload, + payload_type: str = None, + ) -> Post: + if payload_type is None: + if payload is None: + payload_type = "null" + else: + payload_type = type(payload).__name__ + + post = Post( + id=post_id, + parent_id=parent_id, + thread_id=thread_id, + workpackage_id=workpackage_id, + person_id=self.person_id, + 
role=role, + frontend_post_id=frontend_post_id, + api_client_id=self.api_client.id, + payload_type=payload_type, + payload=PayloadContainer(payload=payload), + ) + self.db.add(post) + self.db.commit() + self.db.refresh(post) + return post + + def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction: + if self.person_id is None: + raise RuntimeError("User required") + + container = PayloadContainer(payload=payload) + reaction = PostReaction( + post_id=post_id, + person_id=self.person_id, + payload=container, + api_client_id=self.api_client.id, + payload_type=type(payload).__name__, + ) + self.db.add(reaction) + self.db.commit() + self.db.refresh(reaction) + return reaction