move prompt lottery tree activation into separate transaction scope

This commit is contained in:
Andreas Köpf
2023-02-05 01:16:34 +01:00
parent ae8b9ea09e
commit 263edbaefd
2 changed files with 73 additions and 64 deletions
+1 -1
View File
@@ -47,7 +47,7 @@ class TreeManagerConfiguration(BaseModel):
"""Automatically set tree state to `halted_by_moderator` when more than the specified number
of users skip replying to a message. (auto moderation)"""
auto_mod_red_flags: int = 3
auto_mod_red_flags: int = 4
"""Delete messages that receive more than this number of red flags if it is a reply or
set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)"""
+72 -63
View File
@@ -26,7 +26,12 @@ from oasst_backend.models import (
)
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils import tree_export
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.database_utils import (
CommitMode,
async_managed_tx_method,
managed_tx_function,
managed_tx_method,
)
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
@@ -218,17 +223,12 @@ class TreeManager:
return task_count_by_type
def _prompt_lottery(self, lang: str) -> int:
MAX_RETRIES = 5
def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
# Under high load the DB runs into deadlocks when many trees are released
# simultaneously (happens when the max_active_trees setting is increased).
# To reduce the chance of write conflicts during updates of rows in the
# message_tree_state table we limit the number of trees that are activated
# per _prompt_lottery() call to MAX_ACTIVATE.
MAX_ACTIVATE = 2
retry = 0
# per _prompt_lottery() call to max_activate.
activated = 0
while True:
@@ -237,67 +237,76 @@ class TreeManager:
remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing)
logger.debug(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
logger.info(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
if num_missing_growing == 0 or activated >= MAX_ACTIVATE:
if num_missing_growing == 0 or activated >= max_activate:
return num_missing_growing + remaining_prompt_review
# select among distinct users
authors_qry = (
self.db.query(Message.user_id)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.lang == lang,
not_(Message.deleted),
Message.review_result,
@managed_tx_function(CommitMode.COMMIT)
def activate_one(db: Session) -> int:
# select among distinct users
authors_qry = (
db.query(Message.user_id)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.lang == lang,
not_(Message.deleted),
Message.review_result,
)
.distinct(Message.user_id)
)
.distinct(Message.user_id)
)
author_ids = authors_qry.all()
if len(author_ids) == 0:
logger.info(
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
author_ids = authors_qry.all()
if len(author_ids) == 0:
logger.info(
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
)
return False
# first select an author
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
# select random prompt of author
qry = (
db.query(MessageTreeState, Message)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.user_id == prompt_author_id,
Message.lang == lang,
not_(Message.deleted),
Message.review_result,
)
.limit(100)
)
prompt_candidates = qry.all()
if len(prompt_candidates) == 0:
logger.warning("No prompt candidates of selected author found.")
return False
winner_prompt = random.choice(prompt_candidates)
message: Message = winner_prompt.Message
logger.info(f"Prompt lottery winner: {message.id=}")
mts: MessageTreeState = winner_prompt.MessageTreeState
mts.state = message_tree_state.State.GROWING
mts.active = True
db.add(mts)
if mts.won_prompt_lottery_date is None:
mts.won_prompt_lottery_date = utcnow()
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
return True
if not activate_one():
return num_missing_growing + remaining_prompt_review
# first select an author
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
# select random prompt of author
qry = (
self.db.query(MessageTreeState, Message)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.user_id == prompt_author_id,
Message.lang == lang,
not_(Message.deleted),
Message.review_result,
)
.limit(100)
)
prompt_candidates = qry.all()
if len(prompt_candidates) == 0:
retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry
if retry < MAX_RETRIES:
continue
else:
logger.warning("Max retries in prompt lottery reached.")
return num_missing_growing + remaining_prompt_review
winner_prompt = random.choice(prompt_candidates)
message: Message = winner_prompt.Message
logger.info(f"Prompt lottery winner: {message.id=}")
mts: MessageTreeState = winner_prompt.MessageTreeState
self._enter_state(mts, message_tree_state.State.GROWING)
self.db.flush()
activated += 1
def _auto_moderation(self, lang: str) -> None:
@@ -333,7 +342,7 @@ class TreeManager:
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=1)
extendible_parents, _ = self.query_extendible_parents(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
@@ -371,7 +380,7 @@ class TreeManager:
logger.warning("Task request without lang tag received, assuming 'en'.")
self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)