add retry_scoring_failed_message_trees cli command (#931)

This commit is contained in:
Vechtomov
2023-01-25 15:40:36 +03:00
committed by GitHub
parent 50e60862b4
commit 1020dcb024
2 changed files with 38 additions and 1 deletions
+21
View File
@@ -291,6 +291,20 @@ def export_ready_trees(file: Optional[str] = None, use_compression: bool = False
logger.exception("Error exporting trees.")
def retry_scoring_failed_message_trees():
try:
logger.info("TreeManager.retry_scoring_failed_message_trees()")
with Session(engine) as db:
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
pr = PromptRepository(db=db, api_client=api_client)
tm = TreeManager(db, pr)
tm.retry_scoring_failed_message_trees()
except Exception:
logger.exception("TreeManager.retry_scoring_failed_message_trees() failed.")
def main():
# Importing here so we don't import packages unnecessarily if we're
# importing main as a module.
@@ -314,6 +328,11 @@ def main():
"--export-file",
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
)
parser.add_argument(
"--retry-scoring",
help="Retry scoring failed message trees",
action=argparse.BooleanOptionalAction,
)
args = parser.parse_args()
@@ -322,6 +341,8 @@ def main():
elif args.export:
use_compression: bool = ".gz" in args.export_file
export_ready_trees(file=args.export_file, use_compression=use_compression)
elif args.retry_scoring:
retry_scoring_failed_message_trees()
else:
uvicorn.run(app, host=args.host, port=args.port)
+17 -1
View File
@@ -652,7 +652,9 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
return True, rankings_by_message
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
def update_message_ranks(
self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]]
) -> bool:
mts = self.pr.fetch_tree_state(message_tree_id)
# check state, allow retry if in SCORING_FAILED state
@@ -1226,6 +1228,20 @@ DELETE FROM user_stats WHERE user_id = :user_id;
message_tree_ids = [ms.message_tree_id for ms in messages]
self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
@managed_tx_method(CommitMode.COMMIT)
def retry_scoring_failed_message_trees(self):
query = self.db.query(MessageTreeState.message_tree_id).filter(
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
)
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
for row in query.all():
try:
message_tree_id = row["message_tree_id"]
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message)
except Exception:
logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})")
if __name__ == "__main__":
from oasst_backend.api.deps import api_auth