mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
move prompt lottery tree activation into separate transaction scope
This commit is contained in:
@@ -47,7 +47,7 @@ class TreeManagerConfiguration(BaseModel):
|
||||
"""Automatically set tree state to `halted_by_moderator` when more than the specified number
|
||||
of users skip replying to a message. (auto moderation)"""
|
||||
|
||||
auto_mod_red_flags: int = 3
|
||||
auto_mod_red_flags: int = 4
|
||||
"""Delete messages that receive more than this number of red flags if it is a reply or
|
||||
set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)"""
|
||||
|
||||
|
||||
@@ -26,7 +26,12 @@ from oasst_backend.models import (
|
||||
)
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils import tree_export
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.database_utils import (
|
||||
CommitMode,
|
||||
async_managed_tx_method,
|
||||
managed_tx_function,
|
||||
managed_tx_method,
|
||||
)
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
@@ -218,17 +223,12 @@ class TreeManager:
|
||||
|
||||
return task_count_by_type
|
||||
|
||||
def _prompt_lottery(self, lang: str) -> int:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
|
||||
# Under high load the DB runs into deadlocks when many trees are released
|
||||
# simultaneously (happens whens the max_active_trees setting is increased).
|
||||
# To reduce the chance of write conflicts during updates of rows in the
|
||||
# message_tree_state table we limit the number of trees that are activated
|
||||
# per _prompt_lottery() call to MAX_ACTIVATE.
|
||||
MAX_ACTIVATE = 2
|
||||
|
||||
retry = 0
|
||||
# per _prompt_lottery() call to max_activate.
|
||||
activated = 0
|
||||
|
||||
while True:
|
||||
@@ -237,67 +237,76 @@ class TreeManager:
|
||||
|
||||
remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
|
||||
num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing)
|
||||
logger.debug(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
|
||||
logger.info(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
|
||||
|
||||
if num_missing_growing == 0 or activated >= MAX_ACTIVATE:
|
||||
if num_missing_growing == 0 or activated >= max_activate:
|
||||
return num_missing_growing + remaining_prompt_review
|
||||
|
||||
# select among distinct users
|
||||
authors_qry = (
|
||||
self.db.query(Message.user_id)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
def activate_one(db: Session) -> int:
|
||||
# select among distinct users
|
||||
authors_qry = (
|
||||
db.query(Message.user_id)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
)
|
||||
.distinct(Message.user_id)
|
||||
)
|
||||
.distinct(Message.user_id)
|
||||
)
|
||||
|
||||
author_ids = authors_qry.all()
|
||||
if len(author_ids) == 0:
|
||||
logger.info(
|
||||
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
|
||||
author_ids = authors_qry.all()
|
||||
if len(author_ids) == 0:
|
||||
logger.info(
|
||||
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
|
||||
)
|
||||
return False
|
||||
|
||||
# first select an authour
|
||||
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
|
||||
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
|
||||
|
||||
# select random prompt of author
|
||||
qry = (
|
||||
db.query(MessageTreeState, Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.user_id == prompt_author_id,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
)
|
||||
.limit(100)
|
||||
)
|
||||
|
||||
prompt_candidates = qry.all()
|
||||
if len(prompt_candidates) == 0:
|
||||
logger.warning("No prompt candidates of selected author found.")
|
||||
return False
|
||||
|
||||
winner_prompt = random.choice(prompt_candidates)
|
||||
message: Message = winner_prompt.Message
|
||||
logger.info(f"Prompt lottery winner: {message.id=}")
|
||||
|
||||
mts: MessageTreeState = winner_prompt.MessageTreeState
|
||||
mts.state = message_tree_state.State.GROWING
|
||||
mts.active = True
|
||||
db.add(mts)
|
||||
|
||||
if mts.won_prompt_lottery_date is None:
|
||||
mts.won_prompt_lottery_date = utcnow()
|
||||
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
return True
|
||||
|
||||
if not activate_one():
|
||||
return num_missing_growing + remaining_prompt_review
|
||||
|
||||
# first select an authour
|
||||
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
|
||||
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
|
||||
|
||||
# select random prompt of author
|
||||
qry = (
|
||||
self.db.query(MessageTreeState, Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.user_id == prompt_author_id,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
)
|
||||
.limit(100)
|
||||
)
|
||||
|
||||
prompt_candidates = qry.all()
|
||||
if len(prompt_candidates) == 0:
|
||||
retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry
|
||||
if retry < MAX_RETRIES:
|
||||
continue
|
||||
else:
|
||||
logger.warning("Max retries in prompt lottery reached.")
|
||||
return num_missing_growing + remaining_prompt_review
|
||||
|
||||
winner_prompt = random.choice(prompt_candidates)
|
||||
message: Message = winner_prompt.Message
|
||||
logger.info(f"Prompt lottery winner: {message.id=}")
|
||||
|
||||
mts: MessageTreeState = winner_prompt.MessageTreeState
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
self.db.flush()
|
||||
activated += 1
|
||||
|
||||
def _auto_moderation(self, lang: str) -> None:
|
||||
@@ -333,7 +342,7 @@ class TreeManager:
|
||||
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
|
||||
|
||||
self._auto_moderation(lang=lang)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=1)
|
||||
extendible_parents, _ = self.query_extendible_parents(lang=lang)
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
@@ -371,7 +380,7 @@ class TreeManager:
|
||||
logger.warning("Task request without lang tag received, assuming 'en'.")
|
||||
|
||||
self._auto_moderation(lang=lang)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2)
|
||||
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
|
||||
Reference in New Issue
Block a user