add update_leaderboard..(), admin/purge/{user-id}/messages, ban param for purge_user()

This commit is contained in:
Andreas Köpf
2023-01-19 15:10:22 +01:00
parent ef8a00e682
commit 547e355e27
6 changed files with 121 additions and 20 deletions
+56 -6
View File
@@ -1,3 +1,4 @@
from datetime import datetime
from uuid import UUID
import pydantic
@@ -5,12 +6,12 @@ from fastapi import APIRouter, Depends
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.config import Settings, settings
from oasst_backend.models.api_client import ApiClient
from oasst_backend.models import ApiClient, User
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.utils import ScopeTimer
from oasst_shared.utils import ScopeTimer, unaware_to_utc
router = APIRouter()
@@ -76,24 +77,26 @@ class PurgeResultModel(pydantic.BaseModel):
duration: float
@router.post("/purge_user/{user_id}", response_model=PurgeResultModel)
@router.post("/purge/{user_id}", response_model=PurgeResultModel)
async def purge_user(
user_id: UUID,
preview: bool = True,
ban: bool = True,
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> str:
assert api_client.trusted
@managed_tx_function(CommitMode.NONE if preview else CommitMode.COMMIT)
def purge_tx(session: deps.Session):
@managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
def purge_tx(session: deps.Session) -> tuple[User, SystemStats, SystemStats]:
pr = PromptRepository(session, api_client)
stats_before = pr.get_stats()
user = pr.user_repository.get_user(user_id)
tm = TreeManager(session, pr)
tm.purge_user(user_id)
tm.purge_user(user_id=user_id, ban=ban)
session.expunge(user)
return user, stats_before, pr.get_stats()
timer = ScopeTimer()
@@ -111,3 +114,50 @@ async def purge_user(
logger.info(f"{before=}; {after=}")
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
@router.post("/purge/{user_id}/messages", response_model=PurgeResultModel)
async def purge_user_messages(
    user_id: UUID,
    purge_initial_prompts: bool = False,
    min_date: datetime = None,
    max_date: datetime = None,
    preview: bool = True,
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> PurgeResultModel:
    """Purge messages of a user, optionally restricted to a date range.

    Args:
        user_id: ID of the user whose messages are purged.
        purge_initial_prompts: Forwarded to `TreeManager.purge_user_messages`.
        min_date: Optional lower date bound; a naive value is treated as UTC.
        max_date: Optional upper date bound; a naive value is treated as UTC.
        preview: When true (the default) the transaction is rolled back
            instead of committed, so only the before/after stats are reported.
        api_client: API client resolved by `deps.get_trusted_api_client`.

    Returns:
        A `PurgeResultModel` holding the system stats before and after the
        purge, the `preview` flag and the elapsed time.
    """
    assert api_client.trusted

    # Normalize naive datetimes to UTC before passing them on as filters.
    min_date = unaware_to_utc(min_date)
    max_date = unaware_to_utc(max_date)

    @managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
    def purge_user_messages_tx(session: deps.Session) -> tuple[User, SystemStats, SystemStats]:
        pr = PromptRepository(session, api_client)
        stats_before = pr.get_stats()

        user = pr.user_repository.get_user(user_id)

        tm = TreeManager(session, pr)
        tm.purge_user_messages(
            user_id, purge_initial_prompts=purge_initial_prompts, min_date=min_date, max_date=max_date
        )

        # Detach the user from the session; presumably so its attributes can
        # still be read for logging after the transaction ended (TODO confirm).
        session.expunge(user)
        return user, stats_before, pr.get_stats()

    timer = ScopeTimer()
    user, before, after = purge_user_messages_tx()
    timer.stop()

    # Same user description for both log levels; only the prefix differs.
    user_info = (
        f"'{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; "
        f"auth-method: '{user.auth_method}')"
    )
    if preview:
        logger.info(f"PURGE USER MESSAGES PREVIEW: {user_info}")
    else:
        logger.warning(f"PURGE USER MESSAGES: {user_info}")

    logger.info(f"{before=}; {after=}")

    return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
@@ -6,6 +6,7 @@ from oasst_backend.models import ApiClient
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@@ -19,3 +20,22 @@ def get_leaderboard(
) -> LeaderboardStats:
usr = UserStatsRepository(db)
return usr.get_leaderboard(time_frame, limit=max_count)
@router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT)
def update_leaderboard_time_frame(
    time_frame: UserStatsTimeFrame,
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
    db: Session = Depends(deps.get_db),
) -> None:
    """Trigger `UserStatsRepository.update_stats` for one time frame.

    The route is declared with status 204 (No Content) and no response model,
    so nothing is returned to the caller.

    Args:
        time_frame: Time frame passed on to `update_stats`.
        api_client: Unused in the body; resolved through
            `deps.get_trusted_api_client` (presumably to restrict the
            endpoint to trusted clients).
        db: Database session used to build the repository.
    """
    usr = UserStatsRepository(db)
    usr.update_stats(time_frame=time_frame)
@router.post("/update", response_model=None, status_code=HTTP_204_NO_CONTENT)
def update_leaderboards_all(
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
    db: Session = Depends(deps.get_db),
) -> None:
    """Trigger `UserStatsRepository.update_all_time_frames`.

    The route is declared with status 204 (No Content) and no response model,
    so nothing is returned to the caller.

    Args:
        api_client: Unused in the body; resolved through
            `deps.get_trusted_api_client` (presumably to restrict the
            endpoint to trusted clients).
        db: Database session used to build the repository.
    """
    usr = UserStatsRepository(db)
    usr.update_all_time_frames()
+4
View File
@@ -7,6 +7,7 @@ from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from oasst_shared.utils import unaware_to_utc
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
@@ -29,6 +30,9 @@ def query_messages(
"""
Query messages.
"""
start_date = unaware_to_utc(start_date)
end_date = unaware_to_utc(end_date)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
+27 -12
View File
@@ -10,14 +10,14 @@ import pydantic
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func, not_, text
from sqlmodel import Session, func, not_, text, update
class TaskType(Enum):
@@ -577,7 +577,6 @@ class TreeManager:
mts = self.pr.fetch_tree_state(message_tree_id)
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
@@ -595,7 +594,6 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.GROWING)
return True
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
@@ -613,7 +611,6 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.RANKING)
return True
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_scoring_state(
self, message_tree_id: UUID
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
@@ -634,7 +631,6 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
return True, rankings_by_message
@managed_tx_method(CommitMode.COMMIT)
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
mts = self.pr.fetch_tree_state(message_tree_id)
@@ -982,12 +978,21 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
message_counts=self.tree_message_count_stats(only_active=True),
)
def get_user_messages_by_tree(self, user_id: UUID) -> Tuple[dict[UUID, list[Message]], list[Message]]:
def get_user_messages_by_tree(
self,
user_id: UUID,
min_date: datetime = None,
max_date: datetime = None,
) -> Tuple[dict[UUID, list[Message]], list[Message]]:
"""Returns a dict with replies by tree (excluding initial prompts) and list of initial prompts
associated with user_id."""
# query all messages of the user
qry = self.db.query(Message).filter(Message.user_id == user_id)
if min_date:
qry = qry.filter(Message.created_date >= min_date)
if max_date:
qry = qry.filter(Message.created_date <= max_date)
prompts: list[Message] = []
replies_by_tree: dict[UUID, list[Message]] = {}
@@ -1045,10 +1050,16 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
logger.debug(f"purge_message_tree updated({message_tree_id=}) {r.rowcount} rows.")
@managed_tx_method(CommitMode.FLUSH)
def purge_user_messages(self, user_id: UUID, purge_initial_prompts: bool = True):
def purge_user_messages(
self,
user_id: UUID,
purge_initial_prompts: bool = True,
min_date: datetime = None,
max_date: datetime = None,
):
# find all affected message trees
replies_by_tree, prompts = self.get_user_messages_by_tree(user_id)
replies_by_tree, prompts = self.get_user_messages_by_tree(user_id, min_date, max_date)
# remove all trees based on initial prompts of the user
if purge_initial_prompts:
@@ -1094,9 +1105,11 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
mts.active = True
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
self.check_condition_for_growing_state(tree_id)
self.check_condition_for_ranking_state(tree_id)
self.check_condition_for_scoring_state(tree_id)
@managed_tx_method(CommitMode.FLUSH)
def purge_user(self, user_id: UUID) -> None:
def purge_user(self, user_id: UUID, ban: bool = True) -> None:
self.purge_user_messages(user_id, purge_initial_prompts=True)
# delete all remaining rows and ban user
@@ -1106,12 +1119,14 @@ DELETE FROM message_reaction WHERE user_id = :user_id;
DELETE FROM task WHERE user_id = :user_id;
DELETE FROM message WHERE user_id = :user_id;
DELETE FROM user_stats WHERE user_id = :user_id;
UPDATE "user" SET deleted = TRUE, enabled = FALSE WHERE id = :user_id;
"""
r = self.db.execute(text(sql_purge_user), {"user_id": user_id})
logger.debug(f"purge_user({user_id=}): {r.rowcount} rows.")
if ban:
self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
if __name__ == "__main__":
from oasst_backend.api.deps import api_auth
@@ -1128,7 +1143,7 @@ if __name__ == "__main__":
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()
# tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
# db.commit()
@@ -19,6 +19,7 @@ class CommitMode(IntEnum):
NONE = 0
FLUSH = 1
COMMIT = 2
ROLLBACK = 3
"""
@@ -41,6 +42,8 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
elif auto_commit == CommitMode.ROLLBACK:
self.db.rollback()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
@@ -75,6 +78,8 @@ def async_managed_tx_method(
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
elif auto_commit == CommitMode.ROLLBACK:
self.db.rollback()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
@@ -118,8 +123,8 @@ def managed_tx_function(
session.commit()
elif auto_commit == CommitMode.FLUSH:
session.flush()
if isinstance(result, SQLModel):
session.refresh(result)
elif auto_commit == CommitMode.ROLLBACK:
session.rollback()
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
+7
View File
@@ -10,6 +10,13 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc)
def unaware_to_utc(d: datetime | None) -> datetime:
"""Set timezeno to UTC if datetime is unaware (tzinfo == None)."""
if d and d.tzinfo is None:
return d.replace(tzinfo=timezone.utc)
return d
class TimerError(Exception):
"""A custom exception used to report errors in use of Timer class"""