mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add update_leaderboard..(), admin/purge/{user-id}/messages, ban param for purge_user()
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pydantic
|
||||
@@ -5,12 +6,12 @@ from fastapi import APIRouter, Depends
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import Settings, settings
|
||||
from oasst_backend.models.api_client import ApiClient
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import ScopeTimer
|
||||
from oasst_shared.utils import ScopeTimer, unaware_to_utc
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -76,24 +77,26 @@ class PurgeResultModel(pydantic.BaseModel):
|
||||
duration: float
|
||||
|
||||
|
||||
@router.post("/purge_user/{user_id}", response_model=PurgeResultModel)
|
||||
@router.post("/purge/{user_id}", response_model=PurgeResultModel)
|
||||
async def purge_user(
|
||||
user_id: UUID,
|
||||
preview: bool = True,
|
||||
ban: bool = True,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
@managed_tx_function(CommitMode.NONE if preview else CommitMode.COMMIT)
|
||||
def purge_tx(session: deps.Session):
|
||||
@managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
|
||||
def purge_tx(session: deps.Session) -> tuple[User, SystemStats, SystemStats]:
|
||||
pr = PromptRepository(session, api_client)
|
||||
|
||||
stats_before = pr.get_stats()
|
||||
|
||||
user = pr.user_repository.get_user(user_id)
|
||||
tm = TreeManager(session, pr)
|
||||
tm.purge_user(user_id)
|
||||
tm.purge_user(user_id=user_id, ban=ban)
|
||||
|
||||
session.expunge(user)
|
||||
return user, stats_before, pr.get_stats()
|
||||
|
||||
timer = ScopeTimer()
|
||||
@@ -111,3 +114,50 @@ async def purge_user(
|
||||
|
||||
logger.info(f"{before=}; {after=}")
|
||||
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
|
||||
|
||||
|
||||
@router.post("/purge/{user_id}/messages", response_model=PurgeResultModel)
|
||||
async def purge_user_messages(
|
||||
user_id: UUID,
|
||||
purge_initial_prompts: bool = False,
|
||||
min_date: datetime = None,
|
||||
max_date: datetime = None,
|
||||
preview: bool = True,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
min_date = unaware_to_utc(min_date)
|
||||
max_date = unaware_to_utc(max_date)
|
||||
|
||||
@managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
|
||||
def purge_user_messages_tx(session: deps.Session):
|
||||
pr = PromptRepository(session, api_client)
|
||||
|
||||
stats_before = pr.get_stats()
|
||||
|
||||
user = pr.user_repository.get_user(user_id)
|
||||
|
||||
tm = TreeManager(session, pr)
|
||||
tm.purge_user_messages(
|
||||
user_id, purge_initial_prompts=purge_initial_prompts, min_date=min_date, max_date=max_date
|
||||
)
|
||||
|
||||
session.expunge(user)
|
||||
return user, stats_before, pr.get_stats()
|
||||
|
||||
timer = ScopeTimer()
|
||||
user, before, after = purge_user_messages_tx()
|
||||
timer.stop()
|
||||
|
||||
if preview:
|
||||
logger.info(
|
||||
f"PURGE USER MESSAGES PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"PURGE USER MESSAGES: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
|
||||
logger.info(f"{before=}; {after=}")
|
||||
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
|
||||
|
||||
@@ -6,6 +6,7 @@ from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -19,3 +20,22 @@ def get_leaderboard(
|
||||
) -> LeaderboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_leaderboard(time_frame, limit=max_count)
|
||||
|
||||
|
||||
@router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def update_leaderboard_time_frame(
|
||||
time_frame: UserStatsTimeFrame,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> LeaderboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.update_stats(time_frame=time_frame)
|
||||
|
||||
|
||||
@router.post("/update", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def update_leaderboards_all(
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> LeaderboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.update_all_time_frames()
|
||||
|
||||
@@ -7,6 +7,7 @@ from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
@@ -29,6 +30,9 @@ def query_messages(
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
start_date = unaware_to_utc(start_date)
|
||||
end_date = unaware_to_utc(end_date)
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
|
||||
@@ -10,14 +10,14 @@ import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func, not_, text
|
||||
from sqlmodel import Session, func, not_, text, update
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -577,7 +577,6 @@ class TreeManager:
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
|
||||
|
||||
@@ -595,7 +594,6 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
|
||||
|
||||
@@ -613,7 +611,6 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.RANKING)
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_scoring_state(
|
||||
self, message_tree_id: UUID
|
||||
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
|
||||
@@ -634,7 +631,6 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
return True, rankings_by_message
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
@@ -982,12 +978,21 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
message_counts=self.tree_message_count_stats(only_active=True),
|
||||
)
|
||||
|
||||
def get_user_messages_by_tree(self, user_id: UUID) -> Tuple[dict[UUID, list[Message]], list[Message]]:
|
||||
def get_user_messages_by_tree(
|
||||
self,
|
||||
user_id: UUID,
|
||||
min_date: datetime = None,
|
||||
max_date: datetime = None,
|
||||
) -> Tuple[dict[UUID, list[Message]], list[Message]]:
|
||||
"""Returns a dict with replies by tree (excluding initial prompts) and list of initial prompts
|
||||
associated with user_id."""
|
||||
|
||||
# query all messages of the user
|
||||
qry = self.db.query(Message).filter(Message.user_id == user_id)
|
||||
if min_date:
|
||||
qry = qry.filter(Message.created_date >= min_date)
|
||||
if max_date:
|
||||
qry = qry.filter(Message.created_date <= max_date)
|
||||
|
||||
prompts: list[Message] = []
|
||||
replies_by_tree: dict[UUID, list[Message]] = {}
|
||||
@@ -1045,10 +1050,16 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
logger.debug(f"purge_message_tree updated({message_tree_id=}) {r.rowcount} rows.")
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user_messages(self, user_id: UUID, purge_initial_prompts: bool = True):
|
||||
def purge_user_messages(
|
||||
self,
|
||||
user_id: UUID,
|
||||
purge_initial_prompts: bool = True,
|
||||
min_date: datetime = None,
|
||||
max_date: datetime = None,
|
||||
):
|
||||
|
||||
# find all affected message trees
|
||||
replies_by_tree, prompts = self.get_user_messages_by_tree(user_id)
|
||||
replies_by_tree, prompts = self.get_user_messages_by_tree(user_id, min_date, max_date)
|
||||
|
||||
# remove all trees based on inital prompts of the user
|
||||
if purge_initial_prompts:
|
||||
@@ -1094,9 +1105,11 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
mts.active = True
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
self.check_condition_for_growing_state(tree_id)
|
||||
self.check_condition_for_ranking_state(tree_id)
|
||||
self.check_condition_for_scoring_state(tree_id)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user(self, user_id: UUID) -> None:
|
||||
def purge_user(self, user_id: UUID, ban: bool = True) -> None:
|
||||
self.purge_user_messages(user_id, purge_initial_prompts=True)
|
||||
|
||||
# delete all remaining rows and ban user
|
||||
@@ -1106,12 +1119,14 @@ DELETE FROM message_reaction WHERE user_id = :user_id;
|
||||
DELETE FROM task WHERE user_id = :user_id;
|
||||
DELETE FROM message WHERE user_id = :user_id;
|
||||
DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
UPDATE "user" SET deleted = TRUE, enabled = FALSE WHERE id = :user_id;
|
||||
"""
|
||||
|
||||
r = self.db.execute(text(sql_purge_user), {"user_id": user_id})
|
||||
logger.debug(f"purge_user({user_id=}): {r.rowcount} rows.")
|
||||
|
||||
if ban:
|
||||
self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import api_auth
|
||||
@@ -1128,7 +1143,7 @@ if __name__ == "__main__":
|
||||
tm = TreeManager(db, pr, cfg)
|
||||
tm.ensure_tree_states()
|
||||
|
||||
# tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
|
||||
tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
|
||||
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
|
||||
# db.commit()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ class CommitMode(IntEnum):
|
||||
NONE = 0
|
||||
FLUSH = 1
|
||||
COMMIT = 2
|
||||
ROLLBACK = 3
|
||||
|
||||
|
||||
"""
|
||||
@@ -41,6 +42,8 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
@@ -75,6 +78,8 @@ def async_managed_tx_method(
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
@@ -118,8 +123,8 @@ def managed_tx_function(
|
||||
session.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
session.rollback()
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
|
||||
@@ -10,6 +10,13 @@ def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def unaware_to_utc(d: datetime | None) -> datetime:
|
||||
"""Set timezeno to UTC if datetime is unaware (tzinfo == None)."""
|
||||
if d and d.tzinfo is None:
|
||||
return d.replace(tzinfo=timezone.utc)
|
||||
return d
|
||||
|
||||
|
||||
class TimerError(Exception):
|
||||
"""A custom exception used to report errors in use of Timer class"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user