endpoint to list chats

This commit is contained in:
Yannic Kilcher
2023-02-09 08:49:01 +01:00
parent 27671e3220
commit a85cc0a47d
+25 -10
View File
@@ -39,14 +39,6 @@ redisClient = redis.Redis(
)
class CreateChatRequest(pydantic.BaseModel):
pass
class CreateChatResponse(pydantic.BaseModel):
id: str
class MessageRequest(pydantic.BaseModel):
message: str = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
@@ -67,24 +59,47 @@ class MessageRequestState(str, enum.Enum):
aborted_by_worker = "aborted_by_worker"
class CreateChatRequest(pydantic.BaseModel):
pass
class ChatListEntry(pydantic.BaseModel):
id: str
class ListChatsResponse(pydantic.BaseModel):
chats: list[ChatListEntry]
class DbChatEntry(pydantic.BaseModel):
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
pending_message_request: MessageRequest | None = None
message_request_state: MessageRequestState | None = None
def to_list_entry(self) -> ChatListEntry:
return ChatListEntry(id=self.id)
# TODO: make real database
CHATS: dict[str, DbChatEntry] = {}
@app.get("/chat")
async def list_chats() -> ListChatsResponse:
"""Lists all chats."""
logger.info("Listing all chats.")
chats = [chat.to_list_entry() for chat in CHATS.values()]
return ListChatsResponse(chats=chats)
@app.post("/chat")
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
async def create_chat(request: CreateChatRequest) -> ChatListEntry:
"""Allows a client to create a new chat."""
logger.info(f"Received {request}")
chat = DbChatEntry()
CHATS[chat.id] = chat
return CreateChatResponse(id=chat.id)
return ChatListEntry(id=chat.id)
@app.get("/chat/{id}")