From f3ffde47ffc1cf2f49a8f835cf2c1a4a38ebcc9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 23:00:54 +0100 Subject: [PATCH] add preferred lonely_children extension (#942) * add preferred lonely_children extension * simplify sampling process, lower the probability to 25% * exclude parents for replies that were recently used * lonely children := count > 0 * consider only tasks not done for parent exclusion * increase lonely child sampling probability --- ...84fcd6900dc_add_task_created_date_index.py | 26 ++++++++++++++ backend/oasst_backend/config.py | 9 +++++ backend/oasst_backend/models/task.py | 4 ++- backend/oasst_backend/task_repository.py | 16 ++++++++- backend/oasst_backend/tree_manager.py | 34 ++++++++++++++++--- 5 files changed, 83 insertions(+), 6 deletions(-) create mode 100644 backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py diff --git a/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py b/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py new file mode 100644 index 00000000..29fd1aec --- /dev/null +++ b/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py @@ -0,0 +1,26 @@ +"""add task created date index + +Revision ID: c84fcd6900dc +Revises: 40ed93df0ed5 +Create Date: 2023-01-26 18:35:43.061589 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c84fcd6900dc" +down_revision = "40ed93df0ed5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f("ix_task_created_date"), "task", ["created_date"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_task_created_date"), table_name="task") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index d831eca2..9952c654 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -57,6 +57,15 @@ class TreeManagerConfiguration(BaseModel): rank_prompter_replies: bool = False + lonely_children_count: int = 3 + """Number of children below which parents are preferred during sampling for reply tasks.""" + + p_lonely_child_extension: float = 0.8 + """Probability to select a parent with less than lonely_children_count children.""" + + recent_tasks_span_sec: int = 3 * 60 # 3 min + """Time in seconds of recent tasks to consider for exclusion during task selection.""" + class Settings(BaseSettings): PROJECT_NAME: str = "open-assistant backend" diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py index a59f689e..7f91b157 100644 --- a/backend/oasst_backend/models/task.py +++ b/backend/oasst_backend/models/task.py @@ -20,7 +20,9 @@ class Task(SQLModel, table=True): ), ) created_date: Optional[datetime] = Field( - sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()), + sa_column=sa.Column( + sa.DateTime(timezone=True), nullable=False, index=True, server_default=sa.func.current_timestamp() + ), ) expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True)) user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index eb100fe3..5c5dea21 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import Optional from uuid import UUID @@ -9,7 +10,7 @@ from oasst_backend.user_repository import UserRepository from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session +from sqlmodel import Session, func, or_ from starlette.status import HTTP_404_NOT_FOUND @@ -219,3 +220,16 @@ class TaskRepository: def fetch_task_by_id(self, task_id: UUID) -> Task: task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none() return task + + def fetch_recent_reply_tasks( + self, max_age: timedelta = timedelta(minutes=5), done: bool = False, limit: int = 100 + ) -> list[Task]: + qry = self.db.query(Task).filter( + func.age(Task.created_date) < max_age, + or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"), + ) + if done is not None: + qry = qry.filter(Task.done == done) + if limit: + qry = qry.limit(limit) + return qry.all() diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 992e75dd..89a51807 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -1,7 +1,7 @@ import json import random import sys -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple @@ -339,6 +339,7 @@ class TreeManager: message_tree_id = messages[-1].message_tree_id case TaskType.LABEL_REPLY: + if task_role == TaskRole.PROMPTER: replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review)) elif task_role == TaskRole.ASSISTANT: @@ -398,19 +399,44 @@ class TreeManager: message_tree_id = message.message_tree_id case TaskType.REPLY: - # select a tree with missing replies + + recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks( + max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False + ) + recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks} + if task_role == TaskRole.PROMPTER: extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents)) elif task_role == TaskRole.ASSISTANT: extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents)) + # select a tree with missing replies if len(extendible_parents) > 0: - random_parent = random.choice(extendible_parents) + random_parent: ExtendibleParentRow = None + if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1: + # check if we have extendible parents with a small number of replies + + lonely_children_parents = [ + p + for p in extendible_parents + if 0 < p.active_children_count < self.cfg.lonely_children_count + and p.parent_id not in recent_reply_task_parents + ] + if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension: + random_parent = random.choice(lonely_children_parents) + + if random_parent is None: + # try to exclude parents for which tasks were recently handed out + fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents] + if len(fresh_parents) > 0: + random_parent = random.choice(fresh_parents) + else: + random_parent = random.choice(extendible_parents) # fetch random conversation to extend logger.debug(f"selected {random_parent=}") messages = self.pr.fetch_message_conversation(random_parent.parent_id) - assert all(m.review_result for m in messages) # ensure all messages have positive review + assert all(m.review_result for m in messages) # ensure all messages have positive reviews conversation = prepare_conversation(messages) # generate reply task depending on last message