shuffle ranking options, add simple mode for text-labels

This commit is contained in:
Andreas Köpf
2023-01-11 23:47:56 +01:00
parent d1e8df3982
commit 025d20e638
3 changed files with 78 additions and 25 deletions
+45 -9
View File
@@ -19,11 +19,11 @@ from sqlmodel import Session, func
class TreeManagerConfiguration(pydantic.BaseModel):
"""Configuration class for the TreeManager"""
"""TreeManager configuration settings"""
max_active_trees: int = 10
"""Maximum number of concurrently active trees in the database.
No new initial prompt tasks will be handed out to users if this
"""Maximum number of concurrently active message trees in the database.
No new initial prompt tasks are handed out to users if this
number is reached."""
max_tree_depth: int = 6
@@ -41,6 +41,15 @@ class TreeManagerConfiguration(pydantic.BaseModel):
num_reviews_reply: int = 3
"""Number of peer review checks to collect per reply (other than initial_prompt)"""
p_full_labeling_review_prompt: float = 0.1
"""Probability of full text-labeling (instead of mandatory only) for initial prompts"""
p_full_labeling_review_reply_assistant: float = 0.1
"""Probability of full text-labeling (instead of mandatory only) for assistant replies"""
p_full_labeling_review_reply_prompter: float = 0.1
"""Probability of full text-labeling (instead of mandatory only) for prompter replies"""
acceptance_threshold_initial_prompt: float = 0.6
"""Threshold for accepting an initial prompt"""
@@ -103,6 +112,8 @@ class IncompleteRankingsRow(pydantic.BaseModel):
class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
def __init__(self, db: Session, prompt_repository: PromptRepository, configuration: TreeManagerConfiguration):
self.db = db
self.cfg = configuration
@@ -242,6 +253,9 @@ class TreeManager:
messages = self.pr.fetch_message_conversation(ranking_parent_id)
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
assert len(replies) > 1
random.shuffle(replies) # hand out replies in random order
reply_messages = prepare_conversation_message_list(replies)
replies = [p.text for p in replies]
@@ -263,26 +277,40 @@ class TreeManager:
assert len(replies_need_review) > 0
random_reply_message_id = random.choice(replies_need_review)
messages = self.pr.fetch_message_conversation(random_reply_message_id)
conversation = prepare_conversation(messages[:-1])
message = messages[-1]
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
if message.role == "assistant":
logger.info("Generating a LabelAssistantReplyTask.")
if random.random() > self.cfg.p_full_labeling_review_reply_assistant:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
)
else:
logger.info("Generating a LabelPrompterReplyTask.")
if random.random() > self.cfg.p_full_labeling_review_reply_prompter:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
)
parent_message_id = message.id
@@ -314,13 +342,21 @@ class TreeManager:
case TaskType.LABEL_PROMPT:
assert len(prompts_need_review) > 0
message = self.pr.fetch_message(random.choice(prompts_need_review))
logger.info("Generating a LabelInitialPromptTask.")
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
if random.random() > self.cfg.p_full_labeling_review_prompt:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
mode=label_mode,
)
parent_message_id = message.id
@@ -179,6 +179,13 @@ class RankAssistantRepliesTask(RankConversationRepliesTask):
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
class LabelTaskMode(str, enum.Enum):
"""Label task mode that allows frontends to select an appropriate UI."""
simple = "simple"
full = "full"
class LabelInitialPromptTask(Task):
"""A task to label an initial prompt."""
@@ -187,6 +194,7 @@ class LabelInitialPromptTask(Task):
prompt: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[LabelTaskMode]
class LabelConversationReplyTask(Task):
@@ -198,6 +206,7 @@ class LabelConversationReplyTask(Task):
reply: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[LabelTaskMode]
class LabelPrompterReplyTask(LabelConversationReplyTask):
+24 -16
View File
@@ -220,15 +220,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
valid_labels = task["valid_labels"]
labels_dict = None
while labels_dict is None:
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
if task["mode"] == "simple" and len(valid_labels) == 1:
answer: str = typer.confirm(f"{valid_labels[0]}?")
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
while labels_dict is None:
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send labels
new_task = _post(
@@ -258,15 +262,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
valid_labels = task["valid_labels"]
labels_dict = None
while labels_dict is None:
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
if task["mode"] == "simple" and len(valid_labels) == 1:
answer: str = typer.confirm(f"{valid_labels[0]}?")
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
while labels_dict is None:
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send labels
new_task = _post(