From 2d4e39cf5dd68c3582f1c8b8b2d08d32825c676e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Mon, 16 Jan 2023 20:31:29 +0100 Subject: [PATCH] use sqlalchemy for query_reviews_for_message() in TreeManager --- backend/oasst_backend/tree_manager.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 34672e2e..225b0146 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -9,7 +9,7 @@ import pydantic from loguru import logger from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list from oasst_backend.config import TreeManagerConfiguration, settings -from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state +from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state from oasst_backend.prompt_repository import PromptRepository from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI @@ -840,15 +840,13 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: - sql_qry = """ -SELECT tl.* -FROM task t - INNER JOIN text_labels tl ON tl.id = t.id -WHERE t.done = TRUE - AND tl.message_id = :message_id -""" - r = self.db.execute(text(sql_qry), {"message_id": message_id}) - return [TextLabels.from_orm(x) for x in r.all()] + qry = ( + self.db.query(TextLabels) + .select_from(Task) + .join(TextLabels, Task.id == TextLabels.id) + .filter(Task.done, TextLabels.message_id == message_id) + ) + return qry.all() @managed_tx_method(CommitMode.FLUSH) def _insert_tree_state( @@ -911,7 +909,12 @@ if __name__ == "__main__": # print("query_extendible_parents", tm.query_extendible_parents()) # print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292"))) - print("next_task:", tm.next_task()) + print( + "query_reviews_for_message", + tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")), + ) + + # print("next_task:", tm.next_task()) # print( # "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))