From b20cee5685bdf805554208ef9bdca4059f5bb40c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 16 Dec 2022 14:53:34 +0100 Subject: [PATCH] tested initial PromptRepository --- backend/app/prompt_repository.py | 123 ++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 43 deletions(-) diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py index 85f372d3..0baac989 100644 --- a/backend/app/prompt_repository.py +++ b/backend/app/prompt_repository.py @@ -48,6 +48,11 @@ class AssistantReplyPayload(TaskPayload): conversation: protocol_schema.Conversation +@payload_type +class PostPayload(BaseModel): + text: str + + @payload_type class ReactionPayload(BaseModel): type: str @@ -67,10 +72,10 @@ class PromptRepository: self.person_id = self.person.id if self.person else None def lookup_person(self, user: protocol_schema.User) -> Person: + if not user: + return None person: Person = ( - self.db.query(Person) - .filter(Person.api_client_id == self.api_client.id and Person.username == user.id) - .first() + self.db.query(Person).filter(Person.api_client_id == self.api_client.id, Person.username == user.id).first() ) if person is None: # user is unknown, create new record @@ -97,7 +102,7 @@ class PromptRepository: # find work package work_pack: WorkPackage = ( self.db.query(WorkPackage) - .filter(WorkPackage.id == task_id and WorkPackage.api_client_id == self.api_client.id) + .filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id) .first() ) if work_pack is None: @@ -111,35 +116,32 @@ class PromptRepository: thread_root = ( self.db.query(Post) .filter( - Post.workpackage_id == work_pack.id - and Post.frontend_post_id == post_id - and Post.parent_id is None - and self.api_client == self.api_client + Post.workpackage_id == work_pack.id, + Post.frontend_post_id == post_id, + Post.parent_id is None, + self.api_client == self.api_client, ) .one_or_none() ) if thread_root is None: thread_id = uuid4() - thread_root = Post( - id=thread_id, + thread_root = self.insert_post( + post_id=thread_id, thread_id=thread_id, - role="system", - person_id=work_pack.person_id, - workpackage_id=work_pack.id, frontend_post_id=post_id, - api_client_id=self.api_client.id, + parent_id=None, + role="system", + workpackage_id=work_pack.id, + payload=None, payload_type="bind", ) - self.db.add(thread_root) - self.db.commit() - self.db.refresh(thread_root) return thread_root def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post: self.validate_post_id(frontend_post_id) post: Post = ( self.db.query(Post) - .filter(Post.api_client_id == self.api_client.id and Post.frontend_post_id == frontend_post_id) + .filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id) .one_or_none() ) if fail_if_missing and post is None: @@ -160,31 +162,27 @@ class PromptRepository: parent_post: Post = ( self.db.query(Post) .filter( - Post.api_client_id == self.api_client.id - and Post.frontend_post_id == reply.post_id - and Post.person_id == self.person_id + Post.api_client_id == self.api_client.id, + Post.frontend_post_id == reply.post_id, + # Post.person_id == self.person_id ) .one_or_none() ) + if parent_post is None: raise RuntimeError(f"Post for post_id {reply.post_id} not found.") # create reply post user_post_id = uuid4() - - user_post = Post( - id=user_post_id, + user_post = self.insert_post( + post_id=user_post_id, + frontend_post_id=reply.user_post_id, parent_id=parent_post.id, thread_id=parent_post.thread_id, workpackage_id=parent_post.workpackage_id, - person_id=self.person_id, role=role, - frontend_post_id=reply.user_post_id, - api_client_id=self.api_client.id, + payload=PostPayload(text=reply.text), ) - self.db.add(user_post) - self.db.commit() - self.db.refresh(user_post) return user_post def store_rating(self, rating: protocol_schema.PostRating) -> Post: @@ -203,19 +201,6 @@ class PromptRepository: reaction = self.insert_reaction(post.id, reaction_payload) return reaction - def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction: - if self.person_id is None: - raise RuntimeError("User required") - - container = PayloadContainer(payload=payload) - reaction = PostReaction( - post_id=post_id, person_id=self.person_id, payload=container, api_client_id=self.api_client.id - ) - self.db.add(reaction) - self.db.commit() - self.db.refresh(reaction) - return reaction - def store_task(self, task: protocol_schema.Task) -> WorkPackage: payload: TaskPayload = None match type(task): @@ -256,3 +241,55 @@ class PromptRepository: self.db.commit() self.db.refresh(wp) return wp + + def insert_post( + self, + *, + post_id: UUID, + frontend_post_id: str, + parent_id: UUID, + thread_id: UUID, + workpackage_id: UUID, + role: str, + payload: PostPayload, + payload_type: str = None, + ) -> Post: + if payload_type is None: + if payload is None: + payload_type = "null" + else: + payload_type = type(payload).__name__ + + post = Post( + id=post_id, + parent_id=parent_id, + thread_id=thread_id, + workpackage_id=workpackage_id, + person_id=self.person_id, + role=role, + frontend_post_id=frontend_post_id, + api_client_id=self.api_client.id, + payload_type=payload_type, + payload=PayloadContainer(payload=payload), + ) + self.db.add(post) + self.db.commit() + self.db.refresh(post) + return post + + def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction: + if self.person_id is None: + raise RuntimeError("User required") + + container = PayloadContainer(payload=payload) + reaction = PostReaction( + post_id=post_id, + person_id=self.person_id, + payload=container, + api_client_id=self.api_client.id, + payload_type=type(payload).__name__, + ) + self.db.add(reaction) + self.db.commit() + self.db.refresh(reaction) + return reaction