Updating work_package to task

This commit is contained in:
alexandrelefourner
2022-12-30 18:05:23 +01:00
committed by Andreas Köpf
parent d118f4e332
commit 35e0c32a08
11 changed files with 131 additions and 131 deletions
+11 -11
View File
@@ -141,36 +141,36 @@ if settings.DEBUG_USE_SEED_DATA:
]
for p in dummy_messages:
wp = pr.fetch_workpackage_by_message_id(p.task_message_id)
if wp and not wp.ack:
task = pr.fetch_task_by_message_id(p.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data work package")
db.delete(wp)
wp = None
if not wp:
db.delete(task)
task = None
if not task:
if p.parent_message_id is None:
wp = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_message_id=None
task = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
print("p.parent_message_id", p.parent_message_id)
parent_message = pr.fetch_message_by_frontend_message_id(p.parent_message_id, fail_if_missing=True)
wp = pr.store_task(
task = pr.store_task(
protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
)
),
thread_id=parent_message.thread_id,
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
pr.bind_frontend_message_id(wp.id, p.task_message_id)
pr.bind_frontend_message_id(task.id, p.task_message_id)
message = pr.store_text_reply(p.text, p.task_message_id, p.user_message_id)
logger.info(
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
)
else:
logger.debug(f"seed data work_package found: {wp.id}")
logger.debug(f"seed data task found: {task.id}")
logger.info("Seed data check completed")
except Exception:
+6 -6
View File
@@ -18,7 +18,7 @@ router = APIRouter()
def generate_task(
request: protocol_schema.TaskRequest, pr: PromptRepository
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
thread_id = None
message_tree_id = None
parent_message_id = None
match request.type:
@@ -63,7 +63,7 @@ def generate_task(
]
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = messages[-1].thread_id
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
@@ -74,7 +74,7 @@ def generate_task(
]
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = messages[-1].thread_id
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")
@@ -121,7 +121,7 @@ def generate_task(
logger.info(f"Generated {task=}.")
return task, thread_id, parent_message_id
return task, message_tree_id, parent_message_id
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
@@ -138,8 +138,8 @@ def request_task(
try:
pr = PromptRepository(db, api_client, request.user)
task, thread_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, thread_id, parent_message_id, request.collective)
task, message_tree_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
raise
+7 -7
View File
@@ -35,13 +35,13 @@ class OasstErrorCode(IntEnum):
USER_NOT_SPECIFIED = 2005
NO_THREADS_FOUND = 2006
NO_REPLIES_FOUND = 2007
WORK_PACKAGE_NOT_FOUND = 2100
WORK_PACKAGE_EXPIRED = 2101
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
WORK_PACKAGE_ALREADY_UPDATED = 2103
WORK_PACKAGE_NOT_ACK = 2104
WORK_PACKAGE_ALREADY_DONE = 2105
WORK_PACKAGE_NOT_COLLECTIVE = 2106
TASK_NOT_FOUND = 2100
TASK_EXPIRED = 2101
TASK_PAYLOAD_TYPE_MISMATCH = 2102
TASK_ALREADY_UPDATED = 2103
TASK_NOT_ACK = 2104
TASK_ALREADY_DONE = 2105
TASK_NOT_COLLECTIVE = 2106
class OasstError(Exception):
+14 -14
View File
@@ -3,7 +3,7 @@ import enum
from typing import Literal, Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
from oasst_backend.models import ApiClient, Journal, Person, Task
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
from oasst_shared.utils import utcnow
from pydantic import BaseModel
@@ -24,7 +24,7 @@ class JournalEvent(BaseModel):
type: str
user_id: Optional[UUID]
message_id: Optional[UUID]
workpackage_id: Optional[UUID]
task_id: Optional[UUID]
task_type: Optional[str]
@@ -54,30 +54,30 @@ class JournalWriter:
self.user = user
self.user_id = self.user.id if self.user else None
def log_text_reply(self, work_package: WorkPackage, message_id: UUID, role: str, length: int) -> Journal:
def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
task_type=task.payload_type,
event_type=JournalEventType.text_reply_to_message,
payload=TextReplyEvent(role=role, length=length),
workpackage_id=work_package.id,
task_id=task.id,
message_id=message_id,
)
def log_rating(self, work_package: WorkPackage, message_id: UUID, rating: int) -> Journal:
def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
task_type=task.payload_type,
event_type=JournalEventType.message_rating,
payload=RatingEvent(rating=rating),
workpackage_id=work_package.id,
task_id=task.id,
message_id=message_id,
)
def log_ranking(self, work_package: WorkPackage, message_id: UUID, ranking: list[int]) -> Journal:
def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal:
return self.log(
task_type=work_package.payload_type,
task_type=task.payload_type,
event_type=JournalEventType.message_ranking,
payload=RankingEvent(ranking=ranking),
workpackage_id=work_package.id,
task_id=task.id,
message_id=message_id,
)
@@ -87,7 +87,7 @@ class JournalWriter:
payload: JournalEvent,
task_type: str,
event_type: str = None,
workpackage_id: Optional[UUID] = None,
task_id: Optional[UUID] = None,
message_id: Optional[UUID] = None,
commit: bool = True,
) -> Journal:
@@ -101,8 +101,8 @@ class JournalWriter:
payload.user_id = self.user_id
if payload.message_id is None:
payload.message_id = message_id
if payload.workpackage_id is None:
payload.workpackage_id = workpackage_id
if payload.task_id is None:
payload.task_id = task_id
if payload.task_type is None:
payload.task_type = task_type
+2 -2
View File
@@ -6,7 +6,7 @@ from .user_stats import UserStats
from .message import Message
from .message_reaction import MessageReaction
from .text_labels import TextLabels
from .work_package import WorkPackage
from .task import Task
__all__ = [
"ApiClient",
@@ -14,7 +14,7 @@ __all__ = [
"UserStats",
"Message",
"MessageReaction",
"WorkPackage",
"Task",
"TextLabels",
"Journal",
"JournalIntegration",
+1 -1
View File
@@ -45,7 +45,7 @@ class AssistantReplyPayload(TaskPayload):
@payload_type
class PostPayload(BaseModel):
class MessagePayload(BaseModel):
text: str
+1 -1
View File
@@ -21,7 +21,7 @@ class Message(SQLModel, table=True):
)
parent_id: UUID = Field(nullable=True)
message_tree_id: UUID = Field(nullable=False, index=True)
workpackage_id: UUID = Field(nullable=True, index=True)
task_id: UUID = Field(nullable=True, index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
@@ -13,8 +13,8 @@ from .payload_column_type import PayloadContainer, payload_column_type
class MessageReaction(SQLModel, table=True):
__tablename__ = "message_reaction"
work_package_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
task_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("task.id"), nullable=False, primary_key=True)
)
user_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
+1 -1
View File
@@ -23,6 +23,6 @@ class UserStats(SQLModel, table=True):
messages: int = 0 # messages sent by user
upvotes: int = 0 # received upvotes (form other users)
downvotes: int = 0 # received downvotes (from other users)
work_reward: int = 0 # reward for workpackage completions
task_reward: int = 0 # reward for task completions
compare_wins: int = 0 # num times user's message won compare tasks
compare_losses: int = 0 # num times users's message lost compare tasks
+2 -2
View File
@@ -11,8 +11,8 @@ from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class WorkPackage(SQLModel, table=True):
__tablename__ = "work_package"
class Task(SQLModel, table=True):
__tablename__ = "task"
id: Optional[UUID] = Field(
sa_column=sa.Column(
+84 -84
View File
@@ -7,7 +7,7 @@ import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, User, Message, MessageReaction, TextLabels, WorkPackage
from oasst_backend.models import ApiClient, User, Message, MessageReaction, TextLabels, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func
@@ -61,17 +61,17 @@ class PromptRepository:
self.validate_message_id(message_id)
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
work_pack: Task = (
self.db.query(Task)
.filter(Task.id == task_id, Task.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
raise OasstError(f"Task for task {task_id} not found", OasstErrorCode.TASK_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
work_pack.frontend_ref_message_id = message_id
work_pack.ack = True
@@ -81,17 +81,17 @@ class PromptRepository:
def acknowledge_task_failure(self, task_id):
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
work_pack: Task = (
self.db.query(Task)
.filter(Task.id == task_id, Task.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
raise OasstError(f"Task for task {task_id} not found", OasstErrorCode.TASK_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
work_pack.ack = False
# ToDo: check race-condition, transaction
@@ -109,36 +109,36 @@ class PromptRepository:
raise OasstError(f"Message with message_id {frontend_message_id} not found.", OasstErrorCode.POST_NOT_FOUND)
return message
def fetch_workpackage_by_message_id(self, message_id: str) -> WorkPackage:
def fetch_task_by_message_id(self, message_id: str) -> Task:
self.validate_message_id(message_id)
work_pack = (
self.db.query(WorkPackage)
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_message_id == message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_ref_message_id == message_id)
.one_or_none()
)
return work_pack
return task
def store_text_reply(self, text: str, message_id: str, user_message_id: str, role: str = None) -> Message:
self.validate_message_id(message_id)
self.validate_message_id(user_message_id)
wp = self.fetch_workpackage_by_message_id(message_id)
task = self.fetch_task_by_message_id(message_id)
if wp is None:
raise OasstError(f"WorkPackage for {message_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not wp.ack:
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
if wp.done:
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
if task is None:
raise OasstError(f"Task for {message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if not task.ack:
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
if task.done:
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
# If there's no parent message assume user started new conversation
role = "user"
depth = 0
if wp.parent_message_id:
parent_message = self.fetch_message(wp.parent_message_id)
if task.parent_message_id:
parent_message = self.fetch_message(task.parent_message_id)
parent_message.children_count += 1
self.db.add(parent_message)
@@ -153,29 +153,29 @@ class PromptRepository:
user_message = self.insert_message(
message_id=new_message_id,
frontend_message_id=user_message_id,
parent_id=wp.parent_message_id,
message_tree_id=wp.message_tree_id or new_message_id,
workpackage_id=wp.id,
parent_id=task.parent_message_id,
message_tree_id=task.message_tree_id or new_message_id,
task_id=task.id,
role=role,
payload=db_payload.MessagePayload(text=text),
depth=depth,
)
if not wp.collective:
wp.done = True
self.db.add(wp)
if not task.collective:
task.done = True
self.db.add(task)
self.db.commit()
self.journal.log_text_reply(work_package=wp, message_id=new_message_id, role=role, length=len(text))
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
return user_message
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
work_package = self.fetch_workpackage_by_message_id(rating.message_id)
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
task = self.fetch_task_by_message_id(rating.message_id)
work_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(work_payload) != db_payload.RateSummaryPayload:
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
f"task payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
@@ -187,23 +187,23 @@ class PromptRepository:
# store reaction to message
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(message.id, reaction_payload)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
if not task.collective:
task.done = True
self.db.add(task)
self.journal.log_rating(work_package, message_id=message.id, rating=rating.rating)
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
self.journal.log_rating(task, message_id=message.id, rating=rating.rating)
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
return reaction
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
# fetch work_package
work_package = self.fetch_workpackage_by_message_id(ranking.message_id)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
# fetch task
task = self.fetch_task_by_message_id(ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
task.payload.payload
)
match type(work_payload):
@@ -219,11 +219,11 @@ class PromptRepository:
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(work_package, message_id=None, ranking=ranking.ranking)
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
@@ -237,18 +237,18 @@ class PromptRepository:
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(work_package, message_id=None, ranking=ranking.ranking)
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
case _:
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
f"task payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
def store_task(
@@ -257,7 +257,7 @@ class PromptRepository:
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
@@ -293,22 +293,22 @@ class PromptRepository:
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
wp = self.insert_work_package(
task = self.insert_task(
payload=payload, id=task.id, message_tree_id=message_tree_id, parent_message_id=parent_message_id, collective=collective
)
assert wp.id == task.id
return wp
assert task.id == task.id
return task
def insert_work_package(
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
) -> Task:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
task = Task(
id=id,
user_id=self.user_id,
payload_type=type(payload).__name__,
@@ -318,10 +318,10 @@ class PromptRepository:
parent_message_id=parent_message_id,
collective=collective,
)
self.db.add(wp)
self.db.add(task)
self.db.commit()
self.db.refresh(wp)
return wp
self.db.refresh(task)
return task
def insert_message(
self,
@@ -330,7 +330,7 @@ class PromptRepository:
frontend_message_id: str,
parent_id: UUID,
message_tree_id: UUID,
workpackage_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
payload_type: str = None,
@@ -346,7 +346,7 @@ class PromptRepository:
id=message_id,
parent_id=parent_id,
message_tree_id=message_tree_id,
workpackage_id=workpackage_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_message_id=frontend_message_id,
@@ -360,13 +360,13 @@ class PromptRepository:
self.db.refresh(message)
return message
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
container = PayloadContainer(payload=payload)
reaction = MessageReaction(
work_package_id=work_package_id,
task_id=task_id,
user_id=self.user_id,
payload=container,
api_client_id=self.api_client.id,
@@ -474,17 +474,17 @@ class PromptRepository:
def close_task(self, message_id: str, allow_personal_tasks: bool = False):
self.validate_message_id(message_id)
wp = self.fetch_workpackage_by_message_id(message_id)
task = self.fetch_task_by_message_id(message_id)
if not wp:
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not allow_personal_tasks and not wp.collective:
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
if wp.done:
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
if not task:
raise OasstError("Work package not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Work package expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
wp.done = True
self.db.add(wp)
task.done = True
self.db.add(task)
self.db.commit()