mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
use sqlalchemy for query_reviews_for_message() in TreeManager
This commit is contained in:
@@ -9,7 +9,7 @@ import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
@@ -840,15 +840,13 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
return query.scalar()
|
||||
|
||||
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
|
||||
sql_qry = """
|
||||
SELECT tl.*
|
||||
FROM task t
|
||||
INNER JOIN text_labels tl ON tl.id = t.id
|
||||
WHERE t.done = TRUE
|
||||
AND tl.message_id = :message_id
|
||||
"""
|
||||
r = self.db.execute(text(sql_qry), {"message_id": message_id})
|
||||
return [TextLabels.from_orm(x) for x in r.all()]
|
||||
qry = (
|
||||
self.db.query(TextLabels)
|
||||
.select_from(Task)
|
||||
.join(TextLabels, Task.id == TextLabels.id)
|
||||
.filter(Task.done, TextLabels.message_id == message_id)
|
||||
)
|
||||
return qry.all()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _insert_tree_state(
|
||||
@@ -911,7 +909,12 @@ if __name__ == "__main__":
|
||||
# print("query_extendible_parents", tm.query_extendible_parents())
|
||||
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
|
||||
|
||||
print("next_task:", tm.next_task())
|
||||
print(
|
||||
"query_reviews_for_message",
|
||||
tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
|
||||
)
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
# print(
|
||||
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
|
||||
|
||||
Reference in New Issue
Block a user