mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add lang_mismatch label to prompt/reply acceptance check (#984)
* add lang_mismatch label to acceptance check * make acceptance formula clearer * make lang_mismatch flag, remove helpful from prompter labels
This commit is contained in:
@@ -57,8 +57,8 @@ class TreeManagerConfiguration(BaseModel):
|
||||
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.lang_mismatch,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.creativity,
|
||||
TextLabel.humor,
|
||||
TextLabel.toxicity,
|
||||
@@ -71,6 +71,7 @@ class TreeManagerConfiguration(BaseModel):
|
||||
|
||||
labels_assistant_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.lang_mismatch,
|
||||
TextLabel.fails_task,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
@@ -86,8 +87,8 @@ class TreeManagerConfiguration(BaseModel):
|
||||
|
||||
labels_prompter_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.lang_mismatch,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.humor,
|
||||
TextLabel.creativity,
|
||||
TextLabel.toxicity,
|
||||
|
||||
@@ -376,6 +376,8 @@ class TreeManager:
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
valid_labels = list(self.cfg.mandatory_labels_assistant_reply)
|
||||
if protocol_schema.TextLabel.lang_mismatch not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.lang_mismatch)
|
||||
if protocol_schema.TextLabel.quality not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.quality)
|
||||
|
||||
@@ -400,6 +402,8 @@ class TreeManager:
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
valid_labels = list(self.cfg.mandatory_labels_prompter_reply)
|
||||
if protocol_schema.TextLabel.lang_mismatch not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.lang_mismatch)
|
||||
if protocol_schema.TextLabel.quality not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.quality)
|
||||
|
||||
@@ -483,6 +487,8 @@ class TreeManager:
|
||||
valid_labels = self.cfg.mandatory_labels_initial_prompt
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
if protocol_schema.TextLabel.lang_mismatch not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.lang_mismatch)
|
||||
|
||||
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
@@ -803,8 +809,12 @@ class TreeManager:
|
||||
return backlog_tree
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
# calculate acceptance based on lang_mismatch & spam label
|
||||
lang_mismatch = np.mean([(l.labels.get(protocol_schema.TextLabel.lang_mismatch) or 0) for l in labels])
|
||||
spam = np.mean([l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
acceptance_score = 1 - (spam + lang_mismatch)
|
||||
logger.debug(f"{acceptance_score=} ({spam=}, {lang_mismatch=})")
|
||||
return acceptance_score
|
||||
|
||||
def _query_need_review(
|
||||
self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str
|
||||
|
||||
@@ -355,6 +355,12 @@ class TextLabel(str, enum.Enum):
|
||||
fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task"
|
||||
|
||||
# flags
|
||||
lang_mismatch = (
|
||||
"lang_mismatch",
|
||||
LabelWidget.flag,
|
||||
"Language mismatch",
|
||||
"The message is written in language that differs from the currently selected language.",
|
||||
)
|
||||
pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)"
|
||||
not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate"
|
||||
hate_speech = (
|
||||
|
||||
@@ -96,9 +96,11 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
labels_dict = {
|
||||
label: 1 if label != "lang_mismatch" and label in labels else 0
|
||||
for label in valid_labels
|
||||
}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
@@ -212,9 +214,11 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
labels_dict = {
|
||||
label: 1 if label != "lang_mismatch" and label in labels else 0
|
||||
for label in valid_labels
|
||||
}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
Reference in New Issue
Block a user