add preferred lonely_children extension (#942)

* add preferred lonely_children extension

* simplify sampling process, lower the probability to 25%

* exclude parents for replies that were recently used

* lonely children := count > 0

* consider only tasks not done for parent exclusion

* increase lonely child sampling probability
This commit is contained in:
Andreas Köpf
2023-01-26 23:00:54 +01:00
committed by GitHub
parent de00cc82d0
commit f3ffde47ff
5 changed files with 83 additions and 6 deletions
@@ -0,0 +1,26 @@
"""add task created date index
Revision ID: c84fcd6900dc
Revises: 40ed93df0ed5
Create Date: 2023-01-26 18:35:43.061589
"""
from alembic import op
# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the migration it
# builds on; both mirror the "Revision ID" / "Revises" lines of the module docstring.
revision = "c84fcd6900dc"
down_revision = "40ed93df0ed5"
# This migration declares no branch labels and no additional dependencies.
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add a non-unique index on the ``created_date`` column of the ``task`` table."""
    # op.f() marks the name as final so no naming convention is applied on top of it.
    index_name = op.f("ix_task_created_date")
    indexed_columns = ["created_date"]
    op.create_index(index_name, "task", indexed_columns, unique=False)
def downgrade() -> None:
    """Remove the ``created_date`` index from the ``task`` table."""
    # op.f() marks the name as final so no naming convention is applied on top of it.
    index_name = op.f("ix_task_created_date")
    op.drop_index(index_name, table_name="task")
+9
View File
@@ -57,6 +57,15 @@ class TreeManagerConfiguration(BaseModel):
rank_prompter_replies: bool = False
lonely_children_count: int = 3
"""Number of children below which parents are preferred during sampling for reply tasks."""
p_lonely_child_extension: float = 0.8
"""Probability to select a parent with less than lonely_children_count children."""
recent_tasks_span_sec: int = 3 * 60 # 3 min
"""Time in seconds of recent tasks to consider for exclusion during task selection."""
class Settings(BaseSettings):
PROJECT_NAME: str = "open-assistant backend"
+3 -1
View File
@@ -20,7 +20,9 @@ class Task(SQLModel, table=True):
),
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()),
sa_column=sa.Column(
sa.DateTime(timezone=True), nullable=False, index=True, server_default=sa.func.current_timestamp()
),
)
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
+15 -1
View File
@@ -1,3 +1,4 @@
from datetime import timedelta
from typing import Optional
from uuid import UUID
@@ -9,7 +10,7 @@ from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from sqlmodel import Session, func, or_
from starlette.status import HTTP_404_NOT_FOUND
@@ -219,3 +220,16 @@ class TaskRepository:
def fetch_task_by_id(self, task_id: UUID) -> Task:
    """Look up a single task by its id, restricted to the current API client.

    The lookup ends in ``one_or_none()``, so the result is ``None`` when no
    row matches both conditions.
    """
    belongs_to_client = Task.api_client_id == self.api_client.id
    has_requested_id = Task.id == task_id
    return self.db.query(Task).filter(belongs_to_client, has_requested_id).one_or_none()
def fetch_recent_reply_tasks(
    self, max_age: timedelta = timedelta(minutes=5), done: Optional[bool] = False, limit: Optional[int] = 100
) -> list[Task]:
    """Return reply tasks that were created less than ``max_age`` ago.

    Only tasks whose payload type is ``AssistantReplyPayload`` or
    ``PrompterReplyPayload`` are considered.

    Args:
        max_age: Maximum time elapsed since the task's ``created_date``.
        done: When not ``None``, only tasks whose ``done`` flag equals this
            value are returned; ``None`` disables the filter.
        limit: Maximum number of rows to return; a falsy value (``None`` or
            ``0``) disables the limit.

    Returns:
        The matching tasks. No ordering is applied to the query.
    """
    # Compare the column against a cutoff computed by the database instead of
    # using single-argument age(): in PostgreSQL age(ts) is measured from
    # current_date (midnight), not from now, so every task created since
    # midnight produced a negative interval and always passed `< max_age`.
    # Leaving created_date un-wrapped should also let a plain index on that
    # column serve this range filter.
    created_after = func.current_timestamp() - max_age
    qry = self.db.query(Task).filter(
        Task.created_date > created_after,
        or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
    )
    if done is not None:
        qry = qry.filter(Task.done == done)
    if limit:
        qry = qry.limit(limit)
    return qry.all()
+30 -4
View File
@@ -1,7 +1,7 @@
import json
import random
import sys
from datetime import datetime
from datetime import datetime, timedelta
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
@@ -339,6 +339,7 @@ class TreeManager:
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_REPLY:
if task_role == TaskRole.PROMPTER:
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
elif task_role == TaskRole.ASSISTANT:
@@ -398,19 +399,44 @@ class TreeManager:
message_tree_id = message.message_tree_id
case TaskType.REPLY:
# select a tree with missing replies
recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False
)
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}
if task_role == TaskRole.PROMPTER:
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
elif task_role == TaskRole.ASSISTANT:
extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
# select a tree with missing replies
if len(extendible_parents) > 0:
random_parent = random.choice(extendible_parents)
random_parent: ExtendibleParentRow = None
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
# check if we have extendible parents with a small number of replies
lonely_children_parents = [
p
for p in extendible_parents
if 0 < p.active_children_count < self.cfg.lonely_children_count
and p.parent_id not in recent_reply_task_parents
]
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
random_parent = random.choice(lonely_children_parents)
if random_parent is None:
# try to exclude parents for which tasks were recently handed out
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
if len(fresh_parents) > 0:
random_parent = random.choice(fresh_parents)
else:
random_parent = random.choice(extendible_parents)
# fetch random conversation to extend
logger.debug(f"selected {random_parent=}")
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
assert all(m.review_result for m in messages) # ensure all messages have positive review
assert all(m.review_result for m in messages) # ensure all messages have positive reviews
conversation = prepare_conversation(messages)
# generate reply task depending on last message