From bb3b0e739781c8d26fb6d2f1d474fda69f9d7260 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Thu, 29 Dec 2022 15:00:20 -0800 Subject: [PATCH] update api client to upstream version --- discord-bot/api_client.py | 79 ----------------------------------- discord-bot/bot/api_client.py | 13 ++++-- 2 files changed, 9 insertions(+), 83 deletions(-) delete mode 100644 discord-bot/api_client.py diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py deleted file mode 100644 index 0c88258e..00000000 --- a/discord-bot/api_client.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -import enum -from typing import Optional, Type - -import requests -from oasst_shared.schemas import protocol as protocol_schema - - -class TaskType(str, enum.Enum): - summarize_story = "summarize_story" - rate_summary = "rate_summary" - initial_prompt = "initial_prompt" - user_reply = "user_reply" - assistant_reply = "assistant_reply" - rank_initial_prompts = "rank_initial_prompts" - rank_user_replies = "rank_user_replies" - rank_assistant_replies = "rank_assistant_replies" - done = "task_done" - - -class ApiClient: - def __init__(self, backend_url: str, api_key: str): - self.backend_url = backend_url - self.api_key = api_key - - task_models_map: dict[str, Type[protocol_schema.Task]] = { - TaskType.summarize_story: protocol_schema.SummarizeStoryTask, - TaskType.rate_summary: protocol_schema.RateSummaryTask, - TaskType.initial_prompt: protocol_schema.InitialPromptTask, - TaskType.user_reply: protocol_schema.UserReplyTask, - TaskType.assistant_reply: protocol_schema.AssistantReplyTask, - TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, - TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, - TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, - TaskType.done: protocol_schema.TaskDone, - } - self.task_models_map = task_models_map - - def post(self, path: str, json: dict) -> dict: - response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key}) - response.raise_for_status() - return response.json() - - def _parse_task(self, data: dict) -> protocol_schema.Task: - if not isinstance(data, dict): - raise ValueError("dict expected") - - task_type = data.get("type") - if task_type not in self.task_models_map: - raise RuntimeError(f"Unsupported task type: {task_type}") - - return self.task_models_map[task_type].parse_obj(data) - - def fetch_task( - self, - task_type: protocol_schema.TaskRequestType, - user: Optional[protocol_schema.User] = None, - collective: bool = False, - ) -> protocol_schema.Task: - req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective) - data = self.post("/api/v1/tasks/", req.dict()) - return self._parse_task(data) - - def fetch_random_task( - self, user: Optional[protocol_schema.User] = None, collective: bool = False - ) -> protocol_schema.Task: - return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective) - - def ack_task(self, task_id: str, post_id: str) -> None: - req = protocol_schema.TaskAck(post_id=post_id) - return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict()) - - def nack_task(self, task_id: str, reason: str) -> None: - req = protocol_schema.TaskNAck(reason=reason) - return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict()) - - def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: - data = self.post("/api/v1/tasks/interaction", interaction.dict()) - return self._parse_task(data) diff --git a/discord-bot/bot/api_client.py b/discord-bot/bot/api_client.py index 9f319869..b5c96505 100644 --- a/discord-bot/bot/api_client.py +++ b/discord-bot/bot/api_client.py @@ -69,19 +69,24 @@ class OasstApiClient: return self.task_models_map[task_type].parse_obj(data) # type: ignore async def fetch_task( - self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None + self, + task_type: protocol_schema.TaskRequestType, + user: Optional[protocol_schema.User] = None, + collective: bool = False, ) -> protocol_schema.Task: """Fetch a task from the backend.""" logger.debug(f"Fetching task {task_type} for user {user}") - req = protocol_schema.TaskRequest(type=task_type.value, user=user) + req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective) resp = await self.post("/api/v1/tasks/", data=req.dict()) logger.debug(f"Fetch task response: {resp}") return self._parse_task(resp) - async def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: + async def fetch_random_task( + self, user: Optional[protocol_schema.User] = None, collective: bool = False + ) -> protocol_schema.Task: """Fetch a random task from the backend.""" logger.debug(f"Fetching random for user {user}") - return await self.fetch_task(protocol_schema.TaskRequestType.random, user) + return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective) async def ack_task(self, task_id: str | UUID, post_id: str): """Send an ACK for a task to the backend."""