From da1c81d2c9ffaf748a94dddce1de8ea4a014b341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 27 Jan 2023 00:54:29 +0100 Subject: [PATCH] Add LabelDescription list to labeling tasks, make +1/-1 emojis exclusive (#947) * add LabelDescription list to labeling tasks * make +1 & -1 emoji exclusive (only one of both or none) * add red_flag emoji to message when reported * fix task's valid labels * fix typo --- backend/oasst_backend/api/v1/text_labels.py | 29 +++++++++- backend/oasst_backend/config.py | 51 ++++++++++++++-- backend/oasst_backend/models/message.py | 6 ++ backend/oasst_backend/prompt_repository.py | 35 ++++++++--- backend/oasst_backend/schemas/text_labels.py | 11 +--- backend/oasst_backend/tree_manager.py | 33 +++++++---- oasst-shared/oasst_shared/schemas/protocol.py | 58 ++++++++++++------- 7 files changed, 169 insertions(+), 54 deletions(-) diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index dc6cc889..2025fd4c 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -3,10 +3,11 @@ from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.prompt_repository import PromptRepository -from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse +from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TextLabel from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -45,7 +46,29 @@ def label_text( def get_valid_lables() -> ValidLabelsResponse: return ValidLabelsResponse( valid_labels=[ - LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text) - for l in protocol_schema.TextLabel + LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) + for l in TextLabel + ] + ) + + +@router.get("/report_labels") +def get_report_lables() -> ValidLabelsResponse: + report_labels = [ + TextLabel.spam, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + TextLabel.moral_judgement, + TextLabel.political_content, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.quality, + ] + return ValidLabelsResponse( + valid_labels=[ + LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) + for l in report_labels ] ) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 9952c654..157845d7 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TextLabel from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator @@ -46,13 +46,56 @@ class TreeManagerConfiguration(BaseModel): num_required_rankings: int = 3 """Number of rankings in which the message participated.""" - mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + labels_initial_prompt: list[TextLabel] = [ + TextLabel.spam, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.creativity, + TextLabel.humor, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + labels_assistant_reply: list[TextLabel] = [ + TextLabel.spam, + TextLabel.fails_task, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.creativity, + TextLabel.humor, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + labels_prompter_reply: list[TextLabel] = [ + TextLabel.spam, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.humor, + TextLabel.creativity, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for initial prompts.""" - mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for assistant replies.""" - mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for prompter replies.""" rank_prompter_replies: bool = False diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 5f323d5d..24fafc01 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -64,6 +64,12 @@ class Message(SQLModel, table=True): if not self.payload or not isinstance(self.payload.payload, MessagePayload): raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR) + def has_emoji(self, emoji_code: str) -> bool: + return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0 + + def has_user_emoji(self, emoji_code: str) -> bool: + return self._user_emojis and emoji_code in self._user_emojis + @property def text(self) -> str: self.ensure_is_message() diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index b31b53d7..d3f655fc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -461,15 +461,22 @@ class PromptRepository: ) if message_id: - message = self.fetch_message(message_id) - if task: + if not task: + if text_labels.is_report is True: + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag + ) + + # update existing record for repeated updates (same user no task associated) + existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) + if existing_text_label is not None: + existing_text_label.labels = text_labels.labels + model = existing_text_label + + else: + message = self.fetch_message(message_id) message.review_count += 1 self.db.add(message) - # for the same User id with no task id associated with the message, then update existing record for repeated updates - existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) - if existing_text_label is not None: - existing_text_label.labels = text_labels.labels - model = existing_text_label self.db.add(model) return model, task, message @@ -936,6 +943,20 @@ WHERE message.id = cc.id; op = protocol_schema.EmojiOp.add if op == protocol_schema.EmojiOp.add: + # hard coded exclusivity of thumbs_up & thumbs_down + if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji( + protocol_schema.EmojiCode.thumbs_down.value + ): + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down + ) + elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji( + protocol_schema.EmojiCode.thumbs_up.value + ): + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up + ) + # insert emoji record & increment count message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji) self.db.add(message_emoji) diff --git a/backend/oasst_backend/schemas/text_labels.py b/backend/oasst_backend/schemas/text_labels.py index 9135c558..e846d8f4 100644 --- a/backend/oasst_backend/schemas/text_labels.py +++ b/backend/oasst_backend/schemas/text_labels.py @@ -1,13 +1,6 @@ -from typing import Optional - +from oasst_shared.schemas.protocol import LabelDescription from pydantic import BaseModel -class LabelOption(BaseModel): - name: str - display_text: str - help_text: Optional[str] - - class ValidLabelsResponse(BaseModel): - valid_labels: list[LabelOption] + valid_labels: list[LabelDescription] diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 89a51807..77419184 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -94,8 +94,6 @@ class TreeManagerStats(pydantic.BaseModel): class TreeManager: - _all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel)) - def __init__( self, db: Session, @@ -216,6 +214,15 @@ class TreeManager: incomplete_rankings=incomplete_rankings, ) + @staticmethod + def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]: + return [ + protocol_schema.LabelDescription( + name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text + ) + for l in valid_labels + ] + def next_task( self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random, @@ -356,14 +363,14 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.full label_disposition = protocol_schema.LabelTaskDisposition.quality - valid_labels = self._all_text_labels if message.role == "assistant": + valid_labels = self.cfg.labels_assistant_reply if ( desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_assistant ): - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) + valid_labels = self.cfg.mandatory_labels_assistant_reply label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam @@ -372,27 +379,30 @@ class TreeManager: message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) else: + valid_labels = self.cfg.labels_prompter_reply if ( desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_prompter ): - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) + valid_labels = self.cfg.mandatory_labels_prompter_reply label_mode = protocol_schema.LabelTaskMode.simple logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") task = protocol_schema.LabelPrompterReplyTask( message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) parent_message_id = message.id @@ -456,10 +466,10 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.full label_disposition = protocol_schema.LabelTaskDisposition.quality - valid_labels = self._all_text_labels + valid_labels = self.cfg.labels_initial_prompt if random.random() > self.cfg.p_full_labeling_review_prompt: - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)) + valid_labels = self.cfg.mandatory_labels_initial_prompt label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam @@ -467,10 +477,11 @@ class TreeManager: task = protocol_schema.LabelInitialPromptTask( message_id=message.id, prompt=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) parent_message_id = message.id @@ -577,7 +588,7 @@ class TreeManager: _, task, msg = pr.store_text_labels(interaction) - # if it was a respones for a task, check if we have enough reviews to calc review_result + # if it was a response for a task, check if we have enough reviews to calc review_result if task and msg: reviews = self.query_reviews_for_message(msg.id) acceptance_score = self._calculate_acceptance(reviews) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 31caa340..22a4adfb 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -231,12 +231,20 @@ class LabelTaskDisposition(str, enum.Enum): spam = "spam" +class LabelDescription(BaseModel): + name: str + widget: str + display_text: str + help_text: Optional[str] + + class AbstractLabelTask(Task): message_id: UUID valid_labels: list[str] mandatory_labels: Optional[list[str]] mode: Optional[LabelTaskMode] disposition: Optional[LabelTaskDisposition] + labels: Optional[list[LabelDescription]] class LabelInitialPromptTask(AbstractLabelTask): @@ -324,39 +332,48 @@ class MessageRanking(Interaction): ranking: conlist(item_type=int, min_items=1) +class LabelWidget(str, enum.Enum): + yes_no = "yes_no" + flag = "flag" + likert = "likert" + + class TextLabel(str, enum.Enum): """A label for a piece of text.""" - def __new__(cls, label: str, display_text: str = "", help_text: str = None): + def __new__(cls, label: str, widget: LabelWidget, display_text: str = "", help_text: str = None): obj = str.__new__(cls, label) obj._value_ = label + obj.widget = widget obj.display_text = display_text obj.help_text = help_text return obj - spam = "spam", "Seems to be intentionally low-quality or irrelevant" - fails_task = "fails_task", "Fails to follow the correct instruction / task" - not_appropriate = "not_appropriate", "Inappropriate for customer assistant" - violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm" - excessive_harm = ( - "excessive_harm", - "Content likely to cause excessive harm not justifiable in the context", - "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", - ) - sexual_content = "sexual_content", "Contains sexual content" - toxicity = "toxicity", "Contains rude, abusive, profane or insulting content" - moral_judgement = "moral_judgement", "Expresses moral judgement" - political_content = "political_content", "Expresses political views" - humor = "humor", "Contains humorous content including sarcasm" + # yes/no questions + spam = "spam", LabelWidget.yes_no, "Seems to be intentionally low-quality or irrelevant" + fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task" + + # flags + pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)" + not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate" hate_speech = ( "hate_speech", + LabelWidget.flag, "Content is abusive or threatening and expresses prejudice against a protected characteristic", - "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", + "Prejudice refers to preconceived views not based on reason. Protected characteristics " + "include gender, ethnicity, religion, sexual orientation, and similar characteristics.", ) - threat = "threat", "Contains a threat against a person or persons" - misleading = "misleading", "Contains text which is incorrect or misleading" - helpful = "helpful", "Completes the task to a high standard" - creative = "creative", "Expresses creativity in responding to the task" + sexual_content = "sexual_content", LabelWidget.flag, "Contains sexual content" + moral_judgement = "moral_judgement", LabelWidget.flag, "Expresses moral judgement" + political_content = "political_content", LabelWidget.flag, "Expresses political views" + + # likert + quality = "quality", LabelWidget.likert, "Overall subjective quality rating of the message" + toxicity = "toxicity", LabelWidget.likert, "Rude, abusive, profane or insulting content" + humor = "humor", LabelWidget.likert, "Humorous content including sarcasm" + helpfulness = "helpfulness", LabelWidget.likert, "Helpfulness of the message" + creativity = "creativity", LabelWidget.likert, "Creativity" + violence = "violence", LabelWidget.likert, "Violence/abuse/terrorism/self-harm" class TextLabels(Interaction): @@ -367,6 +384,7 @@ class TextLabels(Interaction): labels: dict[TextLabel, float] message_id: UUID task_id: Optional[UUID] + is_report: Optional[bool] @property def has_message_id(self) -> bool: