fix export & add emojis (#1004)

This commit is contained in:
Andreas Köpf
2023-01-29 23:28:47 +01:00
committed by GitHub
parent 29b540a1d8
commit eda275bef1
2 changed files with 31 additions and 23 deletions
+15 -6
View File
@@ -318,34 +318,43 @@ def main():
parser.add_argument(
"--print-openapi-schema",
default=False,
help="Dumps the openapi schema to stdout",
action=argparse.BooleanOptionalAction,
action="store_true",
)
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
parser.add_argument("--port", help="The port to run the server", default=8080)
parser.add_argument(
"--export", help="Export all trees which are ready for exporting.", action=argparse.BooleanOptionalAction
"--export",
default=False,
help="Export all trees which are ready for exporting.",
action="store_true",
)
parser.add_argument(
"--export-file",
type=str,
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
)
parser.add_argument(
"--retry-scoring",
default=False,
help="Retry scoring failed message trees",
action=argparse.BooleanOptionalAction,
action="store_true",
)
args = parser.parse_args()
if args.print_openapi_schema:
print(get_openapi_schema())
elif args.export:
if args.export:
use_compression: bool = ".gz" in args.export_file
export_ready_trees(file=args.export_file, use_compression=use_compression)
elif args.retry_scoring:
if args.retry_scoring:
retry_scoring_failed_message_trees()
else:
if not (args.export or args.print_openapi_schema or args.retry_scoring):
uvicorn.run(app, host=args.host, port=args.port)
+16 -17
View File
@@ -20,11 +20,12 @@ class ExportMessageNode(BaseModel):
rank: int | None
synthetic: bool | None
model_name: str | None
emojis: dict[str, int] | None
replies: list[ExportMessageNode] | None
@classmethod
def prep_message_export(cls, message: Message) -> ExportMessageNode:
return cls(
@staticmethod
def prep_message_export(message: Message) -> ExportMessageNode:
return ExportMessageNode(
message_id=str(message.id),
parent_id=str(message.parent_id) if message.parent_id else None,
text=str(message.payload.payload.text),
@@ -33,6 +34,7 @@ class ExportMessageNode(BaseModel):
review_count=message.review_count,
synthetic=message.synthetic,
model_name=message.model_name,
emojis=message.emojis,
rank=message.rank,
)
@@ -43,23 +45,20 @@ class ExportMessageTree(BaseModel):
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
export_tree = ExportMessageTree(message_tree_id=str(message_tree_id))
export_tree_data = [ExportMessageNode.prep_message_export(m) for m in messages]
export_messages = [ExportMessageNode.prep_message_export(m) for m in messages]
message_parents = defaultdict(list)
for message in export_tree_data:
message_parents[message.parent_id].append(message)
messages_by_parent = defaultdict(list)
for message in export_messages:
messages_by_parent[message.parent_id].append(message)
def build_tree(tree: dict, parent: Optional[str], messages: list[Message]):
children = message_parents[parent]
tree.replies = children
def assign_replies(node: ExportMessageNode) -> ExportMessageNode:
node.replies = messages_by_parent[node.message_id]
for child in node.replies:
assign_replies(child)
return node
for idx, child in enumerate(tree.replies):
build_tree(tree.replies[idx], child.message_id, messages)
build_tree(export_tree, None, export_tree_data)
return export_tree
prompt = assign_replies(messages_by_parent[None][0])
return ExportMessageTree(message_tree_id=str(message_tree_id), prompt=prompt)
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None: