diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py index e8d3078e..fc04c272 100644 --- a/backend/oasst_backend/api/v1/admin.py +++ b/backend/oasst_backend/api/v1/admin.py @@ -1,7 +1,15 @@ +from uuid import UUID + import pydantic from fastapi import APIRouter, Depends from loguru import logger from oasst_backend.api import deps +from oasst_backend.models.api_client import ApiClient +from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.tree_manager import TreeManager +from oasst_backend.utils.database_utils import CommitMode, managed_tx_function +from oasst_shared.schemas.protocol import SystemStats +from oasst_shared.utils import ScopeTimer router = APIRouter() @@ -29,3 +37,47 @@ async def create_api_client( ) logger.info(f"Created api_client with key {api_client.api_key}") return api_client.api_key + + +class PurgeResultModel(pydantic.BaseModel): + before: SystemStats + after: SystemStats + preview: bool + duration: float + + +@router.post("/purge_user/{user_id}", response_model=PurgeResultModel) +async def purge_user( + user_id: UUID, + preview: bool = True, + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> str: + assert api_client.trusted + + @managed_tx_function(CommitMode.NONE if preview else CommitMode.COMMIT) + def purge_tx(session: deps.Session): + pr = PromptRepository(session, api_client) + + stats_before = pr.get_stats() + + user = pr.user_repository.get_user(user_id) + tm = TreeManager(session, pr) + tm.purge_user(user_id) + + return user, stats_before, pr.get_stats() + + timer = ScopeTimer() + user, before, after = purge_tx() + timer.stop() + + if preview: + logger.info( + f"PURGE USER PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')" + ) + else: + logger.warning( + f"PURGE USER: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')" + ) + + logger.info(f"{before=}; {after=}") + return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index c65500fb..54b85821 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -36,6 +36,8 @@ def request_task( try: pr = PromptRepository(db, api_client, client_user=request.user) + pr.ensure_user_is_enabled() + tm = TreeManager(db, pr) task, message_tree_id, parent_message_id = tm.next_task(request.type) pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective) @@ -85,6 +87,7 @@ def tasks_acknowledge( try: pr = PromptRepository(db, api_client) + pr.ensure_user_is_enabled() # here we store the message id in the database for the task logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.") @@ -113,6 +116,7 @@ def tasks_acknowledge_failure( logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.") api_client = deps.api_auth(api_key, db) pr = PromptRepository(db, api_client) + pr.ensure_user_is_enabled() pr.task_repository.acknowledge_task_failure(task_id) except (KeyError, RuntimeError): logger.exception("Failed to not acknowledge task.") diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7f51ea19..2dd920e8 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -28,8 +28,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats -from sqlalchemy import update -from sqlmodel import Session, func +from sqlmodel import Session, func, text, update from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -53,6 +52,13 @@ class PromptRepository: ) self.journal = JournalWriter(db, api_client, self.user) + def ensure_user_is_enabled(self): + if self.user is None or self.user_id is None: + raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) + + if self.user.deleted or not self.user.enabled: + raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED) + def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message: validate_frontend_message_id(frontend_message_id) message: Message = ( @@ -146,6 +152,8 @@ class PromptRepository: review_result: bool = False, check_tree_state: bool = True, ) -> Message: + self.ensure_user_is_enabled() + validate_frontend_message_id(frontend_message_id) validate_frontend_message_id(user_frontend_message_id) @@ -354,8 +362,7 @@ class PromptRepository: @managed_tx_method(CommitMode.FLUSH) def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: - if self.user_id is None: - raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) + self.ensure_user_is_enabled() container = PayloadContainer(payload=payload) reaction = MessageReaction( @@ -499,7 +506,7 @@ class PromptRepository: messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all() return messages - def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True): + def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True) -> list[Message]: qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id) if reviewed: qry = qry.filter(Message.review_result) @@ -702,6 +709,21 @@ class PromptRepository: return messages.all() + def update_children_counts(self, message_tree_id: UUID): + sql_update_children_count = """ +UPDATE message SET children_count = cc.children_count +FROM ( + SELECT m.id, count(c.id) - COALESCE(SUM(c.deleted::int), 0) AS children_count + FROM message m + LEFT JOIN message c ON m.id = c.parent_id + WHERE m.message_tree_id = :message_tree_id + GROUP BY m.id +) AS cc +WHERE message.id = cc.id; +""" + r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id}) + logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.") + @managed_tx_method(CommitMode.COMMIT) def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True): """ diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 54f4d698..da8cd9a3 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -17,8 +17,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM from oasst_backend.utils.ranking import ranked_pairs from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlalchemy.sql import text -from sqlmodel import Session, func, not_ +from sqlmodel import Session, func, not_, text class TaskType(Enum): @@ -192,6 +191,8 @@ class TreeManager: return task_count_by_type def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]: + self.pr.ensure_user_is_enabled() + num_active_trees = self.query_num_active_trees() extendible_parents = self.query_extendible_parents() prompts_need_review = self.query_prompts_need_review() @@ -212,6 +213,8 @@ class TreeManager: logger.debug("TreeManager.next_task()") + self.pr.ensure_user_is_enabled() + num_active_trees = self.query_num_active_trees() prompts_need_review = self.query_prompts_need_review() replies_need_review = self.query_replies_need_review() @@ -445,6 +448,7 @@ class TreeManager: @async_managed_tx_method(CommitMode.COMMIT) async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task: pr = self.pr + pr.ensure_user_is_enabled() match type(interaction): case protocol_schema.TextReplyToMessage: logger.info( @@ -978,6 +982,136 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki message_counts=self.tree_message_count_stats(only_active=True), ) + def get_user_messages_by_tree(self, user_id: UUID) -> Tuple[dict[UUID, list[Message]], list[Message]]: + """Returns a dict with replies by tree (excluding initial prompts) and list of initial prompts + associated with user_id.""" + + # query all messages of the user + qry = self.db.query(Message).filter(Message.user_id == user_id) + + prompts: list[Message] = [] + replies_by_tree: dict[UUID, list[Message]] = {} + + # walk over result set and distinguish between initial prompts and replies + for m in qry: + m: Message + + if m.message_tree_id == m.id: + prompts.append(m) + else: + message_list = replies_by_tree.get(m.message_tree_id) + if message_list is None: + message_list = [m] + replies_by_tree[m.message_tree_id] = message_list + else: + message_list.append(m) + + return replies_by_tree, prompts + + def _purge_message_internal(self, message_id: UUID) -> None: + """This internal function deletes a single message. It does not take care of + descendants, children_count in parent etc.""" + + sql_purge_message = """ +DELETE FROM journal j USING message m WHERE j.message_id = :message_id; +DELETE FROM message_embedding e WHERE e.message_id = :message_id; +DELETE FROM message_toxicity t WHERE t.message_id = :message_id; +DELETE FROM text_labels l WHERE l.message_id = :message_id; +-- delete all ranking results that contain message +DELETE FROM message_reaction r WHERE r.payload_type = 'RankingReactionPayload' AND r.task_id IN ( + SELECT t.id FROM message m + JOIN task t ON m.parent_id = t.parent_message_id + WHERE m.id = :message_id); +-- delete task which inserted message +DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id; +DELETE FROM task t WHERE t.parent_message_id = :message_id; +DELETE FROM message WHERE id = :message_id; +""" + r = self.db.execute(text(sql_purge_message), {"message_id": message_id}) + logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.") + + def purge_message_tree(self, message_tree_id: UUID) -> None: + sql_purge_message_tree = """ +DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id; +DELETE FROM message_embedding e USING message m WHERE e.message_id = m.Id AND m.message_tree_id = :message_tree_id; +DELETE FROM message_toxicity t USING message m WHERE t.message_id = m.Id AND m.message_tree_id = :message_tree_id; +DELETE FROM text_labels l USING message m WHERE l.message_id = m.Id AND m.message_tree_id = :message_tree_id; +DELETE FROM message_reaction r USING task t WHERE r.task_id = t.id AND t.message_tree_id = :message_tree_id; +DELETE FROM task t WHERE t.message_tree_id = :message_tree_id; +DELETE FROM message_tree_state WHERE message_tree_id = :message_tree_id; +DELETE FROM message WHERE message_tree_id = :message_tree_id; +""" + r = self.db.execute(text(sql_purge_message_tree), {"message_tree_id": message_tree_id}) + logger.debug(f"purge_message_tree updated({message_tree_id=}) {r.rowcount} rows.") + + @managed_tx_method(CommitMode.FLUSH) + def purge_user_messages(self, user_id: UUID, purge_initial_prompts: bool = True): + + # find all affected message trees + replies_by_tree, prompts = self.get_user_messages_by_tree(user_id) + + # remove all trees based on inital prompts of the user + if purge_initial_prompts: + for p in prompts: + self.purge_message_tree(p.message_tree_id) + if p.message_tree_id in replies_by_tree: + del replies_by_tree[p.message_tree_id] + + # patch all affected message trees + for tree_id, replies in replies_by_tree.items(): + bad_parent_ids = set(m.id for m in replies) + + tree_messages = self.pr.fetch_message_tree(tree_id) + by_id = {m.id: m for m in tree_messages} + + def ancestor_ids(msg: Message) -> list[UUID]: + t = [] + while msg.parent_id is not None: + msg = by_id[msg.parent_id] + t.append(msg.id) + return t + + def is_descendant_of_deleted(m: Message) -> bool: + if m.id in bad_parent_ids: + return True + ancestors = ancestor_ids(m) + if any(a in bad_parent_ids for a in ancestors): + return True + return False + + # start with deepest messages first + tree_messages.sort(key=lambda x: x.depth, reverse=True) + for m in tree_messages: + if is_descendant_of_deleted(m): + self._purge_message_internal(m.id) + + # update childern counts + self.pr.update_children_counts(m.message_tree_id) + + # reactivate tree + logger.info(f"reactivating message tree {tree_id}") + mts = self.pr.fetch_tree_state(tree_id) + mts.active = True + self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW) + self.check_condition_for_growing_state(tree_id) + + @managed_tx_method(CommitMode.FLUSH) + def purge_user(self, user_id: UUID) -> None: + self.purge_user_messages(user_id, purge_initial_prompts=True) + + # delete all remaining rows and ban user + sql_purge_user = """ +DELETE FROM journal WHERE user_id = :user_id; +DELETE FROM message_reaction WHERE user_id = :user_id; +DELETE FROM task WHERE user_id = :user_id; +DELETE FROM message WHERE user_id = :user_id; +DELETE FROM user_stats WHERE user_id = :user_id; +UPDATE "user" SET deleted = TRUE, enabled = FALSE WHERE id = :user_id; +""" + + r = self.db.execute(text(sql_purge_user), {"user_id": user_id}) + logger.debug(f"purge_user({user_id=}): {r.rowcount} rows.") + if __name__ == "__main__": from oasst_backend.api.deps import api_auth @@ -994,6 +1128,10 @@ if __name__ == "__main__": tm = TreeManager(db, pr, cfg) tm.ensure_tree_states() + # tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False) + # tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f")) + # db.commit() + # print("query_num_active_trees", tm.query_num_active_trees()) # print("query_incomplete_rankings", tm.query_incomplete_rankings()) # print("query_replies_need_review", tm.query_replies_need_review()) diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index e60ad746..e8cd2359 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -40,7 +40,7 @@ class OasstErrorCode(IntEnum): RATING_OUT_OF_RANGE = 2002 INVALID_RANKING_VALUE = 2003 INVALID_TASK_TYPE = 2004 - USER_NOT_SPECIFIED = 2005 + NO_MESSAGE_TREE_FOUND = 2006 NO_REPLIES_FOUND = 2007 INVALID_MESSAGE = 2008 @@ -62,11 +62,15 @@ class OasstErrorCode(IntEnum): TASK_NOT_COLLECTIVE = 2106 TASK_NOT_ASSIGNED_TO_USER = 2106 TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107 - USER_NOT_FOUND = 2200 # 3000-4000: external resources HUGGINGFACE_API_ERROR = 3001 + # 4000-5000: user + USER_NOT_SPECIFIED = 4000 + USER_DISABLED = 4001 + USER_NOT_FOUND = 4002 + class OasstError(Exception): """Base class for Open-Assistant exceptions.""" diff --git a/oasst-shared/oasst_shared/utils.py b/oasst-shared/oasst_shared/utils.py index 90ba8c8f..1e9f2ef1 100644 --- a/oasst-shared/oasst_shared/utils.py +++ b/oasst-shared/oasst_shared/utils.py @@ -10,14 +10,40 @@ def utcnow() -> datetime: return datetime.now(timezone.utc) +class TimerError(Exception): + """A custom exception used to report errors in use of Timer class""" + + +class ScopeTimer: + def __init__(self): + self.start() + + def start(self) -> None: + """Measure new start time""" + self.start_time = time.perf_counter() + + def stop(self) -> float: + """Store and return the elapsed time""" + self.elapsed = time.perf_counter() - self.start_time + return self.elapsed + + def __enter__(self): + """Start a new timer as a context manager""" + self.start() + return self + + def __exit__(self, *exc_info): + """Stop the context manager timer""" + self.stop() + + def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None: def decorator(func): @wraps(func) def wrapped(*args, **kwargs): - start = time.time() + timer = ScopeTimer() result = func(*args, **kwargs) - end = time.time() - elapsed = end - start + elapsed = timer.stop() if log_kwargs: kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()]) logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")