diff --git a/inference/server/main.py b/inference/server/main.py
index 0c282394..1f9f16af 100644
--- a/inference/server/main.py
+++ b/inference/server/main.py
@@ -39,14 +39,6 @@ redisClient = redis.Redis(
 )
 
 
-class CreateChatRequest(pydantic.BaseModel):
-    pass
-
-
-class CreateChatResponse(pydantic.BaseModel):
-    id: str
-
-
 class MessageRequest(pydantic.BaseModel):
     message: str = pydantic.Field(..., repr=False)
     model_name: str = "distilgpt2"
@@ -67,24 +59,47 @@ class MessageRequestState(str, enum.Enum):
     aborted_by_worker = "aborted_by_worker"
 
 
+class CreateChatRequest(pydantic.BaseModel):
+    pass
+
+
+class ChatListEntry(pydantic.BaseModel):
+    id: str
+
+
+class ListChatsResponse(pydantic.BaseModel):
+    chats: list[ChatListEntry]
+
+
 class DbChatEntry(pydantic.BaseModel):
     id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
     conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
     pending_message_request: MessageRequest | None = None
     message_request_state: MessageRequestState | None = None
 
+    def to_list_entry(self) -> ChatListEntry:
+        return ChatListEntry(id=self.id)
+
 
 # TODO: make real database
 CHATS: dict[str, DbChatEntry] = {}
 
 
+@app.get("/chat")
+async def list_chats() -> ListChatsResponse:
+    """Lists all chats."""
+    logger.info("Listing all chats.")
+    chats = [chat.to_list_entry() for chat in CHATS.values()]
+    return ListChatsResponse(chats=chats)
+
+
 @app.post("/chat")
-async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
+async def create_chat(request: CreateChatRequest) -> ChatListEntry:
     """Allows a client to create a new chat."""
     logger.info(f"Received {request}")
     chat = DbChatEntry()
     CHATS[chat.id] = chat
-    return CreateChatResponse(id=chat.id)
+    return ChatListEntry(id=chat.id)
 
 
 @app.get("/chat/{id}")