diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 4d6dea30..6ff727de 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -47,7 +47,7 @@ class TreeManagerConfiguration(BaseModel): """Automatically set tree state to `halted_by_moderator` when more than the specified number of users skip replying to a message. (auto moderation)""" - auto_mod_red_flags: int = 3 + auto_mod_red_flags: int = 4 """Delete messages that receive more than this number of red flags if it is a reply or set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)""" diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index cdf61d3f..940e08ea 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -26,7 +26,12 @@ from oasst_backend.models import ( ) from oasst_backend.prompt_repository import PromptRepository from oasst_backend.utils import tree_export -from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method +from oasst_backend.utils.database_utils import ( + CommitMode, + async_managed_tx_method, + managed_tx_function, + managed_tx_method, +) from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_backend.utils.ranking import ranked_pairs from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode @@ -218,17 +223,12 @@ class TreeManager: return task_count_by_type - def _prompt_lottery(self, lang: str) -> int: - MAX_RETRIES = 5 - + def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int: # Under high load the DB runs into deadlocks when many trees are released # simultaneously (happens whens the max_active_trees setting is increased). 
# To reduce the chance of write conflicts during updates of rows in the # message_tree_state table we limit the number of trees that are activated - # per _prompt_lottery() call to MAX_ACTIVATE. - MAX_ACTIVATE = 2 - - retry = 0 + # per _prompt_lottery() call to max_activate. activated = 0 while True: @@ -237,67 +237,76 @@ class TreeManager: remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review) num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing) - logger.debug(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}") + logger.info(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}") - if num_missing_growing == 0 or activated >= MAX_ACTIVATE: + if num_missing_growing == 0 or activated >= max_activate: return num_missing_growing + remaining_prompt_review - # select among distinct users - authors_qry = ( - self.db.query(Message.user_id) - .select_from(MessageTreeState) - .join(Message, MessageTreeState.message_tree_id == Message.id) - .filter( - MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING, - Message.lang == lang, - not_(Message.deleted), - Message.review_result, + @managed_tx_function(CommitMode.COMMIT) + def activate_one(db: Session) -> int: + # select among distinct users + authors_qry = ( + db.query(Message.user_id) + .select_from(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter( + MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING, + Message.lang == lang, + not_(Message.deleted), + Message.review_result, + ) + .distinct(Message.user_id) ) - .distinct(Message.user_id) - ) - author_ids = authors_qry.all() - if len(author_ids) == 0: - logger.info( - f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})." 
+ author_ids = authors_qry.all() + if len(author_ids) == 0: + logger.info( + f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})." + ) + return False + + # first select an author + prompt_author_id: UUID = random.choice(author_ids)["user_id"] + logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.") + + # select random prompt of author + qry = ( + db.query(MessageTreeState, Message) + .select_from(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter( + MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING, + Message.user_id == prompt_author_id, + Message.lang == lang, + not_(Message.deleted), + Message.review_result, + ) + .limit(100) ) + + prompt_candidates = qry.all() + if len(prompt_candidates) == 0: + logger.warning("No prompt candidates of selected author found.") + return False + + winner_prompt = random.choice(prompt_candidates) + message: Message = winner_prompt.Message + logger.info(f"Prompt lottery winner: {message.id=}") + + mts: MessageTreeState = winner_prompt.MessageTreeState + mts.state = message_tree_state.State.GROWING + mts.active = True + db.add(mts) + + if mts.won_prompt_lottery_date is None: + mts.won_prompt_lottery_date = utcnow() + logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})") + + return True + + if not activate_one(): return num_missing_growing + remaining_prompt_review - # first select an authour - prompt_author_id: UUID = random.choice(author_ids)["user_id"] - logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.") - - # select random prompt of author - qry = ( - self.db.query(MessageTreeState, Message) - .select_from(MessageTreeState) - .join(Message, MessageTreeState.message_tree_id == Message.id) - .filter( - MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING, - Message.user_id ==
prompt_author_id, - Message.lang == lang, - not_(Message.deleted), - Message.review_result, - ) - .limit(100) - ) - - prompt_candidates = qry.all() - if len(prompt_candidates) == 0: - retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry - if retry < MAX_RETRIES: - continue - else: - logger.warning("Max retries in prompt lottery reached.") - return num_missing_growing + remaining_prompt_review - - winner_prompt = random.choice(prompt_candidates) - message: Message = winner_prompt.Message - logger.info(f"Prompt lottery winner: {message.id=}") - - mts: MessageTreeState = winner_prompt.MessageTreeState - self._enter_state(mts, message_tree_state.State.GROWING) - self.db.flush() activated += 1 def _auto_moderation(self, lang: str) -> None: @@ -333,7 +342,7 @@ class TreeManager: logger.warning("Task availability request without lang tag received, assuming lang='en'.") self._auto_moderation(lang=lang) - num_missing_prompts = self._prompt_lottery(lang=lang) + num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=1) extendible_parents, _ = self.query_extendible_parents(lang=lang) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) @@ -371,7 +380,7 @@ class TreeManager: logger.warning("Task request without lang tag received, assuming 'en'.") self._auto_moderation(lang=lang) - num_missing_prompts = self._prompt_lottery(lang=lang) + num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang)