Add LabelDescription list to labeling tasks, make +1/-1 emojis exclusive (#947)

* add LabelDescription list to labeling tasks

* make +1 & -1 emoji exclusive (only one of both or none)

* add red_flag emoji to message when reported

* fix task's valid labels

* fix typo
This commit is contained in:
Andreas Köpf
2023-01-27 00:54:29 +01:00
committed by GitHub
parent f3ffde47ff
commit da1c81d2c9
7 changed files with 169 additions and 54 deletions
+26 -3
View File
@@ -3,10 +3,11 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions import OasstError
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
@@ -45,7 +46,29 @@ def label_text(
def get_valid_lables() -> ValidLabelsResponse:
return ValidLabelsResponse(
valid_labels=[
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
for l in protocol_schema.TextLabel
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
for l in TextLabel
]
)
@router.get("/report_labels")
def get_report_lables() -> ValidLabelsResponse:
report_labels = [
TextLabel.spam,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
TextLabel.moral_judgement,
TextLabel.political_content,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.quality,
]
return ValidLabelsResponse(
valid_labels=[
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
for l in report_labels
]
)
+47 -4
View File
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator
@@ -46,13 +46,56 @@ class TreeManagerConfiguration(BaseModel):
num_required_rankings: int = 3
"""Number of rankings in which the message participated."""
mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
labels_initial_prompt: list[TextLabel] = [
TextLabel.spam,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.creativity,
TextLabel.humor,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
labels_assistant_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.fails_task,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.creativity,
TextLabel.humor,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
labels_prompter_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.humor,
TextLabel.creativity,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for initial prompts."""
mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for assistant replies."""
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for prompter replies."""
rank_prompter_replies: bool = False
+6
View File
@@ -64,6 +64,12 @@ class Message(SQLModel, table=True):
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
def has_emoji(self, emoji_code: str) -> bool:
return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0
def has_user_emoji(self, emoji_code: str) -> bool:
return self._user_emojis and emoji_code in self._user_emojis
@property
def text(self) -> str:
self.ensure_is_message()
+28 -7
View File
@@ -461,15 +461,22 @@ class PromptRepository:
)
if message_id:
message = self.fetch_message(message_id)
if task:
if not task:
if text_labels.is_report is True:
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
)
# update existing record for repeated updates (same user no task associated)
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
if existing_text_label is not None:
existing_text_label.labels = text_labels.labels
model = existing_text_label
else:
message = self.fetch_message(message_id)
message.review_count += 1
self.db.add(message)
# for the same User id with no task id associated with the message, then update existing record for repeated updates
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
if existing_text_label is not None:
existing_text_label.labels = text_labels.labels
model = existing_text_label
self.db.add(model)
return model, task, message
@@ -936,6 +943,20 @@ WHERE message.id = cc.id;
op = protocol_schema.EmojiOp.add
if op == protocol_schema.EmojiOp.add:
# hard coded exclusivity of thumbs_up & thumbs_down
if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji(
protocol_schema.EmojiCode.thumbs_down.value
):
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down
)
elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji(
protocol_schema.EmojiCode.thumbs_up.value
):
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
)
# insert emoji record & increment count
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
self.db.add(message_emoji)
+2 -9
View File
@@ -1,13 +1,6 @@
from typing import Optional
from oasst_shared.schemas.protocol import LabelDescription
from pydantic import BaseModel
class LabelOption(BaseModel):
name: str
display_text: str
help_text: Optional[str]
class ValidLabelsResponse(BaseModel):
valid_labels: list[LabelOption]
valid_labels: list[LabelDescription]
+22 -11
View File
@@ -94,8 +94,6 @@ class TreeManagerStats(pydantic.BaseModel):
class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
def __init__(
self,
db: Session,
@@ -216,6 +214,15 @@ class TreeManager:
incomplete_rankings=incomplete_rankings,
)
@staticmethod
def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]:
return [
protocol_schema.LabelDescription(
name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text
)
for l in valid_labels
]
def next_task(
self,
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
@@ -356,14 +363,14 @@ class TreeManager:
label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self._all_text_labels
if message.role == "assistant":
valid_labels = self.cfg.labels_assistant_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
valid_labels = self.cfg.mandatory_labels_assistant_reply
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
@@ -372,27 +379,30 @@ class TreeManager:
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
else:
valid_labels = self.cfg.labels_prompter_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
valid_labels = self.cfg.mandatory_labels_prompter_reply
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
parent_message_id = message.id
@@ -456,10 +466,10 @@ class TreeManager:
label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self._all_text_labels
valid_labels = self.cfg.labels_initial_prompt
if random.random() > self.cfg.p_full_labeling_review_prompt:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
valid_labels = self.cfg.mandatory_labels_initial_prompt
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
@@ -467,10 +477,11 @@ class TreeManager:
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
valid_labels=valid_labels,
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
parent_message_id = message.id
@@ -577,7 +588,7 @@ class TreeManager:
_, task, msg = pr.store_text_labels(interaction)
# if it was a respones for a task, check if we have enough reviews to calc review_result
# if it was a response for a task, check if we have enough reviews to calc review_result
if task and msg:
reviews = self.query_reviews_for_message(msg.id)
acceptance_score = self._calculate_acceptance(reviews)
+38 -20
View File
@@ -231,12 +231,20 @@ class LabelTaskDisposition(str, enum.Enum):
spam = "spam"
class LabelDescription(BaseModel):
name: str
widget: str
display_text: str
help_text: Optional[str]
class AbstractLabelTask(Task):
message_id: UUID
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[LabelTaskMode]
disposition: Optional[LabelTaskDisposition]
labels: Optional[list[LabelDescription]]
class LabelInitialPromptTask(AbstractLabelTask):
@@ -324,39 +332,48 @@ class MessageRanking(Interaction):
ranking: conlist(item_type=int, min_items=1)
class LabelWidget(str, enum.Enum):
yes_no = "yes_no"
flag = "flag"
likert = "likert"
class TextLabel(str, enum.Enum):
"""A label for a piece of text."""
def __new__(cls, label: str, display_text: str = "", help_text: str = None):
def __new__(cls, label: str, widget: LabelWidget, display_text: str = "", help_text: str = None):
obj = str.__new__(cls, label)
obj._value_ = label
obj.widget = widget
obj.display_text = display_text
obj.help_text = help_text
return obj
spam = "spam", "Seems to be intentionally low-quality or irrelevant"
fails_task = "fails_task", "Fails to follow the correct instruction / task"
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
excessive_harm = (
"excessive_harm",
"Content likely to cause excessive harm not justifiable in the context",
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
)
sexual_content = "sexual_content", "Contains sexual content"
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
moral_judgement = "moral_judgement", "Expresses moral judgement"
political_content = "political_content", "Expresses political views"
humor = "humor", "Contains humorous content including sarcasm"
# yes/no questions
spam = "spam", LabelWidget.yes_no, "Seems to be intentionally low-quality or irrelevant"
fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task"
# flags
pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)"
not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate"
hate_speech = (
"hate_speech",
LabelWidget.flag,
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
"Prejudice refers to preconceived views not based on reason. Protected characteristics "
"include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
)
threat = "threat", "Contains a threat against a person or persons"
misleading = "misleading", "Contains text which is incorrect or misleading"
helpful = "helpful", "Completes the task to a high standard"
creative = "creative", "Expresses creativity in responding to the task"
sexual_content = "sexual_content", LabelWidget.flag, "Contains sexual content"
moral_judgement = "moral_judgement", LabelWidget.flag, "Expresses moral judgement"
political_content = "political_content", LabelWidget.flag, "Expresses political views"
# likert
quality = "quality", LabelWidget.likert, "Overall subjective quality rating of the message"
toxicity = "toxicity", LabelWidget.likert, "Rude, abusive, profane or insulting content"
humor = "humor", LabelWidget.likert, "Humorous content including sarcasm"
helpfulness = "helpfulness", LabelWidget.likert, "Helpfulness of the message"
creativity = "creativity", LabelWidget.likert, "Creativity"
violence = "violence", LabelWidget.likert, "Violence/abuse/terrorism/self-harm"
class TextLabels(Interaction):
@@ -367,6 +384,7 @@ class TextLabels(Interaction):
labels: dict[TextLabel, float]
message_id: UUID
task_id: Optional[UUID]
is_report: Optional[bool]
@property
def has_message_id(self) -> bool: