diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 3860bb07..3a85cd5b 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -302,7 +302,8 @@ def tasks_interaction( logger.info( f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}." ) - # TODO: check if the labels are valid? + # Labels are implicitly validated when converting str -> TextLabel + # So no need for explicit validation here pr.store_text_labels(interaction) return protocol_schema.TaskDone() case _: diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index de65749a..2060498d 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -211,9 +211,17 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) valid_labels = task["valid_labels"] - labels_str: str = typer.prompt("Enter labels, separated by commas") - labels = labels_str.lower().replace(" ", "").split(",") - labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + + labels_dict = None + while labels_dict is None: + labels_str: str = typer.prompt("Enter labels, separated by commas") + labels = labels_str.lower().replace(" ", "").split(",") + + if all([label in valid_labels for label in labels]): + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + else: + invalid_labels = [label for label in labels if label not in valid_labels] + typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") # send ranking new_task = _post( @@ -240,9 +248,17 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) valid_labels = task["valid_labels"] - labels_str: str = typer.prompt("Enter labels, separated by commas") - labels = labels_str.lower().replace(" ", "").split(",") - labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + + labels_dict = None + while labels_dict is None: + labels_str: str = typer.prompt("Enter labels, separated by commas") + labels = labels_str.lower().replace(" ", "").split(",") + + if all([label in valid_labels for label in labels]): + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + else: + invalid_labels = [label for label in labels if label not in valid_labels] + typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") # send ranking new_task = _post(