diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 0aeea390..c18bd4c2 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -47,8 +47,13 @@ class TreeManagerConfiguration(BaseModel): """Number of rankings in which the message participated.""" mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + """Mandatory labels in text-labeling tasks for initial prompts.""" + mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + """Mandatory labels in text-labeling tasks for assistant reylies.""" + mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + """Mandatory labels in text-labeling tasks for prompter replies.""" class Settings(BaseSettings): diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 6181d62e..d7abd9f8 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -68,9 +68,11 @@ class IncompleteRankingsRow(pydantic.BaseModel): class TreeManager: _all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel)) - def __init__(self, db: Session, prompt_repository: PromptRepository): + def __init__( + self, db: Session, prompt_repository: PromptRepository, cfg: Optional[TreeManagerConfiguration] = None + ): self.db = db - self.cfg = settings.tree_manager + self.cfg = cfg or settings.tree_manager self.pr = prompt_repository def _task_selection(