diff --git a/backend/alembic/versions/2022_12_29_2103-464ec4667aae_add_collective_flag_to_task.py b/backend/alembic/versions/2022_12_29_2103-464ec4667aae_add_collective_flag_to_task.py new file mode 100644 index 00000000..cbed707c --- /dev/null +++ b/backend/alembic/versions/2022_12_29_2103-464ec4667aae_add_collective_flag_to_task.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +"""add collective flag to task + +Revision ID: 464ec4667aae +Revises: d24b37426857 +Create Date: 2022-12-29 21:03:06.841962 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "464ec4667aae" +down_revision = "d24b37426857" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "work_package", sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("work_package", "collective") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 94a0be4a..1318ba41 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -139,7 +139,7 @@ def request_task( try: pr = PromptRepository(db, api_client, request.user) task, thread_id, parent_post_id = generate_task(request, pr) - pr.store_task(task, thread_id, parent_post_id) + pr.store_task(task, thread_id, parent_post_id, request.collective) except OasstError: raise @@ -252,3 +252,15 @@ def post_interaction( except Exception: logger.exception("Interaction request failed.") raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED) + + +@router.post("/close") +def close_collective_task( + close_task_request: protocol_schema.TaskClose, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), +): + api_client = deps.api_auth(api_key, db) + pr = PromptRepository(db, api_client, user=None) + pr.close_task(close_task_request.post_id) + return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/exceptions.py b/backend/oasst_backend/exceptions.py index 98224790..ba11e931 100644 --- a/backend/oasst_backend/exceptions.py +++ b/backend/oasst_backend/exceptions.py @@ -40,6 +40,7 @@ class OasstErrorCode(IntEnum): WORK_PACKAGE_ALREADY_UPDATED = 2103 WORK_PACKAGE_NOT_ACK = 2104 WORK_PACKAGE_ALREADY_DONE = 2105 + WORK_PACKAGE_NOT_COLLECTIVE = 2106 class OasstError(Exception): diff --git a/backend/oasst_backend/models/work_package.py b/backend/oasst_backend/models/work_package.py index a89ed646..7e568cf7 100644 --- a/backend/oasst_backend/models/work_package.py +++ b/backend/oasst_backend/models/work_package.py @@ -32,6 +32,7 @@ class WorkPackage(SQLModel, table=True): frontend_ref_post_id: Optional[str] = None thread_id: Optional[UUID] = None parent_post_id: Optional[UUID] = None + collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) @property def expired(self) -> bool: diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 6d44c443..f350e9fc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -160,8 +160,9 @@ class PromptRepository: payload=db_payload.PostPayload(text=text), depth=depth, ) - wp.done = True - self.db.add(wp) + if not wp.collective: + wp.done = True + self.db.add(wp) self.db.commit() self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text)) return user_post @@ -186,6 +187,10 @@ class PromptRepository: # store reaction to post reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating) reaction = self.insert_reaction(post.id, reaction_payload) + if not work_package.collective: + work_package.done = True + self.db.add(work_package) + self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating) logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.") return reaction @@ -193,8 +198,9 @@ class PromptRepository: def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction: # fetch work_package work_package = self.fetch_workpackage_by_postid(ranking.post_id) - work_package.done = True - self.db.add(work_package) + if not work_package.collective: + work_package.done = True + self.db.add(work_package) work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = ( work_package.payload.payload @@ -250,6 +256,7 @@ class PromptRepository: task: protocol_schema.Task, thread_id: UUID = None, parent_post_id: UUID = None, + collective: bool = False, ) -> WorkPackage: payload: db_payload.TaskPayload match type(task): @@ -287,10 +294,7 @@ class PromptRepository: raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) wp = self.insert_work_package( - payload=payload, - id=task.id, - thread_id=thread_id, - parent_post_id=parent_post_id, + payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective ) assert wp.id == task.id return wp @@ -301,6 +305,7 @@ class PromptRepository: id: UUID = None, thread_id: UUID = None, parent_post_id: UUID = None, + collective: bool = False, ) -> WorkPackage: c = PayloadContainer(payload=payload) wp = WorkPackage( @@ -311,6 +316,7 @@ class PromptRepository: api_client_id=self.api_client.id, thread_id=thread_id, parent_post_id=parent_post_id, + collective=collective, ) self.db.add(wp) self.db.commit() @@ -463,3 +469,20 @@ class PromptRepository: def fetch_post(self, post_id: UUID) -> Optional[Post]: return self.db.query(Post).filter(Post.id == post_id).one() + + def close_task(self, post_id: str, allow_personal_tasks: bool = False): + self.validate_post_id(post_id) + wp = self.fetch_workpackage_by_postid(post_id) + + if not wp: + raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND) + if wp.expired: + raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED) + if not allow_personal_tasks and not wp.collective: + raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE) + if wp.done: + raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE) + + wp.done = True + self.db.add(wp) + self.db.commit() diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py index 1de6bb17..0c88258e 100644 --- a/discord-bot/api_client.py +++ b/discord-bot/api_client.py @@ -52,14 +52,19 @@ class ApiClient: return self.task_models_map[task_type].parse_obj(data) def fetch_task( - self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None + self, + task_type: protocol_schema.TaskRequestType, + user: Optional[protocol_schema.User] = None, + collective: bool = False, ) -> protocol_schema.Task: - req = protocol_schema.TaskRequest(type=task_type, user=user) + req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective) data = self.post("/api/v1/tasks/", req.dict()) return self._parse_task(data) - def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: - return self.fetch_task(protocol_schema.TaskRequestType.random, user) + def fetch_random_task( + self, user: Optional[protocol_schema.User] = None, collective: bool = False + ) -> protocol_schema.Task: + return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective) def ack_task(self, task_id: str, post_id: str) -> None: req = protocol_schema.TaskAck(post_id=post_id) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 17ee23f0..19b1921d 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -43,6 +43,7 @@ class TaskRequest(BaseModel): type: TaskRequestType = TaskRequestType.random user: Optional[User] = None + collective: bool = False class TaskAck(BaseModel): @@ -57,6 +58,12 @@ class TaskNAck(BaseModel): reason: str +class TaskClose(BaseModel): + """The frontend asks to mark task as done""" + + post_id: str + + class Task(BaseModel): """A task is a unit of work that the backend gives to the frontend."""