mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
Merge branch 'main' into production
This commit is contained in:
@@ -489,7 +489,8 @@ class TreeManager:
|
||||
|
||||
_, task = pr.store_ranking(interaction)
|
||||
|
||||
self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
self.update_message_ranks(task.message_tree_id, rankings_by_message)
|
||||
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
@@ -589,39 +590,56 @@ class TreeManager:
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
|
||||
def check_condition_for_scoring_state(
|
||||
self, message_tree_id: UUID
|
||||
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
|
||||
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
|
||||
mts: MessageTreeState
|
||||
mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.RANKING:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
self.update_message_ranks(rankings_by_message)
|
||||
return True
|
||||
return True, rankings_by_message
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None:
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
sorted_messages = []
|
||||
for msg_reaction in ranking:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
for rank, message_id in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
# check state, allow retry if in SCORING_FAILED state
|
||||
if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
try:
|
||||
for rankings in rankings_by_message.values():
|
||||
sorted_messages = []
|
||||
for msg_reaction in rankings:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
for rank, message_id in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
|
||||
self._enter_state(mts, message_tree_state.State.SCORING_FAILED)
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
|
||||
return True
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
# calculate acceptance based on spam label
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
@@ -291,13 +292,11 @@ WHERE
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.deps import api_auth
|
||||
from oasst_backend.database import engine
|
||||
|
||||
with Session(engine) as session:
|
||||
api_client = get_dummy_api_client(session)
|
||||
usr = UserStatsRepository(session)
|
||||
# usr.update_all_time_frames()
|
||||
# session.commit()
|
||||
# usr.get_leader_board(UserStatsTimeFrame.total)
|
||||
usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
usr = UserStatsRepository(db)
|
||||
usr.update_all_time_frames()
|
||||
db.commit()
|
||||
|
||||
Generated
+11
@@ -45,6 +45,7 @@
|
||||
"sharp": "^0.31.3",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"unique-username-generator": "^1.1.3",
|
||||
"use-debounce": "^9.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -35987,6 +35988,11 @@
|
||||
"imurmurhash": "^0.1.4"
|
||||
}
|
||||
},
|
||||
"node_modules/unique-username-generator": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
|
||||
"integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
|
||||
},
|
||||
"node_modules/unist-builder": {
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
|
||||
@@ -64596,6 +64602,11 @@
|
||||
"imurmurhash": "^0.1.4"
|
||||
}
|
||||
},
|
||||
"unique-username-generator": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
|
||||
"integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
|
||||
},
|
||||
"unist-builder": {
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
|
||||
|
||||
@@ -62,6 +62,7 @@
|
||||
"sharp": "^0.31.3",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"unique-username-generator": "^1.1.3",
|
||||
"use-debounce": "^9.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -74,7 +74,7 @@ export function UserMenu() {
|
||||
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
|
||||
<Avatar size="sm" bgImage={session.user.image}></Avatar>
|
||||
<Text data-cy="username" className="hidden lg:flex">
|
||||
{session.user.name || session.user.email}
|
||||
{session.user.name || "New User"}
|
||||
</Text>
|
||||
</Box>
|
||||
</MenuButton>
|
||||
|
||||
@@ -113,7 +113,7 @@ export class OasstApiClient {
|
||||
type: taskType,
|
||||
user: {
|
||||
id: userToken.sub,
|
||||
display_name: userToken.name || userToken.email,
|
||||
display_name: userToken.name,
|
||||
auth_method: "local",
|
||||
},
|
||||
});
|
||||
@@ -146,7 +146,7 @@ export class OasstApiClient {
|
||||
type: updateType,
|
||||
user: {
|
||||
id: userToken.sub,
|
||||
display_name: userToken.name || userToken.email,
|
||||
display_name: userToken.name,
|
||||
auth_method: "local",
|
||||
},
|
||||
task_id: taskId,
|
||||
|
||||
@@ -7,6 +7,7 @@ import CredentialsProvider from "next-auth/providers/credentials";
|
||||
import DiscordProvider from "next-auth/providers/discord";
|
||||
import EmailProvider from "next-auth/providers/email";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { generateUsername } from "unique-username-generator";
|
||||
|
||||
const providers: Provider[] = [];
|
||||
|
||||
@@ -97,10 +98,11 @@ export const authOptions: AuthOptions = {
|
||||
* This let's use forward the role to the session object.
|
||||
*/
|
||||
async jwt({ token }) {
|
||||
const { isNew, role } = await prisma.user.findUnique({
|
||||
const { isNew, name, role } = await prisma.user.findUnique({
|
||||
where: { id: token.sub },
|
||||
select: { role: true, isNew: true },
|
||||
select: { name: true, role: true, isNew: true },
|
||||
});
|
||||
token.name = name;
|
||||
token.role = role;
|
||||
token.isNew = isNew;
|
||||
return token;
|
||||
@@ -110,7 +112,18 @@ export const authOptions: AuthOptions = {
|
||||
/**
|
||||
* Update the user's role after they have successfully signed in
|
||||
*/
|
||||
async signIn({ user, account }) {
|
||||
async signIn({ user, account, isNewUser }) {
|
||||
if (isNewUser && account.provider === "email") {
|
||||
await prisma.user.update({
|
||||
data: {
|
||||
name: generateUsername(),
|
||||
},
|
||||
where: {
|
||||
id: user.id,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Get the admin list for the user's auth type.
|
||||
const adminForAccountType = adminUserMap.get(account.provider);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user