mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
added tasks to act as user or assistant
This commit is contained in:
@@ -16,9 +16,9 @@ router = APIRouter()
|
||||
|
||||
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
match (request.type):
|
||||
case protocol_schema.TaskRequestType.generic:
|
||||
logger.info("Frontend requested a generic task.")
|
||||
while request.type == protocol_schema.TaskRequestType.generic:
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
logger.info("Frontend requested a random task.")
|
||||
while request.type == protocol_schema.TaskRequestType.random:
|
||||
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
|
||||
return generate_task(request)
|
||||
case protocol_schema.TaskRequestType.summarize_story:
|
||||
@@ -38,6 +38,34 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
task = protocol_schema.UserReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, write me an English essay about water.",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
@@ -45,7 +73,7 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
)
|
||||
logger.info(f"Generated {task=}.")
|
||||
if request.user is not None:
|
||||
task.addressed_users = [request.user]
|
||||
task.addressed_user = request.user
|
||||
|
||||
return task
|
||||
|
||||
@@ -122,7 +150,7 @@ def post_interaction(
|
||||
# here we would store the text reply in the database
|
||||
return protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.user_post_id,
|
||||
addressed_users=[interaction.user],
|
||||
addressed_user=interaction.user,
|
||||
)
|
||||
case protocol_schema.PostRating:
|
||||
logger.info(
|
||||
@@ -132,7 +160,7 @@ def post_interaction(
|
||||
# here we would store the rating in the database
|
||||
return protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.post_id,
|
||||
addressed_users=[interaction.user],
|
||||
addressed_user=interaction.user,
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -8,21 +8,37 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class TaskRequestType(str, enum.Enum):
|
||||
generic = "generic"
|
||||
random = "random"
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
auth_method: Literal["discord", "local"]
|
||||
|
||||
|
||||
class ConversationMessage(BaseModel):
|
||||
"""Represents a message in a conversation between the user and the assistant."""
|
||||
|
||||
text: str
|
||||
is_assistant: bool
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""Represents a conversation between the user and the assistant."""
|
||||
|
||||
messages: list[ConversationMessage] = []
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
"""The frontend asks the backend for a task."""
|
||||
|
||||
type: TaskRequestType = TaskRequestType.generic
|
||||
type: TaskRequestType = TaskRequestType.random
|
||||
user: Optional[User] = None
|
||||
|
||||
|
||||
@@ -31,7 +47,7 @@ class Task(BaseModel):
|
||||
|
||||
id: UUID = pydantic.Field(default_factory=uuid4)
|
||||
type: str
|
||||
addressed_users: Optional[list[User]] = None
|
||||
addressed_user: Optional[User] = None
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
@@ -91,6 +107,21 @@ class InitialPromptTask(Task):
|
||||
)
|
||||
|
||||
|
||||
class UserReplyTask(Task):
|
||||
"""A task to prompt the user to submit a reply to the assistant."""
|
||||
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
conversation: Conversation # the conversation so far
|
||||
hint: str | None = None # e.g. "Try to ask for clarification."
|
||||
|
||||
|
||||
class AssistantReplyTask(Task):
|
||||
"""A task to prompt the user to act as the assistant."""
|
||||
|
||||
type: Literal["assistant_reply"] = "assistant_reply"
|
||||
conversation: Conversation # the conversation so far
|
||||
|
||||
|
||||
class TaskDone(Task):
|
||||
"""Signals to the frontend that the task is done."""
|
||||
|
||||
@@ -99,10 +130,12 @@ class TaskDone(Task):
|
||||
|
||||
|
||||
AnyTask = Union[
|
||||
TaskDone,
|
||||
SummarizeStoryTask,
|
||||
RateSummaryTask,
|
||||
InitialPromptTask,
|
||||
TaskDone,
|
||||
UserReplyTask,
|
||||
AssistantReplyTask,
|
||||
]
|
||||
|
||||
|
||||
|
||||
+73
-13
@@ -7,6 +7,19 @@ import typer
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# debug constants
|
||||
POST_ID = "1234"
|
||||
USER_POST_ID = "5678"
|
||||
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
|
||||
|
||||
|
||||
def _render_message(message: dict) -> str:
|
||||
"""Render a message to the user."""
|
||||
if message["is_assistant"]:
|
||||
return f"Assistant: {message['text']}"
|
||||
return f"User: {message['text']}"
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(backend_url: str, api_key: str):
|
||||
"""Simple REPL frontend."""
|
||||
@@ -17,7 +30,7 @@ def main(backend_url: str, api_key: str):
|
||||
return response.json()
|
||||
|
||||
typer.echo("Requesting work...")
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "generic"})]
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random"})]
|
||||
while tasks:
|
||||
task = tasks.pop(0)
|
||||
match (task["type"]):
|
||||
@@ -26,7 +39,7 @@ def main(backend_url: str, api_key: str):
|
||||
typer.echo(task["story"])
|
||||
|
||||
# acknowledge task
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": "1234"})
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
|
||||
|
||||
summary = typer.prompt("Enter your summary")
|
||||
|
||||
@@ -35,10 +48,10 @@ def main(backend_url: str, api_key: str):
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": "1234",
|
||||
"user_post_id": "5678",
|
||||
"post_id": POST_ID,
|
||||
"user_post_id": USER_POST_ID,
|
||||
"text": summary,
|
||||
"user": {"id": "1234", "name": "John Doe"},
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
@@ -50,7 +63,7 @@ def main(backend_url: str, api_key: str):
|
||||
typer.echo(f"Rating scale: {task['scale']['min']} - {task['scale']['max']}")
|
||||
|
||||
# acknowledge task
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "rating_created", "post_id": "1234"})
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "rating_created", "post_id": POST_ID})
|
||||
|
||||
rating = typer.prompt("Enter your rating", type=int)
|
||||
# send interaction
|
||||
@@ -58,9 +71,9 @@ def main(backend_url: str, api_key: str):
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "post_rating",
|
||||
"post_id": "1234",
|
||||
"post_id": POST_ID,
|
||||
"rating": rating,
|
||||
"user": {"id": "1234", "name": "John Doe"},
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
@@ -69,23 +82,70 @@ def main(backend_url: str, api_key: str):
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": "1234"})
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
|
||||
prompt = typer.prompt("Enter your prompt")
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": "1234",
|
||||
"user_post_id": "5678",
|
||||
"post_id": POST_ID,
|
||||
"user_post_id": USER_POST_ID,
|
||||
"text": prompt,
|
||||
"user": {"id": "1234", "name": "John Doe"},
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "user_reply":
|
||||
typer.echo("Please provide a reply to the assistant.")
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
|
||||
reply = typer.prompt("Enter your reply")
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": POST_ID,
|
||||
"user_post_id": USER_POST_ID,
|
||||
"text": reply,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "assistant_reply":
|
||||
typer.echo("Act as the assistant and reply to the user.")
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
# acknowledge task
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
|
||||
reply = typer.prompt("Enter your reply")
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": POST_ID,
|
||||
"user_post_id": USER_POST_ID,
|
||||
"text": reply,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "task_done":
|
||||
typer.echo("Task done!")
|
||||
if addressed_user := task["addressed_user"]:
|
||||
typer.echo(f"Hey, {addressed_user['display_name']}! Thank you!")
|
||||
else:
|
||||
typer.echo("Task done!")
|
||||
case _:
|
||||
typer.echo(f"Unknown task type {task['type']}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user