Merge branch 'main' into production

This commit is contained in:
Yannic Kilcher
2023-01-17 22:00:46 +01:00
7 changed files with 77 additions and 35 deletions
+39 -21
View File
@@ -489,7 +489,8 @@ class TreeManager:
_, task = pr.store_ranking(interaction)
self.check_condition_for_scoring_state(task.message_tree_id)
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
self.update_message_ranks(task.message_tree_id, rankings_by_message)
case protocol_schema.TextLabels:
logger.info(
@@ -589,39 +590,56 @@ class TreeManager:
return True
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
def check_condition_for_scoring_state(
self, message_tree_id: UUID
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
mts: MessageTreeState
mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.RANKING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
return False, None
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
return False
return False, None
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
self.update_message_ranks(rankings_by_message)
return True
return True, rankings_by_message
@managed_tx_method(CommitMode.COMMIT)
def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None:
for parent_msg_id, ranking in rankings_by_message.items():
sorted_messages = []
for msg_reaction in ranking:
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
logger.debug(f"SORTED MESSAGE {sorted_messages}")
consensus = ranked_pairs(sorted_messages)
logger.debug(f"CONSENSUS: {consensus}\n\n")
for rank, message_id in enumerate(consensus):
# set rank for each message_id for Message rows
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
msg.rank = rank
self.db.add(msg)
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
mts = self.pr.fetch_tree_state(message_tree_id)
# check state, allow retry if in SCORING_FAILED state
if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
try:
for rankings in rankings_by_message.values():
sorted_messages = []
for msg_reaction in rankings:
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
logger.debug(f"SORTED MESSAGE {sorted_messages}")
consensus = ranked_pairs(sorted_messages)
logger.debug(f"CONSENSUS: {consensus}\n\n")
for rank, message_id in enumerate(consensus):
# set rank for each message_id for Message rows
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
msg.rank = rank
self.db.add(msg)
except Exception:
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
self._enter_state(mts, message_tree_state.State.SCORING_FAILED)
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
return True
def _calculate_acceptance(self, labels: list[TextLabels]):
# calculate acceptance based on spam label
@@ -4,6 +4,7 @@ from uuid import UUID
import sqlalchemy as sa
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
@@ -291,13 +292,11 @@ WHERE
if __name__ == "__main__":
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.deps import api_auth
from oasst_backend.database import engine
with Session(engine) as session:
api_client = get_dummy_api_client(session)
usr = UserStatsRepository(session)
# usr.update_all_time_frames()
# session.commit()
# usr.get_leader_board(UserStatsTimeFrame.total)
usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
with Session(engine) as db:
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
usr = UserStatsRepository(db)
usr.update_all_time_frames()
db.commit()
+11
View File
@@ -45,6 +45,7 @@
"sharp": "^0.31.3",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
"unique-username-generator": "^1.1.3",
"use-debounce": "^9.0.2"
},
"devDependencies": {
@@ -35987,6 +35988,11 @@
"imurmurhash": "^0.1.4"
}
},
"node_modules/unique-username-generator": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
"integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
},
"node_modules/unist-builder": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
@@ -64596,6 +64602,11 @@
"imurmurhash": "^0.1.4"
}
},
"unique-username-generator": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
"integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
},
"unist-builder": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
+1
View File
@@ -62,6 +62,7 @@
"sharp": "^0.31.3",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
"unique-username-generator": "^1.1.3",
"use-debounce": "^9.0.2"
},
"devDependencies": {
+1 -1
View File
@@ -74,7 +74,7 @@ export function UserMenu() {
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
<Avatar size="sm" bgImage={session.user.image}></Avatar>
<Text data-cy="username" className="hidden lg:flex">
{session.user.name || session.user.email}
{session.user.name || "New User"}
</Text>
</Box>
</MenuButton>
+2 -2
View File
@@ -113,7 +113,7 @@ export class OasstApiClient {
type: taskType,
user: {
id: userToken.sub,
display_name: userToken.name || userToken.email,
display_name: userToken.name,
auth_method: "local",
},
});
@@ -146,7 +146,7 @@ export class OasstApiClient {
type: updateType,
user: {
id: userToken.sub,
display_name: userToken.name || userToken.email,
display_name: userToken.name,
auth_method: "local",
},
task_id: taskId,
+16 -3
View File
@@ -7,6 +7,7 @@ import CredentialsProvider from "next-auth/providers/credentials";
import DiscordProvider from "next-auth/providers/discord";
import EmailProvider from "next-auth/providers/email";
import prisma from "src/lib/prismadb";
import { generateUsername } from "unique-username-generator";
const providers: Provider[] = [];
@@ -97,10 +98,11 @@ export const authOptions: AuthOptions = {
* This let's use forward the role to the session object.
*/
async jwt({ token }) {
const { isNew, role } = await prisma.user.findUnique({
const { isNew, name, role } = await prisma.user.findUnique({
where: { id: token.sub },
select: { role: true, isNew: true },
select: { name: true, role: true, isNew: true },
});
token.name = name;
token.role = role;
token.isNew = isNew;
return token;
@@ -110,7 +112,18 @@ export const authOptions: AuthOptions = {
/**
* Update the user's role after they have successfully signed in
*/
async signIn({ user, account }) {
async signIn({ user, account, isNewUser }) {
if (isNewUser && account.provider === "email") {
await prisma.user.update({
data: {
name: generateUsername(),
},
where: {
id: user.id,
},
});
}
// Get the admin list for the user's auth type.
const adminForAccountType = adminUserMap.get(account.provider);