mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
add preferred lonely_children extension (#942)
* add preferred lonely_children extension
* simplify sampling process, lower the probability to 25%
* exclude parents for replies that were recently used
* lonely children := count > 0
* consider only tasks not done for parent exclusion
* increase lonely child sampling probability
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""add task created date index
|
||||
|
||||
Revision ID: c84fcd6900dc
|
||||
Revises: 40ed93df0ed5
|
||||
Create Date: 2023-01-26 18:35:43.061589
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c84fcd6900dc"
|
||||
down_revision = "40ed93df0ed5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the non-unique index on the ``created_date`` column of ``task``."""
    # op.f() marks the name as final so no naming convention is applied to it again.
    index_name = op.f("ix_task_created_date")
    indexed_columns = ["created_date"]
    op.create_index(index_name, "task", indexed_columns, unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the index on the ``created_date`` column of ``task`` again."""
    # The name must match the one used by upgrade().
    index_name = op.f("ix_task_created_date")
    op.drop_index(index_name, table_name="task")
|
||||
@@ -57,6 +57,15 @@ class TreeManagerConfiguration(BaseModel):
|
||||
|
||||
rank_prompter_replies: bool = False
|
||||
|
||||
lonely_children_count: int = 3
|
||||
"""Number of children below which parents are preferred during sampling for reply tasks."""
|
||||
|
||||
p_lonely_child_extension: float = 0.8
|
||||
"""Probability to select a parent with less than lonely_children_count children."""
|
||||
|
||||
recent_tasks_span_sec: int = 3 * 60 # 3 min
|
||||
"""Time in seconds of recent tasks to consider for exclusion during task selection."""
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "open-assistant backend"
|
||||
|
||||
@@ -20,7 +20,9 @@ class Task(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa_column=sa.Column(
|
||||
sa.DateTime(timezone=True), nullable=False, index=True, server_default=sa.func.current_timestamp()
|
||||
),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
@@ -9,7 +10,7 @@ from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, func, or_
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -219,3 +220,16 @@ class TaskRepository:
|
||||
def fetch_task_by_id(self, task_id: UUID) -> Task:
|
||||
task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none()
|
||||
return task
|
||||
|
||||
def fetch_recent_reply_tasks(
    self, max_age: timedelta = timedelta(minutes=5), done: Optional[bool] = False, limit: int = 100
) -> list[Task]:
    """Return reply tasks that were created within the last ``max_age``.

    Args:
        max_age: Maximum age of a task, measured against the database's
            current timestamp.
        done: When not ``None``, only tasks whose ``done`` flag equals this
            value are returned. ``None`` disables the filter.
        limit: Maximum number of rows to return. A falsy value (``0`` or
            ``None``) disables the limit.

    Returns:
        The matching ``Task`` rows whose payload type is either
        ``AssistantReplyPayload`` or ``PrompterReplyPayload``.
    """
    # Compare the bare column against a cutoff instead of calling
    # age(created_date): in PostgreSQL the single-argument age() subtracts
    # from midnight of the current date rather than from the current time,
    # so it does not express "created less than max_age ago". Leaving the
    # column unwrapped also lets an index on created_date serve the filter.
    cutoff = func.current_timestamp() - max_age
    qry = self.db.query(Task).filter(
        Task.created_date > cutoff,
        or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
    )
    if done is not None:
        qry = qry.filter(Task.done == done)
    if limit:
        qry = qry.limit(limit)
    return qry.all()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@@ -339,6 +339,7 @@ class TreeManager:
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
|
||||
case TaskType.LABEL_REPLY:
|
||||
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
@@ -398,19 +399,44 @@ class TreeManager:
|
||||
message_tree_id = message.message_tree_id
|
||||
|
||||
case TaskType.REPLY:
|
||||
# select a tree with missing replies
|
||||
|
||||
recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
|
||||
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False
|
||||
)
|
||||
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}
|
||||
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
|
||||
|
||||
# select a tree with missing replies
|
||||
if len(extendible_parents) > 0:
|
||||
random_parent = random.choice(extendible_parents)
|
||||
random_parent: ExtendibleParentRow = None
|
||||
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
|
||||
# check if we have extendible parents with a small number of replies
|
||||
|
||||
lonely_children_parents = [
|
||||
p
|
||||
for p in extendible_parents
|
||||
if 0 < p.active_children_count < self.cfg.lonely_children_count
|
||||
and p.parent_id not in recent_reply_task_parents
|
||||
]
|
||||
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
|
||||
random_parent = random.choice(lonely_children_parents)
|
||||
|
||||
if random_parent is None:
|
||||
# try to exclude parents for which tasks were recently handed out
|
||||
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
|
||||
if len(fresh_parents) > 0:
|
||||
random_parent = random.choice(fresh_parents)
|
||||
else:
|
||||
random_parent = random.choice(extendible_parents)
|
||||
|
||||
# fetch random conversation to extend
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive review
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive reviews
|
||||
conversation = prepare_conversation(messages)
|
||||
|
||||
# generate reply task depending on last message
|
||||
|
||||
Reference in New Issue
Block a user