Import message trees from jsonl file (#964)

* add new backlog_ranking tree state

* add first version of import script

* allow activation of trees during import

* add min_active_rankings_per_lang config param

* add settings docstring
This commit is contained in:
Andreas Köpf
2023-01-28 15:05:46 +01:00
committed by GitHub
parent b2eb94962c
commit c8d16285d0
12 changed files with 377 additions and 49 deletions
@@ -0,0 +1,31 @@
"""add origin column to message_tree_state
Revision ID: 49d8445b4c90
Revises: f856bf19d32b
Create Date: 2023-01-28 11:57:45.580027
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "49d8445b4c90"
down_revision = "f856bf19d32b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False))
op.add_column("message", sa.Column("model_name", sa.String(length=1024), nullable=True))
op.add_column("message_tree_state", sa.Column("origin", sa.String(length=1024), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message_tree_state", "origin")
op.drop_column("message", "model_name")
op.drop_column("message", "synthetic")
# ### end Alembic commands ###
+187
View File
@@ -0,0 +1,187 @@
import argparse
import json
from pathlib import Path
from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
import oasst_backend.utils.database_utils as db_utils
import pydantic
from loguru import logger
from oasst_backend.api.deps import create_api_client
from oasst_backend.models import ApiClient, Message
from oasst_backend.models.message_tree_state import MessageTreeState
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree
from sqlmodel import Session
# well known id
IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363")
class Importer:
def __init__(self, db: Session, origin: str, model_name: Optional[str] = None):
self.db = db
self.origin = origin
self.model_name = model_name
# get import api client
api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first()
if not api_client:
api_client = create_api_client(
session=db,
description="API client used for importing data",
frontend_type="import",
force_id=IMPORT_API_CLIENT_ID,
)
ur = UserRepository(db, api_client)
self.import_user = ur.lookup_system_user(username="import")
self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur)
self.api_client = api_client
def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none()
def import_message(
self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None
) -> Message:
payload = db_payload.MessagePayload(text=message.text)
msg = Message(
id=message.message_id,
message_tree_id=message_tree_id,
frontend_message_id=message.message_id,
parent_id=parent_id,
review_count=message.review_count or 0,
lang=message.lang or "en",
review_result=True,
synthetic=message.synthetic if message.synthetic is not None else True,
model_name=message.model_name or self.model_name,
role=message.role,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
payload=PayloadContainer(payload=payload),
user_id=self.import_user.id,
)
self.db.add(msg)
if message.replies:
for r in message.replies:
self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id)
self.db.flush()
if parent_id is None:
self.pr.update_children_counts(msg.id)
self.db.refresh(msg)
return msg
def import_tree(
self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING
) -> tuple[MessageTreeState, Message]:
assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id
root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id)
assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import"
active = state == TreeState.RANKING
mts = MessageTreeState(
message_tree_id=root_msg.id,
goal_tree_size=0,
max_depth=0,
max_children_count=0,
state=state,
origin=self.origin,
active=active,
)
self.db.add(mts)
return mts, root_msg
def import_file(
input_file_path: Path,
origin: str,
*,
model_name: Optional[str] = None,
num_activate: int = 0,
max_count: Optional[int] = None,
dry_run: bool = False,
) -> int:
@db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT)
def import_tx(db: Session) -> int:
importer = Importer(db, origin=origin, model_name=model_name)
i = 0
with input_file_path.open() as file_in:
# read line tree object
for line in file_in:
dict_tree = json.loads(line)
# validate data
tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree)
existing_mts = importer.fetch_message_tree_state(tree.message_tree_id)
if existing_mts:
logger.info(f"Skipping existing message tree: {tree.message_tree_id}")
else:
state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING
mts, root_msg = importer.import_tree(tree, state=state)
i += 1
logger.info(
f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}"
)
if max_count and i >= max_count:
logger.info(f"Reached max count {max_count} of trees to import.")
break
return i
if dry_run:
logger.info("DRY RUN with rollback")
return import_tx()
def parse_args():
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser()
parser.add_argument(
"input_file_path",
help="Input file path",
)
parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees")
parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)")
parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state")
parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import")
parser.add_argument("--dry_run", type=str2bool, default=False)
args = parser.parse_args()
return args
def main():
args = parse_args()
input_file_path = Path(args.input_file_path)
if not input_file_path.exists() or not input_file_path.is_file():
print("Invalid input file:", args.input_file_path)
exit(1)
dry_run = args.dry_run
num_imported = import_file(
input_file_path,
origin=args.origin or input_file_path.name,
model_name=args.model_name,
num_activate=args.num_activate,
max_count=args.max_count,
dry_run=dry_run,
)
logger.info(f"Done ({num_imported=}, {dry_run=})")
if __name__ == "__main__":
main()
+3 -1
View File
@@ -191,6 +191,7 @@ if settings.DEBUG_USE_SEED_DATA:
review_count=5,
review_result=True,
check_tree_state=False,
check_duplicate=False,
)
if message.parent_id is None:
tm._insert_default_state(
@@ -215,7 +216,8 @@ def ensure_tree_states():
try:
logger.info("Startup: TreeManager.ensure_tree_states()")
with Session(engine) as db:
tm = TreeManager(db, None)
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
tm = TreeManager(db, PromptRepository(db, api_client=api_client))
tm.ensure_tree_states()
except Exception:
+5 -1
View File
@@ -1,6 +1,7 @@
from http import HTTPStatus
from secrets import token_hex
from typing import Generator, NamedTuple
from typing import Generator, NamedTuple, Optional
from uuid import UUID
from fastapi import Depends, Request, Response, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -67,6 +68,7 @@ def create_api_client(
trusted: bool | None = False,
admin_email: str | None = None,
api_key: str | None = None,
force_id: Optional[UUID] = None,
) -> ApiClient:
if api_key is None:
api_key = token_hex(32)
@@ -79,6 +81,8 @@ def create_api_client(
trusted=trusted,
admin_email=admin_email,
)
if force_id:
api_client.id = force_id
session.add(api_client)
session.commit()
session.refresh(api_client)
+9
View File
@@ -46,6 +46,15 @@ class TreeManagerConfiguration(BaseModel):
num_required_rankings: int = 3
"""Number of rankings in which the message participated."""
p_activate_backlog_tree: float = 0.8
"""Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
a terminal state. Use this settting to control ratio of initial prompts and backlog tree
activations."""
min_active_rankings_per_lang: int = 2
"""When the number of active ranking tasks is below this value when a tree enters a terminal
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
labels_initial_prompt: list[TextLabel] = [
TextLabel.spam,
TextLabel.quality,
+5
View File
@@ -57,6 +57,11 @@ class Message(SQLModel, table=True):
rank: Optional[int] = Field(nullable=True)
synthetic: Optional[bool] = Field(
sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False)
)
model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
@@ -43,6 +43,9 @@ class State(str, Enum):
HALTED_BY_MODERATOR = "halted_by_moderator"
"""A moderator decided to manually halt the message tree construction process."""
BACKLOG_RANKING = "backlog_ranking"
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
VALID_STATES = (
State.INITIAL_PROMPT_REVIEW,
@@ -51,6 +54,7 @@ VALID_STATES = (
State.READY_FOR_SCORING,
State.READY_FOR_EXPORT,
State.ABORTED_LOW_GRADE,
State.BACKLOG_RANKING,
)
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
@@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True):
max_children_count: int = Field(nullable=False)
state: str = Field(nullable=False, max_length=128, index=True)
active: bool = Field(nullable=False, index=True)
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
+3 -3
View File
@@ -177,6 +177,7 @@ class PromptRepository:
review_count: int = 0,
review_result: bool = False,
check_tree_state: bool = True,
check_duplicate: bool = True,
) -> Message:
self.ensure_user_is_enabled()
@@ -199,7 +200,7 @@ class PromptRepository:
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
if self.check_users_recent_replies_for_duplicates(text):
if check_duplicate and self.check_users_recent_replies_for_duplicates(text):
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
if task.parent_message_id:
@@ -909,8 +910,7 @@ FROM (
) AS cc
WHERE message.id = cc.id;
"""
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
@managed_tx_method(CommitMode.COMMIT)
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
+80 -28
View File
@@ -25,7 +25,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func, not_, text, update
from sqlmodel import Session, func, not_, or_, text, update
class TaskType(Enum):
@@ -73,6 +73,7 @@ class IncompleteRankingsRow(pydantic.BaseModel):
role: str
children_count: int
child_min_ranking_count: int
message_tree_id: UUID
class Config:
orm_mode = True
@@ -625,19 +626,28 @@ class TreeManager:
return protocol_schema.TaskDone()
@managed_tx_method(CommitMode.FLUSH)
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
assert mts and mts.active
assert mts
is_terminal = state in message_tree_state.TERMINAL_STATES
was_active = mts.active
if is_terminal:
mts.active = False
mts.state = state.value
self.db.add(mts)
self.db.flush
if is_terminal:
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False)
if root_msg and was_active:
if random.random() < self.cfg.p_activate_backlog_tree:
self.activate_backlog_tree(lang=root_msg.lang)
if self.cfg.min_active_rankings_per_lang > 0:
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang)
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
self.activate_backlog_tree(lang=root_msg.lang)
else:
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
@@ -680,24 +690,30 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.RANKING)
return True
def check_condition_for_scoring_state(
self, message_tree_id: UUID
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.RANKING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False, None
if mts.state != message_tree_state.State.SCORING_FAILED:
if not mts.active or mts.state not in (
message_tree_state.State.RANKING,
message_tree_state.State.READY_FOR_SCORING,
):
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
return False, None
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
if (
mts.state != message_tree_state.State.SCORING_FAILED
and mts.state != message_tree_state.State.READY_FOR_SCORING
):
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
self.update_message_ranks(message_tree_id, rankings_by_message)
return True
@@ -759,8 +775,35 @@ class TreeManager:
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
return True
def activate_backlog_tree(self, lang: str) -> MessageTreeState:
while True:
# find tree in backlog state
backlog_tree: MessageTreeState = (
self.db.query(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id) # root msg
.filter(MessageTreeState.state == message_tree_state.State.BACKLOG_RANKING)
.filter(Message.lang == lang)
.limit(1)
.one_or_none()
)
if not backlog_tree:
return None
if len(self.query_tree_ranking_results(message_tree_id=backlog_tree.message_tree_id)) == 0:
logger.info(
f"Backlog tree {backlog_tree.message_tree_id} has no children to rank, aborting with 'aborted_low_grade' state."
)
self._enter_state(backlog_tree, message_tree_state.State.ABORTED_LOW_GRADE)
else:
logger.info(f"Activating backlog tree {backlog_tree.message_tree_id}")
backlog_tree.active = True
self._enter_state(backlog_tree, message_tree_state.State.RANKING)
return backlog_tree
def _calculate_acceptance(self, labels: list[TextLabels]):
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
@@ -828,7 +871,8 @@ class TreeManager:
_sql_find_incomplete_rankings = """
-- find incomplete rankings
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings,
mts.message_tree_id
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active -- only consider active trees
@@ -837,7 +881,7 @@ WHERE mts.active -- only consider active trees
AND m.lang = :lang -- matches lang
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id, m.role
GROUP BY m.parent_id, m.role, mts.message_tree_id
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
@@ -846,7 +890,8 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
SELECT ir.* FROM incomplete_rankings ir
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings,
ir.message_tree_id
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
"""
@@ -985,8 +1030,8 @@ SELECT p.parent_id, mr.* FROM
GROUP BY m.parent_id, m.message_tree_id
HAVING COUNT(m.id) > 1
) as p
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
"""
def query_tree_ranking_results(
@@ -1029,7 +1074,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
self._insert_default_state(id, state=state)
rankings = (
self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all()
self.db.query(MessageTreeState)
.filter(
or_(
MessageTreeState.state == message_tree_state.State.RANKING,
MessageTreeState.state == message_tree_state.State.READY_FOR_SCORING,
)
)
.all()
)
if len(rankings) > 0:
logger.info(f"Checking state of {len(rankings)} message trees in ranking state.")
@@ -1322,17 +1374,17 @@ DELETE FROM user_stats WHERE user_id = :user_id;
@managed_tx_method(CommitMode.COMMIT)
def retry_scoring_failed_message_trees(self):
query = self.db.query(MessageTreeState.message_tree_id).filter(
query = self.db.query(MessageTreeState).filter(
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
)
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
for row in query.all():
for mts in query.all():
mts: MessageTreeState
try:
message_tree_id = row["message_tree_id"]
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message)
if not self.check_condition_for_scoring_state(mts.message_tree_id):
mts.active = True
self._enter_state(message_tree_state.State.RANKING)
except Exception:
logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})")
logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})")
if __name__ == "__main__":
@@ -1366,8 +1418,8 @@ if __name__ == "__main__":
# print("next_task:", tm.next_task())
# print(
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
# )
print(
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
)
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
+33 -10
View File
@@ -117,13 +117,20 @@ class UserRepository:
self.db.add(user)
@managed_tx_method(CommitMode.COMMIT)
def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
def _lookup_user_tx(
self,
*,
username: str,
auth_method: str,
display_name: Optional[str] = None,
create_missing: bool = True,
) -> User | None:
user: User = (
self.db.query(User)
.filter(
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
User.username == username,
User.auth_method == auth_method,
)
.first()
)
@@ -131,30 +138,46 @@ class UserRepository:
if create_missing:
# user is unknown, create new record
user = User(
username=client_user.id,
display_name=client_user.display_name,
username=username,
display_name=display_name,
api_client_id=self.api_client.id,
auth_method=client_user.auth_method,
auth_method=auth_method,
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
)
self.db.add(user)
elif client_user.display_name and client_user.display_name != user.display_name:
elif display_name and display_name != user.display_name:
# we found the user but the display name changed
user.display_name = client_user.display_name
user.display_name = display_name
self.db.add(user)
return user
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
if not client_user:
return None
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
for i in range(num_retries):
try:
return self._lookup_client_user_tx(client_user, create_missing)
return self._lookup_user_tx(
username=client_user.id,
auth_method=client_user.auth_method,
display_name=client_user.display_name,
create_missing=create_missing,
)
except IntegrityError:
# catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username
if i + 1 == num_retries:
raise
@managed_tx_method(CommitMode.COMMIT)
def lookup_system_user(self, username: str, create_missing: bool = True) -> User | None:
return self._lookup_user_tx(
username=username,
auth_method="system",
display_name=f"__system__/{username}",
create_missing=create_missing,
)
def query_users_ordered_by_username(
self,
api_client_id: Optional[UUID] = None,
@@ -214,6 +214,10 @@ class UserStatsRepository:
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
self.session.execute(d)
if None in stats_by_user:
logger.warning("Some messages in DB have NULL values in user_id column.")
del stats_by_user[None]
# compute magic leader score
for v in stats_by_user.values():
v.leader_score = v.compute_leader_score()
+12 -6
View File
@@ -12,12 +12,15 @@ from pydantic import BaseModel
class ExportMessageNode(BaseModel):
message_id: str
parent_id: Optional[str]
text: Optional[str]
parent_id: str | None
text: str
role: str
review_count: Optional[int]
rank: Optional[int]
replies: Optional[list[ExportMessageNode]]
lang: str | None
review_count: int | None
rank: int | None
synthetic: bool | None
model_name: str | None
replies: list[ExportMessageNode] | None
@classmethod
def prep_message_export(cls, message: Message) -> ExportMessageNode:
@@ -26,14 +29,17 @@ class ExportMessageNode(BaseModel):
parent_id=str(message.parent_id) if message.parent_id else None,
text=str(message.payload.payload.text),
role=message.role,
lang=message.lang,
review_count=message.review_count,
synthetic=message.synthetic,
model_name=message.model_name,
rank=message.rank,
)
class ExportMessageTree(BaseModel):
message_tree_id: str
replies: Optional[ExportMessageNode]
prompt: Optional[ExportMessageNode]
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree: