diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 265591c1..2f48bca0 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -489,7 +489,8 @@ class TreeManager: _, task = pr.store_ranking(interaction) - self.check_condition_for_scoring_state(task.message_tree_id) + ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id) + self.update_message_ranks(task.message_tree_id, rankings_by_message) case protocol_schema.TextLabels: logger.info( @@ -589,39 +590,56 @@ class TreeManager: return True @managed_tx_method(CommitMode.COMMIT) - def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool: + def check_condition_for_scoring_state( + self, message_tree_id: UUID + ) -> Tuple[bool, dict[UUID, list[MessageReaction]]]: logger.debug(f"check_condition_for_scoring_state({message_tree_id=})") - mts: MessageTreeState - mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one() + + mts = self.pr.fetch_tree_state(message_tree_id) if not mts.active or mts.state != message_tree_state.State.RANKING: logger.debug(f"False {mts.active=}, {mts.state=}") - return False + return False, None ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) for parent_msg_id, ranking in rankings_by_message.items(): if len(ranking) < self.cfg.num_required_rankings: logger.debug(f"False {parent_msg_id=} {len(ranking)=}") - return False + return False, None self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) - self.update_message_ranks(rankings_by_message) - return True + return True, rankings_by_message @managed_tx_method(CommitMode.COMMIT) - def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None: - for parent_msg_id, ranking in rankings_by_message.items(): - sorted_messages = [] - for msg_reaction in ranking: - sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids) - logger.debug(f"SORTED MESSAGE {sorted_messages}") - consensus = ranked_pairs(sorted_messages) - logger.debug(f"CONSENSUS: {consensus}\n\n") - for rank, message_id in enumerate(consensus): - # set rank for each message_id for Message rows - msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True) - msg.rank = rank - self.db.add(msg) + def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool: + + mts = self.pr.fetch_tree_state(message_tree_id) + # check state, allow retry if in SCORING_FAILED state + if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED): + logger.debug(f"False {mts.active=}, {mts.state=}") + return False + + try: + for rankings in rankings_by_message.values(): + sorted_messages = [] + for msg_reaction in rankings: + sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids) + logger.debug(f"SORTED MESSAGE {sorted_messages}") + consensus = ranked_pairs(sorted_messages) + logger.debug(f"CONSENSUS: {consensus}\n\n") + for rank, message_id in enumerate(consensus): + # set rank for each message_id for Message rows + msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True) + msg.rank = rank + self.db.add(msg) + + except Exception: + logger.exception(f"update_message_ranks({message_tree_id=}) failed") + self._enter_state(mts, message_tree_state.State.SCORING_FAILED) + return False + + self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT) + return True def _calculate_acceptance(self, labels: list[TextLabels]): # calculate acceptance based on spam label diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py index bdd0e2e9..fe466bfa 100644 --- a/backend/oasst_backend/user_stats_repository.py +++ b/backend/oasst_backend/user_stats_repository.py @@ -4,6 +4,7 @@ from uuid import UUID import sqlalchemy as sa from loguru import logger +from oasst_backend.config import settings from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame from oasst_backend.models.db_payload import ( LabelAssistantReplyPayload, @@ -291,13 +292,11 @@ WHERE if __name__ == "__main__": - from oasst_backend.api.deps import get_dummy_api_client + from oasst_backend.api.deps import api_auth from oasst_backend.database import engine - with Session(engine) as session: - api_client = get_dummy_api_client(session) - usr = UserStatsRepository(session) - # usr.update_all_time_frames() - # session.commit() - # usr.get_leader_board(UserStatsTimeFrame.total) - usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce")) + with Session(engine) as db: + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) + usr = UserStatsRepository(db) + usr.update_all_time_frames() + db.commit() diff --git a/website/package-lock.json b/website/package-lock.json index 1fa3d14d..68ef2ba8 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -45,6 +45,7 @@ "sharp": "^0.31.3", "swr": "^2.0.0", "tailwindcss": "^3.2.4", + "unique-username-generator": "^1.1.3", "use-debounce": "^9.0.2" }, "devDependencies": { @@ -35987,6 +35988,11 @@ "imurmurhash": "^0.1.4" } }, + "node_modules/unique-username-generator": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz", + "integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA==" + }, "node_modules/unist-builder": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz", @@ -64596,6 +64602,11 @@ "imurmurhash": "^0.1.4" } }, + "unique-username-generator": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz", + "integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA==" + }, "unist-builder": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz", diff --git a/website/package.json b/website/package.json index 580d0be3..2e46d187 100644 --- a/website/package.json +++ b/website/package.json @@ -62,6 +62,7 @@ "sharp": "^0.31.3", "swr": "^2.0.0", "tailwindcss": "^3.2.4", + "unique-username-generator": "^1.1.3", "use-debounce": "^9.0.2" }, "devDependencies": { diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 7a470464..99ec01f1 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -74,7 +74,7 @@ export function UserMenu() { - {session.user.name || session.user.email} + {session.user.name || "New User"} diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index fb11adec..d48a987c 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -113,7 +113,7 @@ export class OasstApiClient { type: taskType, user: { id: userToken.sub, - display_name: userToken.name || userToken.email, + display_name: userToken.name, auth_method: "local", }, }); @@ -146,7 +146,7 @@ export class OasstApiClient { type: updateType, user: { id: userToken.sub, - display_name: userToken.name || userToken.email, + display_name: userToken.name, auth_method: "local", }, task_id: taskId, diff --git a/website/src/pages/api/auth/[...nextauth].ts b/website/src/pages/api/auth/[...nextauth].ts index 691cbcba..c718ddce 100644 --- a/website/src/pages/api/auth/[...nextauth].ts +++ b/website/src/pages/api/auth/[...nextauth].ts @@ -7,6 +7,7 @@ import CredentialsProvider from "next-auth/providers/credentials"; import DiscordProvider from "next-auth/providers/discord"; import EmailProvider from "next-auth/providers/email"; import prisma from "src/lib/prismadb"; +import { generateUsername } from "unique-username-generator"; const providers: Provider[] = []; @@ -97,10 +98,11 @@ export const authOptions: AuthOptions = { * This let's use forward the role to the session object. */ async jwt({ token }) { - const { isNew, role } = await prisma.user.findUnique({ + const { isNew, name, role } = await prisma.user.findUnique({ where: { id: token.sub }, - select: { role: true, isNew: true }, + select: { name: true, role: true, isNew: true }, }); + token.name = name; token.role = role; token.isNew = isNew; return token; @@ -110,7 +112,18 @@ export const authOptions: AuthOptions = { /** * Update the user's role after they have successfully signed in */ - async signIn({ user, account }) { + async signIn({ user, account, isNewUser }) { + if (isNewUser && account.provider === "email") { + await prisma.user.update({ + data: { + name: generateUsername(), + }, + where: { + id: user.id, + }, + }); + } + // Get the admin list for the user's auth type. const adminForAccountType = adminUserMap.get(account.provider);