mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
Add LabelDescription list to labeling tasks, make +1/-1 emojis exclusive (#947)
* add LabelDescription list to labeling tasks * make +1 & -1 emoji exclusive (only one of both or none) * add red_flag emoji to message when reported * fix task's valid labels * fix typo
This commit is contained in:
@@ -3,10 +3,11 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
|
||||
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
@@ -45,7 +46,29 @@ def label_text(
|
||||
def get_valid_lables() -> ValidLabelsResponse:
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in protocol_schema.TextLabel
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in TextLabel
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/report_labels")
|
||||
def get_report_lables() -> ValidLabelsResponse:
|
||||
report_labels = [
|
||||
TextLabel.spam,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
TextLabel.moral_judgement,
|
||||
TextLabel.political_content,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.quality,
|
||||
]
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in report_labels
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator
|
||||
|
||||
|
||||
@@ -46,13 +46,56 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_required_rankings: int = 3
|
||||
"""Number of rankings in which the message participated."""
|
||||
|
||||
mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.creativity,
|
||||
TextLabel.humor,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
labels_assistant_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.fails_task,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.creativity,
|
||||
TextLabel.humor,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
labels_prompter_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.humor,
|
||||
TextLabel.creativity,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for initial prompts."""
|
||||
|
||||
mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for assistant replies."""
|
||||
|
||||
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for prompter replies."""
|
||||
|
||||
rank_prompter_replies: bool = False
|
||||
|
||||
@@ -64,6 +64,12 @@ class Message(SQLModel, table=True):
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
def has_emoji(self, emoji_code: str) -> bool:
|
||||
return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0
|
||||
|
||||
def has_user_emoji(self, emoji_code: str) -> bool:
|
||||
return self._user_emojis and emoji_code in self._user_emojis
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
self.ensure_is_message()
|
||||
|
||||
@@ -461,15 +461,22 @@ class PromptRepository:
|
||||
)
|
||||
|
||||
if message_id:
|
||||
message = self.fetch_message(message_id)
|
||||
if task:
|
||||
if not task:
|
||||
if text_labels.is_report is True:
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
|
||||
)
|
||||
|
||||
# update existing record for repeated updates (same user no task associated)
|
||||
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
|
||||
if existing_text_label is not None:
|
||||
existing_text_label.labels = text_labels.labels
|
||||
model = existing_text_label
|
||||
|
||||
else:
|
||||
message = self.fetch_message(message_id)
|
||||
message.review_count += 1
|
||||
self.db.add(message)
|
||||
# for the same User id with no task id associated with the message, then update existing record for repeated updates
|
||||
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
|
||||
if existing_text_label is not None:
|
||||
existing_text_label.labels = text_labels.labels
|
||||
model = existing_text_label
|
||||
|
||||
self.db.add(model)
|
||||
return model, task, message
|
||||
@@ -936,6 +943,20 @@ WHERE message.id = cc.id;
|
||||
op = protocol_schema.EmojiOp.add
|
||||
|
||||
if op == protocol_schema.EmojiOp.add:
|
||||
# hard coded exclusivity of thumbs_up & thumbs_down
|
||||
if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji(
|
||||
protocol_schema.EmojiCode.thumbs_down.value
|
||||
):
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down
|
||||
)
|
||||
elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji(
|
||||
protocol_schema.EmojiCode.thumbs_up.value
|
||||
):
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
|
||||
)
|
||||
|
||||
# insert emoji record & increment count
|
||||
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
|
||||
self.db.add(message_emoji)
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from oasst_shared.schemas.protocol import LabelDescription
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LabelOption(BaseModel):
|
||||
name: str
|
||||
display_text: str
|
||||
help_text: Optional[str]
|
||||
|
||||
|
||||
class ValidLabelsResponse(BaseModel):
|
||||
valid_labels: list[LabelOption]
|
||||
valid_labels: list[LabelDescription]
|
||||
|
||||
@@ -94,8 +94,6 @@ class TreeManagerStats(pydantic.BaseModel):
|
||||
|
||||
|
||||
class TreeManager:
|
||||
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
@@ -216,6 +214,15 @@ class TreeManager:
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]:
|
||||
return [
|
||||
protocol_schema.LabelDescription(
|
||||
name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text
|
||||
)
|
||||
for l in valid_labels
|
||||
]
|
||||
|
||||
def next_task(
|
||||
self,
|
||||
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
|
||||
@@ -356,14 +363,14 @@ class TreeManager:
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
if message.role == "assistant":
|
||||
valid_labels = self.cfg.labels_assistant_reply
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
valid_labels = self.cfg.mandatory_labels_assistant_reply
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
|
||||
@@ -372,27 +379,30 @@ class TreeManager:
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
else:
|
||||
valid_labels = self.cfg.labels_prompter_reply
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
valid_labels = self.cfg.mandatory_labels_prompter_reply
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
@@ -456,10 +466,10 @@ class TreeManager:
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
valid_labels = self._all_text_labels
|
||||
valid_labels = self.cfg.labels_initial_prompt
|
||||
|
||||
if random.random() > self.cfg.p_full_labeling_review_prompt:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
|
||||
valid_labels = self.cfg.mandatory_labels_initial_prompt
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
|
||||
@@ -467,10 +477,11 @@ class TreeManager:
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.text,
|
||||
valid_labels=valid_labels,
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
@@ -577,7 +588,7 @@ class TreeManager:
|
||||
|
||||
_, task, msg = pr.store_text_labels(interaction)
|
||||
|
||||
# if it was a respones for a task, check if we have enough reviews to calc review_result
|
||||
# if it was a response for a task, check if we have enough reviews to calc review_result
|
||||
if task and msg:
|
||||
reviews = self.query_reviews_for_message(msg.id)
|
||||
acceptance_score = self._calculate_acceptance(reviews)
|
||||
|
||||
Reference in New Issue
Block a user