mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Add LabelDescription list to labeling tasks, make +1/-1 emojis exclusive (#947)
* add LabelDescription list to labeling tasks * make +1 & -1 emoji exclusive (only one of both or none) * add red_flag emoji to message when reported * fix task's valid labels * fix typo
This commit is contained in:
@@ -3,10 +3,11 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
|
||||
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
@@ -45,7 +46,29 @@ def label_text(
|
||||
def get_valid_lables() -> ValidLabelsResponse:
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in protocol_schema.TextLabel
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in TextLabel
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/report_labels")
|
||||
def get_report_lables() -> ValidLabelsResponse:
|
||||
report_labels = [
|
||||
TextLabel.spam,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
TextLabel.moral_judgement,
|
||||
TextLabel.political_content,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.quality,
|
||||
]
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in report_labels
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator
|
||||
|
||||
|
||||
@@ -46,13 +46,56 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_required_rankings: int = 3
|
||||
"""Number of rankings in which the message participated."""
|
||||
|
||||
mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.creativity,
|
||||
TextLabel.humor,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
labels_assistant_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.fails_task,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.creativity,
|
||||
TextLabel.humor,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
labels_prompter_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.humor,
|
||||
TextLabel.creativity,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for initial prompts."""
|
||||
|
||||
mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for assistant replies."""
|
||||
|
||||
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for prompter replies."""
|
||||
|
||||
rank_prompter_replies: bool = False
|
||||
|
||||
@@ -64,6 +64,12 @@ class Message(SQLModel, table=True):
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
def has_emoji(self, emoji_code: str) -> bool:
|
||||
return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0
|
||||
|
||||
def has_user_emoji(self, emoji_code: str) -> bool:
|
||||
return self._user_emojis and emoji_code in self._user_emojis
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
self.ensure_is_message()
|
||||
|
||||
@@ -461,15 +461,22 @@ class PromptRepository:
|
||||
)
|
||||
|
||||
if message_id:
|
||||
message = self.fetch_message(message_id)
|
||||
if task:
|
||||
if not task:
|
||||
if text_labels.is_report is True:
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
|
||||
)
|
||||
|
||||
# update existing record for repeated updates (same user no task associated)
|
||||
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
|
||||
if existing_text_label is not None:
|
||||
existing_text_label.labels = text_labels.labels
|
||||
model = existing_text_label
|
||||
|
||||
else:
|
||||
message = self.fetch_message(message_id)
|
||||
message.review_count += 1
|
||||
self.db.add(message)
|
||||
# for the same User id with no task id associated with the message, then update existing record for repeated updates
|
||||
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
|
||||
if existing_text_label is not None:
|
||||
existing_text_label.labels = text_labels.labels
|
||||
model = existing_text_label
|
||||
|
||||
self.db.add(model)
|
||||
return model, task, message
|
||||
@@ -936,6 +943,20 @@ WHERE message.id = cc.id;
|
||||
op = protocol_schema.EmojiOp.add
|
||||
|
||||
if op == protocol_schema.EmojiOp.add:
|
||||
# hard coded exclusivity of thumbs_up & thumbs_down
|
||||
if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji(
|
||||
protocol_schema.EmojiCode.thumbs_down.value
|
||||
):
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down
|
||||
)
|
||||
elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji(
|
||||
protocol_schema.EmojiCode.thumbs_up.value
|
||||
):
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
|
||||
)
|
||||
|
||||
# insert emoji record & increment count
|
||||
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
|
||||
self.db.add(message_emoji)
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from oasst_shared.schemas.protocol import LabelDescription
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LabelOption(BaseModel):
|
||||
name: str
|
||||
display_text: str
|
||||
help_text: Optional[str]
|
||||
|
||||
|
||||
class ValidLabelsResponse(BaseModel):
|
||||
valid_labels: list[LabelOption]
|
||||
valid_labels: list[LabelDescription]
|
||||
|
||||
@@ -94,8 +94,6 @@ class TreeManagerStats(pydantic.BaseModel):
|
||||
|
||||
|
||||
class TreeManager:
|
||||
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
@@ -216,6 +214,15 @@ class TreeManager:
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]:
|
||||
return [
|
||||
protocol_schema.LabelDescription(
|
||||
name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text
|
||||
)
|
||||
for l in valid_labels
|
||||
]
|
||||
|
||||
def next_task(
|
||||
self,
|
||||
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
|
||||
@@ -356,14 +363,14 @@ class TreeManager:
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
if message.role == "assistant":
|
||||
valid_labels = self.cfg.labels_assistant_reply
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
valid_labels = self.cfg.mandatory_labels_assistant_reply
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
|
||||
@@ -372,27 +379,30 @@ class TreeManager:
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
else:
|
||||
valid_labels = self.cfg.labels_prompter_reply
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
valid_labels = self.cfg.mandatory_labels_prompter_reply
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
@@ -456,10 +466,10 @@ class TreeManager:
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
valid_labels = self._all_text_labels
|
||||
valid_labels = self.cfg.labels_initial_prompt
|
||||
|
||||
if random.random() > self.cfg.p_full_labeling_review_prompt:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
|
||||
valid_labels = self.cfg.mandatory_labels_initial_prompt
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
|
||||
@@ -467,10 +477,11 @@ class TreeManager:
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.text,
|
||||
valid_labels=valid_labels,
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
@@ -577,7 +588,7 @@ class TreeManager:
|
||||
|
||||
_, task, msg = pr.store_text_labels(interaction)
|
||||
|
||||
# if it was a respones for a task, check if we have enough reviews to calc review_result
|
||||
# if it was a response for a task, check if we have enough reviews to calc review_result
|
||||
if task and msg:
|
||||
reviews = self.query_reviews_for_message(msg.id)
|
||||
acceptance_score = self._calculate_acceptance(reviews)
|
||||
|
||||
@@ -231,12 +231,20 @@ class LabelTaskDisposition(str, enum.Enum):
|
||||
spam = "spam"
|
||||
|
||||
|
||||
class LabelDescription(BaseModel):
|
||||
name: str
|
||||
widget: str
|
||||
display_text: str
|
||||
help_text: Optional[str]
|
||||
|
||||
|
||||
class AbstractLabelTask(Task):
|
||||
message_id: UUID
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[LabelTaskMode]
|
||||
disposition: Optional[LabelTaskDisposition]
|
||||
labels: Optional[list[LabelDescription]]
|
||||
|
||||
|
||||
class LabelInitialPromptTask(AbstractLabelTask):
|
||||
@@ -324,39 +332,48 @@ class MessageRanking(Interaction):
|
||||
ranking: conlist(item_type=int, min_items=1)
|
||||
|
||||
|
||||
class LabelWidget(str, enum.Enum):
|
||||
yes_no = "yes_no"
|
||||
flag = "flag"
|
||||
likert = "likert"
|
||||
|
||||
|
||||
class TextLabel(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
def __new__(cls, label: str, display_text: str = "", help_text: str = None):
|
||||
def __new__(cls, label: str, widget: LabelWidget, display_text: str = "", help_text: str = None):
|
||||
obj = str.__new__(cls, label)
|
||||
obj._value_ = label
|
||||
obj.widget = widget
|
||||
obj.display_text = display_text
|
||||
obj.help_text = help_text
|
||||
return obj
|
||||
|
||||
spam = "spam", "Seems to be intentionally low-quality or irrelevant"
|
||||
fails_task = "fails_task", "Fails to follow the correct instruction / task"
|
||||
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
|
||||
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
|
||||
excessive_harm = (
|
||||
"excessive_harm",
|
||||
"Content likely to cause excessive harm not justifiable in the context",
|
||||
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
|
||||
)
|
||||
sexual_content = "sexual_content", "Contains sexual content"
|
||||
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
|
||||
moral_judgement = "moral_judgement", "Expresses moral judgement"
|
||||
political_content = "political_content", "Expresses political views"
|
||||
humor = "humor", "Contains humorous content including sarcasm"
|
||||
# yes/no questions
|
||||
spam = "spam", LabelWidget.yes_no, "Seems to be intentionally low-quality or irrelevant"
|
||||
fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task"
|
||||
|
||||
# flags
|
||||
pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)"
|
||||
not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate"
|
||||
hate_speech = (
|
||||
"hate_speech",
|
||||
LabelWidget.flag,
|
||||
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
|
||||
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
|
||||
"Prejudice refers to preconceived views not based on reason. Protected characteristics "
|
||||
"include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
|
||||
)
|
||||
threat = "threat", "Contains a threat against a person or persons"
|
||||
misleading = "misleading", "Contains text which is incorrect or misleading"
|
||||
helpful = "helpful", "Completes the task to a high standard"
|
||||
creative = "creative", "Expresses creativity in responding to the task"
|
||||
sexual_content = "sexual_content", LabelWidget.flag, "Contains sexual content"
|
||||
moral_judgement = "moral_judgement", LabelWidget.flag, "Expresses moral judgement"
|
||||
political_content = "political_content", LabelWidget.flag, "Expresses political views"
|
||||
|
||||
# likert
|
||||
quality = "quality", LabelWidget.likert, "Overall subjective quality rating of the message"
|
||||
toxicity = "toxicity", LabelWidget.likert, "Rude, abusive, profane or insulting content"
|
||||
humor = "humor", LabelWidget.likert, "Humorous content including sarcasm"
|
||||
helpfulness = "helpfulness", LabelWidget.likert, "Helpfulness of the message"
|
||||
creativity = "creativity", LabelWidget.likert, "Creativity"
|
||||
violence = "violence", LabelWidget.likert, "Violence/abuse/terrorism/self-harm"
|
||||
|
||||
|
||||
class TextLabels(Interaction):
|
||||
@@ -367,6 +384,7 @@ class TextLabels(Interaction):
|
||||
labels: dict[TextLabel, float]
|
||||
message_id: UUID
|
||||
task_id: Optional[UUID]
|
||||
is_report: Optional[bool]
|
||||
|
||||
@property
|
||||
def has_message_id(self) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user