diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index b805a4d7..dacd5f9b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -160,7 +160,12 @@ class PromptRepository: return message def _validate_task( - self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None + self, + task: Task, + *, + task_id: Optional[UUID] = None, + frontend_message_id: Optional[str] = None, + check_ack: bool = True, ) -> Task: if task is None: if task_id: @@ -171,7 +176,7 @@ class PromptRepository: if task.expired: raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) - if not task.ack: + if check_ack and not task.ack: raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK) if task.done: raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE) @@ -1021,7 +1026,7 @@ WHERE message.id = cc.id; self.ensure_user_is_enabled() task = self.task_repository.fetch_task_by_id(task_id) - self._validate_task(task) + self._validate_task(task, check_ack=False) if not task.collective: task.skipped = True