From a6c957ccfdb500aa48d68beaeb6419b1f81579cb Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sat, 17 Dec 2022 23:55:55 +0100 Subject: [PATCH] renamed to open assistant --- backend/app/alembic/env.py | 2 +- backend/app/main.py | 4 +- backend/app/oasst/api/deps.py | 57 +++++ backend/app/oasst/api/v1/api.py | 6 + backend/app/oasst/api/v1/tasks.py | 259 ++++++++++++++++++++ backend/app/oasst/database.py | 8 + backend/app/oasst/models/db_payload.py | 94 ++++++++ backend/app/oasst/prompt_repository.py | 311 +++++++++++++++++++++++++ 8 files changed, 738 insertions(+), 3 deletions(-) create mode 100644 backend/app/oasst/api/deps.py create mode 100644 backend/app/oasst/api/v1/api.py create mode 100644 backend/app/oasst/api/v1/tasks.py create mode 100644 backend/app/oasst/database.py create mode 100644 backend/app/oasst/models/db_payload.py create mode 100644 backend/app/oasst/prompt_repository.py diff --git a/backend/app/alembic/env.py b/backend/app/alembic/env.py index 634b79a5..6d2ec0c3 100644 --- a/backend/app/alembic/env.py +++ b/backend/app/alembic/env.py @@ -3,7 +3,7 @@ from logging.config import fileConfig import sqlmodel from alembic import context -from ocgpt import models # noqa: F401 +from oasst import models # noqa: F401 from sqlalchemy import engine_from_config, pool # this is the Alembic Config object, which provides diff --git a/backend/app/main.py b/backend/app/main.py index 02f0198a..f78f608c 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -5,8 +5,8 @@ import alembic.command import alembic.config import fastapi from loguru import logger -from ocgpt.api.v1.api import api_router -from ocgpt.config import settings +from oasst.api.v1.api import api_router +from oasst.config import settings from starlette.middleware.cors import CORSMiddleware app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json") diff --git a/backend/app/oasst/api/deps.py b/backend/app/oasst/api/deps.py new file mode 100644 index 00000000..bfa54931 --- /dev/null +++ b/backend/app/oasst/api/deps.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +from secrets import token_hex +from typing import Generator +from uuid import UUID + +from fastapi import HTTPException, Security +from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery +from loguru import logger +from oasst.config import settings +from oasst.database import engine +from oasst.models import ApiClient +from sqlmodel import Session +from starlette.status import HTTP_403_FORBIDDEN + + +def get_db() -> Generator: + with Session(engine) as db: + yield db + + +api_key_query = APIKeyQuery(name="api_key", auto_error=False) +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + +async def get_api_key( + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), +): + if api_key_query: + return api_key_query + else: + return api_key_header + + +def api_auth( + api_key: APIKey, + db: Session, +) -> ApiClient: + + if api_key is not None: + if settings.ALLOW_ANY_API_KEY: + # make sure that a dummy api key exits in db (foreign key references) + ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") + api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() + if api_client is None: + token = token_hex(32) + logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") + api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") + db.add(api_client) + db.commit() + return api_client + + api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() + if api_client is not None and api_client.enabled: + return api_client + + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials") diff --git a/backend/app/oasst/api/v1/api.py b/backend/app/oasst/api/v1/api.py new file mode 100644 index 00000000..3d568cb9 --- /dev/null +++ b/backend/app/oasst/api/v1/api.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +from fastapi import APIRouter +from oasst.api.v1 import tasks + +api_router = APIRouter() +api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) diff --git a/backend/app/oasst/api/v1/tasks.py b/backend/app/oasst/api/v1/tasks.py new file mode 100644 index 00000000..41f01f3c --- /dev/null +++ b/backend/app/oasst/api/v1/tasks.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +import random +from typing import Any +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security.api_key import APIKey +from loguru import logger +from oasst.api import deps +from oasst.models.db_payload import TaskPayload +from oasst.prompt_repository import PromptRepository +from oasst.schemas import protocol as protocol_schema +from sqlmodel import Session +from starlette.status import HTTP_400_BAD_REQUEST + +router = APIRouter() + + +def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: + match (request.type): + case protocol_schema.TaskRequestType.random: + logger.info("Frontend requested a random task.") + while request.type == protocol_schema.TaskRequestType.random: + request.type = random.choice(list(protocol_schema.TaskRequestType)).value + return generate_task(request) + case protocol_schema.TaskRequestType.summarize_story: + logger.info("Generating a SummarizeStoryTask.") + task = protocol_schema.SummarizeStoryTask( + story="This is a story. A very long story. So long, it needs to be summarized.", + ) + case protocol_schema.TaskRequestType.rate_summary: + logger.info("Generating a RateSummaryTask.") + task = protocol_schema.RateSummaryTask( + full_text="This is a story. A very long story. So long, it needs to be summarized.", + summary="This is a summary.", + scale=protocol_schema.RatingScale(min=1, max=5), + ) + case protocol_schema.TaskRequestType.initial_prompt: + logger.info("Generating an InitialPromptTask.") + task = protocol_schema.InitialPromptTask( + hint="Ask the assistant about a current event." # this is optional + ) + case protocol_schema.TaskRequestType.user_reply: + logger.info("Generating a UserReplyTask.") + task = protocol_schema.UserReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, what's going on in the world?", + is_assistant=False, + ), + protocol_schema.ConversationMessage( + text="I'm not sure I understood correctly, could you rephrase that?", + is_assistant=True, + ), + ], + ) + ) + case protocol_schema.TaskRequestType.assistant_reply: + logger.info("Generating a AssistantReplyTask.") + task = protocol_schema.AssistantReplyTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, write me an English essay about water.", + is_assistant=False, + ), + ], + ) + ) + case protocol_schema.TaskRequestType.rank_initial_prompts: + logger.info("Generating a RankInitialPromptsTask.") + task = protocol_schema.RankInitialPromptsTask( + prompts=[ + "Please write a story about a time you were happy.", + "Please write a story about a time you were sad.", + ] + ) + case protocol_schema.TaskRequestType.rank_user_replies: + logger.info("Generating a RankUserRepliesTask.") + task = protocol_schema.RankUserRepliesTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, what's going on in the world?", + is_assistant=False, + ), + protocol_schema.ConversationMessage( + text="I'm not sure I understood correctly, could you rephrase that?", + is_assistant=True, + ), + ], + ), + replies=[ + "Oh come oooooon!", + "What are the news?", + ], + ) + + case protocol_schema.TaskRequestType.rank_assistant_replies: + logger.info("Generating a RankAssistantRepliesTask.") + task = protocol_schema.RankAssistantRepliesTask( + conversation=protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text="Hey, assistant, what's going on in the world?", + is_assistant=False, + ), + ], + ), + replies=[ + "I'm not sure I understood correctly, could you rephrase that?", + "The world is fine. All good.", + "Crap is hitting the fan. Start farming.", + ], + ) + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid request type.", + ) + + logger.info(f"Generated {task=}.") + + return task + + +@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added +def request_task( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + request: protocol_schema.TaskRequest, +) -> Any: + """ + Create new task. + """ + api_client = deps.api_auth(api_key, db) + + try: + task = generate_task(request) + + pr = PromptRepository(db, api_client, request.user) + pr.store_task(task) + + except Exception: + logger.exception("Failed to generate task.") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + ) + return task + + +@router.post("/{task_id}/ack") +def acknowledge_task( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + task_id: UUID, + ack_request: protocol_schema.TaskAck, +) -> Any: + """ + The frontend acknowledges a task. + """ + + api_client = deps.api_auth(api_key, db) + + try: + pr = PromptRepository(db, api_client, user=None) + + # here we store the post id in the database for the task + pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id) + logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.") + + except Exception: + logger.exception("Failed to acknowledge task.") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + ) + return {} + + +@router.post("/{task_id}/nack") +def acknowledge_task_failure( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + task_id: UUID, + nack_request: protocol_schema.TaskNAck, +) -> Any: + """ + The frontend reports failure to implement a task. + """ + deps.api_auth(api_key, db) + + logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.") + # here we would store the post id in the database for the task + return {} + + +@router.post("/interaction") +def post_interaction( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + interaction: protocol_schema.AnyInteraction, +) -> Any: + """ + The frontend reports an interaction. + """ + api_client = deps.api_auth(api_key, db) + + try: + pr = PromptRepository(db, api_client, user=interaction.user) + + match type(interaction): + case protocol_schema.TextReplyToPost: + logger.info( + f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}." + ) + + work_package = pr.fetch_workpackage_by_postid(interaction.post_id) + work_payload: TaskPayload = work_package.payload.payload + logger.info(f"found task work package in db: {work_payload}") + + # here we store the text reply in the database + # ToDo: role user or agent? + pr.store_text_reply(interaction, role="unknown") + + return protocol_schema.TaskDone() + case protocol_schema.PostRating: + logger.info( + f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}." + ) + + # here we store the rating in the database + pr.store_rating(interaction) + + return protocol_schema.TaskDone() + case protocol_schema.PostRanking: + logger.info( + f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}." + ) + + # TODO: check if the ranking is valid + pr.store_ranking(interaction) + # here we would store the ranking in the database + return protocol_schema.TaskDone() + case _: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Invalid response type.", + ) + + except Exception: + logger.exception("Interaction request failed.") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + ) diff --git a/backend/app/oasst/database.py b/backend/app/oasst/database.py new file mode 100644 index 00000000..ca729f4e --- /dev/null +++ b/backend/app/oasst/database.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +from oasst.config import settings +from sqlmodel import create_engine + +if settings.DATABASE_URI is None: + raise ValueError("DATABASE_URI is not set") + +engine = create_engine(settings.DATABASE_URI) diff --git a/backend/app/oasst/models/db_payload.py b/backend/app/oasst/models/db_payload.py new file mode 100644 index 00000000..b01cecce --- /dev/null +++ b/backend/app/oasst/models/db_payload.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +from typing import Literal + +from oasst.models.payload_column_type import payload_type +from oasst.schemas import protocol as protocol_schema +from pydantic import BaseModel + + +@payload_type +class TaskPayload(BaseModel): + type: str + + +@payload_type +class SummarizationStoryPayload(TaskPayload): + type: Literal["summarize_story"] = "summarize_story" + story: str + + +@payload_type +class RateSummaryPayload(TaskPayload): + type: Literal["rate_summary"] = "rate_summary" + full_text: str + summary: str + scale: protocol_schema.RatingScale + + +@payload_type +class InitialPromptPayload(TaskPayload): + type: Literal["initial_prompt"] = "initial_prompt" + hint: str + + +@payload_type +class UserReplyPayload(TaskPayload): + type: Literal["user_reply"] = "user_reply" + conversation: protocol_schema.Conversation + hint: str | None + + +@payload_type +class AssistantReplyPayload(TaskPayload): + type: Literal["assistant_reply"] = "assistant_reply" + conversation: protocol_schema.Conversation + + +@payload_type +class PostPayload(BaseModel): + text: str + + +@payload_type +class ReactionPayload(BaseModel): + type: str + + +@payload_type +class RatingReactionPayload(ReactionPayload): + type: Literal["post_rating"] = "post_rating" + rating: str + + +@payload_type +class RankingReactionPayload(ReactionPayload): + type: Literal["post_ranking"] = "post_ranking" + ranking: list[int] + + +@payload_type +class RankConversationRepliesPayload(TaskPayload): + conversation: protocol_schema.Conversation # the conversation so far + replies: list[str] + + +@payload_type +class RankInitialPromptsPayload(TaskPayload): + """A task to rank a set of initial prompts.""" + + type: Literal["rank_initial_prompts"] = "rank_initial_prompts" + prompts: list[str] + + +@payload_type +class RankUserRepliesPayload(RankConversationRepliesPayload): + """A task to rank a set of user replies to a conversation.""" + + type: Literal["rank_user_replies"] = "rank_user_replies" + + +@payload_type +class RankAssistantRepliesPayload(RankConversationRepliesPayload): + """A task to rank a set of assistant replies to a conversation.""" + + type: Literal["rank_assistant_replies"] = "rank_assistant_replies" diff --git a/backend/app/oasst/prompt_repository.py b/backend/app/oasst/prompt_repository.py new file mode 100644 index 00000000..35f1c9b3 --- /dev/null +++ b/backend/app/oasst/prompt_repository.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +from datetime import datetime +from typing import Optional +from uuid import UUID, uuid4 + +import oasst.models.db_payload as db_payload +from loguru import logger +from oasst.models import ApiClient, Person, Post, PostReaction, WorkPackage +from oasst.models.payload_column_type import PayloadContainer +from oasst.schemas import protocol as protocol_schema +from sqlmodel import Session + + +class PromptRepository: + def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]): + self.db = db + self.api_client = api_client + self.person = self.lookup_person(user) + self.person_id = self.person.id if self.person else None + + def lookup_person(self, user: protocol_schema.User) -> Person: + if not user: + return None + person: Person = ( + self.db.query(Person) + .filter( + Person.api_client_id == self.api_client.id, + Person.username == user.id, + Person.auth_method == user.auth_method, + ) + .first() + ) + if person is None: + # user is unknown, create new record + person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id) + self.db.add(person) + self.db.commit() + self.db.refresh(person) + elif user.display_name and user.display_name != person.display_name: + # we found the user but the display name changed + person.display_name = user.display_name + self.db.add(person) + self.db.commit() + return person + + def validate_post_id(self, post_id: str) -> None: + if not isinstance(post_id, str): + raise TypeError(f"post_id must be string, not {type(post_id)}") + if not post_id: + raise ValueError("post_id must not be empty") + + def bind_frontend_post_id(self, task_id: UUID, post_id: str): + self.validate_post_id(post_id) + + # find work package + work_pack: WorkPackage = ( + self.db.query(WorkPackage) + .filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id) + .first() + ) + if work_pack is None: + raise KeyError(f"WorkPackage for task {task_id} not found") + if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date: + raise RuntimeError("WorkPackage already expired.") + + # ToDo: check race-condition, transaction + + # check if task thread exits + thread_root = ( + self.db.query(Post) + .filter( + Post.workpackage_id == work_pack.id, + Post.frontend_post_id == post_id, + Post.parent_id is None, + Post.api_client_id == self.api_client.id, + ) + .one_or_none() + ) + if thread_root is None: + thread_id = uuid4() + thread_root = self.insert_post( + post_id=thread_id, + thread_id=thread_id, + frontend_post_id=post_id, + parent_id=None, + role="system", + workpackage_id=work_pack.id, + payload=None, + payload_type="bind", + ) + return thread_root + + def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post: + self.validate_post_id(frontend_post_id) + post: Post = ( + self.db.query(Post) + .filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id) + .one_or_none() + ) + if fail_if_missing and post is None: + raise KeyError(f"Post with post_id {frontend_post_id} not found.") + return post + + def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage: + self.validate_post_id(post_id) + post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True) + work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one() + return work_pack + + def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post: + self.validate_post_id(reply.post_id) + self.validate_post_id(reply.user_post_id) + + # find post with post-id + parent_post: Post = ( + self.db.query(Post) + .filter( + Post.api_client_id == self.api_client.id, + Post.frontend_post_id == reply.post_id, + # Post.person_id == self.person_id + ) + .one_or_none() + ) + + if parent_post is None: + raise KeyError(f"Post for post_id {reply.post_id} not found.") + + # create reply post + user_post_id = uuid4() + user_post = self.insert_post( + post_id=user_post_id, + frontend_post_id=reply.user_post_id, + parent_id=parent_post.id, + thread_id=parent_post.thread_id, + workpackage_id=parent_post.workpackage_id, + role=role, + payload=db_payload.PostPayload(text=reply.text), + ) + return user_post + + def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction: + post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True) + + work_package = self.fetch_workpackage_by_postid(rating.post_id) + work_payload: db_payload.RateSummaryPayload = work_package.payload.payload + if type(work_payload) != db_payload.RateSummaryPayload: + raise ValueError( + f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}" + ) + + if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max: + raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}") + + # store reaction to post + reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating) + reaction = self.insert_reaction(post.id, reaction_payload) + logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.") + return reaction + + def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction: + post = self.fetch_post_by_frontend_post_id(ranking.post_id, fail_if_missing=True) + + # fetch work_package + work_package = self.fetch_workpackage_by_postid(ranking.post_id) + work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = ( + work_package.payload.payload + ) + + match type(work_payload): + + case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload: + # validate ranking + num_replies = len(work_payload.replies) + if sorted(ranking.ranking) != list(range(num_replies)): + raise ValueError( + f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})." + ) + + # store reaction to post + reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) + reaction = self.insert_reaction(post.id, reaction_payload) + + logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") + + return reaction + + case db_payload.RankInitialPromptsPayload: + # validate ranking + if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))): + raise ValueError( + f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})." + ) + + # store reaction to post + reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) + reaction = self.insert_reaction(post.id, reaction_payload) + + logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") + + return reaction + + case _: + raise ValueError( + f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}" + ) + + def store_task(self, task: protocol_schema.Task) -> WorkPackage: + payload: db_payload.TaskPayload + match type(task): + case protocol_schema.SummarizeStoryTask: + payload = db_payload.SummarizationStoryPayload(story=task.story) + + case protocol_schema.RateSummaryTask: + payload = db_payload.RateSummaryPayload( + full_text=task.full_text, summary=task.summary, scale=task.scale + ) + + case protocol_schema.InitialPromptTask: + payload = db_payload.InitialPromptPayload(hint=task.hint) + + case protocol_schema.UserReplyTask: + payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint) + + case protocol_schema.AssistantReplyTask: + payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation) + + case protocol_schema.RankInitialPromptsTask: + payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts) + + case protocol_schema.RankUserRepliesTask: + payload = db_payload.RankUserRepliesPayload( + tpye=task.type, conversation=task.conversation, replies=task.replies + ) + + case protocol_schema.RankAssistantRepliesTask: + payload = db_payload.RankAssistantRepliesPayload( + tpye=task.type, conversation=task.conversation, replies=task.replies + ) + + case _: + raise ValueError(f"Invalid task type: {type(task)=}") + + wp = self.insert_work_package(payload=payload, id=task.id) + assert wp.id == task.id + return wp + + def insert_work_package(self, payload: db_payload.TaskPayload, id: UUID = None) -> WorkPackage: + c = PayloadContainer(payload=payload) + wp = WorkPackage( + id=id, + person_id=self.person_id, + payload_type=type(payload).__name__, + payload=c, + api_client_id=self.api_client.id, + ) + self.db.add(wp) + self.db.commit() + self.db.refresh(wp) + return wp + + def insert_post( + self, + *, + post_id: UUID, + frontend_post_id: str, + parent_id: UUID, + thread_id: UUID, + workpackage_id: UUID, + role: str, + payload: db_payload.PostPayload, + payload_type: str = None, + ) -> Post: + if payload_type is None: + if payload is None: + payload_type = "null" + else: + payload_type = type(payload).__name__ + + post = Post( + id=post_id, + parent_id=parent_id, + thread_id=thread_id, + workpackage_id=workpackage_id, + person_id=self.person_id, + role=role, + frontend_post_id=frontend_post_id, + api_client_id=self.api_client.id, + payload_type=payload_type, + payload=PayloadContainer(payload=payload), + ) + self.db.add(post) + self.db.commit() + self.db.refresh(post) + return post + + def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction: + if self.person_id is None: + raise ValueError("User required") + + container = PayloadContainer(payload=payload) + reaction = PostReaction( + post_id=post_id, + person_id=self.person_id, + payload=container, + api_client_id=self.api_client.id, + payload_type=type(payload).__name__, + ) + self.db.add(reaction) + self.db.commit() + self.db.refresh(reaction) + return reaction