From 35e0c32a08a86235ce0cf410cf18787bbbba0b4b Mon Sep 17 00:00:00 2001 From: alexandrelefourner Date: Fri, 30 Dec 2022 18:05:23 +0100 Subject: [PATCH] Updating work_package to task --- backend/main.py | 22 +-- backend/oasst_backend/api/v1/tasks.py | 12 +- backend/oasst_backend/exceptions.py | 14 +- backend/oasst_backend/journal_writer.py | 28 +-- backend/oasst_backend/models/__init__.py | 4 +- backend/oasst_backend/models/db_payload.py | 2 +- backend/oasst_backend/models/message.py | 2 +- .../oasst_backend/models/message_reaction.py | 4 +- backend/oasst_backend/models/user_stats.py | 2 +- backend/oasst_backend/models/work_package.py | 4 +- backend/oasst_backend/prompt_repository.py | 168 +++++++++--------- 11 files changed, 131 insertions(+), 131 deletions(-) diff --git a/backend/main.py b/backend/main.py index fb3d14b9..51f95241 100644 --- a/backend/main.py +++ b/backend/main.py @@ -141,36 +141,36 @@ if settings.DEBUG_USE_SEED_DATA: ] for p in dummy_messages: - wp = pr.fetch_workpackage_by_message_id(p.task_message_id) - if wp and not wp.ack: + task = pr.fetch_task_by_message_id(p.task_message_id) + if task and not task.ack: logger.warning("Deleting unacknowledged seed data work package") - db.delete(wp) - wp = None - if not wp: + db.delete(task) + task = None + if not task: if p.parent_message_id is None: - wp = pr.store_task( - protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_message_id=None + task = pr.store_task( + protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None ) else: print("p.parent_message_id", p.parent_message_id) parent_message = pr.fetch_message_by_frontend_message_id(p.parent_message_id, fail_if_missing=True) - wp = pr.store_task( + task = pr.store_task( protocol_schema.AssistantReplyTask( conversation=protocol_schema.Conversation( messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)] ) ), - thread_id=parent_message.thread_id, + message_tree_id=parent_message.message_tree_id, parent_message_id=parent_message.id, ) - pr.bind_frontend_message_id(wp.id, p.task_message_id) + pr.bind_frontend_message_id(task.id, p.task_message_id) message = pr.store_text_reply(p.text, p.task_message_id, p.user_message_id) logger.info( f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}" ) else: - logger.debug(f"seed data work_package found: {wp.id}") + logger.debug(f"seed data task found: {task.id}") logger.info("Seed data check completed") except Exception: diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index c8975ac5..220d5949 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -18,7 +18,7 @@ router = APIRouter() def generate_task( request: protocol_schema.TaskRequest, pr: PromptRepository ) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]: - thread_id = None + message_tree_id = None parent_message_id = None match request.type: @@ -63,7 +63,7 @@ def generate_task( ] task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages)) - thread_id = messages[-1].thread_id + message_tree_id = messages[-1].message_tree_id parent_message_id = messages[-1].id case protocol_schema.TaskRequestType.assistant_reply: logger.info("Generating a AssistantReplyTask.") @@ -74,7 +74,7 @@ def generate_task( ] task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages)) - thread_id = messages[-1].thread_id + message_tree_id = messages[-1].message_tree_id parent_message_id = messages[-1].id case protocol_schema.TaskRequestType.rank_initial_prompts: logger.info("Generating a RankInitialPromptsTask.") @@ -121,7 +121,7 @@ def generate_task( logger.info(f"Generated {task=}.") - return task, thread_id, parent_message_id + return task, message_tree_id, parent_message_id @router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added @@ -138,8 +138,8 @@ def request_task( try: pr = PromptRepository(db, api_client, request.user) - task, thread_id, parent_message_id = generate_task(request, pr) - pr.store_task(task, thread_id, parent_message_id, request.collective) + task, message_tree_id, parent_message_id = generate_task(request, pr) + pr.store_task(task, message_tree_id, parent_message_id, request.collective) except OasstError: raise diff --git a/backend/oasst_backend/exceptions.py b/backend/oasst_backend/exceptions.py index 1c30e453..237d8284 100644 --- a/backend/oasst_backend/exceptions.py +++ b/backend/oasst_backend/exceptions.py @@ -35,13 +35,13 @@ class OasstErrorCode(IntEnum): USER_NOT_SPECIFIED = 2005 NO_THREADS_FOUND = 2006 NO_REPLIES_FOUND = 2007 - WORK_PACKAGE_NOT_FOUND = 2100 - WORK_PACKAGE_EXPIRED = 2101 - WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102 - WORK_PACKAGE_ALREADY_UPDATED = 2103 - WORK_PACKAGE_NOT_ACK = 2104 - WORK_PACKAGE_ALREADY_DONE = 2105 - WORK_PACKAGE_NOT_COLLECTIVE = 2106 + TASK_NOT_FOUND = 2100 + TASK_EXPIRED = 2101 + TASK_PAYLOAD_TYPE_MISMATCH = 2102 + TASK_ALREADY_UPDATED = 2103 + TASK_NOT_ACK = 2104 + TASK_ALREADY_DONE = 2105 + TASK_NOT_COLLECTIVE = 2106 class OasstError(Exception): diff --git a/backend/oasst_backend/journal_writer.py b/backend/oasst_backend/journal_writer.py index d9cf5e6e..60508433 100644 --- a/backend/oasst_backend/journal_writer.py +++ b/backend/oasst_backend/journal_writer.py @@ -3,7 +3,7 @@ import enum from typing import Literal, Optional from uuid import UUID -from oasst_backend.models import ApiClient, Journal, Person, WorkPackage +from oasst_backend.models import ApiClient, Journal, Person, Task from oasst_backend.models.payload_column_type import PayloadContainer, payload_type from oasst_shared.utils import utcnow from pydantic import BaseModel @@ -24,7 +24,7 @@ class JournalEvent(BaseModel): type: str user_id: Optional[UUID] message_id: Optional[UUID] - workpackage_id: Optional[UUID] + task_id: Optional[UUID] task_type: Optional[str] @@ -54,30 +54,30 @@ class JournalWriter: self.user = user self.user_id = self.user.id if self.user else None - def log_text_reply(self, work_package: WorkPackage, message_id: UUID, role: str, length: int) -> Journal: + def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal: return self.log( - task_type=work_package.payload_type, + task_type=task.payload_type, event_type=JournalEventType.text_reply_to_message, payload=TextReplyEvent(role=role, length=length), - workpackage_id=work_package.id, + task_id=task.id, message_id=message_id, ) - def log_rating(self, work_package: WorkPackage, message_id: UUID, rating: int) -> Journal: + def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal: return self.log( - task_type=work_package.payload_type, + task_type=task.payload_type, event_type=JournalEventType.message_rating, payload=RatingEvent(rating=rating), - workpackage_id=work_package.id, + task_id=task.id, message_id=message_id, ) - def log_ranking(self, work_package: WorkPackage, message_id: UUID, ranking: list[int]) -> Journal: + def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal: return self.log( - task_type=work_package.payload_type, + task_type=task.payload_type, event_type=JournalEventType.message_ranking, payload=RankingEvent(ranking=ranking), - workpackage_id=work_package.id, + task_id=task.id, message_id=message_id, ) @@ -87,7 +87,7 @@ class JournalWriter: payload: JournalEvent, task_type: str, event_type: str = None, - workpackage_id: Optional[UUID] = None, + task_id: Optional[UUID] = None, message_id: Optional[UUID] = None, commit: bool = True, ) -> Journal: @@ -101,8 +101,8 @@ class JournalWriter: payload.user_id = self.user_id if payload.message_id is None: payload.message_id = message_id - if payload.workpackage_id is None: - payload.workpackage_id = workpackage_id + if payload.task_id is None: + payload.task_id = task_id if payload.task_type is None: payload.task_type = task_type diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index d85df2ba..99030517 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -6,7 +6,7 @@ from .user_stats import UserStats from .message import Message from .message_reaction import MessageReaction from .text_labels import TextLabels -from .work_package import WorkPackage +from .task import Task __all__ = [ "ApiClient", @@ -14,7 +14,7 @@ __all__ = [ "UserStats", "Message", "MessageReaction", - "WorkPackage", + "Task", "TextLabels", "Journal", "JournalIntegration", diff --git a/backend/oasst_backend/models/db_payload.py b/backend/oasst_backend/models/db_payload.py index 7c952284..b44228e0 100644 --- a/backend/oasst_backend/models/db_payload.py +++ b/backend/oasst_backend/models/db_payload.py @@ -45,7 +45,7 @@ class AssistantReplyPayload(TaskPayload): @payload_type -class PostPayload(BaseModel): +class MessagePayload(BaseModel): text: str diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 02ffcf3a..37babdbb 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -21,7 +21,7 @@ class Message(SQLModel, table=True): ) parent_id: UUID = Field(nullable=True) message_tree_id: UUID = Field(nullable=False, index=True) - workpackage_id: UUID = Field(nullable=True, index=True) + task_id: UUID = Field(nullable=True, index=True) user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True) role: str = Field(nullable=False, max_length=128) api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") diff --git a/backend/oasst_backend/models/message_reaction.py b/backend/oasst_backend/models/message_reaction.py index 1761de89..9c93961f 100644 --- a/backend/oasst_backend/models/message_reaction.py +++ b/backend/oasst_backend/models/message_reaction.py @@ -13,8 +13,8 @@ from .payload_column_type import PayloadContainer, payload_column_type class MessageReaction(SQLModel, table=True): __tablename__ = "message_reaction" - work_package_id: Optional[UUID] = Field( - sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True) + task_id: Optional[UUID] = Field( + sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("task.id"), nullable=False, primary_key=True) ) user_id: UUID = Field( sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True) diff --git a/backend/oasst_backend/models/user_stats.py b/backend/oasst_backend/models/user_stats.py index 0825513c..a92775b9 100644 --- a/backend/oasst_backend/models/user_stats.py +++ b/backend/oasst_backend/models/user_stats.py @@ -23,6 +23,6 @@ class UserStats(SQLModel, table=True): messages: int = 0 # messages sent by user upvotes: int = 0 # received upvotes (form other users) downvotes: int = 0 # received downvotes (from other users) - work_reward: int = 0 # reward for workpackage completions + task_reward: int = 0 # reward for task completions compare_wins: int = 0 # num times user's message won compare tasks compare_losses: int = 0 # num times users's message lost compare tasks diff --git a/backend/oasst_backend/models/work_package.py b/backend/oasst_backend/models/work_package.py index 612cf243..e2a4358e 100644 --- a/backend/oasst_backend/models/work_package.py +++ b/backend/oasst_backend/models/work_package.py @@ -11,8 +11,8 @@ from sqlmodel import Field, SQLModel from .payload_column_type import PayloadContainer, payload_column_type -class WorkPackage(SQLModel, table=True): - __tablename__ = "work_package" +class Task(SQLModel, table=True): + __tablename__ = "task" id: Optional[UUID] = Field( sa_column=sa.Column( diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 6ebd58b5..1d30623b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -7,7 +7,7 @@ import oasst_backend.models.db_payload as db_payload from loguru import logger from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ApiClient, User, Message, MessageReaction, TextLabels, WorkPackage +from oasst_backend.models import ApiClient, User, Message, MessageReaction, TextLabels, Task from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session, func @@ -61,17 +61,17 @@ class PromptRepository: self.validate_message_id(message_id) # find work package - work_pack: WorkPackage = ( - self.db.query(WorkPackage) - .filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id) + work_pack: Task = ( + self.db.query(Task) + .filter(Task.id == task_id, Task.api_client_id == self.api_client.id) .first() ) if work_pack is None: - raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND) + raise OasstError(f"Task for task {task_id} not found", OasstErrorCode.TASK_NOT_FOUND) if work_pack.expired: - raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED) + raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) if work_pack.done or work_pack.ack is not None: - raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED) + raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) work_pack.frontend_ref_message_id = message_id work_pack.ack = True @@ -81,17 +81,17 @@ class PromptRepository: def acknowledge_task_failure(self, task_id): # find work package - work_pack: WorkPackage = ( - self.db.query(WorkPackage) - .filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id) + work_pack: Task = ( + self.db.query(Task) + .filter(Task.id == task_id, Task.api_client_id == self.api_client.id) .first() ) if work_pack is None: - raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND) + raise OasstError(f"Task for task {task_id} not found", OasstErrorCode.TASK_NOT_FOUND) if work_pack.expired: - raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED) + raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) if work_pack.done or work_pack.ack is not None: - raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED) + raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) work_pack.ack = False # ToDo: check race-condition, transaction @@ -109,36 +109,36 @@ class PromptRepository: raise OasstError(f"Message with message_id {frontend_message_id} not found.", OasstErrorCode.POST_NOT_FOUND) return message - def fetch_workpackage_by_message_id(self, message_id: str) -> WorkPackage: + def fetch_task_by_message_id(self, message_id: str) -> Task: self.validate_message_id(message_id) - work_pack = ( - self.db.query(WorkPackage) - .filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_message_id == message_id) + task = ( + self.db.query(Task) + .filter(Task.api_client_id == self.api_client.id, Task.frontend_ref_message_id == message_id) .one_or_none() ) - return work_pack + return task def store_text_reply(self, text: str, message_id: str, user_message_id: str, role: str = None) -> Message: self.validate_message_id(message_id) self.validate_message_id(user_message_id) - wp = self.fetch_workpackage_by_message_id(message_id) + task = self.fetch_task_by_message_id(message_id) - if wp is None: - raise OasstError(f"WorkPackage for {message_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND) - if wp.expired: - raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED) - if not wp.ack: - raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK) - if wp.done: - raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE) + if task is None: + raise OasstError(f"Task for {message_id=} not found", OasstErrorCode.TASK_NOT_FOUND) + if task.expired: + raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) + if not task.ack: + raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK) + if task.done: + raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE) # If there's no parent message assume user started new conversation role = "user" depth = 0 - if wp.parent_message_id: - parent_message = self.fetch_message(wp.parent_message_id) + if task.parent_message_id: + parent_message = self.fetch_message(task.parent_message_id) parent_message.children_count += 1 self.db.add(parent_message) @@ -153,29 +153,29 @@ class PromptRepository: user_message = self.insert_message( message_id=new_message_id, frontend_message_id=user_message_id, - parent_id=wp.parent_message_id, - message_tree_id=wp.message_tree_id or new_message_id, - workpackage_id=wp.id, + parent_id=task.parent_message_id, + message_tree_id=task.message_tree_id or new_message_id, + task_id=task.id, role=role, payload=db_payload.MessagePayload(text=text), depth=depth, ) - if not wp.collective: - wp.done = True - self.db.add(wp) + if not task.collective: + task.done = True + self.db.add(task) self.db.commit() - self.journal.log_text_reply(work_package=wp, message_id=new_message_id, role=role, length=len(text)) + self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text)) return user_message def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction: message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True) - work_package = self.fetch_workpackage_by_message_id(rating.message_id) - work_payload: db_payload.RateSummaryPayload = work_package.payload.payload + task = self.fetch_task_by_message_id(rating.message_id) + work_payload: db_payload.RateSummaryPayload = task.payload.payload if type(work_payload) != db_payload.RateSummaryPayload: raise OasstError( - f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}", - OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, + f"task payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}", + OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH, ) if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max: @@ -187,23 +187,23 @@ class PromptRepository: # store reaction to message reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating) reaction = self.insert_reaction(message.id, reaction_payload) - if not work_package.collective: - work_package.done = True - self.db.add(work_package) + if not task.collective: + task.done = True + self.db.add(task) - self.journal.log_rating(work_package, message_id=message.id, rating=rating.rating) - logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.") + self.journal.log_rating(task, message_id=message.id, rating=rating.rating) + logger.info(f"Ranking {rating.rating} stored for task {task.id}.") return reaction def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction: - # fetch work_package - work_package = self.fetch_workpackage_by_message_id(ranking.message_id) - if not work_package.collective: - work_package.done = True - self.db.add(work_package) + # fetch task + task = self.fetch_task_by_message_id(ranking.message_id) + if not task.collective: + task.done = True + self.db.add(task) work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = ( - work_package.payload.payload + task.payload.payload ) match type(work_payload): @@ -219,11 +219,11 @@ class PromptRepository: # store reaction to message reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) - reaction = self.insert_reaction(work_package.id, reaction_payload) + reaction = self.insert_reaction(task.id, reaction_payload) # TODO: resolve message_id - self.journal.log_ranking(work_package, message_id=None, ranking=ranking.ranking) + self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking) - logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") + logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.") return reaction @@ -237,18 +237,18 @@ class PromptRepository: # store reaction to message reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) - reaction = self.insert_reaction(work_package.id, reaction_payload) + reaction = self.insert_reaction(task.id, reaction_payload) # TODO: resolve message_id - self.journal.log_ranking(work_package, message_id=None, ranking=ranking.ranking) + self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking) - logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") + logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.") return reaction case _: raise OasstError( - f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}", - OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, + f"task payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}", + OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH, ) def store_task( @@ -257,7 +257,7 @@ class PromptRepository: message_tree_id: UUID = None, parent_message_id: UUID = None, collective: bool = False, - ) -> WorkPackage: + ) -> Task: payload: db_payload.TaskPayload match type(task): case protocol_schema.SummarizeStoryTask: @@ -293,22 +293,22 @@ class PromptRepository: case _: raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) - wp = self.insert_work_package( + task = self.insert_task( payload=payload, id=task.id, message_tree_id=message_tree_id, parent_message_id=parent_message_id, collective=collective ) - assert wp.id == task.id - return wp + assert task.id == task.id + return task - def insert_work_package( + def insert_task( self, payload: db_payload.TaskPayload, id: UUID = None, message_tree_id: UUID = None, parent_message_id: UUID = None, collective: bool = False, - ) -> WorkPackage: + ) -> Task: c = PayloadContainer(payload=payload) - wp = WorkPackage( + task = Task( id=id, user_id=self.user_id, payload_type=type(payload).__name__, @@ -318,10 +318,10 @@ class PromptRepository: parent_message_id=parent_message_id, collective=collective, ) - self.db.add(wp) + self.db.add(task) self.db.commit() - self.db.refresh(wp) - return wp + self.db.refresh(task) + return task def insert_message( self, @@ -330,7 +330,7 @@ class PromptRepository: frontend_message_id: str, parent_id: UUID, message_tree_id: UUID, - workpackage_id: UUID, + task_id: UUID, role: str, payload: db_payload.MessagePayload, payload_type: str = None, @@ -346,7 +346,7 @@ class PromptRepository: id=message_id, parent_id=parent_id, message_tree_id=message_tree_id, - workpackage_id=workpackage_id, + task_id=task_id, user_id=self.user_id, role=role, frontend_message_id=frontend_message_id, @@ -360,13 +360,13 @@ class PromptRepository: self.db.refresh(message) return message - def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: + def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) container = PayloadContainer(payload=payload) reaction = MessageReaction( - work_package_id=work_package_id, + task_id=task_id, user_id=self.user_id, payload=container, api_client_id=self.api_client.id, @@ -474,17 +474,17 @@ class PromptRepository: def close_task(self, message_id: str, allow_personal_tasks: bool = False): self.validate_message_id(message_id) - wp = self.fetch_workpackage_by_message_id(message_id) + task = self.fetch_task_by_message_id(message_id) - if not wp: - raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND) - if wp.expired: - raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED) - if not allow_personal_tasks and not wp.collective: - raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE) - if wp.done: - raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE) + if not task: + raise OasstError("Work package not found", OasstErrorCode.TASK_NOT_FOUND) + if task.expired: + raise OasstError("Work package expired", OasstErrorCode.TASK_EXPIRED) + if not allow_personal_tasks and not task.collective: + raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE) + if task.done: + raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE) - wp.done = True - self.db.add(wp) + task.done = True + self.db.add(task) self.db.commit()