Files
Open-Assistant/backend/oasst_backend/task_repository.py
T
Andreas Köpf 14fa08e2e7 Message tree state machine (#555)
* add query_incomplete_rankings()

* Add SQL queries for TreeManager task selection

* first working version of TreeManager.next_task()

* remove old generate_task(), add mandatory_labels to text_labels task

* Add ConversationMessage list to Ranking tasks

* add more sophisticated sql queries to find extendible trees

* add TreeManager.query_extendible_parents()

* fix task validation, seed data insertion (reviewed)

* provide user for task selection in text-frontend

* enter 'growing' state

* enter 'aborted_low_grade' state

* enter 'ranking' state

* check tree 'growing' state upon relpy insertion

* exclude user from labeling their own messages (added DEBUG_ALLOW_SELF_LABELING setting)

* add DEBUG_ALLOW_SELF_LABELING to docker-compose.yaml

* fix ranking submission

* add query_tree_ranking_results()

* add ranked_message_ids to RankingReactionPayload

* fix reply_messages instead of prompt_messages

* incorment 'ranking_count' of ranked replies

* added logic to check_condition_for_scoring_state

* changes to msg_tree_state_machine

* pre-commit changes

* enter 'ready_for_scoring' state

* re-add HF embedding call (lost during merge)

* use prepare_conversation() helper for seed-data creation

* Partially add user specified task selection

Co-authored-by: Daniel Hug <danielpatrickhug@gmail.com>
2023-01-11 10:54:03 +01:00

208 lines
8.3 KiB
Python

from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_404_NOT_FOUND
def validate_frontend_message_id(message_id: str) -> None:
# TODO: Should it be replaced with fastapi/pydantic validation?
if not isinstance(message_id, str):
raise OasstError(
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
)
if not message_id:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
class TaskRepository:
def __init__(
self,
db: Session,
api_client: ApiClient,
client_user: Optional[protocol_schema.User],
user_repository: UserRepository,
):
self.db = db
self.api_client = api_client
self.user_repository = user_repository
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
def store_task(
self,
task: protocol_schema.Task,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompt_messages=task.prompt_messages)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task_model = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task_model.id == task.id
return task_model
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
validate_frontend_message_id(frontend_message_id)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
"""
validate_frontend_message_id(frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
if not task:
raise OasstError(
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
task.done = True
self.db.add(task)
self.db.commit()
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
c = PayloadContainer(payload=payload)
task = Task(
id=id,
user_id=self.user_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
logger.debug(f"inserting {task=}")
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
)
return task
def fetch_task_by_id(self, task_id: UUID) -> Task:
task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none()
return task