mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
endpoint to list chats
This commit is contained in:
+25
-10
@@ -39,14 +39,6 @@ redisClient = redis.Redis(
|
||||
)
|
||||
|
||||
|
||||
class CreateChatRequest(pydantic.BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class CreateChatResponse(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class MessageRequest(pydantic.BaseModel):
|
||||
message: str = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
@@ -67,24 +59,47 @@ class MessageRequestState(str, enum.Enum):
|
||||
aborted_by_worker = "aborted_by_worker"
|
||||
|
||||
|
||||
class CreateChatRequest(pydantic.BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class ChatListEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class ListChatsResponse(pydantic.BaseModel):
|
||||
chats: list[ChatListEntry]
|
||||
|
||||
|
||||
class DbChatEntry(pydantic.BaseModel):
|
||||
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
|
||||
pending_message_request: MessageRequest | None = None
|
||||
message_request_state: MessageRequestState | None = None
|
||||
|
||||
def to_list_entry(self) -> ChatListEntry:
|
||||
return ChatListEntry(id=self.id)
|
||||
|
||||
|
||||
# TODO: make real database
|
||||
CHATS: dict[str, DbChatEntry] = {}
|
||||
|
||||
|
||||
@app.get("/chat")
|
||||
async def list_chats() -> ListChatsResponse:
|
||||
"""Lists all chats."""
|
||||
logger.info("Listing all chats.")
|
||||
chats = [chat.to_list_entry() for chat in CHATS.values()]
|
||||
return ListChatsResponse(chats=chats)
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
|
||||
async def create_chat(request: CreateChatRequest) -> ChatListEntry:
|
||||
"""Allows a client to create a new chat."""
|
||||
logger.info(f"Received {request}")
|
||||
chat = DbChatEntry()
|
||||
CHATS[chat.id] = chat
|
||||
return CreateChatResponse(id=chat.id)
|
||||
return ChatListEntry(id=chat.id)
|
||||
|
||||
|
||||
@app.get("/chat/{id}")
|
||||
|
||||
Reference in New Issue
Block a user