mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
role='user' -> role='prompter'
This commit is contained in:
committed by
Andreas Köpf
parent
f6ea90187c
commit
ee14554e1b
+4
-4
@@ -87,7 +87,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
user_message_id="6f1d0711",
|
||||
parent_message_id=None,
|
||||
text="Hi!",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="74c381d4",
|
||||
@@ -101,14 +101,14 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
user_message_id="a8c01c04",
|
||||
parent_message_id="4a24530b",
|
||||
text="Do you have a recipe for potato soup?",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="643716c1",
|
||||
user_message_id="f43a93b7",
|
||||
parent_message_id="4a24530b",
|
||||
text="Who were the 8 presidents before George Washington?",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="2e4e1e6",
|
||||
@@ -122,7 +122,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
user_message_id="cec432cf",
|
||||
parent_message_id=None,
|
||||
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="6066118e",
|
||||
|
||||
@@ -54,7 +54,7 @@ def generate_task(
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
case protocol_schema.TaskRequestType.prompter_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
messages = pr.fetch_random_conversation("assistant")
|
||||
task_messages = [
|
||||
@@ -64,7 +64,7 @@ def generate_task(
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
@@ -85,7 +85,7 @@ def generate_task(
|
||||
|
||||
messages = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
|
||||
case protocol_schema.TaskRequestType.rank_user_replies:
|
||||
case protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
logger.info("Generating a RankUserRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
|
||||
|
||||
@@ -97,7 +97,7 @@ def generate_task(
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
task = protocol_schema.RankUserRepliesTask(
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=task_messages,
|
||||
),
|
||||
@@ -106,7 +106,7 @@ def generate_task(
|
||||
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="user")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
|
||||
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
|
||||
@@ -32,8 +32,8 @@ class InitialPromptPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class UserReplyPayload(TaskPayload):
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
class PrompterReplyPayload(TaskPayload):
|
||||
type: Literal["prompter_reply"] = "prompter_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
hint: str | None
|
||||
|
||||
@@ -81,10 +81,10 @@ class RankInitialPromptsPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankUserRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
class RankPrompterRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of prompter replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
|
||||
|
||||
|
||||
@payload_type
|
||||
|
||||
@@ -23,7 +23,7 @@ class Message(SQLModel, table=True):
|
||||
message_tree_id: UUID = Field(nullable=False, index=True)
|
||||
task_id: UUID = Field(nullable=True, index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128)
|
||||
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
|
||||
@@ -132,7 +132,7 @@ class PromptRepository:
|
||||
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
# If there's no parent message assume user started new conversation
|
||||
role = "user"
|
||||
role = "prompter"
|
||||
depth = 0
|
||||
|
||||
if task.parent_message_id:
|
||||
@@ -142,7 +142,7 @@ class PromptRepository:
|
||||
|
||||
depth = parent_message.depth + 1
|
||||
if parent_message.role == "assistant":
|
||||
role = "user"
|
||||
role = "prompter"
|
||||
else:
|
||||
role = "assistant"
|
||||
|
||||
@@ -206,7 +206,7 @@ class PromptRepository:
|
||||
|
||||
match type(task_payload):
|
||||
|
||||
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
# validate ranking
|
||||
num_replies = len(task_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
@@ -269,8 +269,8 @@ class PromptRepository:
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.UserReplyTask:
|
||||
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
@@ -278,8 +278,8 @@ class PromptRepository:
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankUserRepliesTask:
|
||||
payload = db_payload.RankUserRepliesPayload(
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ class TaskType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
@@ -27,10 +27,10 @@ class ApiClient:
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.user_reply: protocol_schema.UserReplyTask,
|
||||
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
|
||||
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
|
||||
+3
-3
@@ -137,13 +137,13 @@ class OpenAssistantBot(BotBase):
|
||||
handler = task_handlers.RateSummaryHandler()
|
||||
case TaskType.initial_prompt:
|
||||
handler = task_handlers.InitialPromptHandler()
|
||||
case TaskType.user_reply:
|
||||
handler = task_handlers.UserReplyHandler()
|
||||
case TaskType.prompter_reply:
|
||||
handler = task_handlers.PrompterReplyHandler()
|
||||
case TaskType.assistant_reply:
|
||||
handler = task_handlers.AssistantReplyHandler()
|
||||
case TaskType.rank_initial_prompts:
|
||||
handler = task_handlers.RankInitialPromptsHandler()
|
||||
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
|
||||
case TaskType.rank_prompter_replies | TaskType.rank_assistant_replies:
|
||||
handler = task_handlers.RankConversationsHandler()
|
||||
case _:
|
||||
logger.warning(f"Unsupported task type received: {task.type}")
|
||||
|
||||
@@ -146,15 +146,15 @@ class InitialPromptHandler(ChannelTaskBase):
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class UserReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.UserReplyTask
|
||||
class PrompterReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.PrompterReplyTask
|
||||
thread_name: str = "User replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_user_reply.msg")
|
||||
return await self.post_teaser_msg("teaser_prompter_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
|
||||
await self.bot.post_template("task_prompter_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
|
||||
@@ -12,10 +12,10 @@ class TaskRequestType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class ConversationMessage(BaseModel):
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""Represents a conversation between the user and the assistant."""
|
||||
"""Represents a conversation between the prompter and the assistant."""
|
||||
|
||||
messages: list[ConversationMessage] = []
|
||||
|
||||
@@ -114,10 +114,10 @@ class ReplyToConversationTask(Task):
|
||||
conversation: Conversation # the conversation so far
|
||||
|
||||
|
||||
class UserReplyTask(ReplyToConversationTask, WithHintMixin):
|
||||
class PrompterReplyTask(ReplyToConversationTask, WithHintMixin):
|
||||
"""A task to prompt the user to submit a reply to the assistant."""
|
||||
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
type: Literal["prompter_reply"] = "prompter_reply"
|
||||
|
||||
|
||||
class AssistantReplyTask(ReplyToConversationTask):
|
||||
@@ -141,10 +141,10 @@ class RankConversationRepliesTask(Task):
|
||||
replies: list[str]
|
||||
|
||||
|
||||
class RankUserRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
class RankPrompterRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of prompter replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
|
||||
|
||||
|
||||
class RankAssistantRepliesTask(RankConversationRepliesTask):
|
||||
@@ -165,11 +165,11 @@ AnyTask = Union[
|
||||
RateSummaryTask,
|
||||
InitialPromptTask,
|
||||
ReplyToConversationTask,
|
||||
UserReplyTask,
|
||||
PrompterReplyTask,
|
||||
AssistantReplyTask,
|
||||
RankInitialPromptsTask,
|
||||
RankConversationRepliesTask,
|
||||
RankUserRepliesTask,
|
||||
RankPrompterRepliesTask,
|
||||
RankAssistantRepliesTask,
|
||||
]
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ def _render_message(message: dict) -> str:
|
||||
"""Render a message to the user."""
|
||||
if message["is_assistant"]:
|
||||
return f"Assistant: {message['text']}"
|
||||
return f"User: {message['text']}"
|
||||
return f"Prompter: {message['text']}"
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -107,7 +107,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "user_reply":
|
||||
case "prompter_reply":
|
||||
typer.echo("Please provide a reply to the assistant.")
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
@@ -178,7 +178,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "rank_user_replies" | "rank_assistant_replies":
|
||||
case "rank_prompter_replies" | "rank_assistant_replies":
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
|
||||
Reference in New Issue
Block a user