role='user' -> role='prompter'

This commit is contained in:
Andreas Köpf
2022-12-30 23:48:03 +01:00
committed by Andreas Köpf
parent f6ea90187c
commit ee14554e1b
12 changed files with 46 additions and 46 deletions
+4 -4
View File
@@ -87,7 +87,7 @@ if settings.DEBUG_USE_SEED_DATA:
user_message_id="6f1d0711",
parent_message_id=None,
text="Hi!",
role="user",
role="prompter",
),
DummyMessage(
task_message_id="74c381d4",
@@ -101,14 +101,14 @@ if settings.DEBUG_USE_SEED_DATA:
user_message_id="a8c01c04",
parent_message_id="4a24530b",
text="Do you have a recipe for potato soup?",
role="user",
role="prompter",
),
DummyMessage(
task_message_id="643716c1",
user_message_id="f43a93b7",
parent_message_id="4a24530b",
text="Who were the 8 presidents before George Washington?",
role="user",
role="prompter",
),
DummyMessage(
task_message_id="2e4e1e6",
@@ -122,7 +122,7 @@ if settings.DEBUG_USE_SEED_DATA:
user_message_id="cec432cf",
parent_message_id=None,
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
role="user",
role="prompter",
),
DummyMessage(
task_message_id="6066118e",
+5 -5
View File
@@ -54,7 +54,7 @@ def generate_task(
task = protocol_schema.InitialPromptTask(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.user_reply:
case protocol_schema.TaskRequestType.prompter_reply:
logger.info("Generating a UserReplyTask.")
messages = pr.fetch_random_conversation("assistant")
task_messages = [
@@ -64,7 +64,7 @@ def generate_task(
for msg in messages
]
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.assistant_reply:
@@ -85,7 +85,7 @@ def generate_task(
messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
case protocol_schema.TaskRequestType.rank_user_replies:
case protocol_schema.TaskRequestType.rank_prompter_replies:
logger.info("Generating a RankUserRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
@@ -97,7 +97,7 @@ def generate_task(
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankUserRepliesTask(
task = protocol_schema.RankPrompterRepliesTask(
conversation=protocol_schema.Conversation(
messages=task_messages,
),
@@ -106,7 +106,7 @@ def generate_task(
case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="user")
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
task_messages = [
protocol_schema.ConversationMessage(
+5 -5
View File
@@ -32,8 +32,8 @@ class InitialPromptPayload(TaskPayload):
@payload_type
class UserReplyPayload(TaskPayload):
type: Literal["user_reply"] = "user_reply"
class PrompterReplyPayload(TaskPayload):
type: Literal["prompter_reply"] = "prompter_reply"
conversation: protocol_schema.Conversation
hint: str | None
@@ -81,10 +81,10 @@ class RankInitialPromptsPayload(TaskPayload):
@payload_type
class RankUserRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of user replies to a conversation."""
class RankPrompterRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of prompter replies to a conversation."""
type: Literal["rank_user_replies"] = "rank_user_replies"
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
@payload_type
+1 -1
View File
@@ -23,7 +23,7 @@ class Message(SQLModel, table=True):
message_tree_id: UUID = Field(nullable=False, index=True)
task_id: UUID = Field(nullable=True, index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128)
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
+7 -7
View File
@@ -132,7 +132,7 @@ class PromptRepository:
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
# If there's no parent message assume user started new conversation
role = "user"
role = "prompter"
depth = 0
if task.parent_message_id:
@@ -142,7 +142,7 @@ class PromptRepository:
depth = parent_message.depth + 1
if parent_message.role == "assistant":
role = "user"
role = "prompter"
else:
role = "assistant"
@@ -206,7 +206,7 @@ class PromptRepository:
match type(task_payload):
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
# validate ranking
num_replies = len(task_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
@@ -269,8 +269,8 @@ class PromptRepository:
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.UserReplyTask:
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
@@ -278,8 +278,8 @@ class PromptRepository:
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
case protocol_schema.RankUserRepliesTask:
payload = db_payload.RankUserRepliesPayload(
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
)
+4 -4
View File
@@ -10,10 +10,10 @@ class TaskType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
prompter_reply = "prompter_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
@@ -27,10 +27,10 @@ class ApiClient:
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.user_reply: protocol_schema.UserReplyTask,
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
+3 -3
View File
@@ -137,13 +137,13 @@ class OpenAssistantBot(BotBase):
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
handler = task_handlers.InitialPromptHandler()
case TaskType.user_reply:
handler = task_handlers.UserReplyHandler()
case TaskType.prompter_reply:
handler = task_handlers.PrompterReplyHandler()
case TaskType.assistant_reply:
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
case TaskType.rank_prompter_replies | TaskType.rank_assistant_replies:
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
+4 -4
View File
@@ -146,15 +146,15 @@ class InitialPromptHandler(ChannelTaskBase):
await self.handle_text_reply_to_post(msg)
class UserReplyHandler(ChannelTaskBase):
task: protocol_schema.UserReplyTask
class PrompterReplyHandler(ChannelTaskBase):
task: protocol_schema.PrompterReplyTask
thread_name: str = "User replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_user_reply.msg")
return await self.post_teaser_msg("teaser_prompter_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
await self.bot.post_template("task_prompter_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
+10 -10
View File
@@ -12,10 +12,10 @@ class TaskRequestType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
prompter_reply = "prompter_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
@@ -33,7 +33,7 @@ class ConversationMessage(BaseModel):
class Conversation(BaseModel):
"""Represents a conversation between the user and the assistant."""
"""Represents a conversation between the prompter and the assistant."""
messages: list[ConversationMessage] = []
@@ -114,10 +114,10 @@ class ReplyToConversationTask(Task):
conversation: Conversation # the conversation so far
class UserReplyTask(ReplyToConversationTask, WithHintMixin):
class PrompterReplyTask(ReplyToConversationTask, WithHintMixin):
"""A task to prompt the user to submit a reply to the assistant."""
type: Literal["user_reply"] = "user_reply"
type: Literal["prompter_reply"] = "prompter_reply"
class AssistantReplyTask(ReplyToConversationTask):
@@ -141,10 +141,10 @@ class RankConversationRepliesTask(Task):
replies: list[str]
class RankUserRepliesTask(RankConversationRepliesTask):
"""A task to rank a set of user replies to a conversation."""
class RankPrompterRepliesTask(RankConversationRepliesTask):
"""A task to rank a set of prompter replies to a conversation."""
type: Literal["rank_user_replies"] = "rank_user_replies"
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
class RankAssistantRepliesTask(RankConversationRepliesTask):
@@ -165,11 +165,11 @@ AnyTask = Union[
RateSummaryTask,
InitialPromptTask,
ReplyToConversationTask,
UserReplyTask,
PrompterReplyTask,
AssistantReplyTask,
RankInitialPromptsTask,
RankConversationRepliesTask,
RankUserRepliesTask,
RankPrompterRepliesTask,
RankAssistantRepliesTask,
]
+3 -3
View File
@@ -21,7 +21,7 @@ def _render_message(message: dict) -> str:
"""Render a message to the user."""
if message["is_assistant"]:
return f"Assistant: {message['text']}"
return f"User: {message['text']}"
return f"Prompter: {message['text']}"
@app.command()
@@ -107,7 +107,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
)
tasks.append(new_task)
case "user_reply":
case "prompter_reply":
typer.echo("Please provide a reply to the assistant.")
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
@@ -178,7 +178,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
)
tasks.append(new_task)
case "rank_user_replies" | "rank_assistant_replies":
case "rank_prompter_replies" | "rank_assistant_replies":
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
typer.echo(_render_message(message))