add lang_mismatch label to prompt/reply acceptance check (#984)

* add lang_mismatch label to acceptance check

* make acceptance formula clearer

* make lang_mismatch flag, remove helpful from prompter labels
This commit is contained in:
Andreas Köpf
2023-01-29 17:34:37 +01:00
committed by GitHub
parent 61b2949122
commit 0601d874e8
4 changed files with 29 additions and 8 deletions
+3 -2
View File
@@ -57,8 +57,8 @@ class TreeManagerConfiguration(BaseModel):
labels_initial_prompt: list[TextLabel] = [
TextLabel.spam,
TextLabel.lang_mismatch,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.creativity,
TextLabel.humor,
TextLabel.toxicity,
@@ -71,6 +71,7 @@ class TreeManagerConfiguration(BaseModel):
labels_assistant_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.lang_mismatch,
TextLabel.fails_task,
TextLabel.quality,
TextLabel.helpfulness,
@@ -86,8 +87,8 @@ class TreeManagerConfiguration(BaseModel):
labels_prompter_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.lang_mismatch,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.humor,
TextLabel.creativity,
TextLabel.toxicity,
+12 -2
View File
@@ -376,6 +376,8 @@ class TreeManager:
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
valid_labels = list(self.cfg.mandatory_labels_assistant_reply)
if protocol_schema.TextLabel.lang_mismatch not in valid_labels:
valid_labels.append(protocol_schema.TextLabel.lang_mismatch)
if protocol_schema.TextLabel.quality not in valid_labels:
valid_labels.append(protocol_schema.TextLabel.quality)
@@ -400,6 +402,8 @@ class TreeManager:
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
valid_labels = list(self.cfg.mandatory_labels_prompter_reply)
if protocol_schema.TextLabel.lang_mismatch not in valid_labels:
valid_labels.append(protocol_schema.TextLabel.lang_mismatch)
if protocol_schema.TextLabel.quality not in valid_labels:
valid_labels.append(protocol_schema.TextLabel.quality)
@@ -483,6 +487,8 @@ class TreeManager:
valid_labels = self.cfg.mandatory_labels_initial_prompt
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
if protocol_schema.TextLabel.lang_mismatch not in valid_labels:
valid_labels.append(protocol_schema.TextLabel.lang_mismatch)
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
task = protocol_schema.LabelInitialPromptTask(
@@ -803,8 +809,12 @@ class TreeManager:
return backlog_tree
def _calculate_acceptance(self, labels: list[TextLabels]):
    """Return the acceptance score computed from the spam and lang_mismatch labels.

    The score is ``1 - (mean(spam) + mean(lang_mismatch))`` taken over all
    entries of ``labels``.

    Args:
        labels: label sets to aggregate; each entry's ``labels`` mapping is
            read (presumably keyed by ``protocol_schema.TextLabel`` -- verify
            against the ``TextLabels`` model).

    Returns:
        The acceptance score as produced by ``np.mean`` arithmetic.
    """
    spam_key = protocol_schema.TextLabel.spam
    lang_mismatch_key = protocol_schema.TextLabel.lang_mismatch

    # lang_mismatch may be absent (or None) on a label set; treat that as 0.
    lang_mismatch = np.mean([(text_labels.labels.get(lang_mismatch_key) or 0) for text_labels in labels])
    # spam is read with plain indexing, i.e. it is expected on every label set.
    spam = np.mean([text_labels.labels[spam_key] for text_labels in labels])

    acceptance_score = 1 - (spam + lang_mismatch)
    logger.debug(f"{acceptance_score=} ({spam=}, {lang_mismatch=})")
    return acceptance_score
def _query_need_review(
self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str
@@ -355,6 +355,12 @@ class TextLabel(str, enum.Enum):
fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task"
# flags
lang_mismatch = (
"lang_mismatch",
LabelWidget.flag,
"Language mismatch",
"The message is written in language that differs from the currently selected language.",
)
pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)"
not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate"
hate_speech = (
+8 -4
View File
@@ -96,9 +96,11 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
else:
while labels_dict is None:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
labels_dict = {
label: 1 if label != "lang_mismatch" and label in labels else 0
for label in valid_labels
}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
@@ -212,9 +214,11 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
else:
while labels_dict is None:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
labels_dict = {
label: 1 if label != "lang_mismatch" and label in labels else 0
for label in valid_labels
}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")