mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Import message trees from jsonl file (#964)
* add new backlog_ranking tree state * add first version of import script * allow activation of trees during import * add min_active_rankings_per_lang config param * add settings docstring
This commit is contained in:
+31
@@ -0,0 +1,31 @@
|
||||
"""add origin column to message_tree_state
|
||||
|
||||
Revision ID: 49d8445b4c90
|
||||
Revises: f856bf19d32b
|
||||
Create Date: 2023-01-28 11:57:45.580027
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "49d8445b4c90"
|
||||
down_revision = "f856bf19d32b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message", sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
op.add_column("message", sa.Column("model_name", sa.String(length=1024), nullable=True))
|
||||
op.add_column("message_tree_state", sa.Column("origin", sa.String(length=1024), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message_tree_state", "origin")
|
||||
op.drop_column("message", "model_name")
|
||||
op.drop_column("message", "synthetic")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,187 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import oasst_backend.utils.database_utils as db_utils
|
||||
import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import create_api_client
|
||||
from oasst_backend.models import ApiClient, Message
|
||||
from oasst_backend.models.message_tree_state import MessageTreeState
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree
|
||||
from sqlmodel import Session
|
||||
|
||||
# well known id
|
||||
IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363")
|
||||
|
||||
|
||||
class Importer:
|
||||
def __init__(self, db: Session, origin: str, model_name: Optional[str] = None):
|
||||
self.db = db
|
||||
self.origin = origin
|
||||
self.model_name = model_name
|
||||
|
||||
# get import api client
|
||||
api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first()
|
||||
if not api_client:
|
||||
api_client = create_api_client(
|
||||
session=db,
|
||||
description="API client used for importing data",
|
||||
frontend_type="import",
|
||||
force_id=IMPORT_API_CLIENT_ID,
|
||||
)
|
||||
|
||||
ur = UserRepository(db, api_client)
|
||||
self.import_user = ur.lookup_system_user(username="import")
|
||||
self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur)
|
||||
self.api_client = api_client
|
||||
|
||||
def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
|
||||
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none()
|
||||
|
||||
def import_message(
|
||||
self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None
|
||||
) -> Message:
|
||||
payload = db_payload.MessagePayload(text=message.text)
|
||||
msg = Message(
|
||||
id=message.message_id,
|
||||
message_tree_id=message_tree_id,
|
||||
frontend_message_id=message.message_id,
|
||||
parent_id=parent_id,
|
||||
review_count=message.review_count or 0,
|
||||
lang=message.lang or "en",
|
||||
review_result=True,
|
||||
synthetic=message.synthetic if message.synthetic is not None else True,
|
||||
model_name=message.model_name or self.model_name,
|
||||
role=message.role,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
user_id=self.import_user.id,
|
||||
)
|
||||
self.db.add(msg)
|
||||
if message.replies:
|
||||
for r in message.replies:
|
||||
self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id)
|
||||
self.db.flush()
|
||||
if parent_id is None:
|
||||
self.pr.update_children_counts(msg.id)
|
||||
self.db.refresh(msg)
|
||||
return msg
|
||||
|
||||
def import_tree(
|
||||
self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING
|
||||
) -> tuple[MessageTreeState, Message]:
|
||||
assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id
|
||||
root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id)
|
||||
assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import"
|
||||
active = state == TreeState.RANKING
|
||||
mts = MessageTreeState(
|
||||
message_tree_id=root_msg.id,
|
||||
goal_tree_size=0,
|
||||
max_depth=0,
|
||||
max_children_count=0,
|
||||
state=state,
|
||||
origin=self.origin,
|
||||
active=active,
|
||||
)
|
||||
self.db.add(mts)
|
||||
return mts, root_msg
|
||||
|
||||
|
||||
def import_file(
|
||||
input_file_path: Path,
|
||||
origin: str,
|
||||
*,
|
||||
model_name: Optional[str] = None,
|
||||
num_activate: int = 0,
|
||||
max_count: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
) -> int:
|
||||
@db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT)
|
||||
def import_tx(db: Session) -> int:
|
||||
importer = Importer(db, origin=origin, model_name=model_name)
|
||||
i = 0
|
||||
with input_file_path.open() as file_in:
|
||||
# read line tree object
|
||||
for line in file_in:
|
||||
dict_tree = json.loads(line)
|
||||
|
||||
# validate data
|
||||
tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree)
|
||||
existing_mts = importer.fetch_message_tree_state(tree.message_tree_id)
|
||||
if existing_mts:
|
||||
logger.info(f"Skipping existing message tree: {tree.message_tree_id}")
|
||||
else:
|
||||
state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING
|
||||
mts, root_msg = importer.import_tree(tree, state=state)
|
||||
i += 1
|
||||
logger.info(
|
||||
f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}"
|
||||
)
|
||||
|
||||
if max_count and i >= max_count:
|
||||
logger.info(f"Reached max count {max_count} of trees to import.")
|
||||
break
|
||||
return i
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN with rollback")
|
||||
return import_tx()
|
||||
|
||||
|
||||
def parse_args():
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"input_file_path",
|
||||
help="Input file path",
|
||||
)
|
||||
parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees")
|
||||
parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)")
|
||||
parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state")
|
||||
parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import")
|
||||
parser.add_argument("--dry_run", type=str2bool, default=False)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
input_file_path = Path(args.input_file_path)
|
||||
if not input_file_path.exists() or not input_file_path.is_file():
|
||||
print("Invalid input file:", args.input_file_path)
|
||||
exit(1)
|
||||
|
||||
dry_run = args.dry_run
|
||||
num_imported = import_file(
|
||||
input_file_path,
|
||||
origin=args.origin or input_file_path.name,
|
||||
model_name=args.model_name,
|
||||
num_activate=args.num_activate,
|
||||
max_count=args.max_count,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
logger.info(f"Done ({num_imported=}, {dry_run=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+3
-1
@@ -191,6 +191,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
review_count=5,
|
||||
review_result=True,
|
||||
check_tree_state=False,
|
||||
check_duplicate=False,
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(
|
||||
@@ -215,7 +216,8 @@ def ensure_tree_states():
|
||||
try:
|
||||
logger.info("Startup: TreeManager.ensure_tree_states()")
|
||||
with Session(engine) as db:
|
||||
tm = TreeManager(db, None)
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
tm = TreeManager(db, PromptRepository(db, api_client=api_client))
|
||||
tm.ensure_tree_states()
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator, NamedTuple
|
||||
from typing import Generator, NamedTuple, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
@@ -67,6 +68,7 @@ def create_api_client(
|
||||
trusted: bool | None = False,
|
||||
admin_email: str | None = None,
|
||||
api_key: str | None = None,
|
||||
force_id: Optional[UUID] = None,
|
||||
) -> ApiClient:
|
||||
if api_key is None:
|
||||
api_key = token_hex(32)
|
||||
@@ -79,6 +81,8 @@ def create_api_client(
|
||||
trusted=trusted,
|
||||
admin_email=admin_email,
|
||||
)
|
||||
if force_id:
|
||||
api_client.id = force_id
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
session.refresh(api_client)
|
||||
|
||||
@@ -46,6 +46,15 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_required_rankings: int = 3
|
||||
"""Number of rankings in which the message participated."""
|
||||
|
||||
p_activate_backlog_tree: float = 0.8
|
||||
"""Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
|
||||
a terminal state. Use this settting to control ratio of initial prompts and backlog tree
|
||||
activations."""
|
||||
|
||||
min_active_rankings_per_lang: int = 2
|
||||
"""When the number of active ranking tasks is below this value when a tree enters a terminal
|
||||
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
|
||||
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
|
||||
@@ -57,6 +57,11 @@ class Message(SQLModel, table=True):
|
||||
|
||||
rank: Optional[int] = Field(nullable=True)
|
||||
|
||||
synthetic: Optional[bool] = Field(
|
||||
sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False)
|
||||
)
|
||||
model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
|
||||
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
|
||||
|
||||
|
||||
@@ -43,6 +43,9 @@ class State(str, Enum):
|
||||
HALTED_BY_MODERATOR = "halted_by_moderator"
|
||||
"""A moderator decided to manually halt the message tree construction process."""
|
||||
|
||||
BACKLOG_RANKING = "backlog_ranking"
|
||||
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
State.INITIAL_PROMPT_REVIEW,
|
||||
@@ -51,6 +54,7 @@ VALID_STATES = (
|
||||
State.READY_FOR_SCORING,
|
||||
State.READY_FOR_EXPORT,
|
||||
State.ABORTED_LOW_GRADE,
|
||||
State.BACKLOG_RANKING,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
|
||||
@@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True):
|
||||
max_children_count: int = Field(nullable=False)
|
||||
state: str = Field(nullable=False, max_length=128, index=True)
|
||||
active: bool = Field(nullable=False, index=True)
|
||||
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
|
||||
@@ -177,6 +177,7 @@ class PromptRepository:
|
||||
review_count: int = 0,
|
||||
review_result: bool = False,
|
||||
check_tree_state: bool = True,
|
||||
check_duplicate: bool = True,
|
||||
) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
@@ -199,7 +200,7 @@ class PromptRepository:
|
||||
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
if self.check_users_recent_replies_for_duplicates(text):
|
||||
if check_duplicate and self.check_users_recent_replies_for_duplicates(text):
|
||||
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
|
||||
|
||||
if task.parent_message_id:
|
||||
@@ -909,8 +910,7 @@ FROM (
|
||||
) AS cc
|
||||
WHERE message.id = cc.id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
|
||||
self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
|
||||
@@ -25,7 +25,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func, not_, text, update
|
||||
from sqlmodel import Session, func, not_, or_, text, update
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -73,6 +73,7 @@ class IncompleteRankingsRow(pydantic.BaseModel):
|
||||
role: str
|
||||
children_count: int
|
||||
child_min_ranking_count: int
|
||||
message_tree_id: UUID
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
@@ -625,19 +626,28 @@ class TreeManager:
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
|
||||
assert mts and mts.active
|
||||
assert mts
|
||||
|
||||
is_terminal = state in message_tree_state.TERMINAL_STATES
|
||||
|
||||
was_active = mts.active
|
||||
if is_terminal:
|
||||
mts.active = False
|
||||
mts.state = state.value
|
||||
self.db.add(mts)
|
||||
self.db.flush
|
||||
|
||||
if is_terminal:
|
||||
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
|
||||
root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False)
|
||||
if root_msg and was_active:
|
||||
if random.random() < self.cfg.p_activate_backlog_tree:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
|
||||
if self.cfg.min_active_rankings_per_lang > 0:
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang)
|
||||
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
else:
|
||||
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
@@ -680,24 +690,30 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.RANKING)
|
||||
return True
|
||||
|
||||
def check_condition_for_scoring_state(
|
||||
self, message_tree_id: UUID
|
||||
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.RANKING:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False, None
|
||||
if mts.state != message_tree_state.State.SCORING_FAILED:
|
||||
if not mts.active or mts.state not in (
|
||||
message_tree_state.State.RANKING,
|
||||
message_tree_state.State.READY_FOR_SCORING,
|
||||
):
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
return False, None
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
if (
|
||||
mts.state != message_tree_state.State.SCORING_FAILED
|
||||
and mts.state != message_tree_state.State.READY_FOR_SCORING
|
||||
):
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
self.update_message_ranks(message_tree_id, rankings_by_message)
|
||||
return True
|
||||
|
||||
@@ -759,8 +775,35 @@ class TreeManager:
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
|
||||
|
||||
return True
|
||||
|
||||
def activate_backlog_tree(self, lang: str) -> MessageTreeState:
|
||||
while True:
|
||||
# find tree in backlog state
|
||||
backlog_tree: MessageTreeState = (
|
||||
self.db.query(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id) # root msg
|
||||
.filter(MessageTreeState.state == message_tree_state.State.BACKLOG_RANKING)
|
||||
.filter(Message.lang == lang)
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if not backlog_tree:
|
||||
return None
|
||||
|
||||
if len(self.query_tree_ranking_results(message_tree_id=backlog_tree.message_tree_id)) == 0:
|
||||
logger.info(
|
||||
f"Backlog tree {backlog_tree.message_tree_id} has no children to rank, aborting with 'aborted_low_grade' state."
|
||||
)
|
||||
self._enter_state(backlog_tree, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
else:
|
||||
logger.info(f"Activating backlog tree {backlog_tree.message_tree_id}")
|
||||
backlog_tree.active = True
|
||||
self._enter_state(backlog_tree, message_tree_state.State.RANKING)
|
||||
return backlog_tree
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
@@ -828,7 +871,8 @@ class TreeManager:
|
||||
_sql_find_incomplete_rankings = """
|
||||
-- find incomplete rankings
|
||||
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings,
|
||||
mts.message_tree_id
|
||||
FROM message_tree_state mts
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active -- only consider active trees
|
||||
@@ -837,7 +881,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.lang = :lang -- matches lang
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id, m.role
|
||||
GROUP BY m.parent_id, m.role, mts.message_tree_id
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
@@ -846,7 +890,8 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
|
||||
SELECT ir.* FROM incomplete_rankings ir
|
||||
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings,
|
||||
ir.message_tree_id
|
||||
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
"""
|
||||
|
||||
@@ -985,8 +1030,8 @@ SELECT p.parent_id, mr.* FROM
|
||||
GROUP BY m.parent_id, m.message_tree_id
|
||||
HAVING COUNT(m.id) > 1
|
||||
) as p
|
||||
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
"""
|
||||
|
||||
def query_tree_ranking_results(
|
||||
@@ -1029,7 +1074,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
self._insert_default_state(id, state=state)
|
||||
|
||||
rankings = (
|
||||
self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all()
|
||||
self.db.query(MessageTreeState)
|
||||
.filter(
|
||||
or_(
|
||||
MessageTreeState.state == message_tree_state.State.RANKING,
|
||||
MessageTreeState.state == message_tree_state.State.READY_FOR_SCORING,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if len(rankings) > 0:
|
||||
logger.info(f"Checking state of {len(rankings)} message trees in ranking state.")
|
||||
@@ -1322,17 +1374,17 @@ DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def retry_scoring_failed_message_trees(self):
|
||||
query = self.db.query(MessageTreeState.message_tree_id).filter(
|
||||
query = self.db.query(MessageTreeState).filter(
|
||||
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
|
||||
)
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
for row in query.all():
|
||||
for mts in query.all():
|
||||
mts: MessageTreeState
|
||||
try:
|
||||
message_tree_id = row["message_tree_id"]
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message)
|
||||
if not self.check_condition_for_scoring_state(mts.message_tree_id):
|
||||
mts.active = True
|
||||
self._enter_state(message_tree_state.State.RANKING)
|
||||
except Exception:
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})")
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1366,8 +1418,8 @@ if __name__ == "__main__":
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
# print(
|
||||
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
|
||||
# )
|
||||
print(
|
||||
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
|
||||
)
|
||||
|
||||
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
|
||||
@@ -117,13 +117,20 @@ class UserRepository:
|
||||
self.db.add(user)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
def _lookup_user_tx(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
auth_method: str,
|
||||
display_name: Optional[str] = None,
|
||||
create_missing: bool = True,
|
||||
) -> User | None:
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
User.username == username,
|
||||
User.auth_method == auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -131,30 +138,46 @@ class UserRepository:
|
||||
if create_missing:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
auth_method=auth_method,
|
||||
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
|
||||
)
|
||||
self.db.add(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
elif display_name and display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
user.display_name = display_name
|
||||
self.db.add(user)
|
||||
|
||||
return user
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
|
||||
if not client_user:
|
||||
return None
|
||||
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
return self._lookup_client_user_tx(client_user, create_missing)
|
||||
return self._lookup_user_tx(
|
||||
username=client_user.id,
|
||||
auth_method=client_user.auth_method,
|
||||
display_name=client_user.display_name,
|
||||
create_missing=create_missing,
|
||||
)
|
||||
except IntegrityError:
|
||||
# catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username
|
||||
if i + 1 == num_retries:
|
||||
raise
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def lookup_system_user(self, username: str, create_missing: bool = True) -> User | None:
|
||||
return self._lookup_user_tx(
|
||||
username=username,
|
||||
auth_method="system",
|
||||
display_name=f"__system__/{username}",
|
||||
create_missing=create_missing,
|
||||
)
|
||||
|
||||
def query_users_ordered_by_username(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
|
||||
@@ -214,6 +214,10 @@ class UserStatsRepository:
|
||||
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
|
||||
self.session.execute(d)
|
||||
|
||||
if None in stats_by_user:
|
||||
logger.warning("Some messages in DB have NULL values in user_id column.")
|
||||
del stats_by_user[None]
|
||||
|
||||
# compute magic leader score
|
||||
for v in stats_by_user.values():
|
||||
v.leader_score = v.compute_leader_score()
|
||||
|
||||
@@ -12,12 +12,15 @@ from pydantic import BaseModel
|
||||
|
||||
class ExportMessageNode(BaseModel):
|
||||
message_id: str
|
||||
parent_id: Optional[str]
|
||||
text: Optional[str]
|
||||
parent_id: str | None
|
||||
text: str
|
||||
role: str
|
||||
review_count: Optional[int]
|
||||
rank: Optional[int]
|
||||
replies: Optional[list[ExportMessageNode]]
|
||||
lang: str | None
|
||||
review_count: int | None
|
||||
rank: int | None
|
||||
synthetic: bool | None
|
||||
model_name: str | None
|
||||
replies: list[ExportMessageNode] | None
|
||||
|
||||
@classmethod
|
||||
def prep_message_export(cls, message: Message) -> ExportMessageNode:
|
||||
@@ -26,14 +29,17 @@ class ExportMessageNode(BaseModel):
|
||||
parent_id=str(message.parent_id) if message.parent_id else None,
|
||||
text=str(message.payload.payload.text),
|
||||
role=message.role,
|
||||
lang=message.lang,
|
||||
review_count=message.review_count,
|
||||
synthetic=message.synthetic,
|
||||
model_name=message.model_name,
|
||||
rank=message.rank,
|
||||
)
|
||||
|
||||
|
||||
class ExportMessageTree(BaseModel):
|
||||
message_tree_id: str
|
||||
replies: Optional[ExportMessageNode]
|
||||
prompt: Optional[ExportMessageNode]
|
||||
|
||||
|
||||
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
|
||||
|
||||
Reference in New Issue
Block a user