mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
fix export & add emojis (#1004)
This commit is contained in:
+15
-6
@@ -318,34 +318,43 @@ def main():
|
||||
|
||||
parser.add_argument(
|
||||
"--print-openapi-schema",
|
||||
default=False,
|
||||
help="Dumps the openapi schema to stdout",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
|
||||
parser.add_argument("--port", help="The port to run the server", default=8080)
|
||||
parser.add_argument(
|
||||
"--export", help="Export all trees which are ready for exporting.", action=argparse.BooleanOptionalAction
|
||||
"--export",
|
||||
default=False,
|
||||
help="Export all trees which are ready for exporting.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-file",
|
||||
type=str,
|
||||
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retry-scoring",
|
||||
default=False,
|
||||
help="Retry scoring failed message trees",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.print_openapi_schema:
|
||||
print(get_openapi_schema())
|
||||
elif args.export:
|
||||
|
||||
if args.export:
|
||||
use_compression: bool = ".gz" in args.export_file
|
||||
export_ready_trees(file=args.export_file, use_compression=use_compression)
|
||||
elif args.retry_scoring:
|
||||
|
||||
if args.retry_scoring:
|
||||
retry_scoring_failed_message_trees()
|
||||
else:
|
||||
|
||||
if not (args.export or args.print_openapi_schema or args.retry_scoring):
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
|
||||
@@ -20,11 +20,12 @@ class ExportMessageNode(BaseModel):
|
||||
rank: int | None
|
||||
synthetic: bool | None
|
||||
model_name: str | None
|
||||
emojis: dict[str, int] | None
|
||||
replies: list[ExportMessageNode] | None
|
||||
|
||||
@classmethod
|
||||
def prep_message_export(cls, message: Message) -> ExportMessageNode:
|
||||
return cls(
|
||||
@staticmethod
|
||||
def prep_message_export(message: Message) -> ExportMessageNode:
|
||||
return ExportMessageNode(
|
||||
message_id=str(message.id),
|
||||
parent_id=str(message.parent_id) if message.parent_id else None,
|
||||
text=str(message.payload.payload.text),
|
||||
@@ -33,6 +34,7 @@ class ExportMessageNode(BaseModel):
|
||||
review_count=message.review_count,
|
||||
synthetic=message.synthetic,
|
||||
model_name=message.model_name,
|
||||
emojis=message.emojis,
|
||||
rank=message.rank,
|
||||
)
|
||||
|
||||
@@ -43,23 +45,20 @@ class ExportMessageTree(BaseModel):
|
||||
|
||||
|
||||
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
|
||||
export_tree = ExportMessageTree(message_tree_id=str(message_tree_id))
|
||||
export_tree_data = [ExportMessageNode.prep_message_export(m) for m in messages]
|
||||
export_messages = [ExportMessageNode.prep_message_export(m) for m in messages]
|
||||
|
||||
message_parents = defaultdict(list)
|
||||
for message in export_tree_data:
|
||||
message_parents[message.parent_id].append(message)
|
||||
messages_by_parent = defaultdict(list)
|
||||
for message in export_messages:
|
||||
messages_by_parent[message.parent_id].append(message)
|
||||
|
||||
def build_tree(tree: dict, parent: Optional[str], messages: list[Message]):
|
||||
children = message_parents[parent]
|
||||
tree.replies = children
|
||||
def assign_replies(node: ExportMessageNode) -> ExportMessageNode:
|
||||
node.replies = messages_by_parent[node.message_id]
|
||||
for child in node.replies:
|
||||
assign_replies(child)
|
||||
return node
|
||||
|
||||
for idx, child in enumerate(tree.replies):
|
||||
build_tree(tree.replies[idx], child.message_id, messages)
|
||||
|
||||
build_tree(export_tree, None, export_tree_data)
|
||||
|
||||
return export_tree
|
||||
prompt = assign_replies(messages_by_parent[None][0])
|
||||
return ExportMessageTree(message_tree_id=str(message_tree_id), prompt=prompt)
|
||||
|
||||
|
||||
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
|
||||
|
||||
Reference in New Issue
Block a user