mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
added support for collective tasks
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add collective flag to task
|
||||
|
||||
Revision ID: 464ec4667aae
|
||||
Revises: d24b37426857
|
||||
Create Date: 2022-12-29 21:03:06.841962
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "464ec4667aae"
|
||||
down_revision = "d24b37426857"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"work_package", sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False)
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("work_package", "collective")
|
||||
# ### end Alembic commands ###
|
||||
@@ -139,7 +139,7 @@ def request_task(
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
task, thread_id, parent_post_id = generate_task(request, pr)
|
||||
pr.store_task(task, thread_id, parent_post_id)
|
||||
pr.store_task(task, thread_id, parent_post_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -252,3 +252,15 @@ def post_interaction(
|
||||
except Exception:
|
||||
logger.exception("Interaction request failed.")
|
||||
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
|
||||
|
||||
|
||||
@router.post("/close")
|
||||
def close_collective_task(
|
||||
close_task_request: protocol_schema.TaskClose,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.close_task(close_task_request.post_id)
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@@ -40,6 +40,7 @@ class OasstErrorCode(IntEnum):
|
||||
WORK_PACKAGE_ALREADY_UPDATED = 2103
|
||||
WORK_PACKAGE_NOT_ACK = 2104
|
||||
WORK_PACKAGE_ALREADY_DONE = 2105
|
||||
WORK_PACKAGE_NOT_COLLECTIVE = 2106
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
|
||||
@@ -32,6 +32,7 @@ class WorkPackage(SQLModel, table=True):
|
||||
frontend_ref_post_id: Optional[str] = None
|
||||
thread_id: Optional[UUID] = None
|
||||
parent_post_id: Optional[UUID] = None
|
||||
collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
|
||||
@@ -160,8 +160,9 @@ class PromptRepository:
|
||||
payload=db_payload.PostPayload(text=text),
|
||||
depth=depth,
|
||||
)
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
if not wp.collective:
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
|
||||
return user_post
|
||||
@@ -186,6 +187,10 @@ class PromptRepository:
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
|
||||
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
return reaction
|
||||
@@ -193,8 +198,9 @@ class PromptRepository:
|
||||
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
@@ -250,6 +256,7 @@ class PromptRepository:
|
||||
task: protocol_schema.Task,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
@@ -287,10 +294,7 @@ class PromptRepository:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
wp = self.insert_work_package(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
thread_id=thread_id,
|
||||
parent_post_id=parent_post_id,
|
||||
payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective
|
||||
)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
@@ -301,6 +305,7 @@ class PromptRepository:
|
||||
id: UUID = None,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
@@ -311,6 +316,7 @@ class PromptRepository:
|
||||
api_client_id=self.api_client.id,
|
||||
thread_id=thread_id,
|
||||
parent_post_id=parent_post_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
@@ -463,3 +469,20 @@ class PromptRepository:
|
||||
|
||||
def fetch_post(self, post_id: UUID) -> Optional[Post]:
|
||||
return self.db.query(Post).filter(Post.id == post_id).one()
|
||||
|
||||
def close_task(self, post_id: str, allow_personal_tasks: bool = False):
|
||||
self.validate_post_id(post_id)
|
||||
wp = self.fetch_workpackage_by_postid(post_id)
|
||||
|
||||
if not wp:
|
||||
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not allow_personal_tasks and not wp.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
|
||||
if wp.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
|
||||
@@ -52,14 +52,19 @@ class ApiClient:
|
||||
return self.task_models_map[task_type].parse_obj(data)
|
||||
|
||||
def fetch_task(
|
||||
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
|
||||
self,
|
||||
task_type: protocol_schema.TaskRequestType,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
collective: bool = False,
|
||||
) -> protocol_schema.Task:
|
||||
req = protocol_schema.TaskRequest(type=task_type, user=user)
|
||||
req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective)
|
||||
data = self.post("/api/v1/tasks/", req.dict())
|
||||
return self._parse_task(data)
|
||||
|
||||
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
|
||||
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
|
||||
def fetch_random_task(
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
||||
) -> protocol_schema.Task:
|
||||
return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective)
|
||||
|
||||
def ack_task(self, task_id: str, post_id: str) -> None:
|
||||
req = protocol_schema.TaskAck(post_id=post_id)
|
||||
|
||||
@@ -43,6 +43,7 @@ class TaskRequest(BaseModel):
|
||||
|
||||
type: TaskRequestType = TaskRequestType.random
|
||||
user: Optional[User] = None
|
||||
collective: bool = False
|
||||
|
||||
|
||||
class TaskAck(BaseModel):
|
||||
@@ -57,6 +58,12 @@ class TaskNAck(BaseModel):
|
||||
reason: str
|
||||
|
||||
|
||||
class TaskClose(BaseModel):
|
||||
"""The frontend asks to mark task as done"""
|
||||
|
||||
post_id: str
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""A task is a unit of work that the backend gives to the frontend."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user