mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
add admin purge user function (#834)
* add admin purge user function
* improve comments
* minor naming changes
* ensuer user is enabled for tasks api requests
* add preview with stats to /admin/purge_user/{id} endpoint
* add update_children_counts()
This commit is contained in:
@@ -1,7 +1,15 @@
|
||||
from uuid import UUID
|
||||
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models.api_client import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import ScopeTimer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -29,3 +37,47 @@ async def create_api_client(
|
||||
)
|
||||
logger.info(f"Created api_client with key {api_client.api_key}")
|
||||
return api_client.api_key
|
||||
|
||||
|
||||
class PurgeResultModel(pydantic.BaseModel):
|
||||
before: SystemStats
|
||||
after: SystemStats
|
||||
preview: bool
|
||||
duration: float
|
||||
|
||||
|
||||
@router.post("/purge_user/{user_id}", response_model=PurgeResultModel)
|
||||
async def purge_user(
|
||||
user_id: UUID,
|
||||
preview: bool = True,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
@managed_tx_function(CommitMode.NONE if preview else CommitMode.COMMIT)
|
||||
def purge_tx(session: deps.Session):
|
||||
pr = PromptRepository(session, api_client)
|
||||
|
||||
stats_before = pr.get_stats()
|
||||
|
||||
user = pr.user_repository.get_user(user_id)
|
||||
tm = TreeManager(session, pr)
|
||||
tm.purge_user(user_id)
|
||||
|
||||
return user, stats_before, pr.get_stats()
|
||||
|
||||
timer = ScopeTimer()
|
||||
user, before, after = purge_tx()
|
||||
timer.stop()
|
||||
|
||||
if preview:
|
||||
logger.info(
|
||||
f"PURGE USER PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"PURGE USER: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
|
||||
logger.info(f"{before=}; {after=}")
|
||||
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
|
||||
|
||||
@@ -36,6 +36,8 @@ def request_task(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, client_user=request.user)
|
||||
pr.ensure_user_is_enabled()
|
||||
|
||||
tm = TreeManager(db, pr)
|
||||
task, message_tree_id, parent_message_id = tm.next_task(request.type)
|
||||
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
@@ -85,6 +87,7 @@ def tasks_acknowledge(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.ensure_user_is_enabled()
|
||||
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
@@ -113,6 +116,7 @@ def tasks_acknowledge_failure(
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.ensure_user_is_enabled()
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
|
||||
@@ -28,8 +28,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from sqlmodel import Session, func, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -53,6 +52,13 @@ class PromptRepository:
|
||||
)
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def ensure_user_is_enabled(self):
|
||||
if self.user is None or self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
if self.user.deleted or not self.user.enabled:
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
message: Message = (
|
||||
@@ -146,6 +152,8 @@ class PromptRepository:
|
||||
review_result: bool = False,
|
||||
check_tree_state: bool = True,
|
||||
) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
@@ -354,8 +362,7 @@ class PromptRepository:
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = MessageReaction(
|
||||
@@ -499,7 +506,7 @@ class PromptRepository:
|
||||
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return messages
|
||||
|
||||
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True):
|
||||
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True) -> list[Message]:
|
||||
qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
|
||||
if reviewed:
|
||||
qry = qry.filter(Message.review_result)
|
||||
@@ -702,6 +709,21 @@ class PromptRepository:
|
||||
|
||||
return messages.all()
|
||||
|
||||
def update_children_counts(self, message_tree_id: UUID):
|
||||
sql_update_children_count = """
|
||||
UPDATE message SET children_count = cc.children_count
|
||||
FROM (
|
||||
SELECT m.id, count(c.id) - COALESCE(SUM(c.deleted::int), 0) AS children_count
|
||||
FROM message m
|
||||
LEFT JOIN message c ON m.id = c.parent_id
|
||||
WHERE m.message_tree_id = :message_tree_id
|
||||
GROUP BY m.id
|
||||
) AS cc
|
||||
WHERE message.id = cc.id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
"""
|
||||
|
||||
@@ -17,8 +17,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.sql import text
|
||||
from sqlmodel import Session, func, not_
|
||||
from sqlmodel import Session, func, not_, text
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -192,6 +191,8 @@ class TreeManager:
|
||||
return task_count_by_type
|
||||
|
||||
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
extendible_parents = self.query_extendible_parents()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
@@ -212,6 +213,8 @@ class TreeManager:
|
||||
|
||||
logger.debug("TreeManager.next_task()")
|
||||
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
@@ -445,6 +448,7 @@ class TreeManager:
|
||||
@async_managed_tx_method(CommitMode.COMMIT)
|
||||
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
|
||||
pr = self.pr
|
||||
pr.ensure_user_is_enabled()
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
logger.info(
|
||||
@@ -978,6 +982,136 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
message_counts=self.tree_message_count_stats(only_active=True),
|
||||
)
|
||||
|
||||
def get_user_messages_by_tree(self, user_id: UUID) -> Tuple[dict[UUID, list[Message]], list[Message]]:
|
||||
"""Returns a dict with replies by tree (excluding initial prompts) and list of initial prompts
|
||||
associated with user_id."""
|
||||
|
||||
# query all messages of the user
|
||||
qry = self.db.query(Message).filter(Message.user_id == user_id)
|
||||
|
||||
prompts: list[Message] = []
|
||||
replies_by_tree: dict[UUID, list[Message]] = {}
|
||||
|
||||
# walk over result set and distinguish between initial prompts and replies
|
||||
for m in qry:
|
||||
m: Message
|
||||
|
||||
if m.message_tree_id == m.id:
|
||||
prompts.append(m)
|
||||
else:
|
||||
message_list = replies_by_tree.get(m.message_tree_id)
|
||||
if message_list is None:
|
||||
message_list = [m]
|
||||
replies_by_tree[m.message_tree_id] = message_list
|
||||
else:
|
||||
message_list.append(m)
|
||||
|
||||
return replies_by_tree, prompts
|
||||
|
||||
def _purge_message_internal(self, message_id: UUID) -> None:
|
||||
"""This internal function deletes a single message. It does not take care of
|
||||
descendants, children_count in parent etc."""
|
||||
|
||||
sql_purge_message = """
|
||||
DELETE FROM journal j USING message m WHERE j.message_id = :message_id;
|
||||
DELETE FROM message_embedding e WHERE e.message_id = :message_id;
|
||||
DELETE FROM message_toxicity t WHERE t.message_id = :message_id;
|
||||
DELETE FROM text_labels l WHERE l.message_id = :message_id;
|
||||
-- delete all ranking results that contain message
|
||||
DELETE FROM message_reaction r WHERE r.payload_type = 'RankingReactionPayload' AND r.task_id IN (
|
||||
SELECT t.id FROM message m
|
||||
JOIN task t ON m.parent_id = t.parent_message_id
|
||||
WHERE m.id = :message_id);
|
||||
-- delete task which inserted message
|
||||
DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id;
|
||||
DELETE FROM task t WHERE t.parent_message_id = :message_id;
|
||||
DELETE FROM message WHERE id = :message_id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
|
||||
logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")
|
||||
|
||||
def purge_message_tree(self, message_tree_id: UUID) -> None:
|
||||
sql_purge_message_tree = """
|
||||
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_embedding e USING message m WHERE e.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_toxicity t USING message m WHERE t.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM text_labels l USING message m WHERE l.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_reaction r USING task t WHERE r.task_id = t.id AND t.message_tree_id = :message_tree_id;
|
||||
DELETE FROM task t WHERE t.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_tree_state WHERE message_tree_id = :message_tree_id;
|
||||
DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_purge_message_tree), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"purge_message_tree updated({message_tree_id=}) {r.rowcount} rows.")
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user_messages(self, user_id: UUID, purge_initial_prompts: bool = True):
|
||||
|
||||
# find all affected message trees
|
||||
replies_by_tree, prompts = self.get_user_messages_by_tree(user_id)
|
||||
|
||||
# remove all trees based on inital prompts of the user
|
||||
if purge_initial_prompts:
|
||||
for p in prompts:
|
||||
self.purge_message_tree(p.message_tree_id)
|
||||
if p.message_tree_id in replies_by_tree:
|
||||
del replies_by_tree[p.message_tree_id]
|
||||
|
||||
# patch all affected message trees
|
||||
for tree_id, replies in replies_by_tree.items():
|
||||
bad_parent_ids = set(m.id for m in replies)
|
||||
|
||||
tree_messages = self.pr.fetch_message_tree(tree_id)
|
||||
by_id = {m.id: m for m in tree_messages}
|
||||
|
||||
def ancestor_ids(msg: Message) -> list[UUID]:
|
||||
t = []
|
||||
while msg.parent_id is not None:
|
||||
msg = by_id[msg.parent_id]
|
||||
t.append(msg.id)
|
||||
return t
|
||||
|
||||
def is_descendant_of_deleted(m: Message) -> bool:
|
||||
if m.id in bad_parent_ids:
|
||||
return True
|
||||
ancestors = ancestor_ids(m)
|
||||
if any(a in bad_parent_ids for a in ancestors):
|
||||
return True
|
||||
return False
|
||||
|
||||
# start with deepest messages first
|
||||
tree_messages.sort(key=lambda x: x.depth, reverse=True)
|
||||
for m in tree_messages:
|
||||
if is_descendant_of_deleted(m):
|
||||
self._purge_message_internal(m.id)
|
||||
|
||||
# update childern counts
|
||||
self.pr.update_children_counts(m.message_tree_id)
|
||||
|
||||
# reactivate tree
|
||||
logger.info(f"reactivating message tree {tree_id}")
|
||||
mts = self.pr.fetch_tree_state(tree_id)
|
||||
mts.active = True
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
self.check_condition_for_growing_state(tree_id)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user(self, user_id: UUID) -> None:
|
||||
self.purge_user_messages(user_id, purge_initial_prompts=True)
|
||||
|
||||
# delete all remaining rows and ban user
|
||||
sql_purge_user = """
|
||||
DELETE FROM journal WHERE user_id = :user_id;
|
||||
DELETE FROM message_reaction WHERE user_id = :user_id;
|
||||
DELETE FROM task WHERE user_id = :user_id;
|
||||
DELETE FROM message WHERE user_id = :user_id;
|
||||
DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
UPDATE "user" SET deleted = TRUE, enabled = FALSE WHERE id = :user_id;
|
||||
"""
|
||||
|
||||
r = self.db.execute(text(sql_purge_user), {"user_id": user_id})
|
||||
logger.debug(f"purge_user({user_id=}): {r.rowcount} rows.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import api_auth
|
||||
@@ -994,6 +1128,10 @@ if __name__ == "__main__":
|
||||
tm = TreeManager(db, pr, cfg)
|
||||
tm.ensure_tree_states()
|
||||
|
||||
# tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
|
||||
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
|
||||
# db.commit()
|
||||
|
||||
# print("query_num_active_trees", tm.query_num_active_trees())
|
||||
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
# print("query_replies_need_review", tm.query_replies_need_review())
|
||||
|
||||
@@ -40,7 +40,7 @@ class OasstErrorCode(IntEnum):
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
INVALID_TASK_TYPE = 2004
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
|
||||
NO_MESSAGE_TREE_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
INVALID_MESSAGE = 2008
|
||||
@@ -62,11 +62,15 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
TASK_NOT_ASSIGNED_TO_USER = 2106
|
||||
TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107
|
||||
USER_NOT_FOUND = 2200
|
||||
|
||||
# 3000-4000: external resources
|
||||
HUGGINGFACE_API_ERROR = 3001
|
||||
|
||||
# 4000-5000: user
|
||||
USER_NOT_SPECIFIED = 4000
|
||||
USER_DISABLED = 4001
|
||||
USER_NOT_FOUND = 4002
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
"""Base class for Open-Assistant exceptions."""
|
||||
|
||||
@@ -10,14 +10,40 @@ def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class TimerError(Exception):
|
||||
"""A custom exception used to report errors in use of Timer class"""
|
||||
|
||||
|
||||
class ScopeTimer:
|
||||
def __init__(self):
|
||||
self.start()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Measure new start time"""
|
||||
self.start_time = time.perf_counter()
|
||||
|
||||
def stop(self) -> float:
|
||||
"""Store and return the elapsed time"""
|
||||
self.elapsed = time.perf_counter() - self.start_time
|
||||
return self.elapsed
|
||||
|
||||
def __enter__(self):
|
||||
"""Start a new timer as a context manager"""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
"""Stop the context manager timer"""
|
||||
self.stop()
|
||||
|
||||
|
||||
def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None:
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
start = time.time()
|
||||
timer = ScopeTimer()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
elapsed = end - start
|
||||
elapsed = timer.stop()
|
||||
if log_kwargs:
|
||||
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
|
||||
logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")
|
||||
|
||||
Reference in New Issue
Block a user