Merge branch 'main' into 371_set_labels

This commit is contained in:
James Melvin
2023-01-08 09:28:38 +05:30
91 changed files with 2103 additions and 2051 deletions
+1 -4
View File
@@ -4,7 +4,6 @@ on:
push:
branches:
- main
- docs-site-poc
paths:
- ".github/workflows/deploy-docs-site.yaml"
- "docs/**"
@@ -45,9 +44,7 @@ jobs:
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
if:
${{ github.ref == 'refs/heads/main' || github.ref ==
'refs/heads/docs-site-poc' }}
if: ${{ github.ref == 'refs/heads/main' }}
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./docs/build
+2
View File
@@ -2,3 +2,5 @@
/website/ @fozziethebeat @k-nearest-neighbor @AbdBarho
/model/ @theblackcat102 @sanagno
/copilot/ @fozziethebeat @andreaskoepf @yk
/docs/ @andrewm4894 @andreaskoepf @yk
/.devcontainer/ @andrewm4894 @andreaskoepf @yk
+1
View File
@@ -6,6 +6,7 @@
<div align="center">
<a href="https://github.com/LAION-AI/Open-Assistant/stargazers">![GitHub Repo stars](https://img.shields.io/github/stars/LAION-AI/Open-Assistant?style=social)</a>
<a href="https://laion-ai.github.io/Open-Assistant/">![Docs](https://img.shields.io/badge/docs-laion--ai.github.io%2FOpen--Assistant%2F-green)</a>
<a href="https://github.com/LAION-AI/Open-Assistant/actions/workflows/build-frontend.yaml">![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/build-frontend.yaml?label=frontend)</a>
<a href="https://github.com/LAION-AI/Open-Assistant/actions/workflows/pre-commit.yaml">![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/pre-commit.yaml?label=pre-commit)</a>
<a href="https://github.com/LAION-AI/Open-Assistant/actions/workflows/test-api-contract.yaml">![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/test-api-contract.yaml?label=api)</a>
+21
View File
@@ -1,5 +1,18 @@
# Open-Assistant REST Backend
## Backend Development Setup
In root directory, run
`docker compose up backend-dev --build --attach-dependencies` to start a
database. The default settings are already configured to connect to the database
at `localhost:5432`.
Make sure you have all requirements installed. You can do this by running
`pip install -r requirements.txt` inside the `backend` folder and
`pip install -e .` inside the `oasst-shared` folder. Then, run the backend using
the `run-local.sh` script inside the `scripts` folder. This will start the
backend server at `http://localhost:8080`.
## REST Server Configuration
Please either use environment variables or create a `.env` file in the backend
@@ -20,3 +33,11 @@ REDIS_PORT=6379
Have a look into the main `README.md` file for more information on how to set up
the backend for development. Use the scripts within the
scripts/backend-development folder to run the BE API locally.
## Alembic
To create an Alembic database migration script after sql-models were modified
run `alembic revision --autogenerate -m "..."` ("..." is what you did) in the
`/backend` directory. Then edit the newly created file. See
[here](https://alembic.sqlalchemy.org/en/latest/tutorial.html) for more
information.
@@ -0,0 +1,23 @@
"""added frontend_type to api_client
Revision ID: ba61fe17fb6e
Revises: 20cd871f4ec7
Create Date: 2023-01-07 12:50:32.195930
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ba61fe17fb6e"
down_revision = "20cd871f4ec7"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("api_client", sa.Column("frontend_type", sa.String(256), nullable=True))
def downgrade() -> None:
op.drop_column("api_client", "frontend_id")
+20 -66
View File
@@ -1,3 +1,4 @@
import json
from http import HTTPStatus
from math import ceil
from pathlib import Path
@@ -6,7 +7,6 @@ from typing import Optional
import alembic.command
import alembic.config
import fastapi
import pydantic
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from loguru import logger
@@ -17,6 +17,7 @@ from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
from sqlmodel import Session
from starlette.middleware.cors import CORSMiddleware
@@ -97,7 +98,7 @@ if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
def seed_data():
class DummyMessage(pydantic.BaseModel):
class DummyMessage(BaseModel):
task_message_id: str
user_message_id: str
parent_message_id: Optional[str]
@@ -111,64 +112,10 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
dummy_messages = [
DummyMessage(
task_message_id="de111fa8",
user_message_id="6f1d0711",
parent_message_id=None,
text="Hi!",
role="prompter",
),
DummyMessage(
task_message_id="74c381d4",
user_message_id="4a24530b",
parent_message_id="6f1d0711",
text="Hello! How can I help you?",
role="assistant",
),
DummyMessage(
task_message_id="3d5dc440",
user_message_id="a8c01c04",
parent_message_id="4a24530b",
text="Do you have a recipe for potato soup?",
role="prompter",
),
DummyMessage(
task_message_id="643716c1",
user_message_id="f43a93b7",
parent_message_id="4a24530b",
text="Who were the 8 presidents before George Washington?",
role="prompter",
),
DummyMessage(
task_message_id="2e4e1e6",
user_message_id="c886920",
parent_message_id="6f1d0711",
text="Hey buddy! How can I serve you?",
role="assistant",
),
DummyMessage(
task_message_id="970c437d",
user_message_id="cec432cf",
parent_message_id=None,
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
role="prompter",
),
DummyMessage(
task_message_id="6066118e",
user_message_id="4f85f637",
parent_message_id="cec432cf",
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
role="assistant",
),
DummyMessage(
task_message_id="ba87780d",
user_message_id="0e276b98",
parent_message_id="cec432cf",
text="I'm unsure how to interpret this. Is it a riddle?",
role="assistant",
),
]
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
@@ -185,12 +132,20 @@ if settings.DEBUG_USE_SEED_DATA:
parent_message = pr.fetch_message_by_frontend_message_id(
msg.parent_message_id, fail_if_missing=True
)
task = pr.store_task(
protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
conversation_messages = pr.fetch_message_conversation(parent_message)
conversation = protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text=cmsg.text,
is_assistant=cmsg.role == "assistant",
message_id=cmsg.id,
fronend_message_id=cmsg.frontend_message_id,
)
),
for cmsg in conversation_messages
]
)
task = pr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
@@ -219,7 +174,6 @@ if __name__ == "__main__":
# Importing here so we don't import packages unnecessarily if we're
# importing main as a module.
import argparse
import json
import uvicorn
+7 -1
View File
@@ -40,7 +40,13 @@ def get_dummy_api_client(db: Session) -> ApiClient:
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token", trusted=True)
api_client = ApiClient(
id=ANY_API_KEY_ID,
api_key=token,
description="ANY_API_KEY, random token",
trusted=True,
frontend_type="Test frontend",
)
db.add(api_client)
db.commit()
return api_client
@@ -2,9 +2,7 @@ from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
@@ -20,11 +18,6 @@ def get_message_by_frontend_id(
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
if not isinstance(message.payload.payload, MessagePayload):
# Unexpected message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return utils.prepare_message(message)
-6
View File
@@ -5,9 +5,7 @@ from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
@@ -55,10 +53,6 @@ def get_message(
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
if not isinstance(message.payload.payload, MessagePayload):
# Unexptcted message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return utils.prepare_message(message)
+28 -16
View File
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
@@ -58,7 +59,10 @@ def generate_task(
messages = pr.fetch_random_conversation("assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
text=msg.text,
is_assistant=(msg.role == "assistant"),
message_id=msg.id,
front_end_id=msg.frontend_message_id,
)
for msg in messages
]
@@ -71,7 +75,10 @@ def generate_task(
messages = pr.fetch_random_conversation("prompter")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
text=msg.text,
is_assistant=(msg.role == "assistant"),
message_id=msg.id,
front_end_id=msg.frontend_message_id,
)
for msg in messages
]
@@ -83,19 +90,21 @@ def generate_task(
logger.info("Generating a RankInitialPromptsTask.")
messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages])
case protocol_schema.TaskRequestType.rank_prompter_replies:
logger.info("Generating a RankPrompterRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
text=p.text,
is_assistant=(p.role == "assistant"),
message_id=p.id,
front_end_id=p.frontend_message_id,
)
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
replies = [p.text for p in replies]
task = protocol_schema.RankPrompterRepliesTask(
conversation=protocol_schema.Conversation(
messages=task_messages,
@@ -109,14 +118,16 @@ def generate_task(
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
text=p.text,
is_assistant=(p.role == "assistant"),
message_id=p.id,
front_end_id=p.frontend_message_id,
)
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
replies = [p.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(messages=task_messages),
conversation=prepare_conversation(conversation),
replies=replies,
)
@@ -125,29 +136,29 @@ def generate_task(
message = pr.fetch_random_initial_prompts(1)[0]
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.payload.payload.text,
prompt=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case protocol_schema.TaskRequestType.label_prompter_reply:
logger.info("Generating a LabelPrompterReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
message = messages[0].payload.payload.text
message = messages[0]
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message,
conversation=prepare_conversation(conversation),
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case protocol_schema.TaskRequestType.label_assistant_reply:
logger.info("Generating a LabelAssistantReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
message = messages[0].payload.payload.text
message = messages[0]
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message,
conversation=prepare_conversation(conversation),
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
@@ -292,7 +303,8 @@ def tasks_interaction(
logger.info(
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
)
# TODO: check if the labels are valid?
# Labels are implicitly validated when converting str -> TextLabel
# So no need for explicit validation here
pr.store_text_labels(interaction)
return protocol_schema.TaskDone()
case _:
@@ -3,6 +3,7 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
@@ -32,3 +33,13 @@ def label_text(
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
@router.get("/valid_labels")
def get_valid_lables() -> ValidLabelsResponse:
return ValidLabelsResponse(
valid_labels=[
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
for l in protocol_schema.TextLabel
]
)
+7 -11
View File
@@ -1,19 +1,14 @@
from http import HTTPStatus
from uuid import UUID
from oasst_backend.models import Message
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
def prepare_message(m: Message) -> protocol.Message:
if not isinstance(m.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
return protocol.Message(
id=m.id,
parent_id=m.parent_id,
text=m.payload.payload.text,
text=m.text,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
)
@@ -26,10 +21,13 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
protocol.ConversationMessage(
text=message.text,
is_assistant=(message.role == "assistant"),
message_id=message.id,
frontend_message_id=message.frontend_message_id,
)
)
return protocol.Conversation(messages=conv_messages)
@@ -38,8 +36,6 @@ def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(prepare_message(message))
return protocol.MessageTree(id=tree_id, messages=tree_messages)
+5 -1
View File
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
from pydantic import AnyHttpUrl, BaseSettings, FilePath, PostgresDsn, validator
class Settings(BaseSettings):
@@ -21,6 +22,9 @@ class Settings(BaseSettings):
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
)
HUGGING_FACE_API_KEY: str = ""
@@ -20,3 +20,4 @@ class ApiClient(SQLModel, table=True):
admin_email: Optional[str] = Field(max_length=256, nullable=True)
enabled: bool = Field(default=True)
trusted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
frontend_type: str = Field(max_length=256, nullable=True)
+5 -5
View File
@@ -32,7 +32,7 @@ class Journal(SQLModel, table=True):
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
api_client_id: UUID = Field(foreign_key="api_client.id")
@@ -49,7 +49,7 @@ class JournalIntegration(SQLModel, table=True):
),
)
description: str = Field(max_length=512, primary_key=True)
last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True)
last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
last_error: str = Field(nullable=True)
next_run: datetime = Field(nullable=True)
last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True)
last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
last_error: Optional[str] = Field(nullable=True)
next_run: Optional[datetime] = Field(nullable=True)
+19 -5
View File
@@ -1,9 +1,12 @@
from datetime import datetime
from http import HTTPStatus
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from sqlalchemy import false
from sqlmodel import Field, Index, SQLModel
@@ -19,19 +22,30 @@ class Message(SQLModel, table=True):
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
parent_id: UUID = Field(nullable=True)
parent_id: Optional[UUID] = Field(nullable=True)
message_tree_id: UUID = Field(nullable=False, index=True)
task_id: UUID = Field(nullable=True, index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
task_id: Optional[UUID] = Field(nullable=True, index=True)
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128, regex="^prompter|assistant$")
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
payload: Optional[PayloadContainer] = Field(
sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
)
lang: str = Field(nullable=False, max_length=200, default="en-US")
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
@property
def text(self) -> str:
self.ensure_is_message()
return self.payload.payload.text
+1 -1
View File
@@ -22,7 +22,7 @@ class Task(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
@@ -0,0 +1,13 @@
from typing import Optional
from pydantic import BaseModel
class LabelOption(BaseModel):
name: str
display_text: str
help_text: Optional[str]
class ValidLabelsResponse(BaseModel):
valid_labels: list[LabelOption]
@@ -0,0 +1,72 @@
[
{
"task_message_id": "de111fa8",
"user_message_id": "6f1d0711",
"parent_message_id": null,
"text": "Hi!",
"role": "prompter"
},
{
"task_message_id": "74c381d4",
"user_message_id": "4a24530b",
"parent_message_id": "6f1d0711",
"text": "Hello! How can I help you?",
"role": "assistant"
},
{
"task_message_id": "3d5dc440",
"user_message_id": "a8c01c04",
"parent_message_id": "4a24530b",
"text": "Do you have a recipe for potato soup?",
"role": "prompter"
},
{
"task_message_id": "643716c1",
"user_message_id": "f43a93b7",
"parent_message_id": "4a24530b",
"text": "Who were the 8 presidents before George Washington?",
"role": "prompter"
},
{
"task_message_id": "2e4e1e6",
"user_message_id": "c886920",
"parent_message_id": "6f1d0711",
"text": "Hey buddy! How can I serve you?",
"role": "assistant"
},
{
"task_message_id": "970c437d",
"user_message_id": "cec432cf",
"parent_message_id": null,
"text": "euirdteunvglfe23908230892309832098 AAAAAAAA",
"role": "prompter"
},
{
"task_message_id": "6066118e",
"user_message_id": "4f85f637",
"parent_message_id": "cec432cf",
"text": "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
"role": "assistant"
},
{
"task_message_id": "ba87780d",
"user_message_id": "0e276b98",
"parent_message_id": "cec432cf",
"text": "I'm unsure how to interpret this. Is it a riddle?",
"role": "assistant"
},
{
"task_message_id": "b8e98ed6",
"user_message_id": "89384709",
"parent_message_id": "0e276b98",
"text": "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?",
"role": "prompter"
},
{
"task_message_id": "9a0e7683",
"user_message_id": "6d452c57",
"parent_message_id": "0e276b98",
"text": "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?",
"role": "prompter"
}
]
+1
View File
@@ -14,3 +14,4 @@ COPY ./backend/alembic /app/alembic
COPY ./backend/alembic.ini /app/alembic.ini
COPY ./backend/main.py /app/main.py
COPY ./backend/oasst_backend /app/oasst_backend
COPY ./backend/test_data /app/test_data
+5 -3
View File
@@ -1,7 +1,9 @@
# Website
# Docs Site
This website is built using [Docusaurus 2](https://docusaurus.io/), a modern
static website generator.
https://laion-ai.github.io/Open-Assistant/
This [site](https://laion-ai.github.io/Open-Assistant/) is built using
[Docusaurus 2](https://docusaurus.io/), a modern static website generator.
### Contributing
+1 -1
View File
@@ -20,7 +20,7 @@ const config = {
// If you aren't using GitHub pages, you don't need these.
organizationName: "LAION-AI", // Usually your GitHub org/user name.
projectName: "Open-Assistant", // Usually your repo name.
deploymentBranch: "docs-site-poc",
deploymentBranch: "main",
// Even if you don't use internalization, you can use this field to set useful
// metadata like html lang. For example, if your site is Chinese, you may want
@@ -0,0 +1,15 @@
model_name: microsoft/deberta-v2-xlarge
learning_rate: 1e-5
freeze_layer: 15
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 16
per_device_train_batch_size: 1
warmup_steps: 600
eval_steps: 200
save_steps: 500
max_length: 512
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
@@ -0,0 +1,14 @@
model_name: microsoft/deberta-v3-base
learning_rate: 1e-5
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 32
per_device_train_batch_size: 2
warmup_steps: 600
eval_steps: 200
save_steps: 500
max_length: 512
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
@@ -0,0 +1,13 @@
model_name: deepset/deberta-v3-large-squad2
learning_rate: 1e-5
gradient_checkpointing: false
gradient_accumulation_steps: 32
per_device_train_batch_size: 1
warmup_steps: 600
eval_steps: 200
save_steps: 500
max_length: 512
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
@@ -0,0 +1,14 @@
model_name: microsoft/deberta-v3-large
learning_rate: 1e-5
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 32
per_device_train_batch_size: 1
warmup_steps: 600
eval_steps: 200
save_steps: 500
max_length: 512
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
+20 -1
View File
@@ -11,6 +11,7 @@ from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, W
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
AdamW,
AutoModelForSequenceClassification,
DataCollator,
EvalPrediction,
@@ -19,6 +20,8 @@ from transformers import (
Trainer,
TrainerCallback,
TrainingArguments,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
@@ -179,7 +182,7 @@ if __name__ == "__main__":
evaluation_strategy="steps",
eval_steps=training_conf["eval_steps"],
save_steps=1000,
report_to="local",
report_to="wandb",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
@@ -202,6 +205,21 @@ if __name__ == "__main__":
else:
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
assert len(evals) > 0
optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = None
if "scheduler" in training_conf:
if training_conf["scheduler"] == "linear":
scheduler = get_linear_schedule_with_warmup()
elif training_conf["scheduler"] == "cosine":
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=len(train)
* args.num_train_epochs
/ (args.per_device_train_batch_size * args.gradient_accumulation_steps),
)
trainer = RankTrainer(
model=model,
model_name=model_name,
@@ -211,6 +229,7 @@ if __name__ == "__main__":
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
optimizers=(optimizer, scheduler),
)
# trainer.evaluate()
trainer.train()
@@ -1,5 +1,13 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -1,5 +1,13 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-argumentation/EssayInstructions.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -1,5 +1,13 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-argumentation/EssayRevision.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
File diff suppressed because one or more lines are too long
@@ -1,5 +1,23 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/detoxify-evaluation/DetoxityEvaluation.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# uncomment below to install required python packages\n",
"#!pip install detoxify"
]
},
{
"attachments": {},
"cell_type": "markdown",
+19 -2
View File
@@ -34,6 +34,8 @@ class ConversationMessage(BaseModel):
text: str
is_assistant: bool
message_id: Optional[UUID] = None
frontend_message_id: Optional[str] = None
class Conversation(BaseModel):
@@ -263,10 +265,25 @@ class MessageRanking(Interaction):
class TextLabel(str, enum.Enum):
"""A label for a piece of text."""
def __new__(cls, label: str, display_text: str = "", help_text: str = None):
obj = str.__new__(cls, label)
obj._value_ = label
obj.display_text = display_text
obj.help_text = help_text
return obj
spam = "spam"
violence = "violence"
sexual_content = "sexual_content"
fails_task = "fails_task", "Fails to follow the correct instruction / task"
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
harmful = (
"harmful",
"Harmful content",
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
)
sexual_content = "sexual_content", "Contains sexual content"
toxicity = "toxicity"
moral_judgement = "moral_judgement", "Expresses moral judgement"
political_content = "political_content"
humor = "humor"
sarcasm = "sarcasm"
+29 -10
View File
@@ -1,5 +1,6 @@
"""Simple REPL frontend."""
import http
import random
import requests
@@ -30,6 +31,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()
if response.status_code == http.HTTPStatus.NO_CONTENT:
return None
return response.json()
typer.echo("Requesting work...")
@@ -191,7 +194,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
ranking_str = typer.prompt("Enter the reply numbers in order of preference, separated by commas")
ranking = [int(x) - 1 for x in ranking_str.split(",")]
# send ranking
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
@@ -211,11 +214,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
# send ranking
labels_dict = None
while labels_dict is None:
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
@@ -240,17 +251,25 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
# send ranking
labels_dict = None
while labels_dict is None:
labels_str: str = typer.prompt("Enter labels, separated by commas")
labels = labels_str.lower().replace(" ", "").split(",")
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"text": task["prompt"],
"text": task["reply"],
"labels": labels_dict,
"user": USER,
},
+2 -1
View File
@@ -8,7 +8,8 @@
"rules": {
"unused-imports/no-unused-imports": "warn",
"simple-import-sort/imports": "warn",
"simple-import-sort/exports": "warn"
"simple-import-sort/exports": "warn",
"eqeqeq": "warn"
},
"plugins": ["simple-import-sort", "unused-imports"]
}
+29 -616
View File
File diff suppressed because it is too large Load Diff
-1
View File
@@ -50,7 +50,6 @@
"react": "18.2.0",
"react-dom": "18.2.0",
"react-icons": "^4.7.1",
"sharp": "0.31.2",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
"use-debounce": "^9.0.2"
+1 -1
View File
@@ -48,7 +48,7 @@ export function CallToAction() {
here:
</p>
<div className="mt-8 flex justify-center">
<a href="https://discord.gg/pXtnYk9c" rel="noreferrer" target="_blank">
<a href="https://ykilcher.com/open-assistant-discord" rel="noreferrer" target="_blank">
<button
type="button"
className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
+10 -2
View File
@@ -1,5 +1,13 @@
import { Button, useDisclosure } from "@chakra-ui/react";
import { Modal, ModalOverlay, ModalContent, ModalHeader, ModalBody, ModalCloseButton } from "@chakra-ui/react";
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalHeader,
ModalOverlay,
useDisclosure,
} from "@chakra-ui/react";
import React from "react";
export const CollapsableText = ({ text, maxLength = 220 }) => {
@@ -1,17 +0,0 @@
import { Box } from "@chakra-ui/react";
import { Message } from "./Messages";
export const ContextMessages = ({ messages }: { messages: Message[] }) => {
return (
<Box className="flex flex-col gap-1">
{messages.map((message, i) => {
return (
<Box key={i}>
<span>{message.is_assistant ? "Assistant: " : "User: "}</span>
<span>{message.text}</span>
</Box>
);
})}
</Box>
);
};
+38 -111
View File
@@ -1,126 +1,53 @@
import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
import Link from "next/link";
const crTasks = [
{
label: "Create Initial Prompts",
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
type: "create",
pathname: "/create/initial_prompt",
},
{
label: "Reply as User",
desc: "Chat with Open Assistant and help improve its responses as you interact with it.",
type: "create",
pathname: "/create/user_reply",
},
{
label: "Reply as Assistant",
desc: "Help Open Assistant improve its responses to conversations with other users.",
type: "create",
pathname: "/create/assistant_reply",
},
];
import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes";
const evTasks = [
{
label: "Rank User Replies",
type: "eval",
desc: "Help Open Assistant improve its responses to conversations with other users.",
pathname: "/evaluate/rank_user_replies",
},
{
label: "Rank Assistant Replies",
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
type: "eval",
pathname: "/evaluate/rank_assistant_replies",
},
{
label: "Rank Initial Prompts",
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
type: "eval;",
pathname: "/evaluate/rank_initial_prompts",
},
];
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label];
export const TaskOption = () => {
const backgroundColor = useColorModeValue("white", "gray.700");
return (
<Box className="flex flex-col gap-14" fontFamily="inter">
<div>
<Text className="text-2xl font-bold pb-4">Create</Text>
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
{crTasks.map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
<GridItem
bg={backgroundColor}
borderRadius="xl"
boxShadow="base"
className="flex flex-col justify-between h-full"
>
<Box className="p-6 pb-10">
<Flex flexDir="column" gap="3">
<Heading size="md" fontFamily="inter">
{item.label}
</Heading>
<Text size="sm" opacity="80%">
{item.desc}
</Text>
</Flex>
</Box>
<Box
bg="blue.500"
borderBottomRadius="xl"
className="px-6 py-2 transition-colors duration-300"
_hover={{ backgroundColor: "blue.600" }}
{displayTaskCategories.map((category, categoryIndex) => (
<div key={categoryIndex}>
<Text className="text-2xl font-bold pb-4">{category}</Text>
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
<GridItem
bg={backgroundColor}
borderRadius="xl"
boxShadow="base"
className="flex flex-col justify-between h-full"
>
<Text fontWeight="bold" color="white">
Go
</Text>
</Box>
</GridItem>
</Link>
))}
</SimpleGrid>
</div>
<div>
<Text className="text-2xl font-bold pb-4">Evaluate</Text>
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
{evTasks.map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
<GridItem
bg={backgroundColor}
borderRadius="xl"
boxShadow="base"
className="flex flex-col justify-between h-full"
>
<Box className="p-6 pb-10">
<Flex flexDir="column" gap="3">
<Heading size="md" fontFamily="inter">
{item.label}
</Heading>
<Text size="sm" opacity="80%">
{item.desc}
<Box className="p-6 pb-10">
<Flex flexDir="column" gap="3">
<Heading size="md" fontFamily="inter">
{item.label}
</Heading>
<Text size="sm" opacity="80%">
{item.desc}
</Text>
</Flex>
</Box>
<Box
bg="blue.500"
borderBottomRadius="xl"
className="px-6 py-2 transition-colors duration-300"
_hover={{ backgroundColor: "blue.600" }}
>
<Text fontWeight="bold" color="white">
Go
</Text>
</Flex>
</Box>
<Box
bg="blue.500"
borderBottomRadius="xl"
className="px-6 py-2 transition-colors duration-300"
_hover={{ backgroundColor: "blue.600" }}
>
<Text fontWeight="bold" color="white">
Go
</Text>
</Box>
</GridItem>
</Link>
))}
</SimpleGrid>
</div>
</Box>
</GridItem>
</Link>
))}
</SimpleGrid>
</div>
))}
</Box>
);
};
@@ -1,3 +1,2 @@
export { LeaderboardTable } from "./LeaderboardTable";
export { SideMenu } from "./SideMenu";
export { TaskOption } from "./TaskOption";
+3 -2
View File
@@ -24,8 +24,8 @@ import {
import { FlagIcon, QuestionMarkCircleIcon } from "@heroicons/react/20/solid";
import { useState } from "react";
import poster from "src/lib/poster";
import useSWRMutation from "swr/mutation";
import { colors } from "styles/Theme/colors";
import useSWRMutation from "swr/mutation";
export const FlaggableElement = (props) => {
const [isEditing, setIsEditing] = useBoolean();
@@ -118,7 +118,8 @@ export const FlaggableElement = (props) => {
</Popover>
);
};
function FlagCheckbox(props: {
export function FlagCheckbox(props: {
option: textFlagLabels;
idx: number;
checkboxValues: boolean[];
+29 -43
View File
@@ -1,6 +1,7 @@
import { useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
import { useMemo } from "react";
export function Footer() {
const { colorMode } = useColorMode();
@@ -9,7 +10,7 @@ export function Footer() {
return (
<footer className={bgColorClass}>
<div className={`flex mx-auto max-w-7xl justify-between py-10 px-10 border-t ${borderClass}`}>
<div className={`flex mx-auto max-w-7xl justify-between border-t p-10 ${borderClass}`}>
<div className="flex items-center pr-8">
<Link href="/" aria-label="Home" className="flex items-center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
@@ -21,50 +22,35 @@ export function Footer() {
</div>
</div>
<nav className="flex justify-center gap-20">
<nav className="flex justify-center gap-20">
<div className="flex flex-col text-sm leading-7">
<b>Legal</b>
<div className="flex flex-col leading-5">
<Link href="/privacy-policy" aria-label="Privacy Policy" className="hover:underline underline-offset-2">
Privacy Policy
</Link>
<Link
href="/terms-of-service"
aria-label="Terms of Service"
className="hover:underline underline-offset-2"
>
Terms of Service
</Link>
</div>
</div>
<div className="flex flex-col text-sm leading-7">
<b>Connect</b>
<div className="flex flex-col leading-5">
<Link
href="https://github.com/LAION-AI/Open-Assistant"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Privacy Policy"
className="hover:underline underline-offset-2"
>
Github
</Link>
<Link
href="https://discord.gg/pXtnYk9c"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Terms of Service"
className="hover:underline underline-offset-2"
>
Discord
</Link>
</div>
</div>
</nav>
{/* </div> */}
<nav className="grid grid-cols-2 gap-20 leading-5 text-sm">
<div className="flex flex-col">
<b className="pb-1">Legal</b>
<FooterLink href="/privacy-policy" label="Privacy Policy" />
<FooterLink href="/terms-of-service" label="Terms of Service" />
</div>
<div className="flex flex-col">
<b className="pb-1">Connect</b>
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label="Github" />
<FooterLink href="https://ykilcher.com/open-assistant-discord" label="Discord" />
</div>
</nav>
</div>
</footer>
);
}
const FooterLink = ({ href, label }: { href: string; label: string }) =>
useMemo(
() => (
<Link
href={href}
rel="noopener noreferrer nofollow"
target="_blank"
aria-label={label}
className="hover:underline underline-offset-2"
>
{label}
</Link>
),
[href, label]
);
@@ -2,6 +2,7 @@ import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react";
import { Popover } from "@headlessui/react";
import { AnimatePresence, motion } from "framer-motion";
import Image from "next/image";
import NextLink from "next/link";
import { signOut, useSession } from "next-auth/react";
import React from "react";
import { FiLayout, FiLogOut, FiSettings } from "react-icons/fi";
@@ -77,6 +78,7 @@ export function UserMenu() {
<Box className="flex flex-col gap-1">
{accountOptions.map((item) => (
<Link
as={NextLink}
key={item.name}
href={item.href}
aria-label={item.desc}
+38 -1
View File
@@ -1,9 +1,11 @@
// https://nextjs.org/docs/basic-features/layouts
import type { NextPage } from "next";
import { FiLayout, FiMessageSquare, FiUsers } from "react-icons/fi";
import { Header } from "src/components/Header";
import { Footer } from "./Footer";
import { SideMenuLayout } from "./SideMenuLayout";
export type NextPageWithLayout<P = unknown, IP = P> = NextPage<P, IP> & {
getLayout?: (page: React.ReactElement) => React.ReactNode;
@@ -28,7 +30,42 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
export const getDashboardLayout = (page: React.ReactElement) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
{page}
<SideMenuLayout
menuButtonOptions={[
{
label: "Dashboard",
pathname: "/dashboard",
desc: "Dashboard Home",
icon: FiLayout,
},
{
label: "Messages",
pathname: "/messages",
desc: "Messages Dashboard",
icon: FiMessageSquare,
},
]}
>
{page}
</SideMenuLayout>
</div>
);
export const getAdminLayout = (page: React.ReactElement) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
<SideMenuLayout
menuButtonOptions={[
{
label: "Users",
pathname: "/admin",
desc: "Users Dashboard",
icon: FiUsers,
},
]}
>
{page}
</SideMenuLayout>
</div>
);
+18 -16
View File
@@ -1,5 +1,6 @@
import { Grid } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useMemo } from "react";
import { FlaggableElement } from "./FlaggableElement";
@@ -8,29 +9,30 @@ export interface Message {
is_assistant: boolean;
}
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
if (colorMode === "light") {
return isAssistant ? "bg-slate-800" : "bg-sky-900";
} else {
return isAssistant ? "bg-black" : "bg-sky-900";
}
};
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const { colorMode } = useColorMode();
const items = messages.map((messageProps: Message, i: number) => {
const { text } = messageProps;
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
return (
<FlaggableElement text={text} post_id={post_id} key={i + text}>
<div
key={i + text}
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
>
{text}
</div>
<MessageView {...messageProps} />
</FlaggableElement>
);
});
// Maybe also show a legend of the colors?
return <Grid gap={2}>{items}</Grid>;
};
export const MessageView = ({ is_assistant, text }: Message) => {
const { colorMode } = useColorMode();
const bgColor = useMemo(() => {
if (colorMode === "light") {
return is_assistant ? "bg-slate-800" : "bg-sky-900";
} else {
return is_assistant ? "bg-black" : "bg-sky-900";
}
}, [colorMode, is_assistant]);
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
};
@@ -1,5 +1,5 @@
import { Box, CircularProgress, Stack, StackDivider, useColorModeValue } from "@chakra-ui/react";
import { MessageTableEntry } from "./MessageTableEntry";
import { Stack, StackDivider } from "@chakra-ui/react";
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
export function MessageTable({ messages }) {
return (
@@ -1,9 +1,19 @@
import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
import { boolean } from "boolean";
import NextLink from "next/link";
import { FlaggableElement } from "../FlaggableElement";
import { FlaggableElement } from "src/components/FlaggableElement";
export function MessageTableEntry({ item, idx }) {
interface Message {
text: string;
id: string;
is_assistant: boolean;
}
interface MessageTableEntryProps {
item: Message;
idx: number;
}
export function MessageTableEntry(props: MessageTableEntryProps) {
const { item, idx } = props;
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
return (
@@ -0,0 +1,104 @@
import { Box, CircularProgress, Flex, HStack, StackDivider, StackProps, Text, TextProps } from "@chakra-ui/react";
import { boolean } from "boolean";
import { useState } from "react";
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
import fetcher from "src/lib/fetcher";
import useSWR from "swr";
const MessageHeaderProps: TextProps = {
align: "center",
fontSize: "xl",
py: "2",
};
const MessageStackProps: StackProps = {
spacing: "2",
alignItems: "start",
justifyContent: "center",
divider: <StackDivider />,
};
interface MessageWithChildrenProps {
id: string;
depth?: number;
maxDepth?: number;
isOnlyChild?: boolean;
}
export function MessageWithChildren(props: MessageWithChildrenProps) {
const { id, depth, maxDepth, isOnlyChild = true } = props;
const [message, setMessage] = useState(null);
const [children, setChildren] = useState(null);
const { isLoading } = useSWR(id ? `/api/messages/${id}` : null, fetcher, {
onSuccess: (data) => {
setMessage(data);
},
onError: () => {
setMessage(null);
},
});
const { isLoading: isLoadingChildren } = useSWR(id ? `/api/messages/${id}/children` : null, fetcher, {
onSuccess: (data) => {
setChildren(data);
},
onError: () => {
setChildren(null);
},
});
const renderRecursive = maxDepth && ((depth && depth < maxDepth) || !depth);
const isFirst = depth === 0 || !depth;
const isFirstOrOnly = isFirst || boolean(isOnlyChild);
if (isLoading || isLoadingChildren) {
return <CircularProgress isIndeterminate />;
}
return (
<>
{message && (
<>
<Text {...MessageHeaderProps}>{isFirst ? "Message" : depth === 1 ? "Children" : "Ancestor"}</Text>
<Flex justifyContent="center" pb="2">
<Box maxWidth="container.sm" flex="1" px={isFirstOrOnly ? [4, 6, 8, 9] : "0"}>
<Box px={isFirstOrOnly ? "2" : "0"}>
<MessageTableEntry item={message} idx={1} />
</Box>
</Box>
</Flex>
</>
)}
{children && Array.isArray(children) && children.length > 0 ? (
renderRecursive ? (
<HStack {...MessageStackProps}>
{children.map((item, idx) => (
<Box flex="1" key={`recursiveMessageWChildren_${idx}`}>
<MessageWithChildren
id={item.id}
depth={depth ? depth + 1 : 1}
maxDepth={maxDepth}
isOnlyChild={children.length === 1 && isOnlyChild}
/>
</Box>
))}
</HStack>
) : (
<>
<Text {...MessageHeaderProps}>{isFirstOrOnly ? "Children" : "Ancestor"}</Text>
<HStack {...MessageStackProps}>
{children.map((item, idx) => (
<Box maxWidth="container.sm" flex="1" key={`recursiveMessageWChildren_${idx}`}>
<MessageTableEntry item={item} idx={idx * 2} />
</Box>
))}
</HStack>
</>
)
) : (
<></>
)}
</>
);
}
@@ -1,37 +1,24 @@
import { Box, Button, Link, Text, Tooltip, useColorMode } from "@chakra-ui/react";
import { Box, Button, Text, Tooltip, useColorMode } from "@chakra-ui/react";
import Link from "next/link";
import { useRouter } from "next/router";
import { FiLayout, FiSun, FiMessageSquare } from "react-icons/fi";
import { FiSun } from "react-icons/fi";
import { IconType } from "react-icons/lib";
import { colors } from "styles/Theme/colors";
export function SideMenu() {
export interface MenuButtonOption {
label: string;
pathname: string;
desc: string;
icon: IconType;
}
export interface SideMenuProps {
buttonOptions: MenuButtonOption[];
}
export function SideMenu(props: SideMenuProps) {
const router = useRouter();
const { colorMode, toggleColorMode } = useColorMode();
const buttonOptions = [
{
label: "Dashboard",
pathname: "/dashboard",
desc: "Dashboard Home",
icon: FiLayout,
},
{
label: "Messages",
pathname: "/messages",
desc: "Messages Dashboard",
icon: FiMessageSquare,
},
// {
// label: "Leaderboard",
// pathname: "#",
// desc: "Public Leaderboard",
// icon: FiAward,
// },
// {
// label: "Stats",
// pathname: "#",
// desc: "User Statistics",
// icon: FiBarChart2,
// },
];
return (
<main className="sticky top-0 sm:h-full">
@@ -43,7 +30,7 @@ export function SideMenu() {
className="grid grid-cols-4 gap-2 sm:flex sm:flex-col sm:justify-between p-4 h-full"
>
<nav className="grid grid-cols-3 col-span-3 sm:flex sm:flex-col gap-2">
{buttonOptions.map((item, itemIndex) => (
{props.buttonOptions.map((item, itemIndex) => (
<Tooltip
key={itemIndex}
fontFamily="inter"
+23
View File
@@ -0,0 +1,23 @@
import { Box, useColorMode } from "@chakra-ui/react";
import { MenuButtonOption, SideMenu } from "src/components/SideMenu";
import { colors } from "styles/Theme/colors";
interface SideMenuLayoutProps {
menuButtonOptions: MenuButtonOption[];
children: React.ReactNode;
}
export const SideMenuLayout = (props: SideMenuLayoutProps) => {
const { colorMode } = useColorMode();
return (
<Box backgroundColor={colorMode === "light" ? colors.light.bg : colors.dark.bg} className="sm:overflow-hidden">
<Box className="sm:flex h-full gap-6">
<Box className="p-6 sm:pr-0">
<SideMenu buttonOptions={props.menuButtonOptions} />
</Box>
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">{props.children}</Box>
</Box>
</Box>
);
};
+12 -9
View File
@@ -1,5 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import { Flex } from "@chakra-ui/react";
import clsx from "clsx";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
@@ -14,18 +15,20 @@ export interface TaskControlsProps {
}
export const TaskControls = (props: TaskControlsProps) => {
const extraClases = props.className || "";
const { colorMode } = useColorMode();
const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto";
const taskControlClases =
colorMode === "light"
? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}`
: `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
const isLightMode = colorMode === "light";
const endTask = props.tasks[props.tasks.length - 1];
return (
<section className={taskControlClases}>
<section
className={clsx(
"flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto space-y-4 sm:space-y-0 sm:flex",
props.className,
{
"bg-white text-gray-800 shadow-lg": isLightMode,
"bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset": !isLightMode,
}
)}
>
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
@@ -10,7 +10,7 @@ import {
ModalOverlay,
useDisclosure,
} from "@chakra-ui/react";
import { TaskControls, TaskControlsProps } from "./TaskControls";
import { TaskControls, TaskControlsProps } from "src/components/Survey/TaskControls";
interface TaskControlsOverridableProps extends TaskControlsProps {
isValid: boolean;
@@ -12,7 +12,7 @@ interface TrackedTextboxProps {
}
export const TrackedTextarea = (props: TrackedTextboxProps) => {
const wordCount = props.text.split(" ").length - 1;
const wordCount = (props.text.match(/\w+/g) || []).length;
let progressColor: string;
switch (true) {
@@ -28,7 +28,7 @@ export const TrackedTextarea = (props: TrackedTextboxProps) => {
return (
<Stack direction={"column"}>
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} onCapture />
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} />
<Progress size={"md"} rounded={"md"} value={wordCount} colorScheme={progressColor} max={props.thresholds.goal} />
</Stack>
);
+1 -1
View File
@@ -1,6 +1,6 @@
export const TaskInfo = ({ id, output }: { id: string; output: string }) => {
return (
<div className="grid grid-cols-[min-content_auto] gap-x-2 ">
<div className="grid grid-cols-[min-content_auto] gap-x-2">
<b>Prompt</b>
<span data-cy="task-id">{id}</span>
<b>Output</b>
@@ -1,39 +0,0 @@
import { Card, CardBody, Flex, Heading } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
export type OptionProps = {
img: string;
alt: string;
title: string;
link: string;
};
export const TaskOption = (props: OptionProps) => {
const { alt, img, title, link } = props;
return (
<Link href={link}>
<Card
maxW="300"
minW="300"
minH="300"
maxH="300"
className="transition ease-in-out duration-500 sm:grayscale hover:grayscale-0"
>
<CardBody width="full" height="full">
<Flex direction="column" alignItems="center" justifyContent="center">
<Image src={img} alt={alt} width={200} height={200} />
<Heading
mt={-10}
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
textAlign="center"
fontSize="3xl"
>
{title}
</Heading>
</Flex>
</CardBody>
</Card>
</Link>
);
};
@@ -1,23 +0,0 @@
import { Divider, Flex, Heading } from "@chakra-ui/react";
import React from "react";
export type TaskOptionsProps = {
title: string;
children: JSX.Element | JSX.Element[];
};
export const TaskOptions = (props: TaskOptionsProps) => {
const { title, children } = props;
return (
<Flex gap={10} wrap="wrap" justifyContent="center">
<Heading
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
fontSize="5xl"
>
{title}
</Heading>
<Divider mt={-8} />
{children}
</Flex>
);
};
@@ -1,73 +0,0 @@
import { Flex } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import React from "react";
import { TaskOption } from "./TaskOption";
import { TaskOptions } from "./TaskOptions";
export const TaskSelection = () => {
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
return (
<Flex
gap={10}
wrap="wrap"
justifyContent="space-evenly"
width="full"
height="full"
alignItems={"center"}
className={mainBgClasses}
>
<TaskOptions key="create" title="Create">
{/* <TaskOption
alt="Summarize Stories"
img="/images/logos/logo.svg"
title="Summarize stories"
link="/create/summarize_story"
/> */}
<TaskOption
alt="Create Initial Prompt"
img="/images/logos/logo.svg"
title="Create Initial Prompt"
link="/create/initial_prompt"
/>
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
<TaskOption
alt="Reply as Assistant"
img="/images/logos/logo.svg"
title="Reply as Assistant"
link="/create/assistant_reply"
/>
</TaskOptions>
<TaskOptions key="evaluate" title="Evaluate">
{/*
Commented out while the backend does not support them.
<TaskOption
alt="Rate Prompts"
img="/images/logos/logo.svg"
title="Rate Prompts"
link="/evaluate/rate_summary"
/> */}
<TaskOption
alt="Rank Initial Prompts"
img="/images/logos/logo.svg"
title="Rank Initial Prompts"
link="/evaluate/rank_initial_prompts"
/>
<TaskOption
alt="Rank User Replies"
img="/images/logos/logo.svg"
title="Rank User Replies"
link="/evaluate/rank_user_replies"
/>
<TaskOption
alt="Rank Assistant Replies"
img="/images/logos/logo.svg"
title="Rank Assistant Replies"
link="/evaluate/rank_assistant_replies"
/>
</TaskOptions>
</Flex>
);
};
@@ -1,3 +0,0 @@
export { TaskOption } from "./TaskOption";
export { TaskOptions } from "./TaskOptions";
export { TaskSelection } from "./TaskSelection";
@@ -0,0 +1,54 @@
import { useState } from "react";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses }) => {
const task = tasks[0].task;
const [inputText, setInputText] = useState("");
const submitResponse = (task: { id: string }) => {
const text = inputText.trim();
trigger({
id: task.id,
update_type: "text_reply_to_message",
content: {
text,
},
});
};
const fetchNextTask = () => {
setInputText("");
mutate();
};
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
setInputText(event.target.value);
};
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">{taskType.label}</h5>
<p className="text-lg py-1">{taskType.overview}</p>
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
</>
<>
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
<TrackedTextarea
text={inputText}
onTextChange={textChangeHandler}
thresholds={{ low: 20, medium: 40, goal: 50 }}
textareaProps={{ placeholder: "Reply..." }}
/>
</>
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
);
};
@@ -0,0 +1,52 @@
import { useState } from "react";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
import { MessageTable } from "../Messages/MessageTable";
export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
const [ranking, setRanking] = useState<number[]>([]);
const submitResponse = (task) => {
trigger({
id: task.id,
update_type: "message_ranking",
content: {
ranking,
},
});
};
const fetchNextTask = () => {
setRanking([]);
mutate();
};
let messages = null;
if (tasks[0].task.conversation) {
messages = tasks[0].task.conversation.messages;
messages = messages.map((message, index) => ({ ...message, id: index }));
}
const sortables = tasks[0].task.replies ? "replies" : "prompts";
return (
<div className={`p-12 ${mainBgClasses}`}>
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
<p className="text-lg py-1">
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
</p>
{messages ? <MessageTable messages={messages} /> : null}
<Sortable items={tasks[0].task[sortables]} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task[sortables].length}
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
/>
</div>
);
};
+28
View File
@@ -0,0 +1,28 @@
import { CreateTask } from "./CreateTask";
import { EvaluateTask } from "./EvaluateTask";
import { TaskCategory, TaskTypes } from "./TaskTypes";
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
const task = tasks[0].task;
function taskTypeComponent(type) {
const taskType = TaskTypes.find((taskType) => taskType.type === type);
const category = taskType.category;
switch (category) {
case TaskCategory.Create:
return (
<CreateTask
tasks={tasks}
trigger={trigger}
mutate={mutate}
taskType={taskType}
mainBgClasses={mainBgClasses}
/>
);
case TaskCategory.Evaluate:
return <EvaluateTask tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />;
}
}
return taskTypeComponent(task.type);
};
@@ -0,0 +1,66 @@
export enum TaskCategory {
Create = "Create",
Evaluate = "Evaluate",
Label = "Label",
}
export const TaskTypes = [
// create
{
label: "Create Initial Prompts",
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
category: TaskCategory.Create,
pathname: "/create/initial_prompt",
type: "initial_prompt",
overview: "Create an initial message to send to the assistant",
instruction: "Provide the initial prompt",
},
{
label: "Reply as User",
desc: "Chat with Open Assistant and help improve its responses as you interact with it.",
category: TaskCategory.Create,
pathname: "/create/user_reply",
type: "prompter_reply",
overview: "Given the following conversation, provide an adequate reply",
instruction: "Provide the user`s reply",
},
{
label: "Reply as Assistant",
desc: "Help Open Assistant improve its responses to conversations with other users.",
category: TaskCategory.Create,
pathname: "/create/assistant_reply",
type: "assistant_reply",
overview: "Given the following conversation, provide an adequate reply",
instruction: "Provide the assistant`s reply",
},
// evaluate
{
label: "Rank User Replies",
category: TaskCategory.Evaluate,
desc: "Help Open Assistant improve its responses to conversations with other users.",
pathname: "/evaluate/rank_user_replies",
type: "rank_prompter_replies",
},
{
label: "Rank Assistant Replies",
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
category: TaskCategory.Evaluate,
pathname: "/evaluate/rank_assistant_replies",
type: "rank_assistant_replies",
},
{
label: "Rank Initial Prompts",
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
category: TaskCategory.Evaluate,
pathname: "/evaluate/rank_initial_prompts",
type: "rank_initial_prompts",
},
// label
{
label: "Label Initial Prompt",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_initial_prompt",
type: "label_initial_prompt",
},
];
+44
View File
@@ -0,0 +1,44 @@
import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
import { useState } from "react";
import fetcher from "src/lib/fetcher";
import useSWR from "swr";
/**
* Fetches users from the users api route and then presents them in a simple Chakra table.
*/
const UsersCell = () => {
// Fetch and save the users.
const [users, setUsers] = useState([]);
const { isLoading } = useSWR("/api/admin/users", fetcher, {
onSuccess: setUsers,
});
// Present users in a naive table.
return (
<TableContainer>
<Table variant="simple">
<TableCaption>Users</TableCaption>
<Thead>
<Tr>
<Th>Id</Th>
<Th>Email</Th>
<Th>Name</Th>
<Th>Role</Th>
</Tr>
</Thead>
<Tbody>
{users.map((user, index) => (
<Tr key={index}>
<Td>{user.id}</Td>
<Td>{user.email}</Td>
<Td>{user.name}</Td>
<Td>{user.role}</Td>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
);
};
export default UsersCell;
+1 -1
View File
@@ -1,5 +1,5 @@
import { Container } from "./Container";
import Image from "next/image";
import { Container } from "src/components/Container";
const Vision = () => {
return (
+52
View File
@@ -0,0 +1,52 @@
import { useEffect, useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
interface TaskResponse<TaskType> {
id: string;
userId: string;
task: TaskType;
}
export interface LabelInitialPromptTask {
id: string;
message_id: string;
prompt: string;
type: string;
valid_labels: string[];
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
onSuccess: (data: ConcreteTaskResponse) => {
setTasks([data]);
},
});
useEffect(() => {
if (tasks.length === 0 && !isLoading && !error) {
mutate();
}
}, [tasks, isLoading, mutate, error]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (reply) => {
const newTask: ConcreteTaskResponse = await reply.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
return { tasks, isLoading, submit, error, reset: mutate };
};
+1 -1
View File
@@ -42,7 +42,7 @@ export class OasstApiClient {
} catch (e) {
throw new OasstError(errorText, 0, resp.status);
}
throw new OasstError(error.message, error.error_code, resp.status);
throw new OasstError(error.message ?? error, error.error_code, resp.status);
}
return await resp.json();
+3 -3
View File
@@ -1,9 +1,9 @@
import Image from "next/image";
import { CallToAction } from "src/components/CallToAction";
import { Container } from "src/components/Container";
import Roadmap from "src/components/Roadmap";
import Services from "src/components/Services";
import Vision from "src/components/Vision";
import Roadmap from "src/components/Roadmap";
import { CallToAction } from "src/components/CallToAction";
import Image from "next/image";
const AboutPage = () => {
return (
+49
View File
@@ -0,0 +1,49 @@
import Head from "next/head";
import { useRouter } from "next/router";
import { useSession } from "next-auth/react";
import { useEffect } from "react";
import { getAdminLayout } from "src/components/Layout";
import UsersCell from "src/components/UsersCell";
/**
* Provides the admin index page that will display a list of users and give
* admins the ability to manage their access rights.
*/
const AdminIndex = () => {
const router = useRouter();
const { data: session, status } = useSession();
// Check when the user session is loaded and re-route if the user is not an
// admin. This follows the suggestion by NextJS for handling private pages:
// https://nextjs.org/docs/api-reference/next/router#usage
//
// All admin pages should use the same check and routing steps.
useEffect(() => {
if (status === "loading") {
return;
}
if (session?.user?.role === "admin") {
return;
}
router.push("/");
}, [session, status]);
// Show the final page.
// TODO(#237): Display a component that fetches actual user data.
return (
<>
<Head>
<title>Open Assistant</title>
<meta
name="description"
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className="oa-basic-theme">{status === "loading" ? "loading..." : <UsersCell />}</main>
</>
);
};
AdminIndex.getLayout = getAdminLayout;
export default AdminIndex;
+31
View File
@@ -0,0 +1,31 @@
import { getToken } from "next-auth/jwt";
import client from "src/lib/prismadb";
/**
* Returns a list of user results from the database when the requesting user is
* a logged in admin.
*/
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered or if the user isn't an admin.
if (!token || token.role !== "admin") {
res.status(403).end();
return;
}
// Fetch 20 users.
const users = await client.user.findMany({
select: {
id: true,
role: true,
name: true,
email: true,
},
take: 20,
});
res.status(200).json(users);
};
export default handler;
@@ -0,0 +1,27 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
const { id } = req.query;
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
res.status(200).json(messages);
};
export default handler;
@@ -0,0 +1,27 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
const { id } = req.query;
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
res.status(200).json(messages);
};
export default handler;
@@ -0,0 +1,27 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
const { id } = req.query;
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const message = await messageRes.json();
// Send recieved messages to the client.
res.status(200).json(message);
};
export default handler;
@@ -0,0 +1,48 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
const { id } = req.query;
if (!id) {
res.status(400).end();
return;
}
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const message = await messageRes.json();
if (!message.parent_id) {
res.status(404).end();
return;
}
const parentRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${message.parent_id}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const parent = await parentRes.json();
// Send recieved messages to the client.
res.status(200).json(parent);
};
export default handler;
+6 -1
View File
@@ -35,7 +35,12 @@ const handler = async (req, res) => {
},
});
const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
let newTask;
try {
newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
} catch (err) {
return res.status(500).json(err);
}
// Stores the new task with our database.
const newRegisteredTask = await prisma.registeredTask.create({
+1 -1
View File
@@ -1,5 +1,5 @@
import { getSession } from "next-auth/react";
import prisma from "../../lib/prismadb";
import prisma from "src/lib/prismadb";
// POST /api/post
// Required fields in body: title
+9 -47
View File
@@ -1,11 +1,9 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -13,7 +11,6 @@ import useSWRMutation from "swr/mutation";
const AssistantReply = () => {
const [tasks, setTasks] = useState([]);
const [inputText, setInputText] = useState("");
const { isLoading, mutate } = useSWRImmutable("/api/new_task/assistant_reply ", fetcher, {
onSuccess: (data) => {
@@ -34,26 +31,6 @@ const AssistantReply = () => {
},
});
const submitResponse = (task: { id: string }) => {
const text = inputText.trim();
trigger({
id: task.id,
update_type: "text_reply_to_message",
content: {
text,
},
});
};
const fetchNextTask = () => {
setInputText("");
mutate();
};
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
setInputText(event.target.value);
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -65,29 +42,14 @@ const AssistantReply = () => {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
const task = tasks[0].task;
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Reply as the assistant</h5>
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
<Messages messages={task.conversation.messages} post_id={task.id} />
</>
<>
<h5 className="text-lg font-semibold">Provide the assistant`s reply</h5>
<TrackedTextarea
text={inputText}
onTextChange={textChangeHandler}
thresholds={{ low: 20, medium: 40, goal: 50 }}
textareaProps={{ placeholder: "Reply..." }}
/>
</>
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
<>
<Head>
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
</>
);
};
+9 -43
View File
@@ -1,10 +1,9 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -12,7 +11,6 @@ import useSWRMutation from "swr/mutation";
const InitialPrompt = () => {
const [tasks, setTasks] = useState([]);
const [inputText, setInputText] = useState("");
const { isLoading, mutate } = useSWRImmutable("/api/new_task/initial_prompt ", fetcher, {
onSuccess: (data) => {
@@ -33,26 +31,6 @@ const InitialPrompt = () => {
}
}, [tasks]);
const submitResponse = (task: { id: string }) => {
const text = inputText.trim();
trigger({
id: task.id,
update_type: "text_reply_to_message",
content: {
text,
},
});
};
const fetchNextTask = () => {
setInputText("");
mutate();
};
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
setInputText(event.target.value);
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -65,25 +43,13 @@ const InitialPrompt = () => {
}
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Start a conversation</h5>
<p className="text-lg py-1">Create an initial message to send to the assistant</p>
</>
<>
<h5 className="text-lg font-semibold">Provide the initial prompt</h5>
<TrackedTextarea
text={inputText}
onTextChange={textChangeHandler}
thresholds={{ low: 20, medium: 40, goal: 50 }}
textareaProps={{ placeholder: "Question, task, greeting or similar..." }}
/>
</>
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
<>
<Head>
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
</>
);
};
+9 -48
View File
@@ -1,10 +1,8 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -12,7 +10,6 @@ import useSWRMutation from "swr/mutation";
const UserReply = () => {
const [tasks, setTasks] = useState([]);
const [inputText, setInputText] = useState("");
const { isLoading, mutate } = useSWRImmutable("/api/new_task/prompter_reply", fetcher, {
onSuccess: (data) => {
@@ -33,26 +30,6 @@ const UserReply = () => {
},
});
const submitResponse = (task: { id: string }) => {
const text = inputText.trim();
trigger({
id: task.id,
update_type: "text_reply_to_message",
content: {
text,
},
});
};
const fetchNextTask = () => {
setInputText("");
mutate();
};
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
setInputText(event.target.value);
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -70,30 +47,14 @@ const UserReply = () => {
);
}
const task = tasks[0].task;
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Reply as a user</h5>
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
<Messages messages={task.conversation.messages} post_id={task.id} />
{task.hint && <p className="text-lg py-1">Hint: {task.hint}</p>}
</>
<>
<h5 className="text-lg font-semibold">Provide the user`s reply</h5>
<TrackedTextarea
text={inputText}
onTextChange={textChangeHandler}
thresholds={{ low: 20, medium: 40, goal: 50 }}
textareaProps={{ placeholder: "Reply..." }}
/>
</>
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
<>
<Head>
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
</>
);
};
+3 -16
View File
@@ -1,29 +1,16 @@
import { Box, useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { LeaderboardTable, TaskOption } from "src/components/Dashboard";
import { getDashboardLayout } from "src/components/Layout";
import { LeaderboardTable, SideMenu, TaskOption } from "src/components/Dashboard";
import { colors } from "styles/Theme/colors";
const Dashboard = () => {
const { colorMode } = useColorMode();
return (
<>
<Head>
<title>Dashboard - Open Assistant</title>
<meta name="description" content="Chat with Open Assistant and provide feedback." />
</Head>
<Box backgroundColor={colorMode === "light" ? colors.light.bg : colors.dark.bg} className="sm:overflow-hidden">
<Box className="sm:flex h-full gap-6">
<Box className="p-6 sm:pr-0">
<SideMenu />
</Box>
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">
<TaskOption />
<LeaderboardTable />
</Box>
</Box>
</Box>
<TaskOption />
<LeaderboardTable />
</>
);
};
@@ -1,12 +1,8 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -14,11 +10,6 @@ import useSWRMutation from "swr/mutation";
const RankAssistantReplies = () => {
const [tasks, setTasks] = useState([]);
/**
* This array will contain the ranked indices of the replies
* The best reply will have index 0, and the worst is the last.
*/
const [ranking, setRanking] = useState<number[]>([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_assistant_replies", fetcher, {
onSuccess: (data) => {
@@ -39,21 +30,6 @@ const RankAssistantReplies = () => {
},
});
const submitResponse = (task) => {
trigger({
id: task.id,
update_type: "message_ranking",
content: {
ranking,
},
});
};
const fetchNextTask = () => {
setRanking([]);
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -71,33 +47,13 @@ const RankAssistantReplies = () => {
);
}
const replies = tasks[0].task.replies as string[];
const messages = tasks[0].task.conversation.messages as Message[];
return (
<>
<Head>
<title>Rank Assistant Replies</title>
<meta name="description" content="Rank Assistant Replies." />
</Head>
<div className={`p-12 ${mainBgClasses}`}>
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
<p className="text-lg py-1">
Given the following replies, sort them from best to worst, best being first, worst being last.
</p>
<ContextMessages messages={messages} />
<Sortable items={replies} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task.replies.length}
prepareForSubmit={() => setRanking(tasks[0].task.replies.map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
/>
</div>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
</>
);
};
@@ -2,9 +2,7 @@ import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -12,12 +10,6 @@ import useSWRMutation from "swr/mutation";
const RankInitialPrompts = () => {
const [tasks, setTasks] = useState([]);
/**
* This array will contain the ranked indices of the prompts
* The best prompt will have index 0, and the worst is the last.
*/
const [ranking, setRanking] = useState<number[]>([]);
// const bg = useColorModeValue("gray.100", "gray.800");
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_initial_prompts", fetcher, {
onSuccess: (data) => {
@@ -38,21 +30,6 @@ const RankInitialPrompts = () => {
}
}, [tasks]);
const submitResponse = (task) => {
trigger({
id: task.id,
update_type: "message_ranking",
content: {
ranking,
},
});
};
const fetchNextTask = () => {
setRanking([]);
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -76,23 +53,7 @@ const RankInitialPrompts = () => {
<title>Rank Initial Prompts</title>
<meta name="description" content="Rank initial prompts." />
</Head>
<div className={`p-12 ${mainBgClasses}`}>
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
<p className="text-lg py-1">
Given the following prompts, sort them from best to worst, best being first, worst being last.
</p>
<Sortable items={tasks[0].task.prompts} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task.prompts.length}
prepareForSubmit={() => setRanking(tasks[0].task.prompts.map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
/>
</div>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
</>
);
};
@@ -1,12 +1,8 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -14,11 +10,6 @@ import useSWRMutation from "swr/mutation";
const RankUserReplies = () => {
const [tasks, setTasks] = useState([]);
/**
* This array will contain the ranked indices of the replies
* The best reply will have index 0, and the worst is the last.
*/
const [ranking, setRanking] = useState<number[]>([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_prompter_replies", fetcher, {
onSuccess: (data) => {
@@ -39,21 +30,6 @@ const RankUserReplies = () => {
},
});
const submitResponse = (task) => {
trigger({
id: task.id,
update_type: "message_ranking",
content: {
ranking,
},
});
};
const fetchNextTask = () => {
setRanking([]);
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -70,8 +46,6 @@ const RankUserReplies = () => {
</div>
);
}
const replies = tasks[0].task.replies as string[];
const messages = tasks[0].task.conversation.messages as Message[];
return (
<>
@@ -79,24 +53,7 @@ const RankUserReplies = () => {
<title>Rank User Replies</title>
<meta name="description" content="Rank User Replies." />
</Head>
<div className={`p-12 ${mainBgClasses}`}>
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
<p className="text-lg py-1">
Given the following replies, sort them from best to worst, best being first, worst being last.
</p>
<ContextMessages messages={messages} />
<Sortable items={replies} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task.replies.length}
prepareForSubmit={() => setRanking(tasks[0].task.replies.map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
/>
</div>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
</>
);
};
@@ -0,0 +1,113 @@
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useEffect, useId, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
import { colors } from "styles/Theme/colors";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
taskApiEndpoint: "label_initial_prompt",
});
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
const labels = task.valid_labels.reduce((obj, label, i) => {
obj[label] = sliderValues[i].toString();
return obj;
}, {} as Record<string, string>);
submit(id, task.message_id, task.prompt, labels);
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
const task = tasks[0].task;
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
<p className="text-lg py-1">Provide labels for the following prompt</p>
<MessageView text={task.prompt} is_assistant />
</>
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
</div>
);
};
export default LabelInitialPrompt;
// TODO: consolidate with FlaggableElement
interface CheckboxSliderGroupProps {
labelIDs: Array<string>;
onChange: (sliderValues: number[]) => unknown;
}
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
useEffect(() => {
onChange(sliderValues);
}, [sliderValues, onChange]);
return (
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
{labelIDs.map((labelId, idx) => (
<CheckboxSliderItem
key={idx}
labelId={labelId}
sliderValue={sliderValues[idx]}
sliderHandler={(sliderValue) => {
const newState = sliderValues.slice();
newState[idx] = sliderValue;
setSliderValues(newState);
}}
/>
))}
</Grid>
);
};
function CheckboxSliderItem(props: {
labelId: string;
sliderValue: number;
sliderHandler: (newVal: number) => unknown;
}) {
const id = useId();
const { colorMode } = useColorMode();
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
return (
<>
<label className="text-sm" htmlFor={id}>
{/* TODO: display real text instead of just the id */}
<span className={labelTextClass}>{props.labelId}</span>
</label>
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
<SliderTrack>
<SliderFilledTrack />
<SliderThumb />
</SliderTrack>
</Slider>
</>
);
}
+62
View File
@@ -0,0 +1,62 @@
import { Box, Container, Text, useColorModeValue } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
import { MessageWithChildren } from "src/components/Messages/MessageWithChildren";
import fetcher from "src/lib/fetcher";
import useSWR from "swr";
const MessageDetail = ({ id }) => {
const mainBg = useColorModeValue("bg-slate-300", "bg-slate-900");
const [parent, setParent] = useState(null);
const { isLoading: isLoadingParent } = useSWR(id ? `/api/messages/${id}/parent` : null, fetcher, {
onSuccess: (data) => {
setParent(data);
},
onError: () => {
setParent(null);
},
});
if (isLoadingParent) {
return <LoadingScreen text="Loading..." />;
}
return (
<>
<Head>
<title>Open Assistant</title>
<meta
name="description"
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className={`${mainBg}`}>
<Container w="100%" pt={[2, 2, 4, 4]}>
{parent && (
<>
<Text align="center" fontSize="xl">
Parent
</Text>
<Box rounded="lg" p="2">
<MessageTableEntry item={parent} idx={1} />
</Box>
</>
)}
</Container>
<Box pb="4" maxW="full" px="2">
<MessageWithChildren id={id} maxDepth={2} />
</Box>
</main>
</>
);
};
MessageDetail.getInitialProps = async ({ query }) => {
const { id } = query;
return { id };
};
export default MessageDetail;
+43 -47
View File
@@ -1,79 +1,75 @@
import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { useEffect, useState } from "react";
import { getDashboardLayout } from "src/components/Layout";
import { MessageTable } from "src/components/Messages/MessageTable";
import fetcher from "src/lib/fetcher";
import useSWRImmutable from "swr/immutable";
import fetcher from "src/lib/fetcher";
import { SideMenu } from "src/components/Dashboard";
import { MessageTable } from "src/components/Messages/MessageTable";
import { getDashboardLayout } from "src/components/Layout";
import { colors } from "styles/Theme/colors";
const MessagesDashboard = () => {
const bgColor = useColorModeValue(colors.light.bg, colors.dark.bg);
const boxBgColor = useColorModeValue("white", "gray.700");
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
const [messages, setMessages] = useState([]);
const [userMessages, setUserMessages] = useState([]);
const { isLoading: isLoadingAll } = useSWRImmutable("/api/messages", fetcher, {
const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, {
onSuccess: (data) => {
setMessages(data);
},
});
const { isLoading: isLoadingUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
onSuccess: (data) => {
setUserMessages(data);
},
});
useEffect(() => {
if (messages.length == 0) {
mutateAll();
}
if (userMessages.length == 0) {
mutateUser();
}
}, [messages, userMessages]);
return (
<>
<Head>
<title>Messages - Open Assistant</title>
<meta name="description" content="Chat with Open Assistant and provide feedback." />
</Head>
<Box backgroundColor={bgColor} className="sm:overflow-hidden">
<Box className="sm:flex h-full gap-6">
<Box className="p-6 sm:pr-0">
<SideMenu />
</Box>
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">
<SimpleGrid columns={[1, 1, 1, 2]} gap={4}>
<Box>
<Text className="text-2xl font-bold" pb="4">
Most recent messages
</Text>
<Box
backgroundColor={boxBgColor}
boxShadow="base"
dropShadow={boxAccentColor}
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
</Box>
</Box>
<Box>
<Text className="text-2xl font-bold" pb="4">
Your most recent messages
</Text>
<Box
backgroundColor={boxBgColor}
boxShadow="base"
dropShadow={boxAccentColor}
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
</Box>
</Box>
</SimpleGrid>
<SimpleGrid columns={[1, 1, 1, 2]} gap={4}>
<Box>
<Text className="text-2xl font-bold" pb="4">
Most recent messages
</Text>
<Box
backgroundColor={boxBgColor}
boxShadow="base"
dropShadow={boxAccentColor}
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
</Box>
</Box>
</Box>
<Box>
<Text className="text-2xl font-bold" pb="4">
Your most recent messages
</Text>
<Box
backgroundColor={boxBgColor}
boxShadow="base"
dropShadow={boxAccentColor}
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
</Box>
</Box>
</SimpleGrid>
</>
);
};
-2
View File
@@ -1,7 +1,5 @@
import { Container, Heading } from "@chakra-ui/react";
import Head from "next/head";
import { Footer } from "src/components/Footer";
import { Header } from "src/components/Header";
import { getTransparentHeaderLayout } from "src/components/Layout";
const PrivacyPolicy = () => {
-2
View File
@@ -1,7 +1,5 @@
import { Container, Heading } from "@chakra-ui/react";
import Head from "next/head";
import { Footer } from "src/components/Footer";
import { Header } from "src/components/Header";
import { getTransparentHeaderLayout } from "src/components/Layout";
const TermsOfService = () => {