From eda275bef1773b4e8c623fdbad1a9da41b478f21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sun, 29 Jan 2023 23:28:47 +0100 Subject: [PATCH] fix export & add emojis (#1004) --- backend/main.py | 21 ++++++++++---- backend/oasst_backend/utils/tree_export.py | 33 +++++++++++----------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/backend/main.py b/backend/main.py index 06c67a2e..5454d880 100644 --- a/backend/main.py +++ b/backend/main.py @@ -318,34 +318,43 @@ def main(): parser.add_argument( "--print-openapi-schema", + default=False, help="Dumps the openapi schema to stdout", - action=argparse.BooleanOptionalAction, + action="store_true", ) parser.add_argument("--host", help="The host to run the server", default="0.0.0.0") parser.add_argument("--port", help="The port to run the server", default=8080) parser.add_argument( - "--export", help="Export all trees which are ready for exporting.", action=argparse.BooleanOptionalAction + "--export", + default=False, + help="Export all trees which are ready for exporting.", + action="store_true", ) parser.add_argument( "--export-file", + type=str, help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT", ) parser.add_argument( "--retry-scoring", + default=False, help="Retry scoring failed message trees", - action=argparse.BooleanOptionalAction, + action="store_true", ) args = parser.parse_args() if args.print_openapi_schema: print(get_openapi_schema()) - elif args.export: + + if args.export: use_compression: bool = ".gz" in args.export_file export_ready_trees(file=args.export_file, use_compression=use_compression) - elif args.retry_scoring: + + if args.retry_scoring: retry_scoring_failed_message_trees() - else: + + if not (args.export or args.print_openapi_schema or args.retry_scoring): uvicorn.run(app, host=args.host, port=args.port) diff --git a/backend/oasst_backend/utils/tree_export.py b/backend/oasst_backend/utils/tree_export.py index 5cd69abe..ce69edad 100644 --- a/backend/oasst_backend/utils/tree_export.py +++ b/backend/oasst_backend/utils/tree_export.py @@ -20,11 +20,12 @@ class ExportMessageNode(BaseModel): rank: int | None synthetic: bool | None model_name: str | None + emojis: dict[str, int] | None replies: list[ExportMessageNode] | None - @classmethod - def prep_message_export(cls, message: Message) -> ExportMessageNode: - return cls( + @staticmethod + def prep_message_export(message: Message) -> ExportMessageNode: + return ExportMessageNode( message_id=str(message.id), parent_id=str(message.parent_id) if message.parent_id else None, text=str(message.payload.payload.text), @@ -33,6 +34,7 @@ class ExportMessageNode(BaseModel): review_count=message.review_count, synthetic=message.synthetic, model_name=message.model_name, + emojis=message.emojis, rank=message.rank, ) @@ -43,23 +45,20 @@ class ExportMessageTree(BaseModel): def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree: - export_tree = ExportMessageTree(message_tree_id=str(message_tree_id)) - export_tree_data = [ExportMessageNode.prep_message_export(m) for m in messages] + export_messages = [ExportMessageNode.prep_message_export(m) for m in messages] - message_parents = defaultdict(list) - for message in export_tree_data: - message_parents[message.parent_id].append(message) + messages_by_parent = defaultdict(list) + for message in export_messages: + messages_by_parent[message.parent_id].append(message) - def build_tree(tree: dict, parent: Optional[str], messages: list[Message]): - children = message_parents[parent] - tree.replies = children + def assign_replies(node: ExportMessageNode) -> ExportMessageNode: + node.replies = messages_by_parent[node.message_id] + for child in node.replies: + assign_replies(child) + return node - for idx, child in enumerate(tree.replies): - build_tree(tree.replies[idx], child.message_id, messages) - - build_tree(export_tree, None, export_tree_data) - - return export_tree + prompt = assign_replies(messages_by_parent[None][0]) + return ExportMessageTree(message_tree_id=str(message_tree_id), prompt=prompt) def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None: