diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 52aff24e..d6e5bcfc 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -374,9 +374,11 @@ class TreeManager: desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_assistant ): - valid_labels = self.cfg.mandatory_labels_assistant_reply label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam + valid_labels = list(self.cfg.mandatory_labels_assistant_reply) + if protocol_schema.LabelTaskDisposition.quality not in valid_labels: + valid_labels.append(protocol_schema.TextLabel.quality) logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") task = protocol_schema.LabelAssistantReplyTask( @@ -396,8 +398,12 @@ class TreeManager: desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_prompter ): - valid_labels = self.cfg.mandatory_labels_prompter_reply label_mode = protocol_schema.LabelTaskMode.simple + label_disposition = protocol_schema.LabelTaskDisposition.spam + valid_labels = list(self.cfg.mandatory_labels_prompter_reply) + if protocol_schema.LabelTaskDisposition.quality not in valid_labels: + valid_labels.append(protocol_schema.TextLabel.quality) + logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") task = protocol_schema.LabelPrompterReplyTask( message_id=message.id,