mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
Updating work_package to task
This commit is contained in:
committed by
Andreas Köpf
parent
d118f4e332
commit
35e0c32a08
+11
-11
@@ -141,36 +141,36 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
]
|
||||
|
||||
for p in dummy_messages:
|
||||
wp = pr.fetch_workpackage_by_message_id(p.task_message_id)
|
||||
if wp and not wp.ack:
|
||||
task = pr.fetch_task_by_message_id(p.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data work package")
|
||||
db.delete(wp)
|
||||
wp = None
|
||||
if not wp:
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if p.parent_message_id is None:
|
||||
wp = pr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_message_id=None
|
||||
task = pr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
print("p.parent_message_id", p.parent_message_id)
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(p.parent_message_id, fail_if_missing=True)
|
||||
wp = pr.store_task(
|
||||
task = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
|
||||
)
|
||||
),
|
||||
thread_id=parent_message.thread_id,
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
pr.bind_frontend_message_id(wp.id, p.task_message_id)
|
||||
pr.bind_frontend_message_id(task.id, p.task_message_id)
|
||||
message = pr.store_text_reply(p.text, p.task_message_id, p.user_message_id)
|
||||
|
||||
logger.info(
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data work_package found: {wp.id}")
|
||||
logger.debug(f"seed data task found: {task.id}")
|
||||
logger.info("Seed data check completed")
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -18,7 +18,7 @@ router = APIRouter()
|
||||
def generate_task(
|
||||
request: protocol_schema.TaskRequest, pr: PromptRepository
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
thread_id = None
|
||||
message_tree_id = None
|
||||
parent_message_id = None
|
||||
|
||||
match request.type:
|
||||
@@ -63,7 +63,7 @@ def generate_task(
|
||||
]
|
||||
|
||||
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = messages[-1].thread_id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
@@ -74,7 +74,7 @@ def generate_task(
|
||||
]
|
||||
|
||||
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = messages[-1].thread_id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.rank_initial_prompts:
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
@@ -121,7 +121,7 @@ def generate_task(
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task, thread_id, parent_message_id
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
@@ -138,8 +138,8 @@ def request_task(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
task, thread_id, parent_message_id = generate_task(request, pr)
|
||||
pr.store_task(task, thread_id, parent_message_id, request.collective)
|
||||
task, message_tree_id, parent_message_id = generate_task(request, pr)
|
||||
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
|
||||
@@ -35,13 +35,13 @@ class OasstErrorCode(IntEnum):
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
NO_THREADS_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
WORK_PACKAGE_NOT_FOUND = 2100
|
||||
WORK_PACKAGE_EXPIRED = 2101
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
WORK_PACKAGE_ALREADY_UPDATED = 2103
|
||||
WORK_PACKAGE_NOT_ACK = 2104
|
||||
WORK_PACKAGE_ALREADY_DONE = 2105
|
||||
WORK_PACKAGE_NOT_COLLECTIVE = 2106
|
||||
TASK_NOT_FOUND = 2100
|
||||
TASK_EXPIRED = 2101
|
||||
TASK_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
TASK_ALREADY_UPDATED = 2103
|
||||
TASK_NOT_ACK = 2104
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
|
||||
@@ -3,7 +3,7 @@ import enum
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
|
||||
from oasst_backend.models import ApiClient, Journal, Person, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
|
||||
from oasst_shared.utils import utcnow
|
||||
from pydantic import BaseModel
|
||||
@@ -24,7 +24,7 @@ class JournalEvent(BaseModel):
|
||||
type: str
|
||||
user_id: Optional[UUID]
|
||||
message_id: Optional[UUID]
|
||||
workpackage_id: Optional[UUID]
|
||||
task_id: Optional[UUID]
|
||||
task_type: Optional[str]
|
||||
|
||||
|
||||
@@ -54,30 +54,30 @@ class JournalWriter:
|
||||
self.user = user
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
def log_text_reply(self, work_package: WorkPackage, message_id: UUID, role: str, length: int) -> Journal:
|
||||
def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.text_reply_to_message,
|
||||
payload=TextReplyEvent(role=role, length=length),
|
||||
workpackage_id=work_package.id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_rating(self, work_package: WorkPackage, message_id: UUID, rating: int) -> Journal:
|
||||
def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_rating,
|
||||
payload=RatingEvent(rating=rating),
|
||||
workpackage_id=work_package.id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_ranking(self, work_package: WorkPackage, message_id: UUID, ranking: list[int]) -> Journal:
|
||||
def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_ranking,
|
||||
payload=RankingEvent(ranking=ranking),
|
||||
workpackage_id=work_package.id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ class JournalWriter:
|
||||
payload: JournalEvent,
|
||||
task_type: str,
|
||||
event_type: str = None,
|
||||
workpackage_id: Optional[UUID] = None,
|
||||
task_id: Optional[UUID] = None,
|
||||
message_id: Optional[UUID] = None,
|
||||
commit: bool = True,
|
||||
) -> Journal:
|
||||
@@ -101,8 +101,8 @@ class JournalWriter:
|
||||
payload.user_id = self.user_id
|
||||
if payload.message_id is None:
|
||||
payload.message_id = message_id
|
||||
if payload.workpackage_id is None:
|
||||
payload.workpackage_id = workpackage_id
|
||||
if payload.task_id is None:
|
||||
payload.task_id = task_id
|
||||
if payload.task_type is None:
|
||||
payload.task_type = task_type
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from .user_stats import UserStats
|
||||
from .message import Message
|
||||
from .message_reaction import MessageReaction
|
||||
from .text_labels import TextLabels
|
||||
from .work_package import WorkPackage
|
||||
from .task import Task
|
||||
|
||||
__all__ = [
|
||||
"ApiClient",
|
||||
@@ -14,7 +14,7 @@ __all__ = [
|
||||
"UserStats",
|
||||
"Message",
|
||||
"MessageReaction",
|
||||
"WorkPackage",
|
||||
"Task",
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
|
||||
@@ -45,7 +45,7 @@ class AssistantReplyPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class PostPayload(BaseModel):
|
||||
class MessagePayload(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class Message(SQLModel, table=True):
|
||||
)
|
||||
parent_id: UUID = Field(nullable=True)
|
||||
message_tree_id: UUID = Field(nullable=False, index=True)
|
||||
workpackage_id: UUID = Field(nullable=True, index=True)
|
||||
task_id: UUID = Field(nullable=True, index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
|
||||
@@ -13,8 +13,8 @@ from .payload_column_type import PayloadContainer, payload_column_type
|
||||
class MessageReaction(SQLModel, table=True):
|
||||
__tablename__ = "message_reaction"
|
||||
|
||||
work_package_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
|
||||
task_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("task.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
user_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
|
||||
|
||||
@@ -23,6 +23,6 @@ class UserStats(SQLModel, table=True):
|
||||
messages: int = 0 # messages sent by user
|
||||
upvotes: int = 0 # received upvotes (form other users)
|
||||
downvotes: int = 0 # received downvotes (from other users)
|
||||
work_reward: int = 0 # reward for workpackage completions
|
||||
task_reward: int = 0 # reward for task completions
|
||||
compare_wins: int = 0 # num times user's message won compare tasks
|
||||
compare_losses: int = 0 # num times users's message lost compare tasks
|
||||
|
||||
@@ -11,8 +11,8 @@ from sqlmodel import Field, SQLModel
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class WorkPackage(SQLModel, table=True):
|
||||
__tablename__ = "work_package"
|
||||
class Task(SQLModel, table=True):
|
||||
__tablename__ = "task"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
|
||||
@@ -7,7 +7,7 @@ import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, User, Message, MessageReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models import ApiClient, User, Message, MessageReaction, TextLabels, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func
|
||||
@@ -61,17 +61,17 @@ class PromptRepository:
|
||||
self.validate_message_id(message_id)
|
||||
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
work_pack: Task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.id == task_id, Task.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
raise OasstError(f"Task for task {task_id} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
work_pack.frontend_ref_message_id = message_id
|
||||
work_pack.ack = True
|
||||
@@ -81,17 +81,17 @@ class PromptRepository:
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
work_pack: Task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.id == task_id, Task.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
raise OasstError(f"Task for task {task_id} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
work_pack.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
@@ -109,36 +109,36 @@ class PromptRepository:
|
||||
raise OasstError(f"Message with message_id {frontend_message_id} not found.", OasstErrorCode.POST_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def fetch_workpackage_by_message_id(self, message_id: str) -> WorkPackage:
|
||||
def fetch_task_by_message_id(self, message_id: str) -> Task:
|
||||
self.validate_message_id(message_id)
|
||||
work_pack = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_message_id == message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_ref_message_id == message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
return work_pack
|
||||
return task
|
||||
|
||||
def store_text_reply(self, text: str, message_id: str, user_message_id: str, role: str = None) -> Message:
|
||||
self.validate_message_id(message_id)
|
||||
self.validate_message_id(user_message_id)
|
||||
|
||||
wp = self.fetch_workpackage_by_message_id(message_id)
|
||||
task = self.fetch_task_by_message_id(message_id)
|
||||
|
||||
if wp is None:
|
||||
raise OasstError(f"WorkPackage for {message_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not wp.ack:
|
||||
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
|
||||
if wp.done:
|
||||
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if not task.ack:
|
||||
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
|
||||
if task.done:
|
||||
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
# If there's no parent message assume user started new conversation
|
||||
role = "user"
|
||||
depth = 0
|
||||
|
||||
if wp.parent_message_id:
|
||||
parent_message = self.fetch_message(wp.parent_message_id)
|
||||
if task.parent_message_id:
|
||||
parent_message = self.fetch_message(task.parent_message_id)
|
||||
parent_message.children_count += 1
|
||||
self.db.add(parent_message)
|
||||
|
||||
@@ -153,29 +153,29 @@ class PromptRepository:
|
||||
user_message = self.insert_message(
|
||||
message_id=new_message_id,
|
||||
frontend_message_id=user_message_id,
|
||||
parent_id=wp.parent_message_id,
|
||||
message_tree_id=wp.message_tree_id or new_message_id,
|
||||
workpackage_id=wp.id,
|
||||
parent_id=task.parent_message_id,
|
||||
message_tree_id=task.message_tree_id or new_message_id,
|
||||
task_id=task.id,
|
||||
role=role,
|
||||
payload=db_payload.MessagePayload(text=text),
|
||||
depth=depth,
|
||||
)
|
||||
if not wp.collective:
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.journal.log_text_reply(work_package=wp, message_id=new_message_id, role=role, length=len(text))
|
||||
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
|
||||
return user_message
|
||||
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
work_package = self.fetch_workpackage_by_message_id(rating.message_id)
|
||||
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
|
||||
task = self.fetch_task_by_message_id(rating.message_id)
|
||||
work_payload: db_payload.RateSummaryPayload = task.payload.payload
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
f"task payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
@@ -187,23 +187,23 @@ class PromptRepository:
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(message.id, reaction_payload)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
self.journal.log_rating(work_package, message_id=message.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
self.journal.log_rating(task, message_id=message.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
|
||||
return reaction
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_message_id(ranking.message_id)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
# fetch task
|
||||
task = self.fetch_task_by_message_id(ranking.message_id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
task.payload.payload
|
||||
)
|
||||
|
||||
match type(work_payload):
|
||||
@@ -219,11 +219,11 @@ class PromptRepository:
|
||||
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(work_package, message_id=None, ranking=ranking.ranking)
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
@@ -237,18 +237,18 @@ class PromptRepository:
|
||||
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(work_package, message_id=None, ranking=ranking.ranking)
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
f"task payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(
|
||||
@@ -257,7 +257,7 @@ class PromptRepository:
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
@@ -293,22 +293,22 @@ class PromptRepository:
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
wp = self.insert_work_package(
|
||||
task = self.insert_task(
|
||||
payload=payload, id=task.id, message_tree_id=message_tree_id, parent_message_id=parent_message_id, collective=collective
|
||||
)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def insert_work_package(
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
task = Task(
|
||||
id=id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
@@ -318,10 +318,10 @@ class PromptRepository:
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(wp)
|
||||
return wp
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def insert_message(
|
||||
self,
|
||||
@@ -330,7 +330,7 @@ class PromptRepository:
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
workpackage_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
@@ -346,7 +346,7 @@ class PromptRepository:
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
message_tree_id=message_tree_id,
|
||||
workpackage_id=workpackage_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_message_id=frontend_message_id,
|
||||
@@ -360,13 +360,13 @@ class PromptRepository:
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = MessageReaction(
|
||||
work_package_id=work_package_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
@@ -474,17 +474,17 @@ class PromptRepository:
|
||||
|
||||
def close_task(self, message_id: str, allow_personal_tasks: bool = False):
|
||||
self.validate_message_id(message_id)
|
||||
wp = self.fetch_workpackage_by_message_id(message_id)
|
||||
task = self.fetch_task_by_message_id(message_id)
|
||||
|
||||
if not wp:
|
||||
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not allow_personal_tasks and not wp.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
|
||||
if wp.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
if not task:
|
||||
raise OasstError("Work package not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Work package expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
Reference in New Issue
Block a user