Add LabelDescription list to labeling tasks, make +1/-1 emojis exclusive (#947)

* add LabelDescription list to labeling tasks

* make +1 & -1 emoji exclusive (only one of both or none)

* add red_flag emoji to message when reported

* fix task's valid labels

* fix typo
This commit is contained in:
Andreas Köpf
2023-01-27 00:54:29 +01:00
committed by GitHub
parent f3ffde47ff
commit da1c81d2c9
7 changed files with 169 additions and 54 deletions
+26 -3
View File
@@ -3,10 +3,11 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions import OasstError
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
@@ -45,7 +46,29 @@ def label_text(
def get_valid_lables() -> ValidLabelsResponse:
return ValidLabelsResponse(
valid_labels=[
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
for l in protocol_schema.TextLabel
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
for l in TextLabel
]
)
@router.get("/report_labels")
def get_report_lables() -> ValidLabelsResponse:
report_labels = [
TextLabel.spam,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
TextLabel.moral_judgement,
TextLabel.political_content,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.quality,
]
return ValidLabelsResponse(
valid_labels=[
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
for l in report_labels
]
)
+47 -4
View File
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator
@@ -46,13 +46,56 @@ class TreeManagerConfiguration(BaseModel):
num_required_rankings: int = 3
"""Number of rankings in which the message participated."""
mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
labels_initial_prompt: list[TextLabel] = [
TextLabel.spam,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.creativity,
TextLabel.humor,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
labels_assistant_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.fails_task,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.creativity,
TextLabel.humor,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
labels_prompter_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.humor,
TextLabel.creativity,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for initial prompts."""
mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for assistant replies."""
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for prompter replies."""
rank_prompter_replies: bool = False
+6
View File
@@ -64,6 +64,12 @@ class Message(SQLModel, table=True):
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
def has_emoji(self, emoji_code: str) -> bool:
return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0
def has_user_emoji(self, emoji_code: str) -> bool:
return self._user_emojis and emoji_code in self._user_emojis
@property
def text(self) -> str:
self.ensure_is_message()
+28 -7
View File
@@ -461,15 +461,22 @@ class PromptRepository:
)
if message_id:
message = self.fetch_message(message_id)
if task:
if not task:
if text_labels.is_report is True:
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
)
# update existing record for repeated updates (same user no task associated)
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
if existing_text_label is not None:
existing_text_label.labels = text_labels.labels
model = existing_text_label
else:
message = self.fetch_message(message_id)
message.review_count += 1
self.db.add(message)
# for the same User id with no task id associated with the message, then update existing record for repeated updates
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
if existing_text_label is not None:
existing_text_label.labels = text_labels.labels
model = existing_text_label
self.db.add(model)
return model, task, message
@@ -936,6 +943,20 @@ WHERE message.id = cc.id;
op = protocol_schema.EmojiOp.add
if op == protocol_schema.EmojiOp.add:
# hard coded exclusivity of thumbs_up & thumbs_down
if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji(
protocol_schema.EmojiCode.thumbs_down.value
):
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down
)
elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji(
protocol_schema.EmojiCode.thumbs_up.value
):
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
)
# insert emoji record & increment count
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
self.db.add(message_emoji)
+2 -9
View File
@@ -1,13 +1,6 @@
from typing import Optional
from oasst_shared.schemas.protocol import LabelDescription
from pydantic import BaseModel
class LabelOption(BaseModel):
name: str
display_text: str
help_text: Optional[str]
class ValidLabelsResponse(BaseModel):
valid_labels: list[LabelOption]
valid_labels: list[LabelDescription]
+22 -11
View File
@@ -94,8 +94,6 @@ class TreeManagerStats(pydantic.BaseModel):
class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
def __init__(
self,
db: Session,
@@ -216,6 +214,15 @@ class TreeManager:
incomplete_rankings=incomplete_rankings,
)
@staticmethod
def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]:
return [
protocol_schema.LabelDescription(
name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text
)
for l in valid_labels
]
def next_task(
self,
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
@@ -356,14 +363,14 @@ class TreeManager:
label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self._all_text_labels
if message.role == "assistant":
valid_labels = self.cfg.labels_assistant_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
valid_labels = self.cfg.mandatory_labels_assistant_reply
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
@@ -372,27 +379,30 @@ class TreeManager:
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
else:
valid_labels = self.cfg.labels_prompter_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
valid_labels = self.cfg.mandatory_labels_prompter_reply
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
parent_message_id = message.id
@@ -456,10 +466,10 @@ class TreeManager:
label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self._all_text_labels
valid_labels = self.cfg.labels_initial_prompt
if random.random() > self.cfg.p_full_labeling_review_prompt:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
valid_labels = self.cfg.mandatory_labels_initial_prompt
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
@@ -467,10 +477,11 @@ class TreeManager:
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
valid_labels=valid_labels,
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
parent_message_id = message.id
@@ -577,7 +588,7 @@ class TreeManager:
_, task, msg = pr.store_text_labels(interaction)
# if it was a respones for a task, check if we have enough reviews to calc review_result
# if it was a response for a task, check if we have enough reviews to calc review_result
if task and msg:
reviews = self.query_reviews_for_message(msg.id)
acceptance_score = self._calculate_acceptance(reviews)