mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
tested initial PromptRepository
This commit is contained in:
@@ -48,6 +48,11 @@ class AssistantReplyPayload(TaskPayload):
|
||||
conversation: protocol_schema.Conversation
|
||||
|
||||
|
||||
@payload_type
|
||||
class PostPayload(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class ReactionPayload(BaseModel):
|
||||
type: str
|
||||
@@ -67,10 +72,10 @@ class PromptRepository:
|
||||
self.person_id = self.person.id if self.person else None
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
if not user:
|
||||
return None
|
||||
person: Person = (
|
||||
self.db.query(Person)
|
||||
.filter(Person.api_client_id == self.api_client.id and Person.username == user.id)
|
||||
.first()
|
||||
self.db.query(Person).filter(Person.api_client_id == self.api_client.id, Person.username == user.id).first()
|
||||
)
|
||||
if person is None:
|
||||
# user is unknown, create new record
|
||||
@@ -97,7 +102,7 @@ class PromptRepository:
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id and WorkPackage.api_client_id == self.api_client.id)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
@@ -111,35 +116,32 @@ class PromptRepository:
|
||||
thread_root = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.workpackage_id == work_pack.id
|
||||
and Post.frontend_post_id == post_id
|
||||
and Post.parent_id is None
|
||||
and self.api_client == self.api_client
|
||||
Post.workpackage_id == work_pack.id,
|
||||
Post.frontend_post_id == post_id,
|
||||
Post.parent_id is None,
|
||||
self.api_client == self.api_client,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
if thread_root is None:
|
||||
thread_id = uuid4()
|
||||
thread_root = Post(
|
||||
id=thread_id,
|
||||
thread_root = self.insert_post(
|
||||
post_id=thread_id,
|
||||
thread_id=thread_id,
|
||||
role="system",
|
||||
person_id=work_pack.person_id,
|
||||
workpackage_id=work_pack.id,
|
||||
frontend_post_id=post_id,
|
||||
api_client_id=self.api_client.id,
|
||||
parent_id=None,
|
||||
role="system",
|
||||
workpackage_id=work_pack.id,
|
||||
payload=None,
|
||||
payload_type="bind",
|
||||
)
|
||||
self.db.add(thread_root)
|
||||
self.db.commit()
|
||||
self.db.refresh(thread_root)
|
||||
return thread_root
|
||||
|
||||
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
|
||||
self.validate_post_id(frontend_post_id)
|
||||
post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(Post.api_client_id == self.api_client.id and Post.frontend_post_id == frontend_post_id)
|
||||
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
@@ -160,31 +162,27 @@ class PromptRepository:
|
||||
parent_post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.api_client_id == self.api_client.id
|
||||
and Post.frontend_post_id == reply.post_id
|
||||
and Post.person_id == self.person_id
|
||||
Post.api_client_id == self.api_client.id,
|
||||
Post.frontend_post_id == reply.post_id,
|
||||
# Post.person_id == self.person_id
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if parent_post is None:
|
||||
raise RuntimeError(f"Post for post_id {reply.post_id} not found.")
|
||||
|
||||
# create reply post
|
||||
user_post_id = uuid4()
|
||||
|
||||
user_post = Post(
|
||||
id=user_post_id,
|
||||
user_post = self.insert_post(
|
||||
post_id=user_post_id,
|
||||
frontend_post_id=reply.user_post_id,
|
||||
parent_id=parent_post.id,
|
||||
thread_id=parent_post.thread_id,
|
||||
workpackage_id=parent_post.workpackage_id,
|
||||
person_id=self.person_id,
|
||||
role=role,
|
||||
frontend_post_id=reply.user_post_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload=PostPayload(text=reply.text),
|
||||
)
|
||||
self.db.add(user_post)
|
||||
self.db.commit()
|
||||
self.db.refresh(user_post)
|
||||
return user_post
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> Post:
|
||||
@@ -203,19 +201,6 @@ class PromptRepository:
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
return reaction
|
||||
|
||||
def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
raise RuntimeError("User required")
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
post_id=post_id, person_id=self.person_id, payload=container, api_client_id=self.api_client.id
|
||||
)
|
||||
self.db.add(reaction)
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
|
||||
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
|
||||
payload: TaskPayload = None
|
||||
match type(task):
|
||||
@@ -256,3 +241,55 @@ class PromptRepository:
|
||||
self.db.commit()
|
||||
self.db.refresh(wp)
|
||||
return wp
|
||||
|
||||
def insert_post(
|
||||
self,
|
||||
*,
|
||||
post_id: UUID,
|
||||
frontend_post_id: str,
|
||||
parent_id: UUID,
|
||||
thread_id: UUID,
|
||||
workpackage_id: UUID,
|
||||
role: str,
|
||||
payload: PostPayload,
|
||||
payload_type: str = None,
|
||||
) -> Post:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
post = Post(
|
||||
id=post_id,
|
||||
parent_id=parent_id,
|
||||
thread_id=thread_id,
|
||||
workpackage_id=workpackage_id,
|
||||
person_id=self.person_id,
|
||||
role=role,
|
||||
frontend_post_id=frontend_post_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
)
|
||||
self.db.add(post)
|
||||
self.db.commit()
|
||||
self.db.refresh(post)
|
||||
return post
|
||||
|
||||
def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
raise RuntimeError("User required")
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
post_id=post_id,
|
||||
person_id=self.person_id,
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
|
||||
Reference in New Issue
Block a user