use sqlalchemy for query_reviews_for_message() in TreeManager

This commit is contained in:
Andreas Köpf
2023-01-16 20:31:29 +01:00
parent 6ccbd38462
commit 2d4e39cf5d
+14 -11
View File
@@ -9,7 +9,7 @@ import pydantic
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
@@ -840,15 +840,13 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
return query.scalar()
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
sql_qry = """
SELECT tl.*
FROM task t
INNER JOIN text_labels tl ON tl.id = t.id
WHERE t.done = TRUE
AND tl.message_id = :message_id
"""
r = self.db.execute(text(sql_qry), {"message_id": message_id})
return [TextLabels.from_orm(x) for x in r.all()]
qry = (
self.db.query(TextLabels)
.select_from(Task)
.join(TextLabels, Task.id == TextLabels.id)
.filter(Task.done, TextLabels.message_id == message_id)
)
return qry.all()
@managed_tx_method(CommitMode.FLUSH)
def _insert_tree_state(
@@ -911,7 +909,12 @@ if __name__ == "__main__":
# print("query_extendible_parents", tm.query_extendible_parents())
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
print("next_task:", tm.next_task())
print(
"query_reviews_for_message",
tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
)
# print("next_task:", tm.next_task())
# print(
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))