diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 95e7867e..c3728dfc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -397,7 +397,7 @@ class PromptRepository: distinct_threads = distinct_threads.filter(Post.role == require_role) distinct_threads = distinct_threads.subquery() - random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1).subquery() + random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1) thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all() return thread_posts @@ -443,7 +443,7 @@ class PromptRepository: if post_role: parent = parent.filter(Post.role == post_role) - parent = parent.order_by(func.random()).limit(1).subquery() + parent = parent.order_by(func.random()).limit(1) replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all() if not replies: raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)