diff --git a/backend/main.py b/backend/main.py index a9dd8f72..ea4b25da 100644 --- a/backend/main.py +++ b/backend/main.py @@ -291,6 +291,20 @@ def export_ready_trees(file: Optional[str] = None, use_compression: bool = False logger.exception("Error exporting trees.") +def retry_scoring_failed_message_trees(): + try: + logger.info("TreeManager.retry_scoring_failed_message_trees()") + with Session(engine) as db: + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) + + pr = PromptRepository(db=db, api_client=api_client) + tm = TreeManager(db, pr) + tm.retry_scoring_failed_message_trees() + + except Exception: + logger.exception("TreeManager.retry_scoring_failed_message_trees() failed.") + + def main(): # Importing here so we don't import packages unnecessarily if we're # importing main as a module. @@ -314,6 +328,11 @@ def main(): "--export-file", help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT", ) + parser.add_argument( + "--retry-scoring", + help="Retry scoring failed message trees", + action=argparse.BooleanOptionalAction, + ) args = parser.parse_args() @@ -322,6 +341,8 @@ def main(): elif args.export: use_compression: bool = ".gz" in args.export_file export_ready_trees(file=args.export_file, use_compression=use_compression) + elif args.retry_scoring: + retry_scoring_failed_message_trees() else: uvicorn.run(app, host=args.host, port=args.port) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 48dea1c9..1828f8ab 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -652,7 +652,9 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) return True, rankings_by_message - def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool: + def update_message_ranks( + self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]] + ) -> bool: mts = self.pr.fetch_tree_state(message_tree_id) # check state, allow retry if in SCORING_FAILED state @@ -1226,6 +1228,20 @@ DELETE FROM user_stats WHERE user_id = :user_id; message_tree_ids = [ms.message_tree_id for ms in messages] self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression) + @managed_tx_method(CommitMode.COMMIT) + def retry_scoring_failed_message_trees(self): + query = self.db.query(MessageTreeState.message_tree_id).filter( + MessageTreeState.state == message_tree_state.State.SCORING_FAILED + ) + ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" + for row in query.all(): + try: + message_tree_id = row["message_tree_id"] + rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) + self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message) + except Exception: + logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})") + if __name__ == "__main__": from oasst_backend.api.deps import api_auth