mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
add endpoint to put message tree into 'halted_by_moderator' state (#1025)
* add PUT endpoint to put message tree into halted_by_moderator state * add private _reactivate_tree() method
This commit is contained in:
@@ -5,8 +5,10 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models import ApiClient, MessageTreeState
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.message_tree import MessageTreeStateResponse
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
@@ -189,6 +191,55 @@ def get_tree(
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree/state", response_model=MessageTreeStateResponse)
|
||||
def get_message_tree_state(
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> MessageTreeStateResponse:
|
||||
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
mts = pr.fetch_tree_state(message.message_tree_id)
|
||||
return MessageTreeStateResponse(
|
||||
message_tree_id=mts.message_tree_id,
|
||||
state=mts.state,
|
||||
active=mts.active,
|
||||
goal_tree_size=mts.goal_tree_size,
|
||||
max_children_count=mts.max_children_count,
|
||||
max_depth=mts.max_depth,
|
||||
origin=mts.origin,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{message_id}/tree/state", response_model=MessageTreeStateResponse)
|
||||
def put_message_tree_state(
|
||||
*,
|
||||
message_id: UUID,
|
||||
halt: bool,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> MessageTreeStateResponse:
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
def halt_tree_tx(session: deps.Session) -> MessageTreeState:
|
||||
pr = PromptRepository(session, api_client, frontend_user=frontend_user)
|
||||
tm = TreeManager(session, pr)
|
||||
return tm.halt_tree(message_id, halt=halt)
|
||||
|
||||
mts = halt_tree_tx()
|
||||
return MessageTreeStateResponse(
|
||||
message_tree_id=mts.message_tree_id,
|
||||
state=mts.state,
|
||||
active=mts.active,
|
||||
goal_tree_size=mts.goal_tree_size,
|
||||
max_children_count=mts.max_children_count,
|
||||
max_depth=mts.max_depth,
|
||||
origin=mts.origin,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children(
|
||||
*,
|
||||
|
||||
@@ -46,12 +46,11 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_required_rankings: int = 3
|
||||
"""Number of rankings in which the message participated."""
|
||||
|
||||
p_activate_backlog_tree: float = 0.8
|
||||
p_activate_backlog_tree: float = 0.1
|
||||
"""Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
|
||||
a terminal state. Use this settting to control ratio of initial prompts and backlog tree
|
||||
activations."""
|
||||
a terminal state."""
|
||||
|
||||
min_active_rankings_per_lang: int = 2
|
||||
min_active_rankings_per_lang: int = 0
|
||||
"""When the number of active ranking tasks is below this value when a tree enters a terminal
|
||||
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
|
||||
|
||||
|
||||
@@ -57,7 +57,13 @@ VALID_STATES = (
|
||||
State.BACKLOG_RANKING,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
|
||||
TERMINAL_STATES = (
|
||||
State.READY_FOR_EXPORT,
|
||||
State.ABORTED_LOW_GRADE,
|
||||
State.SCORING_FAILED,
|
||||
State.HALTED_BY_MODERATOR,
|
||||
State.BACKLOG_RANKING,
|
||||
)
|
||||
|
||||
|
||||
class MessageTreeState(SQLModel, table=True):
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MessageTreeStateResponse(BaseModel):
|
||||
message_tree_id: UUID
|
||||
state: TreeState
|
||||
goal_tree_size: int
|
||||
max_depth: int
|
||||
max_children_count: int
|
||||
active: bool
|
||||
origin: str | None
|
||||
@@ -635,8 +635,7 @@ class TreeManager:
|
||||
|
||||
is_terminal = state in message_tree_state.TERMINAL_STATES
|
||||
was_active = mts.active
|
||||
if is_terminal:
|
||||
mts.active = False
|
||||
mts.active = not is_terminal
|
||||
mts.state = state.value
|
||||
self.db.add(mts)
|
||||
self.db.flush
|
||||
@@ -1288,6 +1287,13 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
r = self.db.execute(text(sql_purge_message_tree), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.")
|
||||
|
||||
def _reactivate_tree(self, mts: MessageTreeState):
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
tree_id = mts.message_tree_id
|
||||
if self.check_condition_for_growing_state(tree_id):
|
||||
if self.check_condition_for_ranking_state(tree_id):
|
||||
self.check_condition_for_scoring_state(tree_id)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user_messages(
|
||||
self,
|
||||
@@ -1347,10 +1353,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
logger.info(f"reactivating message tree {tree_id}")
|
||||
mts = self.pr.fetch_tree_state(tree_id)
|
||||
mts.active = True
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
self.check_condition_for_growing_state(tree_id)
|
||||
self.check_condition_for_ranking_state(tree_id)
|
||||
self.check_condition_for_scoring_state(tree_id)
|
||||
self._reactivate_tree(mts)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user(self, user_id: UUID, ban: bool = True) -> None:
|
||||
@@ -1424,6 +1427,18 @@ DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
except Exception:
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})")
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def halt_tree(self, message_id: UUID, halt: bool = True) -> MessageTreeState:
|
||||
message = self.pr.fetch_message(message_id, fail_if_missing=True)
|
||||
mts = self.pr.fetch_tree_state(message.message_tree_id)
|
||||
|
||||
if halt:
|
||||
self._enter_state(mts, message_tree_state.State.HALTED_BY_MODERATOR)
|
||||
else:
|
||||
self._reactivate_tree(mts)
|
||||
|
||||
return mts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import api_auth
|
||||
|
||||
Reference in New Issue
Block a user