From af0711e505dc9ba640bfc7e8a646ec10d7ee1e7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 4 Feb 2023 14:36:29 +0100 Subject: [PATCH] include initial prompt review in user stats --- .../oasst_backend/user_stats_repository.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py index 4c28b293..cca0d6bf 100644 --- a/backend/oasst_backend/user_stats_repository.py +++ b/backend/oasst_backend/user_stats_repository.py @@ -18,11 +18,20 @@ from oasst_backend.models import ( ) from oasst_backend.models.db_payload import ( LabelAssistantReplyPayload, + LabelInitialPromptPayload, LabelPrompterReplyPayload, RankingReactionPayload, ) from oasst_backend.models.message_tree_state import State as TreeState -from oasst_shared.schemas.protocol import EmojiCode, LeaderboardStats, TextLabel, TrollboardStats, TrollScore, UserScore +from oasst_shared.schemas.protocol import ( + EmojiCode, + LabelTaskMode, + LeaderboardStats, + TextLabel, + TrollboardStats, + TrollScore, + UserScore, +) from oasst_shared.utils import log_timing, utcnow from sqlalchemy.dialects import postgresql from sqlalchemy.sql.functions import coalesce @@ -310,9 +319,9 @@ class UserStatsRepository: for r in qry: uid, mode, count = r s = get_stats(uid) - if mode == "simple": + if mode == LabelTaskMode.simple: s.labels_simple = count - elif mode == "full": + elif mode == LabelTaskMode.full: s.labels_full = count qry = self.query_labels_by_mode_per_user( @@ -321,9 +330,20 @@ class UserStatsRepository: for r in qry: uid, mode, count = r s = get_stats(uid) - if mode == "simple": + if mode == LabelTaskMode.simple: s.labels_simple += count - elif mode == "full": + elif mode == LabelTaskMode.full: + s.labels_full += count + + qry = self.query_labels_by_mode_per_user( + payload_type=LabelInitialPromptPayload.__name__, reference_time=base_date + ) + for r in qry: + uid, mode, count = r + s = get_stats(uid) + if mode == LabelTaskMode.simple: + s.labels_simple += count + elif mode == LabelTaskMode.full: s.labels_full += count qry = self.query_rankings_per_user(reference_time=base_date)