From c8d16285d016822b0b03f821b04e099f1b4c38c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 15:05:46 +0100 Subject: [PATCH] Import message trees from jsonl file (#964) * add new backlog_ranking tree state * add first version of import script * allow activation of trees during import * add min_active_rankings_per_lang config param * add settings docstring --- ...add_origin_column_to_message_tree_state.py | 31 +++ backend/import.py | 187 ++++++++++++++++++ backend/main.py | 4 +- backend/oasst_backend/api/deps.py | 6 +- backend/oasst_backend/config.py | 9 + backend/oasst_backend/models/message.py | 5 + .../models/message_tree_state.py | 5 + backend/oasst_backend/prompt_repository.py | 6 +- backend/oasst_backend/tree_manager.py | 108 +++++++--- backend/oasst_backend/user_repository.py | 43 +++- .../oasst_backend/user_stats_repository.py | 4 + backend/oasst_backend/utils/tree_export.py | 18 +- 12 files changed, 377 insertions(+), 49 deletions(-) create mode 100644 backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py create mode 100644 backend/import.py diff --git a/backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py b/backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py new file mode 100644 index 00000000..e11f8f9e --- /dev/null +++ b/backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py @@ -0,0 +1,31 @@ +"""add origin column to message_tree_state + +Revision ID: 49d8445b4c90 +Revises: f856bf19d32b +Create Date: 2023-01-28 11:57:45.580027 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "49d8445b4c90" +down_revision = "f856bf19d32b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("message", sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False)) + op.add_column("message", sa.Column("model_name", sa.String(length=1024), nullable=True)) + op.add_column("message_tree_state", sa.Column("origin", sa.String(length=1024), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message_tree_state", "origin") + op.drop_column("message", "model_name") + op.drop_column("message", "synthetic") + # ### end Alembic commands ### diff --git a/backend/import.py b/backend/import.py new file mode 100644 index 00000000..46a9c2cb --- /dev/null +++ b/backend/import.py @@ -0,0 +1,187 @@ +import argparse +import json +from pathlib import Path +from typing import Optional +from uuid import UUID + +import oasst_backend.models.db_payload as db_payload +import oasst_backend.utils.database_utils as db_utils +import pydantic +from loguru import logger +from oasst_backend.api.deps import create_api_client +from oasst_backend.models import ApiClient, Message +from oasst_backend.models.message_tree_state import MessageTreeState +from oasst_backend.models.message_tree_state import State as TreeState +from oasst_backend.models.payload_column_type import PayloadContainer +from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.user_repository import UserRepository +from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree +from sqlmodel import Session + +# well known id +IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363") + + +class Importer: + def __init__(self, db: Session, origin: str, model_name: Optional[str] = None): + self.db = db + self.origin = origin + self.model_name = model_name + + # get import api client + api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first() + if not api_client: + api_client = create_api_client( + 
session=db, + description="API client used for importing data", + frontend_type="import", + force_id=IMPORT_API_CLIENT_ID, + ) + + ur = UserRepository(db, api_client) + self.import_user = ur.lookup_system_user(username="import") + self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur) + self.api_client = api_client + + def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState: + return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none() + + def import_message( + self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None + ) -> Message: + payload = db_payload.MessagePayload(text=message.text) + msg = Message( + id=message.message_id, + message_tree_id=message_tree_id, + frontend_message_id=message.message_id, + parent_id=parent_id, + review_count=message.review_count or 0, + lang=message.lang or "en", + review_result=True, + synthetic=message.synthetic if message.synthetic is not None else True, + model_name=message.model_name or self.model_name, + role=message.role, + api_client_id=self.api_client.id, + payload_type=type(payload).__name__, + payload=PayloadContainer(payload=payload), + user_id=self.import_user.id, + ) + self.db.add(msg) + if message.replies: + for r in message.replies: + self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id) + self.db.flush() + if parent_id is None: + self.pr.update_children_counts(msg.id) + self.db.refresh(msg) + return msg + + def import_tree( + self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING + ) -> tuple[MessageTreeState, Message]: + assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id + root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id) + assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import" + active = state == TreeState.RANKING + 
mts = MessageTreeState( + message_tree_id=root_msg.id, + goal_tree_size=0, + max_depth=0, + max_children_count=0, + state=state, + origin=self.origin, + active=active, + ) + self.db.add(mts) + return mts, root_msg + + +def import_file( + input_file_path: Path, + origin: str, + *, + model_name: Optional[str] = None, + num_activate: int = 0, + max_count: Optional[int] = None, + dry_run: bool = False, +) -> int: + @db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT) + def import_tx(db: Session) -> int: + importer = Importer(db, origin=origin, model_name=model_name) + i = 0 + with input_file_path.open() as file_in: + # read line tree object + for line in file_in: + dict_tree = json.loads(line) + + # validate data + tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree) + existing_mts = importer.fetch_message_tree_state(tree.message_tree_id) + if existing_mts: + logger.info(f"Skipping existing message tree: {tree.message_tree_id}") + else: + state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING + mts, root_msg = importer.import_tree(tree, state=state) + i += 1 + logger.info( + f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}" + ) + + if max_count and i >= max_count: + logger.info(f"Reached max count {max_count} of trees to import.") + break + return i + + if dry_run: + logger.info("DRY RUN with rollback") + return import_tx() + + +def parse_args(): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser() + parser.add_argument( + "input_file_path", + help="Input file path", + ) + parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees") + 
parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)") + parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state") + parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import") + parser.add_argument("--dry_run", type=str2bool, default=False) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + input_file_path = Path(args.input_file_path) + if not input_file_path.exists() or not input_file_path.is_file(): + print("Invalid input file:", args.input_file_path) + exit(1) + + dry_run = args.dry_run + num_imported = import_file( + input_file_path, + origin=args.origin or input_file_path.name, + model_name=args.model_name, + num_activate=args.num_activate, + max_count=args.max_count, + dry_run=dry_run, + ) + + logger.info(f"Done ({num_imported=}, {dry_run=})") + + +if __name__ == "__main__": + main() diff --git a/backend/main.py b/backend/main.py index ea4b25da..06c67a2e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -191,6 +191,7 @@ if settings.DEBUG_USE_SEED_DATA: review_count=5, review_result=True, check_tree_state=False, + check_duplicate=False, ) if message.parent_id is None: tm._insert_default_state( @@ -215,7 +216,8 @@ def ensure_tree_states(): try: logger.info("Startup: TreeManager.ensure_tree_states()") with Session(engine) as db: - tm = TreeManager(db, None) + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) + tm = TreeManager(db, PromptRepository(db, api_client=api_client)) tm.ensure_tree_states() except Exception: diff --git a/backend/oasst_backend/api/deps.py b/backend/oasst_backend/api/deps.py index dd10bd20..af29285a 100644 --- a/backend/oasst_backend/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -1,6 +1,7 @@ from http import HTTPStatus from secrets import token_hex -from typing import Generator, NamedTuple +from typing import Generator, 
NamedTuple, Optional +from uuid import UUID from fastapi import Depends, Request, Response, Security from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -67,6 +68,7 @@ def create_api_client( trusted: bool | None = False, admin_email: str | None = None, api_key: str | None = None, + force_id: Optional[UUID] = None, ) -> ApiClient: if api_key is None: api_key = token_hex(32) @@ -79,6 +81,8 @@ def create_api_client( trusted=trusted, admin_email=admin_email, ) + if force_id: + api_client.id = force_id session.add(api_client) session.commit() session.refresh(api_client) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 92316037..5c566b1e 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -46,6 +46,15 @@ class TreeManagerConfiguration(BaseModel): num_required_rankings: int = 3 """Number of rankings in which the message participated.""" + p_activate_backlog_tree: float = 0.8 + """Probability to activate a message tree in BACKLOG_RANKING state when another tree enters + a terminal state. Use this setting to control the ratio of initial prompts and backlog tree + activations.""" + + min_active_rankings_per_lang: int = 2 + """When the number of active ranking tasks is below this value when a tree enters a terminal + state an available tree in BACKLOG_RANKING will be activated (i.e. 
enters the RANKING state).""" + labels_initial_prompt: list[TextLabel] = [ TextLabel.spam, TextLabel.quality, diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 24fafc01..2cf3b7bb 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -57,6 +57,11 @@ class Message(SQLModel, table=True): rank: Optional[int] = Field(nullable=True) + synthetic: Optional[bool] = Field( + sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False) + ) + model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True)) + emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False) _user_emojis: Optional[list[str]] = PrivateAttr(default=None) diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py index 2f7ce363..00b94967 100644 --- a/backend/oasst_backend/models/message_tree_state.py +++ b/backend/oasst_backend/models/message_tree_state.py @@ -43,6 +43,9 @@ class State(str, Enum): HALTED_BY_MODERATOR = "halted_by_moderator" """A moderator decided to manually halt the message tree construction process.""" + BACKLOG_RANKING = "backlog_ranking" + """Imported tree ready to be activated and ranked by users (currently inactive).""" + VALID_STATES = ( State.INITIAL_PROMPT_REVIEW, @@ -51,6 +54,7 @@ VALID_STATES = ( State.READY_FOR_SCORING, State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, + State.BACKLOG_RANKING, ) TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR) @@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True): max_children_count: int = Field(nullable=False) state: str = Field(nullable=False, max_length=128, index=True) active: bool = Field(nullable=False, index=True) + origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True)) diff --git 
a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 01f63d26..c69f7340 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -177,6 +177,7 @@ class PromptRepository: review_count: int = 0, review_result: bool = False, check_tree_state: bool = True, + check_duplicate: bool = True, ) -> Message: self.ensure_user_is_enabled() @@ -199,7 +200,7 @@ class PromptRepository: logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.") raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG) - if self.check_users_recent_replies_for_duplicates(text): + if check_duplicate and self.check_users_recent_replies_for_duplicates(text): raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED) if task.parent_message_id: @@ -909,8 +910,7 @@ FROM ( ) AS cc WHERE message.id = cc.id; """ - r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id}) - logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.") + self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id}) @managed_tx_method(CommitMode.COMMIT) def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True): diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 782b9af8..bcab02c3 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -25,7 +25,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM from oasst_backend.utils.ranking import ranked_pairs from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session, func, not_, text, update +from sqlmodel import Session, func, not_, or_, text, update 
class TaskType(Enum): @@ -73,6 +73,7 @@ class IncompleteRankingsRow(pydantic.BaseModel): role: str children_count: int child_min_ranking_count: int + message_tree_id: UUID class Config: orm_mode = True @@ -625,19 +626,28 @@ class TreeManager: return protocol_schema.TaskDone() - @managed_tx_method(CommitMode.FLUSH) def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State): - assert mts and mts.active + assert mts is_terminal = state in message_tree_state.TERMINAL_STATES - + was_active = mts.active if is_terminal: mts.active = False mts.state = state.value self.db.add(mts) + self.db.flush() if is_terminal: logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})") + root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False) + if root_msg and was_active: + if random.random() < self.cfg.p_activate_backlog_tree: + self.activate_backlog_tree(lang=root_msg.lang) + + if self.cfg.min_active_rankings_per_lang > 0: + incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang) + if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang: + self.activate_backlog_tree(lang=root_msg.lang) else: logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})") @@ -680,24 +690,30 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.RANKING) return True - def check_condition_for_scoring_state( - self, message_tree_id: UUID - ) -> Tuple[bool, dict[UUID, list[MessageReaction]]]: + def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_scoring_state({message_tree_id=})") mts = self.pr.fetch_tree_state(message_tree_id) - if not mts.active or mts.state != message_tree_state.State.RANKING: - logger.debug(f"False {mts.active=}, {mts.state=}") - return False, None + if mts.state != message_tree_state.State.SCORING_FAILED: + if not mts.active or mts.state not in ( + message_tree_state.State.RANKING, + 
message_tree_state.State.READY_FOR_SCORING, + ): + logger.debug(f"False {mts.active=}, {mts.state=}") + return False ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) for parent_msg_id, ranking in rankings_by_message.items(): if len(ranking) < self.cfg.num_required_rankings: logger.debug(f"False {parent_msg_id=} {len(ranking)=}") - return False, None + return False - self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) + if ( + mts.state != message_tree_state.State.SCORING_FAILED + and mts.state != message_tree_state.State.READY_FOR_SCORING + ): + self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) self.update_message_ranks(message_tree_id, rankings_by_message) return True @@ -759,8 +775,35 @@ class TreeManager: return False self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT) + return True + def activate_backlog_tree(self, lang: str) -> MessageTreeState: + while True: + # find tree in backlog state + backlog_tree: MessageTreeState = ( + self.db.query(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) # root msg + .filter(MessageTreeState.state == message_tree_state.State.BACKLOG_RANKING) + .filter(Message.lang == lang) + .limit(1) + .one_or_none() + ) + + if not backlog_tree: + return None + + if len(self.query_tree_ranking_results(message_tree_id=backlog_tree.message_tree_id)) == 0: + logger.info( + f"Backlog tree {backlog_tree.message_tree_id} has no children to rank, aborting with 'aborted_low_grade' state." 
+ ) + self._enter_state(backlog_tree, message_tree_state.State.ABORTED_LOW_GRADE) + else: + logger.info(f"Activating backlog tree {backlog_tree.message_tree_id}") + backlog_tree.active = True + self._enter_state(backlog_tree, message_tree_state.State.RANKING) + return backlog_tree + def _calculate_acceptance(self, labels: list[TextLabels]): # calculate acceptance based on spam label return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) @@ -828,7 +871,8 @@ class TreeManager: _sql_find_incomplete_rankings = """ -- find incomplete rankings SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count, - COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings + COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings, + mts.message_tree_id FROM message_tree_state mts INNER JOIN message m ON mts.message_tree_id = m.message_tree_id WHERE mts.active -- only consider active trees @@ -837,7 +881,7 @@ WHERE mts.active -- only consider active trees AND m.lang = :lang -- matches lang AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts -GROUP BY m.parent_id, m.role +GROUP BY m.parent_id, m.role, mts.message_tree_id HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings """ @@ -846,7 +890,8 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings WITH incomplete_rankings AS ({_sql_find_incomplete_rankings}) SELECT ir.* FROM incomplete_rankings ir LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload' -GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings +GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings, + ir.message_tree_id HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0) """ @@ -985,8 +1030,8 @@ SELECT 
p.parent_id, mr.* FROM GROUP BY m.parent_id, m.message_tree_id HAVING COUNT(m.id) > 1 ) as p -INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload') -INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload' +LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload') +LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload' """ def query_tree_ranking_results( @@ -1029,7 +1074,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki self._insert_default_state(id, state=state) rankings = ( - self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all() + self.db.query(MessageTreeState) + .filter( + or_( + MessageTreeState.state == message_tree_state.State.RANKING, + MessageTreeState.state == message_tree_state.State.READY_FOR_SCORING, + ) + ) + .all() ) if len(rankings) > 0: logger.info(f"Checking state of {len(rankings)} message trees in ranking state.") @@ -1322,17 +1374,17 @@ DELETE FROM user_stats WHERE user_id = :user_id; @managed_tx_method(CommitMode.COMMIT) def retry_scoring_failed_message_trees(self): - query = self.db.query(MessageTreeState.message_tree_id).filter( + query = self.db.query(MessageTreeState).filter( MessageTreeState.state == message_tree_state.State.SCORING_FAILED ) - ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" - for row in query.all(): + for mts in query.all(): + mts: MessageTreeState try: - message_tree_id = row["message_tree_id"] - rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) - self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message) + if 
not self.check_condition_for_scoring_state(mts.message_tree_id): + mts.active = True + self._enter_state(mts, message_tree_state.State.RANKING) except Exception: - logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})") + logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})") if __name__ == "__main__": @@ -1366,8 +1418,8 @@ if __name__ == "__main__": # print("next_task:", tm.next_task()) - # print( - # ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921")) - # ) + print( + ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b")) + ) # print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl")) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 6f96748f..8d6187fb 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -117,13 +117,20 @@ class UserRepository: self.db.add(user) @managed_tx_method(CommitMode.COMMIT) - def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + def _lookup_user_tx( + self, + *, + username: str, + auth_method: str, + display_name: Optional[str] = None, + create_missing: bool = True, + ) -> User | None: user: User = ( self.db.query(User) .filter( User.api_client_id == self.api_client.id, - User.username == client_user.id, - User.auth_method == client_user.auth_method, + User.username == username, + User.auth_method == auth_method, ) .first() ) @@ -131,30 +138,46 @@ class UserRepository: if create_missing: # user is unknown, create new record user = User( - username=client_user.id, - display_name=client_user.display_name, + username=username, + display_name=display_name, api_client_id=self.api_client.id, - auth_method=client_user.auth_method, + auth_method=auth_method, + 
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user ) self.db.add(user) - elif client_user.display_name and client_user.display_name != user.display_name: + elif display_name and display_name != user.display_name: # we found the user but the display name changed - user.display_name = client_user.display_name + user.display_name = display_name self.db.add(user) + return user - def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None: if not client_user: return None num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT for i in range(num_retries): try: - return self._lookup_client_user_tx(client_user, create_missing) + return self._lookup_user_tx( + username=client_user.id, + auth_method=client_user.auth_method, + display_name=client_user.display_name, + create_missing=create_missing, + ) except IntegrityError: # catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username if i + 1 == num_retries: raise + @managed_tx_method(CommitMode.COMMIT) + def lookup_system_user(self, username: str, create_missing: bool = True) -> User | None: + return self._lookup_user_tx( + username=username, + auth_method="system", + display_name=f"__system__/{username}", + create_missing=create_missing, + ) + def query_users_ordered_by_username( self, api_client_id: Optional[UUID] = None, diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py index 81e5cb92..4cc0fc84 100644 --- a/backend/oasst_backend/user_stats_repository.py +++ b/backend/oasst_backend/user_stats_repository.py @@ -214,6 +214,10 @@ class UserStatsRepository: d = delete(UserStats).where(UserStats.time_frame == time_frame_key) self.session.execute(d) + if None in stats_by_user: + logger.warning("Some messages in DB have NULL values in user_id 
column.") + del stats_by_user[None] + # compute magic leader score for v in stats_by_user.values(): v.leader_score = v.compute_leader_score() diff --git a/backend/oasst_backend/utils/tree_export.py b/backend/oasst_backend/utils/tree_export.py index ee3de9d7..5cd69abe 100644 --- a/backend/oasst_backend/utils/tree_export.py +++ b/backend/oasst_backend/utils/tree_export.py @@ -12,12 +12,15 @@ from pydantic import BaseModel class ExportMessageNode(BaseModel): message_id: str - parent_id: Optional[str] - text: Optional[str] + parent_id: str | None + text: str role: str - review_count: Optional[int] - rank: Optional[int] - replies: Optional[list[ExportMessageNode]] + lang: str | None + review_count: int | None + rank: int | None + synthetic: bool | None + model_name: str | None + replies: list[ExportMessageNode] | None @classmethod def prep_message_export(cls, message: Message) -> ExportMessageNode: @@ -26,14 +29,17 @@ class ExportMessageNode(BaseModel): parent_id=str(message.parent_id) if message.parent_id else None, text=str(message.payload.payload.text), role=message.role, + lang=message.lang, review_count=message.review_count, + synthetic=message.synthetic, + model_name=message.model_name, rank=message.rank, ) class ExportMessageTree(BaseModel): message_tree_id: str - replies: Optional[ExportMessageNode] + prompt: Optional[ExportMessageNode] def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree: