From 2b561a0dde1d3f564d7d40d33d839f85cd03f43e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Tue, 31 Jan 2023 09:28:33 +0100 Subject: [PATCH] add endpoint to put message tree into 'halted_by_moderator' state (#1025) * add PUT endpoint to put message tree into halted_by_moderator state * add private _reactivate_tree() method --- backend/oasst_backend/api/v1/messages.py | 53 ++++++++++++++++++- backend/oasst_backend/config.py | 7 ++- .../models/message_tree_state.py | 8 ++- backend/oasst_backend/schemas/message_tree.py | 14 +++++ backend/oasst_backend/tree_manager.py | 27 +++++++--- 5 files changed, 97 insertions(+), 12 deletions(-) create mode 100644 backend/oasst_backend/schemas/message_tree.py diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 29468bf1..87b84157 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -5,8 +5,10 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.api.v1 import utils -from oasst_backend.models import ApiClient +from oasst_backend.models import ApiClient, MessageTreeState from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.schemas.message_tree import MessageTreeStateResponse +from oasst_backend.tree_manager import TreeManager from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol @@ -189,6 +191,55 @@ def get_tree( return utils.prepare_tree(tree, message.message_tree_id) +@router.get("/{message_id}/tree/state", response_model=MessageTreeStateResponse) +def get_message_tree_state( + *, + message_id: UUID, + frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +) -> MessageTreeStateResponse: + + pr = PromptRepository(db, api_client, frontend_user=frontend_user) + message = pr.fetch_message(message_id=message_id, fail_if_missing=True) + mts = pr.fetch_tree_state(message.message_tree_id) + return MessageTreeStateResponse( + message_tree_id=mts.message_tree_id, + state=mts.state, + active=mts.active, + goal_tree_size=mts.goal_tree_size, + max_children_count=mts.max_children_count, + max_depth=mts.max_depth, + origin=mts.origin, + ) + + +@router.put("/{message_id}/tree/state", response_model=MessageTreeStateResponse) +def put_message_tree_state( + *, + message_id: UUID, + halt: bool, + frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id), + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> MessageTreeStateResponse: + @managed_tx_function(CommitMode.COMMIT) + def halt_tree_tx(session: deps.Session) -> MessageTreeState: + pr = PromptRepository(session, api_client, frontend_user=frontend_user) + tm = TreeManager(session, pr) + return tm.halt_tree(message_id, halt=halt) + + mts = halt_tree_tx() + return MessageTreeStateResponse( + message_tree_id=mts.message_tree_id, + state=mts.state, + active=mts.active, + goal_tree_size=mts.goal_tree_size, + max_children_count=mts.max_children_count, + max_depth=mts.max_depth, + origin=mts.origin, + ) + + @router.get("/{message_id}/children", response_model=list[protocol.Message]) def get_children( *, diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index da89d4d4..e195d6e8 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -46,12 +46,11 @@ class TreeManagerConfiguration(BaseModel): num_required_rankings: int = 3 """Number of rankings in which the message participated.""" - p_activate_backlog_tree: float = 0.8 + p_activate_backlog_tree: float = 0.1 """Probability to activate a message tree in BACKLOG_RANKING state when another tree enters - a terminal state. Use this settting to control ratio of initial prompts and backlog tree - activations.""" + a terminal state.""" - min_active_rankings_per_lang: int = 2 + min_active_rankings_per_lang: int = 0 """When the number of active ranking tasks is below this value when a tree enters a terminal state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state).""" diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py index 00b94967..a286d483 100644 --- a/backend/oasst_backend/models/message_tree_state.py +++ b/backend/oasst_backend/models/message_tree_state.py @@ -57,7 +57,13 @@ VALID_STATES = ( State.BACKLOG_RANKING, ) -TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR) +TERMINAL_STATES = ( + State.READY_FOR_EXPORT, + State.ABORTED_LOW_GRADE, + State.SCORING_FAILED, + State.HALTED_BY_MODERATOR, + State.BACKLOG_RANKING, +) class MessageTreeState(SQLModel, table=True): diff --git a/backend/oasst_backend/schemas/message_tree.py b/backend/oasst_backend/schemas/message_tree.py new file mode 100644 index 00000000..0eb63e35 --- /dev/null +++ b/backend/oasst_backend/schemas/message_tree.py @@ -0,0 +1,14 @@ +from uuid import UUID + +from oasst_backend.models.message_tree_state import State as TreeState +from pydantic import BaseModel + + +class MessageTreeStateResponse(BaseModel): + message_tree_id: UUID + state: TreeState + goal_tree_size: int + max_depth: int + max_children_count: int + active: bool + origin: str | None diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 4c7a7df6..3f390920 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -635,8 +635,7 @@ class TreeManager: is_terminal = state in message_tree_state.TERMINAL_STATES was_active = mts.active - if is_terminal: - mts.active = False + mts.active = not is_terminal mts.state = state.value self.db.add(mts) self.db.flush @@ -1288,6 +1287,13 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id; r = self.db.execute(text(sql_purge_message_tree), {"message_tree_id": message_tree_id}) logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.") + def _reactivate_tree(self, mts: MessageTreeState): + self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW) + tree_id = mts.message_tree_id + if self.check_condition_for_growing_state(tree_id): + if self.check_condition_for_ranking_state(tree_id): + self.check_condition_for_scoring_state(tree_id) + @managed_tx_method(CommitMode.FLUSH) def purge_user_messages( self, @@ -1347,10 +1353,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id; logger.info(f"reactivating message tree {tree_id}") mts = self.pr.fetch_tree_state(tree_id) mts.active = True - self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW) - self.check_condition_for_growing_state(tree_id) - self.check_condition_for_ranking_state(tree_id) - self.check_condition_for_scoring_state(tree_id) + self._reactivate_tree(mts) @managed_tx_method(CommitMode.FLUSH) def purge_user(self, user_id: UUID, ban: bool = True) -> None: @@ -1424,6 +1427,18 @@ DELETE FROM user_stats WHERE user_id = :user_id; except Exception: logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})") + @managed_tx_method(CommitMode.FLUSH) + def halt_tree(self, message_id: UUID, halt: bool = True) -> MessageTreeState: + message = self.pr.fetch_message(message_id, fail_if_missing=True) + mts = self.pr.fetch_tree_state(message.message_tree_id) + + if halt: + self._enter_state(mts, message_tree_state.State.HALTED_BY_MODERATOR) + else: + self._reactivate_tree(mts) + + return mts + if __name__ == "__main__": from oasst_backend.api.deps import api_auth