From c0391a6df9aff9bf83e2baa1b59d0d5478ad434b Mon Sep 17 00:00:00 2001 From: James Melvin Ebenezer Date: Sun, 22 Jan 2023 15:38:02 +0530 Subject: [PATCH] fix: redundant row updates with no Task id in text_labels table (#876) * fix: redundant row updates with no Task id in text_labels table * fix: review comments incorporated * fix: better error handling and function name * fix: review comments Co-authored-by: James Melvin --- backend/oasst_backend/api/v1/text_labels.py | 16 +++++++++++----- backend/oasst_backend/prompt_repository.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index c9afd88c..dc6cc889 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -4,8 +4,9 @@ from loguru import logger from oasst_backend.api import deps from oasst_backend.prompt_repository import PromptRepository from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse +from oasst_backend.utils.database_utils import CommitMode, managed_tx_function +from oasst_shared.exceptions import OasstError from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -14,20 +15,25 @@ router = APIRouter() @router.post("/", status_code=HTTP_204_NO_CONTENT) def label_text( *, - db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), text_labels: protocol_schema.TextLabels, ) -> None: """ Label a piece of text. """ - api_client = deps.api_auth(api_key, db) + + @managed_tx_function(CommitMode.COMMIT) + def store_text_labels(session: deps.Session): + api_client = deps.api_auth(api_key, session) + pr = PromptRepository(session, api_client, client_user=text_labels.user) + pr.store_text_labels(text_labels) try: logger.info(f"Labeling text {text_labels=}.") - pr = PromptRepository(db, api_client, client_user=text_labels.user) - pr.store_text_labels(text_labels) + store_text_labels() + except OasstError: + raise except Exception: logger.exception("Failed to store label.") raise HTTPException( diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 8c259dda..0a0fa61d 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -448,6 +448,11 @@ class PromptRepository: if task: message.review_count += 1 self.db.add(message) + # for the same User id with no task id associated with the message, then update existing record for repeated updates + existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) + if existing_text_label is not None: + existing_text_label.labels = text_labels.labels + model = existing_text_label self.db.add(model) return model, task, message @@ -561,6 +566,16 @@ class PromptRepository: raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) return message + def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]: + + query = ( + self.db.query(TextLabels) + .outerjoin(Task, Task.id == TextLabels.id) + .filter(Task.id.is_(None), TextLabels.message_id == message_id, TextLabels.user_id == user_id) + ) + text_label = query.one_or_none() + return text_label + @staticmethod def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]: """