From 547e355e2767a294a8f1a94bcf514070a4b7e47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 19 Jan 2023 15:10:22 +0100 Subject: [PATCH] add update_leaderboard..(), admin/purge/{user-id}/messages, ban param for purge_user() --- backend/oasst_backend/api/v1/admin.py | 62 +++++++++++++++++-- backend/oasst_backend/api/v1/leaderboards.py | 20 ++++++ backend/oasst_backend/api/v1/messages.py | 4 ++ backend/oasst_backend/tree_manager.py | 39 ++++++++---- backend/oasst_backend/utils/database_utils.py | 9 ++- oasst-shared/oasst_shared/utils.py | 7 +++ 6 files changed, 121 insertions(+), 20 deletions(-) diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py index e171e6a1..7f6be8e7 100644 --- a/backend/oasst_backend/api/v1/admin.py +++ b/backend/oasst_backend/api/v1/admin.py @@ -1,3 +1,4 @@ +from datetime import datetime from uuid import UUID import pydantic @@ -5,12 +6,12 @@ from fastapi import APIRouter, Depends from loguru import logger from oasst_backend.api import deps from oasst_backend.config import Settings, settings -from oasst_backend.models.api_client import ApiClient +from oasst_backend.models import ApiClient, User from oasst_backend.prompt_repository import PromptRepository from oasst_backend.tree_manager import TreeManager from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.schemas.protocol import SystemStats -from oasst_shared.utils import ScopeTimer +from oasst_shared.utils import ScopeTimer, unaware_to_utc router = APIRouter() @@ -76,24 +77,26 @@ class PurgeResultModel(pydantic.BaseModel): duration: float -@router.post("/purge_user/{user_id}", response_model=PurgeResultModel) +@router.post("/purge/{user_id}", response_model=PurgeResultModel) async def purge_user( user_id: UUID, preview: bool = True, + ban: bool = True, api_client: ApiClient = Depends(deps.get_trusted_api_client), ) -> str: assert api_client.trusted - @managed_tx_function(CommitMode.NONE if preview else CommitMode.COMMIT) - def purge_tx(session: deps.Session): + @managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT) + def purge_tx(session: deps.Session) -> tuple[User, SystemStats, SystemStats]: pr = PromptRepository(session, api_client) stats_before = pr.get_stats() user = pr.user_repository.get_user(user_id) tm = TreeManager(session, pr) - tm.purge_user(user_id) + tm.purge_user(user_id=user_id, ban=ban) + session.expunge(user) return user, stats_before, pr.get_stats() timer = ScopeTimer() @@ -111,3 +114,50 @@ async def purge_user( logger.info(f"{before=}; {after=}") return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed) + + +@router.post("/purge/{user_id}/messages", response_model=PurgeResultModel) +async def purge_user_messages( + user_id: UUID, + purge_initial_prompts: bool = False, + min_date: datetime = None, + max_date: datetime = None, + preview: bool = True, + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> str: + assert api_client.trusted + + min_date = unaware_to_utc(min_date) + max_date = unaware_to_utc(max_date) + + @managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT) + def purge_user_messages_tx(session: deps.Session): + pr = PromptRepository(session, api_client) + + stats_before = pr.get_stats() + + user = pr.user_repository.get_user(user_id) + + tm = TreeManager(session, pr) + tm.purge_user_messages( + user_id, purge_initial_prompts=purge_initial_prompts, min_date=min_date, max_date=max_date + ) + + session.expunge(user) + return user, stats_before, pr.get_stats() + + timer = ScopeTimer() + user, before, after = purge_user_messages_tx() + timer.stop() + + if preview: + logger.info( + f"PURGE USER MESSAGES PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')" + ) + else: + logger.warning( + f"PURGE USER MESSAGES: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')" + ) + + logger.info(f"{before=}; {after=}") + return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed) diff --git a/backend/oasst_backend/api/v1/leaderboards.py b/backend/oasst_backend/api/v1/leaderboards.py index 213855a1..27366475 100644 --- a/backend/oasst_backend/api/v1/leaderboards.py +++ b/backend/oasst_backend/api/v1/leaderboards.py @@ -6,6 +6,7 @@ from oasst_backend.models import ApiClient from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame from oasst_shared.schemas.protocol import LeaderboardStats from sqlmodel import Session +from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() @@ -19,3 +20,22 @@ def get_leaderboard( ) -> LeaderboardStats: usr = UserStatsRepository(db) return usr.get_leaderboard(time_frame, limit=max_count) + + +@router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT) +def update_leaderboard_time_frame( + time_frame: UserStatsTimeFrame, + api_client: ApiClient = Depends(deps.get_trusted_api_client), + db: Session = Depends(deps.get_db), +) -> LeaderboardStats: + usr = UserStatsRepository(db) + return usr.update_stats(time_frame=time_frame) + + +@router.post("/update", response_model=None, status_code=HTTP_204_NO_CONTENT) +def update_leaderboards_all( + api_client: ApiClient = Depends(deps.get_trusted_api_client), + db: Session = Depends(deps.get_db), +) -> LeaderboardStats: + usr = UserStatsRepository(db) + return usr.update_all_time_frames() diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 6229e20c..fcba59df 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -7,6 +7,7 @@ from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol +from oasst_shared.utils import unaware_to_utc from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT @@ -29,6 +30,9 @@ def query_messages( """ Query messages. """ + start_date = unaware_to_utc(start_date) + end_date = unaware_to_utc(end_date) + pr = PromptRepository(db, api_client) messages = pr.query_messages( username=username, diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index da8cd9a3..731f39a8 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -10,14 +10,14 @@ import pydantic from loguru import logger from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list from oasst_backend.config import TreeManagerConfiguration, settings -from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state +from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state from oasst_backend.prompt_repository import PromptRepository from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_backend.utils.ranking import ranked_pairs from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session, func, not_, text +from sqlmodel import Session, func, not_, text, update class TaskType(Enum): @@ -577,7 +577,6 @@ class TreeManager: mts = self.pr.fetch_tree_state(message_tree_id) self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE) - @managed_tx_method(CommitMode.COMMIT) def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_growing_state({message_tree_id=})") @@ -595,7 +594,6 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.GROWING) return True - @managed_tx_method(CommitMode.COMMIT) def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_ranking_state({message_tree_id=})") @@ -613,7 +611,6 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.RANKING) return True - @managed_tx_method(CommitMode.COMMIT) def check_condition_for_scoring_state( self, message_tree_id: UUID ) -> Tuple[bool, dict[UUID, list[MessageReaction]]]: @@ -634,7 +631,6 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) return True, rankings_by_message - @managed_tx_method(CommitMode.COMMIT) def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool: mts = self.pr.fetch_tree_state(message_tree_id) @@ -982,12 +978,21 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki message_counts=self.tree_message_count_stats(only_active=True), ) - def get_user_messages_by_tree(self, user_id: UUID) -> Tuple[dict[UUID, list[Message]], list[Message]]: + def get_user_messages_by_tree( + self, + user_id: UUID, + min_date: datetime = None, + max_date: datetime = None, + ) -> Tuple[dict[UUID, list[Message]], list[Message]]: """Returns a dict with replies by tree (excluding initial prompts) and list of initial prompts associated with user_id.""" # query all messages of the user qry = self.db.query(Message).filter(Message.user_id == user_id) + if min_date: + qry = qry.filter(Message.created_date >= min_date) + if max_date: + qry = qry.filter(Message.created_date <= max_date) prompts: list[Message] = [] replies_by_tree: dict[UUID, list[Message]] = {} @@ -1045,10 +1050,16 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id; logger.debug(f"purge_message_tree updated({message_tree_id=}) {r.rowcount} rows.") @managed_tx_method(CommitMode.FLUSH) - def purge_user_messages(self, user_id: UUID, purge_initial_prompts: bool = True): + def purge_user_messages( + self, + user_id: UUID, + purge_initial_prompts: bool = True, + min_date: datetime = None, + max_date: datetime = None, + ): # find all affected message trees - replies_by_tree, prompts = self.get_user_messages_by_tree(user_id) + replies_by_tree, prompts = self.get_user_messages_by_tree(user_id, min_date, max_date) # remove all trees based on inital prompts of the user if purge_initial_prompts: @@ -1094,9 +1105,11 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id; mts.active = True self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW) self.check_condition_for_growing_state(tree_id) + self.check_condition_for_ranking_state(tree_id) + self.check_condition_for_scoring_state(tree_id) @managed_tx_method(CommitMode.FLUSH) - def purge_user(self, user_id: UUID) -> None: + def purge_user(self, user_id: UUID, ban: bool = True) -> None: self.purge_user_messages(user_id, purge_initial_prompts=True) # delete all remaining rows and ban user @@ -1106,12 +1119,14 @@ DELETE FROM message_reaction WHERE user_id = :user_id; DELETE FROM task WHERE user_id = :user_id; DELETE FROM message WHERE user_id = :user_id; DELETE FROM user_stats WHERE user_id = :user_id; -UPDATE "user" SET deleted = TRUE, enabled = FALSE WHERE id = :user_id; """ r = self.db.execute(text(sql_purge_user), {"user_id": user_id}) logger.debug(f"purge_user({user_id=}): {r.rowcount} rows.") + if ban: + self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False)) + if __name__ == "__main__": from oasst_backend.api.deps import api_auth @@ -1128,7 +1143,7 @@ if __name__ == "__main__": tm = TreeManager(db, pr, cfg) tm.ensure_tree_states() - # tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False) + tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False) # tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f")) # db.commit() diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index d378c6a6..803b81e8 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -19,6 +19,7 @@ class CommitMode(IntEnum): NONE = 0 FLUSH = 1 COMMIT = 2 + ROLLBACK = 3 """ @@ -41,6 +42,8 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s self.db.commit() elif auto_commit == CommitMode.FLUSH: self.db.flush() + elif auto_commit == CommitMode.ROLLBACK: + self.db.rollback() if isinstance(result, SQLModel): self.db.refresh(result) return result @@ -75,6 +78,8 @@ def async_managed_tx_method( self.db.commit() elif auto_commit == CommitMode.FLUSH: self.db.flush() + elif auto_commit == CommitMode.ROLLBACK: + self.db.rollback() if isinstance(result, SQLModel): self.db.refresh(result) return result @@ -118,8 +123,8 @@ def managed_tx_function( session.commit() elif auto_commit == CommitMode.FLUSH: session.flush() - if isinstance(result, SQLModel): - session.refresh(result) + elif auto_commit == CommitMode.ROLLBACK: + session.rollback() return result except OperationalError: logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") diff --git a/oasst-shared/oasst_shared/utils.py b/oasst-shared/oasst_shared/utils.py index 1e9f2ef1..57cb66cc 100644 --- a/oasst-shared/oasst_shared/utils.py +++ b/oasst-shared/oasst_shared/utils.py @@ -10,6 +10,13 @@ def utcnow() -> datetime: return datetime.now(timezone.utc) +def unaware_to_utc(d: datetime | None) -> datetime: + """Set timezeno to UTC if datetime is unaware (tzinfo == None).""" + if d and d.tzinfo is None: + return d.replace(tzinfo=timezone.utc) + return d + + class TimerError(Exception): """A custom exception used to report errors in use of Timer class"""