mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
6452bb860d
added playbook to deploy dev machine added playbook to deploy dev machine added next.js font module, updated fonts, updated login page replaced logos, changed logo on login and header added 404 and email confirmation pages, changed discord and github buttons to show conditionally also removed unused imports, fixed one spelling error, and minor styling changes Quick format to the authenticated user page, updated header with user profile, styling updates fixed html encoding added checkout for release re-vamped release config and ports
75 lines
3.1 KiB
Python
75 lines
3.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
import enum
|
|
from typing import Optional, Type
|
|
|
|
import requests
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
|
|
|
|
class TaskType(str, enum.Enum):
    """String-valued identifiers for the task kinds this client can parse.

    Members subclass ``str``, so a member compares (and hashes) equal to its
    plain string value.
    """

    # Summarisation tasks.
    summarize_story = "summarize_story"
    rate_summary = "rate_summary"

    # Text-authoring tasks.
    initial_prompt = "initial_prompt"
    user_reply = "user_reply"
    assistant_reply = "assistant_reply"

    # Ranking tasks.
    rank_initial_prompts = "rank_initial_prompts"
    rank_user_replies = "rank_user_replies"
    rank_assistant_replies = "rank_assistant_replies"

    # NOTE: the member name ("done") differs from its value ("task_done").
    done = "task_done"
|
|
|
|
|
|
class ApiClient:
    """Small synchronous HTTP client for the backend task API.

    Every request is a JSON ``POST`` to ``backend_url`` + path, sent with the
    API key in the ``X-API-Key`` header. Task payloads returned by the backend
    are parsed into the ``protocol_schema`` model registered for their
    ``"type"`` field.
    """

    def __init__(self, backend_url: str, api_key: str, *, timeout: Optional[float] = 30.0):
        """Store connection settings and build the task-type -> model registry.

        Args:
            backend_url: Base URL; request paths are appended to it verbatim.
            api_key: Value sent in the ``X-API-Key`` header of every request.
            timeout: Timeout in seconds handed to ``requests`` for each call.
                ``None`` means no timeout, i.e. a request may block
                indefinitely if the backend never answers.
        """
        self.backend_url = backend_url
        self.api_key = api_key
        self.timeout = timeout

        # Keys are TaskType members; TaskType subclasses str, so lookups with
        # the plain string taken from a JSON payload hit the same entries.
        task_models_map: dict[str, Type[protocol_schema.Task]] = {
            TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
            TaskType.rate_summary: protocol_schema.RateSummaryTask,
            TaskType.initial_prompt: protocol_schema.InitialPromptTask,
            TaskType.user_reply: protocol_schema.UserReplyTask,
            TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
            TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
            TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
            TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
            TaskType.done: protocol_schema.TaskDone,
        }
        self.task_models_map = task_models_map

    def post(self, path: str, json: dict) -> dict:
        """POST ``json`` to ``backend_url`` + ``path`` and return the decoded JSON body.

        Args:
            path: Path appended verbatim to ``backend_url``.
            json: Payload serialised by ``requests`` as the JSON request body.

        Raises:
            requests.HTTPError: If the response carries an error status
                (raised by ``Response.raise_for_status``).
            requests.Timeout: If the backend does not answer within
                ``self.timeout`` seconds.
        """
        response = requests.post(
            f"{self.backend_url}{path}",
            json=json,
            headers={"X-API-Key": self.api_key},
            # Without an explicit timeout, requests waits forever on a stalled server.
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _parse_task(self, data: dict) -> protocol_schema.Task:
        """Turn a decoded task payload into the model registered for its ``"type"``.

        Raises:
            ValueError: If ``data`` is not a dict.
            RuntimeError: If ``data["type"]`` is missing or not a registered task type.
        """
        if not isinstance(data, dict):
            raise ValueError("dict expected")

        task_type = data.get("type")
        # The isinstance guard keeps an unhashable "type" value (e.g. a list)
        # from escaping as a TypeError out of the dict lookup; it is reported
        # as an unsupported type like any other unknown value.
        model = self.task_models_map.get(task_type) if isinstance(task_type, str) else None
        if model is None:
            raise RuntimeError(f"Unsupported task type: {task_type}")

        return model.parse_obj(data)

    def fetch_task(
        self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
    ) -> protocol_schema.Task:
        """Request a task of ``task_type`` (optionally for ``user``) and return it parsed."""
        req = protocol_schema.TaskRequest(type=task_type, user=user)
        data = self.post("/api/v1/tasks/", req.dict())
        return self._parse_task(data)

    def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
        """Request a task using ``TaskRequestType.random`` and return it parsed."""
        return self.fetch_task(protocol_schema.TaskRequestType.random, user)

    def ack_task(self, task_id: str, post_id: str) -> dict:
        """Send a ``TaskAck`` carrying ``post_id`` for ``task_id``.

        Returns:
            The decoded JSON body of the backend's response.
        """
        req = protocol_schema.TaskAck(post_id=post_id)
        return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())

    def nack_task(self, task_id: str, reason: str) -> dict:
        """Send a ``TaskNAck`` carrying ``reason`` for ``task_id``.

        Returns:
            The decoded JSON body of the backend's response.
        """
        req = protocol_schema.TaskNAck(reason=reason)
        return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())

    def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
        """Submit ``interaction`` and return the task the backend answers with, parsed."""
        data = self.post("/api/v1/tasks/interaction", interaction.dict())
        return self._parse_task(data)
|