added support for collective tasks

This commit is contained in:
Igor Miagkov
2022-12-29 21:32:17 +04:00
parent c5053ed6c9
commit efafc0173a
7 changed files with 92 additions and 13 deletions
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""add collective flag to task
Revision ID: 464ec4667aae
Revises: d24b37426857
Create Date: 2022-12-29 21:03:06.841962
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "464ec4667aae"
down_revision = "d24b37426857"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"work_package", sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False)
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("work_package", "collective")
# ### end Alembic commands ###
+13 -1
View File
@@ -139,7 +139,7 @@ def request_task(
try:
pr = PromptRepository(db, api_client, request.user)
task, thread_id, parent_post_id = generate_task(request, pr)
pr.store_task(task, thread_id, parent_post_id)
pr.store_task(task, thread_id, parent_post_id, request.collective)
except OasstError:
raise
@@ -252,3 +252,15 @@ def post_interaction(
except Exception:
logger.exception("Interaction request failed.")
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
@router.post("/close")
def close_collective_task(
close_task_request: protocol_schema.TaskClose,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.close_task(close_task_request.post_id)
return protocol_schema.TaskDone()
+1
View File
@@ -40,6 +40,7 @@ class OasstErrorCode(IntEnum):
WORK_PACKAGE_ALREADY_UPDATED = 2103
WORK_PACKAGE_NOT_ACK = 2104
WORK_PACKAGE_ALREADY_DONE = 2105
WORK_PACKAGE_NOT_COLLECTIVE = 2106
class OasstError(Exception):
@@ -32,6 +32,7 @@ class WorkPackage(SQLModel, table=True):
frontend_ref_post_id: Optional[str] = None
thread_id: Optional[UUID] = None
parent_post_id: Optional[UUID] = None
collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
@property
def expired(self) -> bool:
+31 -8
View File
@@ -160,8 +160,9 @@ class PromptRepository:
payload=db_payload.PostPayload(text=text),
depth=depth,
)
wp.done = True
self.db.add(wp)
if not wp.collective:
wp.done = True
self.db.add(wp)
self.db.commit()
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
return user_post
@@ -186,6 +187,10 @@ class PromptRepository:
# store reaction to post
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
return reaction
@@ -193,8 +198,9 @@ class PromptRepository:
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
# fetch work_package
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
work_package.done = True
self.db.add(work_package)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
@@ -250,6 +256,7 @@ class PromptRepository:
task: protocol_schema.Task,
thread_id: UUID = None,
parent_post_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
payload: db_payload.TaskPayload
match type(task):
@@ -287,10 +294,7 @@ class PromptRepository:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
wp = self.insert_work_package(
payload=payload,
id=task.id,
thread_id=thread_id,
parent_post_id=parent_post_id,
payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective
)
assert wp.id == task.id
return wp
@@ -301,6 +305,7 @@ class PromptRepository:
id: UUID = None,
thread_id: UUID = None,
parent_post_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
@@ -311,6 +316,7 @@ class PromptRepository:
api_client_id=self.api_client.id,
thread_id=thread_id,
parent_post_id=parent_post_id,
collective=collective,
)
self.db.add(wp)
self.db.commit()
@@ -463,3 +469,20 @@ class PromptRepository:
def fetch_post(self, post_id: UUID) -> Optional[Post]:
return self.db.query(Post).filter(Post.id == post_id).one()
def close_task(self, post_id: str, allow_personal_tasks: bool = False):
self.validate_post_id(post_id)
wp = self.fetch_workpackage_by_postid(post_id)
if not wp:
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not allow_personal_tasks and not wp.collective:
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
if wp.done:
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
wp.done = True
self.db.add(wp)
self.db.commit()
+9 -4
View File
@@ -52,14 +52,19 @@ class ApiClient:
return self.task_models_map[task_type].parse_obj(data)
def fetch_task(
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
self,
task_type: protocol_schema.TaskRequestType,
user: Optional[protocol_schema.User] = None,
collective: bool = False,
) -> protocol_schema.Task:
req = protocol_schema.TaskRequest(type=task_type, user=user)
req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective)
data = self.post("/api/v1/tasks/", req.dict())
return self._parse_task(data)
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
def fetch_random_task(
self, user: Optional[protocol_schema.User] = None, collective: bool = False
) -> protocol_schema.Task:
return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective)
def ack_task(self, task_id: str, post_id: str) -> None:
req = protocol_schema.TaskAck(post_id=post_id)
@@ -43,6 +43,7 @@ class TaskRequest(BaseModel):
type: TaskRequestType = TaskRequestType.random
user: Optional[User] = None
collective: bool = False
class TaskAck(BaseModel):
@@ -57,6 +58,12 @@ class TaskNAck(BaseModel):
reason: str
class TaskClose(BaseModel):
"""The frontend asks to mark task as done"""
post_id: str
class Task(BaseModel):
"""A task is a unit of work that the backend gives to the frontend."""