mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
add retry_scoring_failed_message_trees cli command (#931)
This commit is contained in:
@@ -291,6 +291,20 @@ def export_ready_trees(file: Optional[str] = None, use_compression: bool = False
|
||||
logger.exception("Error exporting trees.")
|
||||
|
||||
|
||||
def retry_scoring_failed_message_trees():
|
||||
try:
|
||||
logger.info("TreeManager.retry_scoring_failed_message_trees()")
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client)
|
||||
tm = TreeManager(db, pr)
|
||||
tm.retry_scoring_failed_message_trees()
|
||||
|
||||
except Exception:
|
||||
logger.exception("TreeManager.retry_scoring_failed_message_trees() failed.")
|
||||
|
||||
|
||||
def main():
|
||||
# Importing here so we don't import packages unnecessarily if we're
|
||||
# importing main as a module.
|
||||
@@ -314,6 +328,11 @@ def main():
|
||||
"--export-file",
|
||||
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retry-scoring",
|
||||
help="Retry scoring failed message trees",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -322,6 +341,8 @@ def main():
|
||||
elif args.export:
|
||||
use_compression: bool = ".gz" in args.export_file
|
||||
export_ready_trees(file=args.export_file, use_compression=use_compression)
|
||||
elif args.retry_scoring:
|
||||
retry_scoring_failed_message_trees()
|
||||
else:
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
@@ -652,7 +652,9 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
return True, rankings_by_message
|
||||
|
||||
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
|
||||
def update_message_ranks(
|
||||
self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]]
|
||||
) -> bool:
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
# check state, allow retry if in SCORING_FAILED state
|
||||
@@ -1226,6 +1228,20 @@ DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
message_tree_ids = [ms.message_tree_id for ms in messages]
|
||||
self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def retry_scoring_failed_message_trees(self):
|
||||
query = self.db.query(MessageTreeState.message_tree_id).filter(
|
||||
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
|
||||
)
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
for row in query.all():
|
||||
try:
|
||||
message_tree_id = row["message_tree_id"]
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message)
|
||||
except Exception:
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import api_auth
|
||||
|
||||
Reference in New Issue
Block a user