diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py index 59c5efda..584c5fb7 100644 --- a/backend/oasst_backend/api/v1/admin.py +++ b/backend/oasst_backend/api/v1/admin.py @@ -54,6 +54,7 @@ class PublicSettings(pydantic.BaseModel): PROJECT_NAME: str API_V1_STR: str + MESSAGE_SIZE_LIMIT: int DEBUG_USE_SEED_DATA: bool DEBUG_ALLOW_SELF_LABELING: bool DEBUG_SKIP_EMBEDDING_COMPUTATION: bool diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 04b3cc99..ed07eb21 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -111,7 +111,7 @@ def get_users_cursor( n = lt return p, n - def remove_extra_item(items: list[protocol.FrontEndUser], lt: str | None, gt: str): + def remove_extra_item(items: list[protocol.FrontEndUser], lt: str | None, gt: str | None): num_rows = len(items) if qry_max_count > max_count and num_rows == qry_max_count: assert not (lt and gt) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 99b10cb4..8ca2a413 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -72,6 +72,7 @@ class Settings(BaseSettings): DATABASE_MAX_TX_RETRY_COUNT: int = 3 RATE_LIMIT: bool = True + MESSAGE_SIZE_LIMIT: int = 2000 REDIS_HOST: str = "localhost" REDIS_PORT: str = "6379" diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 026bb564..1224af4a 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -465,6 +465,13 @@ class TreeManager: f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) + # ensure message size is below the predefined limit + if len(interaction.text) > settings.MESSAGE_SIZE_LIMIT: + logger.error( + f"Message size {len(interaction.text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}." + ) + raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG) + # here we store the text reply in the database message = pr.store_text_reply( text=interaction.text, diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index e960d944..670bbc3e 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -153,7 +153,7 @@ class UserRepository: if api_client_id != self.api_client.id: raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - qry = self.db.query(User).order_by(User.username, User.id) + qry = self.db.query(User) if gte_username is not None: if gt_id: @@ -184,8 +184,14 @@ class UserRepository: pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) qry = qry.filter(User.username.like(pattern)) - if limit is not None: - qry = qry.limit(limit) + if limit is not None and lte_username and not gte_username: + # select top rows but return results in ascernding order + sub_qry = qry.order_by(User.username.desc(), User.id.desc()).limit(limit).subquery("u") + qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.username, User.id) + else: + qry = qry.order_by(User.username, User.id) + if limit is not None: + qry = qry.limit(limit) return qry.all() @@ -210,7 +216,7 @@ class UserRepository: # Unprivileged api client asks for foreign users raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - qry = self.db.query(User).order_by(User.display_name, User.id) + qry = self.db.query(User) if gte_display_name is not None: if gt_id: @@ -254,8 +260,14 @@ class UserRepository: if auth_method: qry = qry.filter(User.auth_method == auth_method) - if limit is not None: - qry = qry.limit(limit) + if limit is not None and lte_display_name and not gte_display_name: + # select top rows but return results in ascernding order + sub_qry = qry.order_by(User.display_name.desc(), User.id.desc()).limit(limit).subquery("u") + qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.display_name, User.id) + else: + qry = qry.order_by(User.display_name, User.id) + if limit is not None: + qry = qry.limit(limit) users = qry.all() diff --git a/inference/README.md b/inference/README.md new file mode 100644 index 00000000..3dee94f9 --- /dev/null +++ b/inference/README.md @@ -0,0 +1,35 @@ +# OpenAssitant Inference + +Preliminary implementation of the inference engine for OpenAssistant. + +## Development (you'll need multiple terminals) + +Run a redis container (or use the one of the general docker compose file): + +```bash +docker run --rm -it -p 6379:6379 redis +``` + +Run the inference server: + +```bash +cd server +pip install -r requirements.txt +uvicorn main:app --reload +``` + +Run one (or more) workers: + +```bash +cd worker +pip install -r requirements.txt +python __main__.py +``` + +Run the client: + +```bash +cd text-client +pip install -r requirements.txt +python __main__.py +``` diff --git a/inference/server/README.md b/inference/server/README.md new file mode 100644 index 00000000..a235a7e6 --- /dev/null +++ b/inference/server/README.md @@ -0,0 +1,10 @@ +# OpenAssistant Inference Server + +Workers communicate with the `/work` endpoint via Websocket. They provide their +configuration and if a task is available, the server returns it. The worker then +performs the task and returns the result in a streaming fashion to the server, +also via websocket. + +Clients first call `/chat` to make a new chat, then add to that via +`/chat//message`. The response is a SSE event source, which will send tokens +as they are available. diff --git a/inference/server/main.py b/inference/server/main.py new file mode 100644 index 00000000..f3ec02b1 --- /dev/null +++ b/inference/server/main.py @@ -0,0 +1,193 @@ +import asyncio +import enum +import uuid + +import fastapi +import pydantic +import redis.asyncio as redis +from fastapi.middleware.cors import CORSMiddleware +from loguru import logger +from oasst_shared.schemas import inference, protocol +from sse_starlette.sse import EventSourceResponse + +app = fastapi.FastAPI() + +# Allow CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class Settings(pydantic.BaseSettings): + redis_host: str = "localhost" + redis_port: int = 6379 + redis_db: int = 0 + + sse_retry_timeout: int = 15000 + + +settings = Settings() + +# create async redis client +redisClient = redis.Redis( + host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True +) + + +class CreateChatRequest(pydantic.BaseModel): + pass + + +class CreateChatResponse(pydantic.BaseModel): + id: str + + +class MessageRequest(pydantic.BaseModel): + message: str = pydantic.Field(..., repr=False) + model_name: str = "distilgpt2" + max_new_tokens: int = 100 + + def compatible_with(self, worker_config: inference.WorkerConfig) -> bool: + return self.model_name == worker_config.model_name + + +class TokenResponseEvent(pydantic.BaseModel): + token: str + + +class MessageRequestState(str, enum.Enum): + pending = "pending" + in_progress = "in_progress" + complete = "complete" + + +class DbChatEntry(pydantic.BaseModel): + id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())) + conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation) + pending_message_request: MessageRequest | None = None + message_request_state: MessageRequestState | None = None + + +# TODO: make real database +CHATS: dict[str, DbChatEntry] = {} + + +@app.post("/chat") +async def create_chat(request: CreateChatRequest) -> CreateChatResponse: + """Allows a client to create a new chat.""" + logger.info(f"Received {request}") + chat = DbChatEntry() + CHATS[chat.id] = chat + return CreateChatResponse(id=chat.id) + + +@app.get("/chat/{id}") +async def get_chat(id: str) -> protocol.Conversation: + """Allows a client to get the current state of a chat.""" + return CHATS[id].conversation + + +@app.post("/chat/{id}/message") +async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request): + """Allows the client to stream the results of a request.""" + + chat = CHATS[id] + if not chat.conversation.is_prompter_turn: + raise fastapi.HTTPException(status_code=400, detail="Not your turn") + if chat.pending_message_request is not None: + raise fastapi.HTTPException(status_code=400, detail="Already pending") + + chat.conversation.messages.append( + protocol.ConversationMessage( + text=message_request.message, + is_assistant=False, + ) + ) + + chat.pending_message_request = message_request + chat.message_request_state = MessageRequestState.pending + + async def event_generator(): + result_data = [] + + try: + while True: + if await fastapi_request.is_disconnected(): + logger.warning("Client disconnected") + break + + item = await redisClient.blpop(chat.id, 1) + if item is None: + continue + + _, response_packet_str = item + response_packet = inference.WorkResponsePacket.parse_raw(response_packet_str) + result_data.append(response_packet) + + if response_packet.is_end: + break + + yield { + "retry": settings.sse_retry_timeout, + "data": TokenResponseEvent(token=response_packet.token).json(), + } + logger.info(f"Finished streaming {chat.id} {len(result_data)=}") + except Exception: + logger.exception(f"Error streaming {chat.id}") + + chat.conversation.messages.append( + protocol.ConversationMessage( + text="".join([d.token for d in result_data[:-1]]), + is_assistant=True, + ) + ) + chat.pending_message_request = None + + return EventSourceResponse(event_generator()) + + +@app.websocket("/work") +async def work(websocket: fastapi.WebSocket): + await websocket.accept() + worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) + while True: + # find a pending task that matches the worker's config + # could also be implemented using task queues + # but general compatibility matching is tricky + for chat in CHATS.values(): + if (request := chat.pending_message_request) is not None: + if chat.message_request_state == MessageRequestState.pending: + if request.compatible_with(worker_config): + break + else: + logger.debug("No pending tasks") + await asyncio.sleep(1) + continue + + chat.message_request_state = MessageRequestState.in_progress + + work_request = inference.WorkRequest( + conversation=chat.conversation, + model_name=request.model_name, + max_new_tokens=request.max_new_tokens, + ) + + logger.info(f"Created {work_request}") + try: + await websocket.send_text(work_request.json()) + while True: + # maybe unnecessary to parse and re-serialize + # could just pass the raw string and mark end via empty string + response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) + await redisClient.rpush(chat.id, response_packet.json()) + if response_packet.is_end: + break + except fastapi.WebSocketException: + # TODO: handle this better + logger.exception(f"Websocket closed during handling of {chat.id}") + + chat.message_request_state = MessageRequestState.complete diff --git a/inference/server/requirements.txt b/inference/server/requirements.txt new file mode 100644 index 00000000..e0a00339 --- /dev/null +++ b/inference/server/requirements.txt @@ -0,0 +1,6 @@ +fastapi[all] +loguru +pydantic +redis +sse-starlette +websockets diff --git a/inference/text-client/__main__.py b/inference/text-client/__main__.py new file mode 100644 index 00000000..bf1f8b02 --- /dev/null +++ b/inference/text-client/__main__.py @@ -0,0 +1,40 @@ +"""Simple REPL frontend.""" + +import json + +import requests +import sseclient +import typer + +app = typer.Typer() + + +@app.command() +def main(backend_url: str = "http://127.0.0.1:8000"): + """Simple REPL client.""" + chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"] + while True: + message = typer.prompt("User").strip() + + # wait for stream to be ready + # could implement a queue position indicator + # could be implemented with long polling + # but server load needs to be considered + response = requests.post( + f"{backend_url}/chat/{chat_id}/message", + json={"message": message}, + stream=True, + headers={"Accept": "text/event-stream"}, + ) + response.raise_for_status() + + client = sseclient.SSEClient(response) + print("Assistant: ", end="", flush=True) + for event in client.events(): + data = json.loads(event.data) + print(data["token"], end="", flush=True) + print() + + +if __name__ == "__main__": + app() diff --git a/inference/text-client/requirements.txt b/inference/text-client/requirements.txt new file mode 100644 index 00000000..8d7bff7d --- /dev/null +++ b/inference/text-client/requirements.txt @@ -0,0 +1,3 @@ +requests +sseclient-py +typer diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py new file mode 100644 index 00000000..ad5e5cef --- /dev/null +++ b/inference/worker/__main__.py @@ -0,0 +1,79 @@ +import re +import time + +import rel +import torch +import typer +import websocket +from loguru import logger +from oasst_shared.schemas import inference, protocol +from transformers import pipeline + +app = typer.Typer() + + +@app.command() +def main( + backend_url: str = "ws://localhost:8000", + model_name: str = "distilgpt2", +): + pipe = pipeline("text-generation", model=model_name) + + def on_open(ws: websocket.WebSocket): + worker_config = inference.WorkerConfig(model_name=model_name) + ws.send(worker_config.json()) + + def on_message(ws: websocket.WebSocket, message: str): + # TODO: what if this comes in, but one is already in progress? + # also need to think of enabling batching + work_request = inference.WorkRequest.parse_raw(message) + + def _prepare_message(message: protocol.ConversationMessage) -> str: + prefix = "Assistant: " if message.is_assistant else "User: " + return prefix + message.text + + # construct prompt + messages = [_prepare_message(message) for message in work_request.conversation.messages] + + prompt = "\n".join(messages) + "\nAssistant:" + + # TODO: replace this with incremental generation + torch.manual_seed(work_request.seed) + model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ + 0 + ]["generated_text"] + model_output = model_output.strip() + + # fake streaming + split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] + pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] + for piece in pieces: + if not piece: + continue + if piece.strip() in ("User:", "Assistant:"): + break + ws.send(inference.WorkResponsePacket(token=piece).json()) + time.sleep(0.1) + ws.send(inference.WorkResponsePacket(is_end=True).json()) + + def on_error(ws: websocket.WebSocket, error: Exception): + logger.error(f"Connection error: {error}") + + def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str): + logger.warning(f"Connection closed: {close_status_code=} {close_msg=}") + + ws = websocket.WebSocketApp( + f"{backend_url}/work", + on_message=on_message, + on_error=on_error, + on_close=on_close, + on_open=on_open, + ) + + ws.run_forever(dispatcher=rel, reconnect=5) + rel.signal(2, rel.abort) + rel.dispatch() + + +if __name__ == "__main__": + app() diff --git a/inference/worker/requirements.txt b/inference/worker/requirements.txt new file mode 100644 index 00000000..c248c652 --- /dev/null +++ b/inference/worker/requirements.txt @@ -0,0 +1,6 @@ +loguru +rel +torch +transformers +typer +websocket-client diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index bed6f942..0a548ebb 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -37,6 +37,7 @@ class OasstErrorCode(IntEnum): TASK_GENERATION_FAILED = 1005 TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006 TASK_AVAILABILITY_QUERY_FAILED = 1007 + TASK_MESSAGE_TOO_LONG = 1008 # 2000-3000: prompt_repository INVALID_FRONTEND_MESSAGE_ID = 2000 diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py new file mode 100644 index 00000000..0acb5014 --- /dev/null +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -0,0 +1,21 @@ +import random + +import pydantic + +from . import protocol + + +class WorkerConfig(pydantic.BaseModel): + model_name: str = "distilgpt2" + + +class WorkRequest(pydantic.BaseModel): + conversation: protocol.Conversation = pydantic.Field(..., repr=False) + model_name: str = "distilgpt2" + max_new_tokens: int = 100 + seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1)) + + +class WorkResponsePacket(pydantic.BaseModel): + token: str | None = None + is_end: bool = False diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index f2164e8f..20bbdf9b 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -64,6 +64,18 @@ class Conversation(BaseModel): messages: list[ConversationMessage] = [] + def __len__(self): + return len(self.messages) + + @property + def is_prompter_turn(self) -> bool: + if len(self) == 0: + return True + last_message = self.messages[-1] + if last_message.is_assistant: + return True + return False + class Message(ConversationMessage): parent_id: Optional[UUID] = None diff --git a/openassistant/datasets/__init__.py b/openassistant/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/openassistant/datasets/soda_synthetic_dialogue/README.md b/openassistant/datasets/soda_synthetic_dialogue/README.md new file mode 100644 index 00000000..c4866e16 --- /dev/null +++ b/openassistant/datasets/soda_synthetic_dialogue/README.md @@ -0,0 +1,108 @@ +--- +annotations_creators: + - no-annotation +language: + - en +language_creators: + - machine-generated +license: + - mit +multilinguality: + - monolingual +pretty_name: "SODA Synthetic Dialogue" +size_categories: + - 1M 1 and sys.argv[1] == "--print" + + +def main(output_dir: str = "data"): + """Download and prepare the dataset for use.""" + + random.seed(42) + dataset = load_dataset("allenai/soda") + os.makedirs(output_dir, exist_ok=True) + + for split in ["train", "test", "validation"]: + with open(f"{output_dir}/{split}.jsonl", "w", encoding="utf8") as output: + + for i in tqdm(range(len(dataset[split])), desc=split): + dat = dataset["train"][i] + title = dat["literal"] + story = dat["narrative"] + + if dat["relation"] == "xWant": + theme = "wanting " + dat["tail"] + elif dat["relation"] == "xNeed": + theme = "needing " + dat["tail"] + elif not dat["tail"].startswith("to ") and not dat["tail"].startswith("and "): + theme = "being " + dat["tail"] + elif dat["tail"].startswith("and "): + theme = "people are " + dat["tail"].replace("and PersonY ", "") + else: + theme = dat["tail"] + theme = theme.replace("PersonY", "another person") + theme = theme.replace("being is", "being") + + dialogue = [s2 + ": " + s1 for s1, s2 in zip(dat["dialogue"], dat["speakers"])] + + if random.randint(0, 6) == 0: + # print("##") + # print(f"User: Can you give me a short story description for this dialog?") + # print(" " + "\n ".join(dialog)) + # print(f"Assistant: Sure, a short story description for this dialog could be: \n {story}") + # print("User: And a title?") + # print(f"Assistant: Sure, a title for this dialog could be: \n {title}") + # if theme: + # print("User: What would be one theme of this story?") + # print(f'Assistant: One theme of this story could be: "{theme}"') + conversation = SUMMARY_TEMPLATE.format(dialogue="\n ".join(dialogue), story=story, title=title) + if theme: + conversation = conversation + THEME_TEMPLATE.format(theme=theme) + elif random.randint(0, 6) == 0: + # print("##") + # print(f"User: Can you write a short dialog based on this story:\n {story}") + # print(f"Assistant: Sure, a dialog for this story could be:") + # print(" " + "\n ".join(dialog)) + # print("User: And a title?") + # print(f"Assistant: Sure, a title for this dialog could be: \n {title}") + # if theme: + # print("User: What would be one theme of this story?") + # print(f'Assistant: One theme of this story could be: "{theme}"') + conversation = NEW_DIALOGUE_TEMPLATE.format( + story=story, dialogue="\n ".join(dialogue), title=title + ) + if theme: + conversation = conversation + THEME_TEMPLATE.format(theme=theme) + elif random.randint(0, 3) == 0: + # print("##") + # print(f"User: Can you write the next few lines of dialog for this scene:") + # if random.randint(0, 1) == 0: + # print(" " + "\n ".join(dialog[:-5])) + # print(f"Assistant: Sure, the next dialog for this scene could be:") + # print(" " + "\n ".join(dialog[-5:])) + # elif random.randint(0, 1) == 0: + # print(" " + "\n ".join(dialog[:-3])) + # print(f"Assistant: Sure, the next dialog for this scene could be:") + # print(" " + "\n ".join(dialog[-3:])) + # else: + # print(" " + "\n ".join(dialog[:-4])) + # print(f"Assistant: Sure, the next dialog for this scene could be:") + # print(" " + "\n ".join(dialog[-4:])) + # print("User: And a title?") + # print(f"Assistant: Sure, a title for this dialog could be: \n {title}") + # print("User: How about a short description?") + # print(f"Assistant: Sure, a short description for this dialog could be: \n {story}") + # if theme: + # print("User: What would be one theme of this story?") + # print(f'Assistant: One theme of this story could be: "{theme}"') + if random.randint(0, 1) == 0: + depth = -5 + elif random.randint(0, 1) == 0: + depth = -3 + else: + depth = -4 + conversation = NEXT_LINES_TEMPLATE.format( + scene="\n ".join(dialogue[:depth]), + dialogue="\n ".join(dialogue[depth:]), + title=title, + story=story, + ) + if theme: + conversation = conversation + THEME_TEMPLATE.format(theme=theme) + elif random.randint(0, 3) == 0: + # print("##") + # title1 = title.split(".")[0] + # title2 = title.split(".")[1] + # print(f"User: Can you write short story and dialog about: {title1}") + # print(f'Assistant: Sure, a short story and dialog about: "{title1}" could be:') + # print(f" {story}") + # if random.randint(0, 1) == 0: + # print(" " + "\n ".join(dialog)) + # elif random.randint(0, 1) == 0 and len(dialog) > 5: + # print(" " + "\n ".join(dialog[:-5])) + # print(f'User: Can you provide more dialog assuming "{title2}"?') + # print(f"Assistant: Sure, the next dialog for this scene could be:") + # print(" " + "\n ".join(dialog[-5:])) + # elif random.randint(0, 1) == 0: + # print(" " + "\n ".join(dialog[:-3])) + # print("User: more please.") + # print(f"Assistant: Sure, the next dialog for this scene could be:") + # print(" " + "\n ".join(dialog[-3:])) + # else: + # print(" " + "\n ".join(dialog[:-4])) + # print(f'User: Can you provide more dialog assuming "{title2}"?') + # print(f"Assistant: Sure, the next dialog for this scene could be:") + # print(" " + "\n ".join(dialog[-4:])) + # if theme: + # print("User: What would be one theme of this story?") + # print(f'Assistant: One theme of this story could be: "{theme}"') + title1 = title.split(".")[0] + title2 = title.split(".")[1] + conversation = NEW_STORY_AND_DIALOGUE_TEMPLATE.format(title1=title1, story=story) + if random.randint(0, 1) == 0: + conversation = FULL_DIALOGUE_TEMPLATE.format( + conversation=conversation, dialogue="\n ".join(dialogue) + ) + elif random.randint(0, 1) == 0 and len(dialogue) > 5: + conversation = MORE_DIALOGUE_TEMPLATE.format( + conversation=conversation, + dialogue1="\n ".join(dialogue[:-5]), + title2=title2, + dialogue2="\n ".join(dialogue[-5:]), + ) + elif random.randint(0, 1) == 0: + conversation = NEXT_DIALOGUE_TEMPLATE.format( + conversation=conversation, + dialogue1="\n ".join(dialogue[:-3]), + dialogue2="\n ".join(dialogue[-3:]), + ) + else: + conversation = MORE_DIALOGUE_TEMPLATE.format( + conversation=conversation, + dialogue1="\n ".join(dialogue[:-4]), + title2=title2, + dialogue2="\n ".join(dialogue[-4:]), + ) + if theme: + conversation = conversation + THEME_TEMPLATE.format(theme=theme) + else: + # print("##") + # print(f"User: Can you write short story and dialog based on the theme:\n {theme}") + # print(f'Assistant: Sure, a short story and dialog based on the theme "{theme}" could be:') + # print(f" {story}") + # print(" " + "\n ".join(dialog)) + # print("User: And a title?") + # print(f"Assistant: Sure, a title for this dialog could be: \n {title}") + conversation = NEW_STORY_AND_DIALOGUE_FROM_THEME_TEMPLATE.format( + theme=theme, story=story, dialogue="\n ".join(dialogue), title=title + ) + if PRINT: + print("##") + print(conversation) + + output.write(f"{json.dumps({'conversation': conversation})}\n") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/openassistant/datasets/soda_synthetic_dialogue/soda_synthetic_dialogue.py b/openassistant/datasets/soda_synthetic_dialogue/soda_synthetic_dialogue.py new file mode 100644 index 00000000..ddc7c883 --- /dev/null +++ b/openassistant/datasets/soda_synthetic_dialogue/soda_synthetic_dialogue.py @@ -0,0 +1,108 @@ +# Copyright 2023 The OpenAssistant Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This dataset is a set of dialogues synthesized from the SODA dataset. +In each dialogue, User and Assistant have a conversation about a story. + +The original collab notebook for this dataset can be found at: +https://colab.research.google.com/drive/1Sw3px5dP8whdqT7QMNoqwmqIasZkMbJi?usp=sharing +""" + +import json +from typing import Dict, List, Tuple + +import datasets + +from .hub import OpenAssistantConfig, features + +_CITATION = """\ +@article{ontocord2023sodasynth, + author = {ontocord and Jeffrey Quesnelle}, + title = {SODA Synthetic Dialogue}, + year = {2023} +} +""" +_DATASETNAME = "soda_synthetic_dialogue" +_DISPLAYNAME = "🥤SODA Synthetic Dialogue" +_DESCRIPTION = "A set of dialogues synthesized from the SODA dataset." +_HOMEPAGE = "" +_LICENSE = "mit" +_URLS = { + _DATASETNAME: {"train": "./data/train.jsonl", "test": "./data/test.jsonl", "validation": "./data/validation.jsonl"} +} +_SUPPORTED_TASKS = ["dialogue-modeling"] +_VERSION = "1.0.0" + + +class SODASyntheticDialogueDataset(datasets.GeneratorBasedBuilder): + """A set of dialogues synthesized from the SODA dataset.""" + + VERSION = datasets.Version(_VERSION) + + BUILDER_CONFIGS = [ + OpenAssistantConfig( + name=f"{_DATASETNAME}_dialogue_modeling", + version=VERSION, + description=f"OpenAssistant dataset config for {_DATASETNAME}", + schema="dialogue_modeling", + subset_id=_DATASETNAME, + ) + ] + + DEFAULT_CONFIG_NAME = f"{_DATASETNAME}_dialogue_modeling" + + def _info(self) -> datasets.DatasetInfo: + + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + homepage=_HOMEPAGE, + license=_LICENSE, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager) -> List[datasets.SplitGenerator]: + """Returns SplitGenerators.""" + + urls = _URLS[_DATASETNAME] + data_dir = dl_manager.download_and_extract(urls) + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + gen_kwargs={"filepath": data_dir, "split": "train"}, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={"filepath": data_dir, "split": "test"}, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + gen_kwargs={"filepath": data_dir, "split": "validation"}, + ), + ] + + def _generate_examples(self, filepath, split: str) -> Tuple[int, Dict]: + """Yields examples as (key, example) tuples.""" + + if self.config.schema == "dialogue_modeling": + key = 0 + with open(filepath[split], "r", encoding="utf8") as data: + while True: + line = data.readline() + if not line: + return + yield key, json.loads(line) + key += 1 diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index 82f378c9..99edf6c3 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -13,6 +13,5 @@ "sign_in": "Sign In", "sign_out": "Sign Out", "terms_of_service": "Terms of Service", - "title": "Open Assistant", - "last_updated_at": "Last updated at: {{val, datetime}}" + "title": "Open Assistant" } diff --git a/website/public/locales/en/index.json b/website/public/locales/en/index.json index 3443e444..7d8c9152 100644 --- a/website/public/locales/en/index.json +++ b/website/public/locales/en/index.json @@ -1,16 +1,15 @@ { - "title": "Open Assistant", - "subtitle": "Conversational AI for everyone.", - "description": "Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world.", "blurb": "We believe we can create a revolution.", "blurb1": "In the same way that Stable Diffusion helped the world make art and images in new ways, we want to improve the world by providing amazing conversational AI.", - "join_us_title": "Join us", - "join_us_description": "All open source projects begin with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity. Are you in? Find us here:", - "faq_title": "Frequently Asked Questions", + "description": "Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world.", "faq_items": { "q0": "How far along is this project?", "a0": "We are in the early stages of development, working from established research in applying RLHF to large language models.", "q1": "Who is behind Open Assistant?", "a1": "Open Assistant is a project organized by LAION and individuals around the world interested in bringing this technology to everyone." - } + }, + "faq_title": "Frequently Asked Questions", + "join_us_description": "All open source projects begin with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity. Are you in? Find us here:", + "join_us_title": "Join us", + "subtitle": "Conversational AI for everyone." } diff --git a/website/public/locales/en/leaderboard.json b/website/public/locales/en/leaderboard.json new file mode 100644 index 00000000..c2dd0832 --- /dev/null +++ b/website/public/locales/en/leaderboard.json @@ -0,0 +1,11 @@ +{ + "daily": "Daily", + "last_updated_at": "Last updated at: {{val, datetime}}", + "leaderboard": "Leaderboard", + "monthly": "Monthly", + "overall": "Overall", + "rank": "Rank", + "score": "Score", + "user": "User", + "weekly": "Weekly" +} diff --git a/website/src/components/Hero.tsx b/website/src/components/Hero.tsx index d401e47e..d9cab0c2 100644 --- a/website/src/components/Hero.tsx +++ b/website/src/components/Hero.tsx @@ -6,7 +6,7 @@ import { AnimatedCircles } from "./AnimatedCircles"; import { Container } from "./Container"; export function Hero() { - const { t } = useTranslation("index"); + const { t } = useTranslation(["index", "common"]); const { colorMode } = useColorMode(); const pTextColor = colorMode === "light" ? "text-gray-600" : "text-white"; const fancyTextGradientClasses = @@ -17,7 +17,7 @@ export function Hero() { - {t("title")} + {t("common:title")} ( desc: "Users Dashboard", icon: FiUsers, }, + { + label: "Status", + pathname: "/admin/status", + desc: "Status Dashboard", + icon: FiActivity, + }, ]} > {page} diff --git a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx index 9750a851..7886784a 100644 --- a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx +++ b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx @@ -6,19 +6,19 @@ import { get } from "src/lib/api"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import useSWRImmutable from "swr/immutable"; -const columns = [ +const getColumns = (t) => [ { - Header: "Rank", + Header: t("rank"), accessor: "rank", style: { width: "90px" }, }, { - Header: "Score", + Header: t("score"), accessor: "leader_score", style: { width: "90px" }, }, { - Header: "User", + Header: t("user"), accessor: "display_name", }, ]; @@ -27,11 +27,13 @@ const columns = [ * Presents a grid of leaderboard entries with more detailed information. */ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) => { - const { t } = useTranslation(); + const { t } = useTranslation(["leaderboard", "common"]); const { data: reply } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}`, get, { revalidateOnMount: true, }); + const columns = useMemo(() => getColumns(t), [t]); + const { getTableProps, getTableBodyProps, headerGroups, rows, prepareRow } = useTable({ columns, data: reply?.leaderboard ?? [], diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 0a5b3ba2..36aa8e66 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -164,6 +164,27 @@ export class OasstApiClient { }); } + /** + * Returns the tasks availability information for given `user`. + */ + async fetch_tasks_availability(user: object): Promise { + return this.post("/api/v1/tasks/availability", user); + } + + /** + * Returns the message stats from the backend. + */ + async fetch_stats(): Promise { + return this.get("/api/v1/stats/"); + } + + /** + * Returns the tree manager stats from the backend. + */ + async fetch_tree_manager(): Promise { + return this.get("/api/v1/stats/tree_manager"); + } + /** * Returns the `BackendUser` associated with `user_id` */ diff --git a/website/src/pages/admin/status/index.tsx b/website/src/pages/admin/status/index.tsx new file mode 100644 index 00000000..12eb0785 --- /dev/null +++ b/website/src/pages/admin/status/index.tsx @@ -0,0 +1,174 @@ +import { + Box, + Card, + CardBody, + CircularProgress, + SimpleGrid, + Text, + Table, + TableCaption, + TableContainer, + Tbody, + Td, + Th, + Thead, + Tr, + useColorMode, +} from "@chakra-ui/react"; +import Head from "next/head"; +import { useRouter } from "next/router"; +import { useSession } from "next-auth/react"; +import { useEffect } from "react"; +import useSWRImmutable from "swr/immutable"; +import { getAdminLayout } from "src/components/Layout"; +import { get } from "src/lib/api"; + +/** + * Provides the admin status page that shows result of calls to several backend API endpoints, + * namely /api/v1/tasks/availability, /api/v1/stats/, /api/v1/stats/tree_manager + */ + +const StatusIndex = () => { + const router = useRouter(); + const { data: session, status } = useSession(); + + const { colorMode } = useColorMode(); + const dataBackgroundColor = colorMode === "light" ? "gray.100" : "gray.800"; + // Check when the user session is loaded and re-route if the user is not an + // admin. This follows the suggestion by NextJS for handling private pages: + // https://nextjs.org/docs/api-reference/next/router#usage + // + // All admin pages should use the same check and routing steps. + useEffect(() => { + if (status === "loading") { + return; + } + if (session?.user?.role === "admin") { + return; + } + router.push("/"); + }, [router, session, status]); + + const { + data: dataStatus, + error: errorStatus, + isLoading: isLoadingStatus, + } = useSWRImmutable("/api/admin/status", get); + + const { tasksAvailability, stats, treeManager } = dataStatus || {}; + + return ( + <> + + Status - Open Assistant + + + + + + + + /api/v1/tasks/availability + + + {tasksAvailability?.status === "fulfilled" ? ( +
{JSON.stringify(tasksAvailability.value, null, 2)}
+ ) : tasksAvailability?.status === "rejected" ? ( +
{JSON.stringify(tasksAvailability.reason, null, 2)}
+ ) : errorStatus ? ( +
{JSON.stringify(errorStatus, null, 2)}
+ ) : ( + + )} +
+
+
+ + + + + /api/v1/stats/ + + + {stats?.status === "fulfilled" ? ( +
{JSON.stringify(stats.value, null, 2)}
+ ) : stats?.status === "rejected" ? ( +
{JSON.stringify(stats.reason, null, 2)}
+ ) : errorStatus ? ( +
{JSON.stringify(errorStatus, null, 2)}
+ ) : ( + + )} +
+
+
+
+
+ + + + /api/v1/stats/tree_manager + + {treeManager?.status === "fulfilled" ? ( + + + state_counts + + +
{JSON.stringify(treeManager.value.state_counts, null, 2)}
+
+ +
+ + message_counts + + + Tree Manager + + + + + + + + + + + + + {treeManager.value.message_counts.map( + ({ message_tree_id, state, depth, oldest, youngest, count, goal_tree_size }) => ( + + + + + + + + + + ) + )} + +
Message Tree IDStateDepthOldestYoungestCountGoal Tree Size
{message_tree_id}{state}{depth}{oldest}{youngest}{count}{goal_tree_size}
+
+
+ ) : treeManager?.status === "rejected" ? ( +
{JSON.stringify(treeManager.reason, null, 2)}
+ ) : errorStatus ? ( +
{JSON.stringify(errorStatus, null, 2)}
+ ) : ( + + )} +
+
+ + ); +}; + +StatusIndex.getLayout = getAdminLayout; + +export default StatusIndex; diff --git a/website/src/pages/api/admin/status.ts b/website/src/pages/api/admin/status.ts new file mode 100644 index 00000000..1da03da8 --- /dev/null +++ b/website/src/pages/api/admin/status.ts @@ -0,0 +1,30 @@ +import { getToken } from "next-auth/jwt"; +import { withRole } from "src/lib/auth"; +import { oasstApiClient } from "src/lib/oasst_api_client"; +import { getBackendUserCore } from "src/lib/users"; + +/** + * Returns tasks availability, stats, and tree manager stats. + */ +const handler = withRole("admin", async (req, res) => { + const dummyUser = { + id: "__dummy_user__", + display_name: "Dummy User", + auth_method: "local", + }; + const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([ + oasstApiClient.fetch_tasks_availability(dummyUser), + oasstApiClient.fetch_stats(), + oasstApiClient.fetch_tree_manager(), + ]); + + const status = { + tasksAvailability: tasksAvailabilityOutcome, + stats: statsOutcome, + treeManager: treeManagerOutcome, + }; + + res.status(200).json(status); +}); + +export default handler; diff --git a/website/src/pages/index.tsx b/website/src/pages/index.tsx index 8fe5d852..888c0d61 100644 --- a/website/src/pages/index.tsx +++ b/website/src/pages/index.tsx @@ -3,17 +3,17 @@ import Head from "next/head"; import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { useEffect } from "react"; import { CallToAction } from "src/components/CallToAction"; import { Faq } from "src/components/Faq"; import { Hero } from "src/components/Hero"; import { getTransparentHeaderLayout } from "src/components/Layout"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const Home = () => { const router = useRouter(); const { status } = useSession(); - const { t } = useTranslation("index"); + const { t } = useTranslation(); useEffect(() => { if (status === "authenticated") { router.push("/dashboard"); @@ -37,10 +37,4 @@ const Home = () => { Home.getLayout = getTransparentHeaderLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["index", "common"])), - }, -}); - export default Home; diff --git a/website/src/pages/leaderboard.tsx b/website/src/pages/leaderboard.tsx index f79dac52..e413366f 100644 --- a/website/src/pages/leaderboard.tsx +++ b/website/src/pages/leaderboard.tsx @@ -1,27 +1,29 @@ import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-ui/react"; import Head from "next/head"; +import { useTranslation } from "next-i18next"; import { getDashboardLayout } from "src/components/Layout"; import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; const Leaderboard = () => { + const { t } = useTranslation(["leaderboard", "common"]); return ( <> - Leaderboard - Open Assistant + {`${t("leaderboard")} - ${t("common:title")}`} - Leaderboard + {t("leaderboard")} - Daily - Weekly - Monthly - Overall + {t("daily")} + {t("weekly")} + {t("monthly")} + {t("overall")}