mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Add user emoji augmentation for message queries (#937)
* add disposition to text labeling tasks
* add emoji stats to ConversationMessage
* add user emoji augmentation for message queries
* add auth_method,username to message queries (query emoji status)
* add auth_method+username for single message
* fix param name typo
* only join rows when message.emojis != JSON.NULL
* formatting
* make sure emojis and user_emojis default to {}, []
* remove init_user(), use fresh empty default collections
This commit is contained in:
@@ -77,7 +77,7 @@ def query_frontend_user_messages(
|
||||
"""
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
|
||||
@@ -34,7 +34,7 @@ def query_messages(
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
@@ -93,7 +93,7 @@ def get_messages_cursor(
|
||||
|
||||
qry_max_count = max_count + 1 if before is None or after is None else max_count
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username, user_id=user_id)
|
||||
items = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
auth_method=auth_method,
|
||||
@@ -137,37 +137,49 @@ def get_messages_cursor(
|
||||
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
message = pr.fetch_message(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -175,24 +187,32 @@ def get_tree(
|
||||
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -200,12 +220,16 @@ def get_descendants(
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -213,12 +237,16 @@ def get_longest_conv(
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
@@ -230,7 +230,7 @@ def query_user_messages(
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, user_id=user_id)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
|
||||
@@ -14,7 +14,8 @@ def prepare_message(m: Message) -> protocol.Message:
|
||||
lang=m.lang,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
emojis=m.emojis,
|
||||
emojis=m.emojis or {},
|
||||
user_emojis=m.user_emojis or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -30,6 +31,8 @@ def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.
|
||||
text=message.text,
|
||||
lang=message.lang,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
emojis=message.emojis or {},
|
||||
user_emojis=message.user_emojis or [],
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from pydantic import PrivateAttr
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
@@ -17,6 +18,13 @@ class Message(SQLModel, table=True):
|
||||
__tablename__ = "message"
|
||||
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any):
|
||||
new_object = super().__new__(cls, *args, **kwargs)
|
||||
# temporary fix until https://github.com/tiangolo/sqlmodel/issues/149 gets merged
|
||||
if not hasattr(new_object, "_user_emojis"):
|
||||
new_object._init_private_attributes()
|
||||
return new_object
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
@@ -49,7 +57,8 @@ class Message(SQLModel, table=True):
|
||||
|
||||
rank: Optional[int] = Field(nullable=True)
|
||||
|
||||
emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
@@ -59,3 +68,7 @@ class Message(SQLModel, table=True):
|
||||
def text(self) -> str:
|
||||
self.ensure_is_message()
|
||||
return self.payload.payload.text
|
||||
|
||||
@property
|
||||
def user_emojis(self) -> str:
|
||||
return self._user_emojis
|
||||
|
||||
@@ -30,8 +30,9 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import Session, and_, func, not_, or_, text, update
|
||||
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -41,14 +42,25 @@ class PromptRepository:
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User] = None,
|
||||
*,
|
||||
user_repository: Optional[UserRepository] = None,
|
||||
task_repository: Optional[TaskRepository] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user_repository = user_repository or UserRepository(db, api_client)
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
if user_id:
|
||||
self.user = self.user_repository.get_user(id=user_id)
|
||||
self.user_id = self.user.id
|
||||
elif auth_method and username:
|
||||
self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username)
|
||||
self.user_id = self.user.id
|
||||
else:
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
|
||||
self.task_repository = task_repository or TaskRepository(
|
||||
db, api_client, client_user, user_repository=self.user_repository
|
||||
@@ -529,7 +541,7 @@ class PromptRepository:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if not include_deleted:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return qry.all()
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def fetch_user_message_trees(
|
||||
self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False
|
||||
@@ -539,7 +551,7 @@ class PromptRepository:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if not include_deleted:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return qry.all()
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]:
|
||||
qry = self.db.query(MessageTreeState).filter(
|
||||
@@ -582,6 +594,10 @@ class PromptRepository:
|
||||
return conversation, replies
|
||||
|
||||
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
|
||||
qry = self.db.query(Message).filter(Message.id == message_id)
|
||||
messages = self._add_user_emojis_all(qry)
|
||||
message = messages[0] if messages else None
|
||||
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
@@ -656,7 +672,7 @@ class PromptRepository:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if exclude_deleted:
|
||||
qry = qry.filter(Message.deleted == sa.false())
|
||||
children = qry.all()
|
||||
children = self._add_user_emojis_all(qry)
|
||||
return children
|
||||
|
||||
def fetch_message_siblings(
|
||||
@@ -674,7 +690,7 @@ class PromptRepository:
|
||||
qry = qry.filter(Message.review_result == reviewed)
|
||||
if deleted is not None:
|
||||
qry = qry.filter(Message.deleted == deleted)
|
||||
siblings = qry.all()
|
||||
siblings = self._add_user_emojis_all(qry)
|
||||
return siblings
|
||||
|
||||
@staticmethod
|
||||
@@ -705,7 +721,7 @@ class PromptRepository:
|
||||
if max_depth is not None:
|
||||
desc = desc.filter(Message.depth <= max_depth)
|
||||
|
||||
desc = desc.all()
|
||||
desc = self._add_user_emojis_all(desc)
|
||||
|
||||
return self.trace_descendants(message, desc)
|
||||
|
||||
@@ -719,6 +735,33 @@ class PromptRepository:
|
||||
max_message = max(tree, key=lambda m: m.children_count)
|
||||
return max_message, [m for m in tree if m.parent_id == max_message.id]
|
||||
|
||||
def _add_user_emojis_all(self, qry: Query) -> list[Message]:
|
||||
if self.user_id is None:
|
||||
return qry.all()
|
||||
|
||||
sq = qry.subquery("m")
|
||||
qry = (
|
||||
self.db.query(Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis"))
|
||||
.select_entity_from(sq)
|
||||
.outerjoin(
|
||||
MessageEmoji,
|
||||
and_(
|
||||
sq.c.id == MessageEmoji.message_id,
|
||||
MessageEmoji.user_id == self.user_id,
|
||||
sq.c.emojis != JSON.NULL,
|
||||
),
|
||||
)
|
||||
.group_by(sq)
|
||||
)
|
||||
messages: list[Message] = []
|
||||
for x in qry:
|
||||
m: Message = x.Message
|
||||
user_emojis = x["user_emojis"]
|
||||
if user_emojis:
|
||||
m._user_emojis = user_emojis.split(",")
|
||||
messages.append(m)
|
||||
return messages
|
||||
|
||||
def query_messages_ordered_by_created_date(
|
||||
self,
|
||||
user_id: Optional[UUID] = None,
|
||||
@@ -801,7 +844,7 @@ class PromptRepository:
|
||||
if lang is not None:
|
||||
qry = qry.filter(Message.lang == lang)
|
||||
|
||||
return qry.all()
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def update_children_counts(self, message_tree_id: UUID):
|
||||
sql_update_children_count = """
|
||||
@@ -902,9 +945,15 @@ WHERE message.id = cc.id;
|
||||
else:
|
||||
count = emoji_counts.get(emoji.value) or 0
|
||||
emoji_counts[emoji.value] = count + 1
|
||||
if message._user_emojis is None:
|
||||
message._user_emojis = []
|
||||
if emoji.value not in message._user_emojis:
|
||||
message._user_emojis.append(emoji.value)
|
||||
elif op == protocol_schema.EmojiOp.remove:
|
||||
# remove emoji record and & decrement count
|
||||
message = self.fetch_message(message_id)
|
||||
if message._user_emojis and emoji.value in message._user_emojis:
|
||||
message._user_emojis.remove(emoji.value)
|
||||
self.db.delete(existing_emoji)
|
||||
emoji_counts = message.emojis
|
||||
count = emoji_counts.get(emoji.value)
|
||||
|
||||
@@ -354,6 +354,7 @@ class TreeManager:
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
if message.role == "assistant":
|
||||
@@ -363,6 +364,8 @@ class TreeManager:
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
@@ -371,6 +374,7 @@ class TreeManager:
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
@@ -387,6 +391,7 @@ class TreeManager:
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
@@ -424,11 +429,13 @@ class TreeManager:
|
||||
message = random.choice(prompts_need_review)
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
if random.random() > self.cfg.p_full_labeling_review_prompt:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
|
||||
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
@@ -437,6 +444,7 @@ class TreeManager:
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
|
||||
@@ -57,6 +57,8 @@ class ConversationMessage(BaseModel):
|
||||
text: str
|
||||
lang: Optional[str] # BCP 47
|
||||
is_assistant: bool
|
||||
emojis: Optional[dict[str, int]] = None
|
||||
user_emojis: Optional[list[str]] = None
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
@@ -80,7 +82,6 @@ class Conversation(BaseModel):
|
||||
class Message(ConversationMessage):
|
||||
parent_id: Optional[UUID] = None
|
||||
created_date: Optional[datetime] = None
|
||||
emojis: Optional[dict] = None
|
||||
|
||||
|
||||
class MessagePage(PageResult):
|
||||
@@ -223,27 +224,34 @@ class LabelTaskMode(str, enum.Enum):
|
||||
full = "full"
|
||||
|
||||
|
||||
class LabelInitialPromptTask(Task):
|
||||
"""A task to label an initial prompt."""
|
||||
class LabelTaskDisposition(str, enum.Enum):
|
||||
"""Reason why the task was issued."""
|
||||
|
||||
type: Literal["label_initial_prompt"] = "label_initial_prompt"
|
||||
quality = "quality"
|
||||
spam = "spam"
|
||||
|
||||
|
||||
class AbstractLabelTask(Task):
|
||||
message_id: UUID
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[LabelTaskMode]
|
||||
disposition: Optional[LabelTaskDisposition]
|
||||
|
||||
|
||||
class LabelConversationReplyTask(Task):
|
||||
class LabelInitialPromptTask(AbstractLabelTask):
|
||||
"""A task to label an initial prompt."""
|
||||
|
||||
type: Literal["label_initial_prompt"] = "label_initial_prompt"
|
||||
prompt: str
|
||||
|
||||
|
||||
class LabelConversationReplyTask(AbstractLabelTask):
|
||||
"""A task to label a reply to a conversation."""
|
||||
|
||||
type: Literal["label_conversation_reply"] = "label_conversation_reply"
|
||||
conversation: Conversation # the conversation so far
|
||||
message_id: UUID
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[LabelTaskMode]
|
||||
|
||||
|
||||
class LabelPrompterReplyTask(LabelConversationReplyTask):
|
||||
|
||||
Reference in New Issue
Block a user