diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 27b46bf5..da89d4d4 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -57,8 +57,8 @@ class TreeManagerConfiguration(BaseModel): labels_initial_prompt: list[TextLabel] = [ TextLabel.spam, + TextLabel.lang_mismatch, TextLabel.quality, - TextLabel.helpfulness, TextLabel.creativity, TextLabel.humor, TextLabel.toxicity, @@ -71,6 +71,7 @@ class TreeManagerConfiguration(BaseModel): labels_assistant_reply: list[TextLabel] = [ TextLabel.spam, + TextLabel.lang_mismatch, TextLabel.fails_task, TextLabel.quality, TextLabel.helpfulness, @@ -86,8 +87,8 @@ class TreeManagerConfiguration(BaseModel): labels_prompter_reply: list[TextLabel] = [ TextLabel.spam, + TextLabel.lang_mismatch, TextLabel.quality, - TextLabel.helpfulness, TextLabel.humor, TextLabel.creativity, TextLabel.toxicity, diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 145ca185..b8466ff0 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -376,6 +376,8 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam valid_labels = list(self.cfg.mandatory_labels_assistant_reply) + if protocol_schema.TextLabel.lang_mismatch not in valid_labels: + valid_labels.append(protocol_schema.TextLabel.lang_mismatch) if protocol_schema.TextLabel.quality not in valid_labels: valid_labels.append(protocol_schema.TextLabel.quality) @@ -400,6 +402,8 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam valid_labels = list(self.cfg.mandatory_labels_prompter_reply) + if protocol_schema.TextLabel.lang_mismatch not in valid_labels: + valid_labels.append(protocol_schema.TextLabel.lang_mismatch) if protocol_schema.TextLabel.quality not in valid_labels: valid_labels.append(protocol_schema.TextLabel.quality) @@ -483,6 +487,8 @@ class TreeManager: valid_labels = self.cfg.mandatory_labels_initial_prompt label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam + if protocol_schema.TextLabel.lang_mismatch not in valid_labels: + valid_labels.append(protocol_schema.TextLabel.lang_mismatch) logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).") task = protocol_schema.LabelInitialPromptTask( @@ -803,8 +809,12 @@ class TreeManager: return backlog_tree def _calculate_acceptance(self, labels: list[TextLabels]): - # calculate acceptance based on spam label - return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) + # calculate acceptance based on lang_mismatch & spam label + lang_mismatch = np.mean([(l.labels.get(protocol_schema.TextLabel.lang_mismatch) or 0) for l in labels]) + spam = np.mean([l.labels[protocol_schema.TextLabel.spam] for l in labels]) + acceptance_score = 1 - (spam + lang_mismatch) + logger.debug(f"{acceptance_score=} ({spam=}, {lang_mismatch=})") + return acceptance_score def _query_need_review( self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 4cdea856..f69d254a 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -355,6 +355,12 @@ class TextLabel(str, enum.Enum): fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task" # flags + lang_mismatch = ( + "lang_mismatch", + LabelWidget.flag, + "Language mismatch", + "The message is written in language that differs from the currently selected language.", + ) pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)" not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate" hate_speech = ( diff --git a/text-frontend/auto_main.py b/text-frontend/auto_main.py index cea07c1e..c4ae19d3 100644 --- a/text-frontend/auto_main.py +++ b/text-frontend/auto_main.py @@ -96,9 +96,11 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): else: while labels_dict is None: labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) - if all([label in valid_labels for label in labels]): - labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + labels_dict = { + label: 1 if label != "lang_mismatch" and label in labels else 0 + for label in valid_labels + } else: invalid_labels = [label for label in labels if label not in valid_labels] typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") @@ -212,9 +214,11 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): else: while labels_dict is None: labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) - if all([label in valid_labels for label in labels]): - labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + labels_dict = { + label: 1 if label != "lang_mismatch" and label in labels else 0 + for label in valid_labels + } else: invalid_labels = [label for label in labels if label not in valid_labels] typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")