mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-03 17:10:10 +08:00
shuffle ranking options, add simple mode for text-labels
This commit is contained in:
@@ -19,11 +19,11 @@ from sqlmodel import Session, func
|
||||
|
||||
|
||||
class TreeManagerConfiguration(pydantic.BaseModel):
|
||||
"""Configuration class for the TreeManager"""
|
||||
"""TreeManager configuration settings"""
|
||||
|
||||
max_active_trees: int = 10
|
||||
"""Maximum number of concurrently active trees in the database.
|
||||
No new initial prompt tasks will be handed out to users if this
|
||||
"""Maximum number of concurrently active message trees in the database.
|
||||
No new initial prompt tasks are handed out to users if this
|
||||
number is reached."""
|
||||
|
||||
max_tree_depth: int = 6
|
||||
@@ -41,6 +41,15 @@ class TreeManagerConfiguration(pydantic.BaseModel):
|
||||
num_reviews_reply: int = 3
|
||||
"""Number of peer review checks to collect per reply (other than initial_prompt)"""
|
||||
|
||||
p_full_labeling_review_prompt: float = 0.1
|
||||
"""Probability of full text-labeling (instead of mandatory only) for initial prompts"""
|
||||
|
||||
p_full_labeling_review_reply_assistant: float = 0.1
|
||||
"""Probability of full text-labeling (instead of mandatory only) for assistant replies"""
|
||||
|
||||
p_full_labeling_review_reply_prompter: float = 0.1
|
||||
"""Probability of full text-labeling (instead of mandatory only) for prompter replies"""
|
||||
|
||||
acceptance_threshold_initial_prompt: float = 0.6
|
||||
"""Threshold for accepting an initial prompt"""
|
||||
|
||||
@@ -103,6 +112,8 @@ class IncompleteRankingsRow(pydantic.BaseModel):
|
||||
|
||||
|
||||
class TreeManager:
|
||||
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
|
||||
|
||||
def __init__(self, db: Session, prompt_repository: PromptRepository, configuration: TreeManagerConfiguration):
|
||||
self.db = db
|
||||
self.cfg = configuration
|
||||
@@ -242,6 +253,9 @@ class TreeManager:
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
|
||||
assert len(replies) > 1
|
||||
random.shuffle(replies) # hand out replies in random order
|
||||
reply_messages = prepare_conversation_message_list(replies)
|
||||
replies = [p.text for p in replies]
|
||||
|
||||
@@ -263,26 +277,40 @@ class TreeManager:
|
||||
assert len(replies_need_review) > 0
|
||||
random_reply_message_id = random.choice(replies_need_review)
|
||||
messages = self.pr.fetch_message_conversation(random_reply_message_id)
|
||||
|
||||
conversation = prepare_conversation(messages[:-1])
|
||||
message = messages[-1]
|
||||
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
if message.role == "assistant":
|
||||
logger.info("Generating a LabelAssistantReplyTask.")
|
||||
if random.random() > self.cfg.p_full_labeling_review_reply_assistant:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
else:
|
||||
logger.info("Generating a LabelPrompterReplyTask.")
|
||||
if random.random() > self.cfg.p_full_labeling_review_reply_prompter:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
@@ -314,13 +342,21 @@ class TreeManager:
|
||||
case TaskType.LABEL_PROMPT:
|
||||
assert len(prompts_need_review) > 0
|
||||
message = self.pr.fetch_message(random.choice(prompts_need_review))
|
||||
logger.info("Generating a LabelInitialPromptTask.")
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
if random.random() > self.cfg.p_full_labeling_review_prompt:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
|
||||
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
|
||||
mode=label_mode,
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
|
||||
@@ -179,6 +179,13 @@ class RankAssistantRepliesTask(RankConversationRepliesTask):
|
||||
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
|
||||
|
||||
|
||||
class LabelTaskMode(str, enum.Enum):
|
||||
"""Label task mode that allows frontends to select an appropriate UI."""
|
||||
|
||||
simple = "simple"
|
||||
full = "full"
|
||||
|
||||
|
||||
class LabelInitialPromptTask(Task):
|
||||
"""A task to label an initial prompt."""
|
||||
|
||||
@@ -187,6 +194,7 @@ class LabelInitialPromptTask(Task):
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[LabelTaskMode]
|
||||
|
||||
|
||||
class LabelConversationReplyTask(Task):
|
||||
@@ -198,6 +206,7 @@ class LabelConversationReplyTask(Task):
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[LabelTaskMode]
|
||||
|
||||
|
||||
class LabelPrompterReplyTask(LabelConversationReplyTask):
|
||||
|
||||
+24
-16
@@ -220,15 +220,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
valid_labels = task["valid_labels"]
|
||||
|
||||
labels_dict = None
|
||||
while labels_dict is None:
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer: str = typer.confirm(f"{valid_labels[0]}?")
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send labels
|
||||
new_task = _post(
|
||||
@@ -258,15 +262,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
valid_labels = task["valid_labels"]
|
||||
|
||||
labels_dict = None
|
||||
while labels_dict is None:
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer: str = typer.confirm(f"{valid_labels[0]}?")
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send labels
|
||||
new_task = _post(
|
||||
|
||||
Reference in New Issue
Block a user