mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
fix: redundant row updates with no Task id in text_labels table (#876)
* fix: redundant row updates with no Task id in text_labels table * fix: review comments incorporated * fix: better error handling and function name * fix: review comments Co-authored-by: James Melvin <melvin@gameface.ai>
This commit is contained in:
committed by
GitHub
parent
43732442fc
commit
c0391a6df9
@@ -4,8 +4,9 @@ from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
@@ -14,20 +15,25 @@ router = APIRouter()
|
||||
@router.post("/", status_code=HTTP_204_NO_CONTENT)
|
||||
def label_text(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
text_labels: protocol_schema.TextLabels,
|
||||
) -> None:
|
||||
"""
|
||||
Label a piece of text.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
def store_text_labels(session: deps.Session):
|
||||
api_client = deps.api_auth(api_key, session)
|
||||
pr = PromptRepository(session, api_client, client_user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
|
||||
try:
|
||||
logger.info(f"Labeling text {text_labels=}.")
|
||||
pr = PromptRepository(db, api_client, client_user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
store_text_labels()
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to store label.")
|
||||
raise HTTPException(
|
||||
|
||||
@@ -448,6 +448,11 @@ class PromptRepository:
|
||||
if task:
|
||||
message.review_count += 1
|
||||
self.db.add(message)
|
||||
# for the same User id with no task id associated with the message, then update existing record for repeated updates
|
||||
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
|
||||
if existing_text_label is not None:
|
||||
existing_text_label.labels = text_labels.labels
|
||||
model = existing_text_label
|
||||
|
||||
self.db.add(model)
|
||||
return model, task, message
|
||||
@@ -561,6 +566,16 @@ class PromptRepository:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
|
||||
|
||||
query = (
|
||||
self.db.query(TextLabels)
|
||||
.outerjoin(Task, Task.id == TextLabels.id)
|
||||
.filter(Task.id.is_(None), TextLabels.message_id == message_id, TextLabels.user_id == user_id)
|
||||
)
|
||||
text_label = query.one_or_none()
|
||||
return text_label
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user