mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into 371_set_labels
This commit is contained in:
@@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- docs-site-poc
|
||||
paths:
|
||||
- ".github/workflows/deploy-docs-site.yaml"
|
||||
- "docs/**"
|
||||
@@ -45,9 +44,7 @@ jobs:
|
||||
|
||||
- name: Deploy
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
if:
|
||||
${{ github.ref == 'refs/heads/main' || github.ref ==
|
||||
'refs/heads/docs-site-poc' }}
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: ./docs/build
|
||||
|
||||
@@ -2,3 +2,5 @@
|
||||
/website/ @fozziethebeat @k-nearest-neighbor @AbdBarho
|
||||
/model/ @theblackcat102 @sanagno
|
||||
/copilot/ @fozziethebeat @andreaskoepf @yk
|
||||
/docs/ @andrewm4894 @andreaskoepf @yk
|
||||
/.devcontainer/ @andrewm4894 @andreaskoepf @yk
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
<div align="center">
|
||||
|
||||
<a href="https://github.com/LAION-AI/Open-Assistant/stargazers"></a>
|
||||
<a href="https://laion-ai.github.io/Open-Assistant/"></a>
|
||||
<a href="https://github.com/LAION-AI/Open-Assistant/actions/workflows/build-frontend.yaml"></a>
|
||||
<a href="https://github.com/LAION-AI/Open-Assistant/actions/workflows/pre-commit.yaml"></a>
|
||||
<a href="https://github.com/LAION-AI/Open-Assistant/actions/workflows/test-api-contract.yaml"></a>
|
||||
|
||||
@@ -1,5 +1,18 @@
|
||||
# Open-Assistant REST Backend
|
||||
|
||||
## Backend Development Setup
|
||||
|
||||
In root directory, run
|
||||
`docker compose up backend-dev --build --attach-dependencies` to start a
|
||||
database. The default settings are already configured to connect to the database
|
||||
at `localhost:5432`.
|
||||
|
||||
Make sure you have all requirements installed. You can do this by running
|
||||
`pip install -r requirements.txt` inside the `backend` folder and
|
||||
`pip install -e .` inside the `oasst-shared` folder. Then, run the backend using
|
||||
the `run-local.sh` script inside the `scripts` folder. This will start the
|
||||
backend server at `http://localhost:8080`.
|
||||
|
||||
## REST Server Configuration
|
||||
|
||||
Please either use environment variables or create a `.env` file in the backend
|
||||
@@ -20,3 +33,11 @@ REDIS_PORT=6379
|
||||
Have a look into the main `README.md` file for more information on how to set up
|
||||
the backend for development. Use the scripts within the
|
||||
scripts/backend-development folder to run the BE API locally.
|
||||
|
||||
## Alembic
|
||||
|
||||
To create an Alembic database migration script after sql-models were modified
|
||||
run `alembic revision --autogenerate -m "..."` ("..." is what you did) in the
|
||||
`/backend` directory. Then edit the newly created file. See
|
||||
[here](https://alembic.sqlalchemy.org/en/latest/tutorial.html) for more
|
||||
information.
|
||||
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
"""added frontend_type to api_client
|
||||
|
||||
Revision ID: ba61fe17fb6e
|
||||
Revises: 20cd871f4ec7
|
||||
Create Date: 2023-01-07 12:50:32.195930
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ba61fe17fb6e"
|
||||
down_revision = "20cd871f4ec7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("api_client", sa.Column("frontend_type", sa.String(256), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("api_client", "frontend_id")
|
||||
+20
-66
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
@@ -6,7 +7,6 @@ from typing import Optional
|
||||
import alembic.command
|
||||
import alembic.config
|
||||
import fastapi
|
||||
import pydantic
|
||||
import redis.asyncio as redis
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from loguru import logger
|
||||
@@ -17,6 +17,7 @@ from oasst_backend.database import engine
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
@@ -97,7 +98,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
def seed_data():
|
||||
class DummyMessage(pydantic.BaseModel):
|
||||
class DummyMessage(BaseModel):
|
||||
task_message_id: str
|
||||
user_message_id: str
|
||||
parent_message_id: Optional[str]
|
||||
@@ -111,64 +112,10 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
|
||||
|
||||
dummy_messages = [
|
||||
DummyMessage(
|
||||
task_message_id="de111fa8",
|
||||
user_message_id="6f1d0711",
|
||||
parent_message_id=None,
|
||||
text="Hi!",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="74c381d4",
|
||||
user_message_id="4a24530b",
|
||||
parent_message_id="6f1d0711",
|
||||
text="Hello! How can I help you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="3d5dc440",
|
||||
user_message_id="a8c01c04",
|
||||
parent_message_id="4a24530b",
|
||||
text="Do you have a recipe for potato soup?",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="643716c1",
|
||||
user_message_id="f43a93b7",
|
||||
parent_message_id="4a24530b",
|
||||
text="Who were the 8 presidents before George Washington?",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="2e4e1e6",
|
||||
user_message_id="c886920",
|
||||
parent_message_id="6f1d0711",
|
||||
text="Hey buddy! How can I serve you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="970c437d",
|
||||
user_message_id="cec432cf",
|
||||
parent_message_id=None,
|
||||
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
|
||||
role="prompter",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="6066118e",
|
||||
user_message_id="4f85f637",
|
||||
parent_message_id="cec432cf",
|
||||
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyMessage(
|
||||
task_message_id="ba87780d",
|
||||
user_message_id="0e276b98",
|
||||
parent_message_id="cec432cf",
|
||||
text="I'm unsure how to interpret this. Is it a riddle?",
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
|
||||
for msg in dummy_messages:
|
||||
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
@@ -185,12 +132,20 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
task = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
|
||||
conversation_messages = pr.fetch_message_conversation(parent_message)
|
||||
conversation = protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text=cmsg.text,
|
||||
is_assistant=cmsg.role == "assistant",
|
||||
message_id=cmsg.id,
|
||||
fronend_message_id=cmsg.frontend_message_id,
|
||||
)
|
||||
),
|
||||
for cmsg in conversation_messages
|
||||
]
|
||||
)
|
||||
task = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
@@ -219,7 +174,6 @@ if __name__ == "__main__":
|
||||
# Importing here so we don't import packages unnecessarily if we're
|
||||
# importing main as a module.
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
||||
@@ -40,7 +40,13 @@ def get_dummy_api_client(db: Session) -> ApiClient:
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token", trusted=True)
|
||||
api_client = ApiClient(
|
||||
id=ANY_API_KEY_ID,
|
||||
api_key=token,
|
||||
description="ANY_API_KEY, random token",
|
||||
trusted=True,
|
||||
frontend_type="Test frontend",
|
||||
)
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
@@ -2,9 +2,7 @@ from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
@@ -20,11 +18,6 @@ def get_message_by_frontend_id(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
# Unexpected message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@ from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -55,10 +53,6 @@ def get_message(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
# Unexptcted message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -58,7 +59,10 @@ def generate_task(
|
||||
messages = pr.fetch_random_conversation("assistant")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
text=msg.text,
|
||||
is_assistant=(msg.role == "assistant"),
|
||||
message_id=msg.id,
|
||||
front_end_id=msg.frontend_message_id,
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
@@ -71,7 +75,10 @@ def generate_task(
|
||||
messages = pr.fetch_random_conversation("prompter")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
text=msg.text,
|
||||
is_assistant=(msg.role == "assistant"),
|
||||
message_id=msg.id,
|
||||
front_end_id=msg.frontend_message_id,
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
@@ -83,19 +90,21 @@ def generate_task(
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
|
||||
messages = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages])
|
||||
case protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
|
||||
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
text=p.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
message_id=p.id,
|
||||
front_end_id=p.frontend_message_id,
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
replies = [p.text for p in replies]
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=task_messages,
|
||||
@@ -109,14 +118,16 @@ def generate_task(
|
||||
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
text=p.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
message_id=p.id,
|
||||
front_end_id=p.frontend_message_id,
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
replies = [p.text for p in replies]
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(messages=task_messages),
|
||||
conversation=prepare_conversation(conversation),
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
@@ -125,29 +136,29 @@ def generate_task(
|
||||
message = pr.fetch_random_initial_prompts(1)[0]
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.payload.payload.text,
|
||||
prompt=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_prompter_reply:
|
||||
logger.info("Generating a LabelPrompterReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
|
||||
message = messages[0].payload.payload.text
|
||||
message = messages[0]
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message,
|
||||
conversation=prepare_conversation(conversation),
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_assistant_reply:
|
||||
logger.info("Generating a LabelAssistantReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
|
||||
message = messages[0].payload.payload.text
|
||||
message = messages[0]
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message,
|
||||
conversation=prepare_conversation(conversation),
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
@@ -292,7 +303,8 @@ def tasks_interaction(
|
||||
logger.info(
|
||||
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
|
||||
)
|
||||
# TODO: check if the labels are valid?
|
||||
# Labels are implicitly validated when converting str -> TextLabel
|
||||
# So no need for explicit validation here
|
||||
pr.store_text_labels(interaction)
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
@@ -32,3 +33,13 @@ def label_text(
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/valid_labels")
|
||||
def get_valid_lables() -> ValidLabelsResponse:
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in protocol_schema.TextLabel
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import Message
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
|
||||
|
||||
def prepare_message(m: Message) -> protocol.Message:
|
||||
if not isinstance(m.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
return protocol.Message(
|
||||
id=m.id,
|
||||
parent_id=m.parent_id,
|
||||
text=m.payload.payload.text,
|
||||
text=m.text,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
)
|
||||
@@ -26,10 +21,13 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
protocol.ConversationMessage(
|
||||
text=message.text,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
message_id=message.id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
)
|
||||
)
|
||||
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
@@ -38,8 +36,6 @@ def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages = []
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(prepare_message(message))
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
|
||||
from pydantic import AnyHttpUrl, BaseSettings, FilePath, PostgresDsn, validator
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -21,6 +22,9 @@ class Settings(BaseSettings):
|
||||
DEBUG_ALLOW_ANY_API_KEY: bool = False
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
|
||||
)
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
|
||||
@@ -20,3 +20,4 @@ class ApiClient(SQLModel, table=True):
|
||||
admin_email: Optional[str] = Field(max_length=256, nullable=True)
|
||||
enabled: bool = Field(default=True)
|
||||
trusted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
frontend_type: str = Field(max_length=256, nullable=True)
|
||||
|
||||
@@ -32,7 +32,7 @@ class Journal(SQLModel, table=True):
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
|
||||
@@ -49,7 +49,7 @@ class JournalIntegration(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
description: str = Field(max_length=512, primary_key=True)
|
||||
last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True)
|
||||
last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
last_error: str = Field(nullable=True)
|
||||
next_run: datetime = Field(nullable=True)
|
||||
last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True)
|
||||
last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
last_error: Optional[str] = Field(nullable=True)
|
||||
next_run: Optional[datetime] = Field(nullable=True)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
@@ -19,19 +22,30 @@ class Message(SQLModel, table=True):
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
parent_id: UUID = Field(nullable=True)
|
||||
parent_id: Optional[UUID] = Field(nullable=True)
|
||||
message_tree_id: UUID = Field(nullable=False, index=True)
|
||||
task_id: UUID = Field(nullable=True, index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
|
||||
task_id: Optional[UUID] = Field(nullable=True, index=True)
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128, regex="^prompter|assistant$")
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
|
||||
payload: Optional[PayloadContainer] = Field(
|
||||
sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
|
||||
)
|
||||
lang: str = Field(nullable=False, max_length=200, default="en-US")
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
self.ensure_is_message()
|
||||
return self.payload.payload.text
|
||||
|
||||
@@ -22,7 +22,7 @@ class Task(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LabelOption(BaseModel):
|
||||
name: str
|
||||
display_text: str
|
||||
help_text: Optional[str]
|
||||
|
||||
|
||||
class ValidLabelsResponse(BaseModel):
|
||||
valid_labels: list[LabelOption]
|
||||
@@ -0,0 +1,72 @@
|
||||
[
|
||||
{
|
||||
"task_message_id": "de111fa8",
|
||||
"user_message_id": "6f1d0711",
|
||||
"parent_message_id": null,
|
||||
"text": "Hi!",
|
||||
"role": "prompter"
|
||||
},
|
||||
{
|
||||
"task_message_id": "74c381d4",
|
||||
"user_message_id": "4a24530b",
|
||||
"parent_message_id": "6f1d0711",
|
||||
"text": "Hello! How can I help you?",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"task_message_id": "3d5dc440",
|
||||
"user_message_id": "a8c01c04",
|
||||
"parent_message_id": "4a24530b",
|
||||
"text": "Do you have a recipe for potato soup?",
|
||||
"role": "prompter"
|
||||
},
|
||||
{
|
||||
"task_message_id": "643716c1",
|
||||
"user_message_id": "f43a93b7",
|
||||
"parent_message_id": "4a24530b",
|
||||
"text": "Who were the 8 presidents before George Washington?",
|
||||
"role": "prompter"
|
||||
},
|
||||
{
|
||||
"task_message_id": "2e4e1e6",
|
||||
"user_message_id": "c886920",
|
||||
"parent_message_id": "6f1d0711",
|
||||
"text": "Hey buddy! How can I serve you?",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"task_message_id": "970c437d",
|
||||
"user_message_id": "cec432cf",
|
||||
"parent_message_id": null,
|
||||
"text": "euirdteunvglfe23908230892309832098 AAAAAAAA",
|
||||
"role": "prompter"
|
||||
},
|
||||
{
|
||||
"task_message_id": "6066118e",
|
||||
"user_message_id": "4f85f637",
|
||||
"parent_message_id": "cec432cf",
|
||||
"text": "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"task_message_id": "ba87780d",
|
||||
"user_message_id": "0e276b98",
|
||||
"parent_message_id": "cec432cf",
|
||||
"text": "I'm unsure how to interpret this. Is it a riddle?",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"task_message_id": "b8e98ed6",
|
||||
"user_message_id": "89384709",
|
||||
"parent_message_id": "0e276b98",
|
||||
"text": "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?",
|
||||
"role": "prompter"
|
||||
},
|
||||
{
|
||||
"task_message_id": "9a0e7683",
|
||||
"user_message_id": "6d452c57",
|
||||
"parent_message_id": "0e276b98",
|
||||
"text": "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?",
|
||||
"role": "prompter"
|
||||
}
|
||||
]
|
||||
@@ -14,3 +14,4 @@ COPY ./backend/alembic /app/alembic
|
||||
COPY ./backend/alembic.ini /app/alembic.ini
|
||||
COPY ./backend/main.py /app/main.py
|
||||
COPY ./backend/oasst_backend /app/oasst_backend
|
||||
COPY ./backend/test_data /app/test_data
|
||||
|
||||
+5
-3
@@ -1,7 +1,9 @@
|
||||
# Website
|
||||
# Docs Site
|
||||
|
||||
This website is built using [Docusaurus 2](https://docusaurus.io/), a modern
|
||||
static website generator.
|
||||
https://laion-ai.github.io/Open-Assistant/
|
||||
|
||||
This [site](https://laion-ai.github.io/Open-Assistant/) is built using
|
||||
[Docusaurus 2](https://docusaurus.io/), a modern static website generator.
|
||||
|
||||
### Contributing
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ const config = {
|
||||
// If you aren't using GitHub pages, you don't need these.
|
||||
organizationName: "LAION-AI", // Usually your GitHub org/user name.
|
||||
projectName: "Open-Assistant", // Usually your repo name.
|
||||
deploymentBranch: "docs-site-poc",
|
||||
deploymentBranch: "main",
|
||||
|
||||
// Even if you don't use internalization, you can use this field to set useful
|
||||
// metadata like html lang. For example, if your site is Chinese, you may want
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
model_name: microsoft/deberta-v2-xlarge
|
||||
learning_rate: 1e-5
|
||||
freeze_layer: 15
|
||||
scheduler: cosine
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 1
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,14 @@
|
||||
model_name: microsoft/deberta-v3-base
|
||||
learning_rate: 1e-5
|
||||
scheduler: cosine
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,13 @@
|
||||
model_name: deepset/deberta-v3-large-squad2
|
||||
learning_rate: 1e-5
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_train_batch_size: 1
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,14 @@
|
||||
model_name: microsoft/deberta-v3-large
|
||||
learning_rate: 1e-5
|
||||
scheduler: cosine
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_train_batch_size: 1
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -11,6 +11,7 @@ from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, W
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
@@ -19,6 +20,8 @@ from transformers import (
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
|
||||
|
||||
@@ -179,7 +182,7 @@ if __name__ == "__main__":
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to="local",
|
||||
report_to="wandb",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
@@ -202,6 +205,21 @@ if __name__ == "__main__":
|
||||
else:
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
|
||||
assert len(evals) > 0
|
||||
|
||||
optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
scheduler = None
|
||||
if "scheduler" in training_conf:
|
||||
if training_conf["scheduler"] == "linear":
|
||||
scheduler = get_linear_schedule_with_warmup()
|
||||
elif training_conf["scheduler"] == "cosine":
|
||||
scheduler = get_cosine_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=args.warmup_steps,
|
||||
num_training_steps=len(train)
|
||||
* args.num_train_epochs
|
||||
/ (args.per_device_train_batch_size * args.gradient_accumulation_steps),
|
||||
)
|
||||
|
||||
trainer = RankTrainer(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
@@ -211,6 +229,7 @@ if __name__ == "__main__":
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
optimizers=(optimizer, scheduler),
|
||||
)
|
||||
# trainer.evaluate()
|
||||
trainer.train()
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-argumentation/EssayInstructions.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-argumentation/EssayRevision.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,5 +1,23 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/detoxify-evaluation/DetoxityEvaluation.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# uncomment below to install required python packages\n",
|
||||
"#!pip install detoxify"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
|
||||
@@ -34,6 +34,8 @@ class ConversationMessage(BaseModel):
|
||||
|
||||
text: str
|
||||
is_assistant: bool
|
||||
message_id: Optional[UUID] = None
|
||||
frontend_message_id: Optional[str] = None
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
@@ -263,10 +265,25 @@ class MessageRanking(Interaction):
|
||||
class TextLabel(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
def __new__(cls, label: str, display_text: str = "", help_text: str = None):
|
||||
obj = str.__new__(cls, label)
|
||||
obj._value_ = label
|
||||
obj.display_text = display_text
|
||||
obj.help_text = help_text
|
||||
return obj
|
||||
|
||||
spam = "spam"
|
||||
violence = "violence"
|
||||
sexual_content = "sexual_content"
|
||||
fails_task = "fails_task", "Fails to follow the correct instruction / task"
|
||||
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
|
||||
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
|
||||
harmful = (
|
||||
"harmful",
|
||||
"Harmful content",
|
||||
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
|
||||
)
|
||||
sexual_content = "sexual_content", "Contains sexual content"
|
||||
toxicity = "toxicity"
|
||||
moral_judgement = "moral_judgement", "Expresses moral judgement"
|
||||
political_content = "political_content"
|
||||
humor = "humor"
|
||||
sarcasm = "sarcasm"
|
||||
|
||||
+29
-10
@@ -1,5 +1,6 @@
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
import http
|
||||
import random
|
||||
|
||||
import requests
|
||||
@@ -30,6 +31,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
if response.status_code == http.HTTPStatus.NO_CONTENT:
|
||||
return None
|
||||
return response.json()
|
||||
|
||||
typer.echo("Requesting work...")
|
||||
@@ -191,7 +194,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
ranking_str = typer.prompt("Enter the reply numbers in order of preference, separated by commas")
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
|
||||
# send ranking
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
@@ -211,11 +214,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
valid_labels = task["valid_labels"]
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
|
||||
# send ranking
|
||||
labels_dict = None
|
||||
while labels_dict is None:
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
@@ -240,17 +251,25 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
valid_labels = task["valid_labels"]
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
|
||||
# send ranking
|
||||
labels_dict = None
|
||||
while labels_dict is None:
|
||||
labels_str: str = typer.prompt("Enter labels, separated by commas")
|
||||
labels = labels_str.lower().replace(" ", "").split(",")
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"text": task["prompt"],
|
||||
"text": task["reply"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
"rules": {
|
||||
"unused-imports/no-unused-imports": "warn",
|
||||
"simple-import-sort/imports": "warn",
|
||||
"simple-import-sort/exports": "warn"
|
||||
"simple-import-sort/exports": "warn",
|
||||
"eqeqeq": "warn"
|
||||
},
|
||||
"plugins": ["simple-import-sort", "unused-imports"]
|
||||
}
|
||||
|
||||
Generated
+29
-616
File diff suppressed because it is too large
Load Diff
@@ -50,7 +50,6 @@
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
"react-icons": "^4.7.1",
|
||||
"sharp": "0.31.2",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"use-debounce": "^9.0.2"
|
||||
|
||||
@@ -48,7 +48,7 @@ export function CallToAction() {
|
||||
here:
|
||||
</p>
|
||||
<div className="mt-8 flex justify-center">
|
||||
<a href="https://discord.gg/pXtnYk9c" rel="noreferrer" target="_blank">
|
||||
<a href="https://ykilcher.com/open-assistant-discord" rel="noreferrer" target="_blank">
|
||||
<button
|
||||
type="button"
|
||||
className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
import { Button, useDisclosure } from "@chakra-ui/react";
|
||||
import { Modal, ModalOverlay, ModalContent, ModalHeader, ModalBody, ModalCloseButton } from "@chakra-ui/react";
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
export const CollapsableText = ({ text, maxLength = 220 }) => {
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
import { Box } from "@chakra-ui/react";
|
||||
import { Message } from "./Messages";
|
||||
|
||||
export const ContextMessages = ({ messages }: { messages: Message[] }) => {
|
||||
return (
|
||||
<Box className="flex flex-col gap-1">
|
||||
{messages.map((message, i) => {
|
||||
return (
|
||||
<Box key={i}>
|
||||
<span>{message.is_assistant ? "Assistant: " : "User: "}</span>
|
||||
<span>{message.text}</span>
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
@@ -1,126 +1,53 @@
|
||||
import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
|
||||
const crTasks = [
|
||||
{
|
||||
label: "Create Initial Prompts",
|
||||
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
type: "create",
|
||||
pathname: "/create/initial_prompt",
|
||||
},
|
||||
{
|
||||
label: "Reply as User",
|
||||
desc: "Chat with Open Assistant and help improve it’s responses as you interact with it.",
|
||||
type: "create",
|
||||
pathname: "/create/user_reply",
|
||||
},
|
||||
{
|
||||
label: "Reply as Assistant",
|
||||
desc: "Help Open Assistant improve its responses to conversations with other users.",
|
||||
type: "create",
|
||||
pathname: "/create/assistant_reply",
|
||||
},
|
||||
];
|
||||
import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes";
|
||||
|
||||
const evTasks = [
|
||||
{
|
||||
label: "Rank User Replies",
|
||||
type: "eval",
|
||||
desc: "Help Open Assistant improve its responses to conversations with other users.",
|
||||
pathname: "/evaluate/rank_user_replies",
|
||||
},
|
||||
|
||||
{
|
||||
label: "Rank Assistant Replies",
|
||||
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
type: "eval",
|
||||
pathname: "/evaluate/rank_assistant_replies",
|
||||
},
|
||||
{
|
||||
label: "Rank Initial Prompts",
|
||||
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
type: "eval;",
|
||||
pathname: "/evaluate/rank_initial_prompts",
|
||||
},
|
||||
];
|
||||
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label];
|
||||
|
||||
export const TaskOption = () => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
|
||||
return (
|
||||
<Box className="flex flex-col gap-14" fontFamily="inter">
|
||||
<div>
|
||||
<Text className="text-2xl font-bold pb-4">Create</Text>
|
||||
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
|
||||
{crTasks.map((item, itemIndex) => (
|
||||
<Link key={itemIndex} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
borderRadius="xl"
|
||||
boxShadow="base"
|
||||
className="flex flex-col justify-between h-full"
|
||||
>
|
||||
<Box className="p-6 pb-10">
|
||||
<Flex flexDir="column" gap="3">
|
||||
<Heading size="md" fontFamily="inter">
|
||||
{item.label}
|
||||
</Heading>
|
||||
<Text size="sm" opacity="80%">
|
||||
{item.desc}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box
|
||||
bg="blue.500"
|
||||
borderBottomRadius="xl"
|
||||
className="px-6 py-2 transition-colors duration-300"
|
||||
_hover={{ backgroundColor: "blue.600" }}
|
||||
{displayTaskCategories.map((category, categoryIndex) => (
|
||||
<div key={categoryIndex}>
|
||||
<Text className="text-2xl font-bold pb-4">{category}</Text>
|
||||
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
|
||||
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
|
||||
<Link key={itemIndex} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
borderRadius="xl"
|
||||
boxShadow="base"
|
||||
className="flex flex-col justify-between h-full"
|
||||
>
|
||||
<Text fontWeight="bold" color="white">
|
||||
Go
|
||||
</Text>
|
||||
</Box>
|
||||
</GridItem>
|
||||
</Link>
|
||||
))}
|
||||
</SimpleGrid>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="text-2xl font-bold pb-4">Evaluate</Text>
|
||||
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
|
||||
{evTasks.map((item, itemIndex) => (
|
||||
<Link key={itemIndex} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
borderRadius="xl"
|
||||
boxShadow="base"
|
||||
className="flex flex-col justify-between h-full"
|
||||
>
|
||||
<Box className="p-6 pb-10">
|
||||
<Flex flexDir="column" gap="3">
|
||||
<Heading size="md" fontFamily="inter">
|
||||
{item.label}
|
||||
</Heading>
|
||||
<Text size="sm" opacity="80%">
|
||||
{item.desc}
|
||||
<Box className="p-6 pb-10">
|
||||
<Flex flexDir="column" gap="3">
|
||||
<Heading size="md" fontFamily="inter">
|
||||
{item.label}
|
||||
</Heading>
|
||||
<Text size="sm" opacity="80%">
|
||||
{item.desc}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box
|
||||
bg="blue.500"
|
||||
borderBottomRadius="xl"
|
||||
className="px-6 py-2 transition-colors duration-300"
|
||||
_hover={{ backgroundColor: "blue.600" }}
|
||||
>
|
||||
<Text fontWeight="bold" color="white">
|
||||
Go
|
||||
</Text>
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box
|
||||
bg="blue.500"
|
||||
borderBottomRadius="xl"
|
||||
className="px-6 py-2 transition-colors duration-300"
|
||||
_hover={{ backgroundColor: "blue.600" }}
|
||||
>
|
||||
<Text fontWeight="bold" color="white">
|
||||
Go
|
||||
</Text>
|
||||
</Box>
|
||||
</GridItem>
|
||||
</Link>
|
||||
))}
|
||||
</SimpleGrid>
|
||||
</div>
|
||||
</Box>
|
||||
</GridItem>
|
||||
</Link>
|
||||
))}
|
||||
</SimpleGrid>
|
||||
</div>
|
||||
))}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
export { LeaderboardTable } from "./LeaderboardTable";
|
||||
export { SideMenu } from "./SideMenu";
|
||||
export { TaskOption } from "./TaskOption";
|
||||
|
||||
@@ -24,8 +24,8 @@ import {
|
||||
import { FlagIcon, QuestionMarkCircleIcon } from "@heroicons/react/20/solid";
|
||||
import { useState } from "react";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
export const FlaggableElement = (props) => {
|
||||
const [isEditing, setIsEditing] = useBoolean();
|
||||
@@ -118,7 +118,8 @@ export const FlaggableElement = (props) => {
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
function FlagCheckbox(props: {
|
||||
|
||||
export function FlagCheckbox(props: {
|
||||
option: textFlagLabels;
|
||||
idx: number;
|
||||
checkboxValues: boolean[];
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useMemo } from "react";
|
||||
|
||||
export function Footer() {
|
||||
const { colorMode } = useColorMode();
|
||||
@@ -9,7 +10,7 @@ export function Footer() {
|
||||
|
||||
return (
|
||||
<footer className={bgColorClass}>
|
||||
<div className={`flex mx-auto max-w-7xl justify-between py-10 px-10 border-t ${borderClass}`}>
|
||||
<div className={`flex mx-auto max-w-7xl justify-between border-t p-10 ${borderClass}`}>
|
||||
<div className="flex items-center pr-8">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
|
||||
@@ -21,50 +22,35 @@ export function Footer() {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<nav className="flex justify-center gap-20">
|
||||
<nav className="flex justify-center gap-20">
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Legal</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link href="/privacy-policy" aria-label="Privacy Policy" className="hover:underline underline-offset-2">
|
||||
Privacy Policy
|
||||
</Link>
|
||||
<Link
|
||||
href="/terms-of-service"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Terms of Service
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Connect</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link
|
||||
href="https://github.com/LAION-AI/Open-Assistant"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Privacy Policy"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Github
|
||||
</Link>
|
||||
<Link
|
||||
href="https://discord.gg/pXtnYk9c"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Discord
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
{/* </div> */}
|
||||
<nav className="grid grid-cols-2 gap-20 leading-5 text-sm">
|
||||
<div className="flex flex-col">
|
||||
<b className="pb-1">Legal</b>
|
||||
<FooterLink href="/privacy-policy" label="Privacy Policy" />
|
||||
<FooterLink href="/terms-of-service" label="Terms of Service" />
|
||||
</div>
|
||||
<div className="flex flex-col">
|
||||
<b className="pb-1">Connect</b>
|
||||
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label="Github" />
|
||||
<FooterLink href="https://ykilcher.com/open-assistant-discord" label="Discord" />
|
||||
</div>
|
||||
</nav>
|
||||
</div>
|
||||
</footer>
|
||||
);
|
||||
}
|
||||
|
||||
const FooterLink = ({ href, label }: { href: string; label: string }) =>
|
||||
useMemo(
|
||||
() => (
|
||||
<Link
|
||||
href={href}
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label={label}
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
{label}
|
||||
</Link>
|
||||
),
|
||||
[href, label]
|
||||
);
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Popover } from "@headlessui/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import Image from "next/image";
|
||||
import NextLink from "next/link";
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import React from "react";
|
||||
import { FiLayout, FiLogOut, FiSettings } from "react-icons/fi";
|
||||
@@ -77,6 +78,7 @@ export function UserMenu() {
|
||||
<Box className="flex flex-col gap-1">
|
||||
{accountOptions.map((item) => (
|
||||
<Link
|
||||
as={NextLink}
|
||||
key={item.name}
|
||||
href={item.href}
|
||||
aria-label={item.desc}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
// https://nextjs.org/docs/basic-features/layouts
|
||||
|
||||
import type { NextPage } from "next";
|
||||
import { FiLayout, FiMessageSquare, FiUsers } from "react-icons/fi";
|
||||
import { Header } from "src/components/Header";
|
||||
|
||||
import { Footer } from "./Footer";
|
||||
import { SideMenuLayout } from "./SideMenuLayout";
|
||||
|
||||
export type NextPageWithLayout<P = unknown, IP = P> = NextPage<P, IP> & {
|
||||
getLayout?: (page: React.ReactElement) => React.ReactNode;
|
||||
@@ -28,7 +30,42 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
|
||||
export const getDashboardLayout = (page: React.ReactElement) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
{page}
|
||||
<SideMenuLayout
|
||||
menuButtonOptions={[
|
||||
{
|
||||
label: "Dashboard",
|
||||
pathname: "/dashboard",
|
||||
desc: "Dashboard Home",
|
||||
icon: FiLayout,
|
||||
},
|
||||
{
|
||||
label: "Messages",
|
||||
pathname: "/messages",
|
||||
desc: "Messages Dashboard",
|
||||
icon: FiMessageSquare,
|
||||
},
|
||||
]}
|
||||
>
|
||||
{page}
|
||||
</SideMenuLayout>
|
||||
</div>
|
||||
);
|
||||
|
||||
export const getAdminLayout = (page: React.ReactElement) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
<SideMenuLayout
|
||||
menuButtonOptions={[
|
||||
{
|
||||
label: "Users",
|
||||
pathname: "/admin",
|
||||
desc: "Users Dashboard",
|
||||
icon: FiUsers,
|
||||
},
|
||||
]}
|
||||
>
|
||||
{page}
|
||||
</SideMenuLayout>
|
||||
</div>
|
||||
);
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Grid } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useMemo } from "react";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
@@ -8,29 +9,30 @@ export interface Message {
|
||||
is_assistant: boolean;
|
||||
}
|
||||
|
||||
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
|
||||
if (colorMode === "light") {
|
||||
return isAssistant ? "bg-slate-800" : "bg-sky-900";
|
||||
} else {
|
||||
return isAssistant ? "bg-black" : "bg-sky-900";
|
||||
}
|
||||
};
|
||||
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
const { text } = messageProps;
|
||||
|
||||
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
|
||||
return (
|
||||
<FlaggableElement text={text} post_id={post_id} key={i + text}>
|
||||
<div
|
||||
key={i + text}
|
||||
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
|
||||
>
|
||||
{text}
|
||||
</div>
|
||||
<MessageView {...messageProps} />
|
||||
</FlaggableElement>
|
||||
);
|
||||
});
|
||||
// Maybe also show a legend of the colors?
|
||||
return <Grid gap={2}>{items}</Grid>;
|
||||
};
|
||||
|
||||
export const MessageView = ({ is_assistant, text }: Message) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const bgColor = useMemo(() => {
|
||||
if (colorMode === "light") {
|
||||
return is_assistant ? "bg-slate-800" : "bg-sky-900";
|
||||
} else {
|
||||
return is_assistant ? "bg-black" : "bg-sky-900";
|
||||
}
|
||||
}, [colorMode, is_assistant]);
|
||||
|
||||
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Box, CircularProgress, Stack, StackDivider, useColorModeValue } from "@chakra-ui/react";
|
||||
import { MessageTableEntry } from "./MessageTableEntry";
|
||||
import { Stack, StackDivider } from "@chakra-ui/react";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
|
||||
export function MessageTable({ messages }) {
|
||||
return (
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import NextLink from "next/link";
|
||||
import { FlaggableElement } from "../FlaggableElement";
|
||||
import { FlaggableElement } from "src/components/FlaggableElement";
|
||||
|
||||
export function MessageTableEntry({ item, idx }) {
|
||||
interface Message {
|
||||
text: string;
|
||||
id: string;
|
||||
is_assistant: boolean;
|
||||
}
|
||||
interface MessageTableEntryProps {
|
||||
item: Message;
|
||||
idx: number;
|
||||
}
|
||||
export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
const { item, idx } = props;
|
||||
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
|
||||
|
||||
return (
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
import { Box, CircularProgress, Flex, HStack, StackDivider, StackProps, Text, TextProps } from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import { useState } from "react";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import useSWR from "swr";
|
||||
|
||||
const MessageHeaderProps: TextProps = {
|
||||
align: "center",
|
||||
fontSize: "xl",
|
||||
py: "2",
|
||||
};
|
||||
|
||||
const MessageStackProps: StackProps = {
|
||||
spacing: "2",
|
||||
alignItems: "start",
|
||||
justifyContent: "center",
|
||||
divider: <StackDivider />,
|
||||
};
|
||||
|
||||
interface MessageWithChildrenProps {
|
||||
id: string;
|
||||
depth?: number;
|
||||
maxDepth?: number;
|
||||
isOnlyChild?: boolean;
|
||||
}
|
||||
|
||||
export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
const { id, depth, maxDepth, isOnlyChild = true } = props;
|
||||
|
||||
const [message, setMessage] = useState(null);
|
||||
const [children, setChildren] = useState(null);
|
||||
|
||||
const { isLoading } = useSWR(id ? `/api/messages/${id}` : null, fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setMessage(data);
|
||||
},
|
||||
onError: () => {
|
||||
setMessage(null);
|
||||
},
|
||||
});
|
||||
const { isLoading: isLoadingChildren } = useSWR(id ? `/api/messages/${id}/children` : null, fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setChildren(data);
|
||||
},
|
||||
onError: () => {
|
||||
setChildren(null);
|
||||
},
|
||||
});
|
||||
|
||||
const renderRecursive = maxDepth && ((depth && depth < maxDepth) || !depth);
|
||||
const isFirst = depth === 0 || !depth;
|
||||
const isFirstOrOnly = isFirst || boolean(isOnlyChild);
|
||||
|
||||
if (isLoading || isLoadingChildren) {
|
||||
return <CircularProgress isIndeterminate />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{message && (
|
||||
<>
|
||||
<Text {...MessageHeaderProps}>{isFirst ? "Message" : depth === 1 ? "Children" : "Ancestor"}</Text>
|
||||
<Flex justifyContent="center" pb="2">
|
||||
<Box maxWidth="container.sm" flex="1" px={isFirstOrOnly ? [4, 6, 8, 9] : "0"}>
|
||||
<Box px={isFirstOrOnly ? "2" : "0"}>
|
||||
<MessageTableEntry item={message} idx={1} />
|
||||
</Box>
|
||||
</Box>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{children && Array.isArray(children) && children.length > 0 ? (
|
||||
renderRecursive ? (
|
||||
<HStack {...MessageStackProps}>
|
||||
{children.map((item, idx) => (
|
||||
<Box flex="1" key={`recursiveMessageWChildren_${idx}`}>
|
||||
<MessageWithChildren
|
||||
id={item.id}
|
||||
depth={depth ? depth + 1 : 1}
|
||||
maxDepth={maxDepth}
|
||||
isOnlyChild={children.length === 1 && isOnlyChild}
|
||||
/>
|
||||
</Box>
|
||||
))}
|
||||
</HStack>
|
||||
) : (
|
||||
<>
|
||||
<Text {...MessageHeaderProps}>{isFirstOrOnly ? "Children" : "Ancestor"}</Text>
|
||||
<HStack {...MessageStackProps}>
|
||||
{children.map((item, idx) => (
|
||||
<Box maxWidth="container.sm" flex="1" key={`recursiveMessageWChildren_${idx}`}>
|
||||
<MessageTableEntry item={item} idx={idx * 2} />
|
||||
</Box>
|
||||
))}
|
||||
</HStack>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
+17
-30
@@ -1,37 +1,24 @@
|
||||
import { Box, Button, Link, Text, Tooltip, useColorMode } from "@chakra-ui/react";
|
||||
import { Box, Button, Text, Tooltip, useColorMode } from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { FiLayout, FiSun, FiMessageSquare } from "react-icons/fi";
|
||||
import { FiSun } from "react-icons/fi";
|
||||
import { IconType } from "react-icons/lib";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export function SideMenu() {
|
||||
export interface MenuButtonOption {
|
||||
label: string;
|
||||
pathname: string;
|
||||
desc: string;
|
||||
icon: IconType;
|
||||
}
|
||||
|
||||
export interface SideMenuProps {
|
||||
buttonOptions: MenuButtonOption[];
|
||||
}
|
||||
|
||||
export function SideMenu(props: SideMenuProps) {
|
||||
const router = useRouter();
|
||||
const { colorMode, toggleColorMode } = useColorMode();
|
||||
const buttonOptions = [
|
||||
{
|
||||
label: "Dashboard",
|
||||
pathname: "/dashboard",
|
||||
desc: "Dashboard Home",
|
||||
icon: FiLayout,
|
||||
},
|
||||
{
|
||||
label: "Messages",
|
||||
pathname: "/messages",
|
||||
desc: "Messages Dashboard",
|
||||
icon: FiMessageSquare,
|
||||
},
|
||||
// {
|
||||
// label: "Leaderboard",
|
||||
// pathname: "#",
|
||||
// desc: "Public Leaderboard",
|
||||
// icon: FiAward,
|
||||
// },
|
||||
// {
|
||||
// label: "Stats",
|
||||
// pathname: "#",
|
||||
// desc: "User Statistics",
|
||||
// icon: FiBarChart2,
|
||||
// },
|
||||
];
|
||||
|
||||
return (
|
||||
<main className="sticky top-0 sm:h-full">
|
||||
@@ -43,7 +30,7 @@ export function SideMenu() {
|
||||
className="grid grid-cols-4 gap-2 sm:flex sm:flex-col sm:justify-between p-4 h-full"
|
||||
>
|
||||
<nav className="grid grid-cols-3 col-span-3 sm:flex sm:flex-col gap-2">
|
||||
{buttonOptions.map((item, itemIndex) => (
|
||||
{props.buttonOptions.map((item, itemIndex) => (
|
||||
<Tooltip
|
||||
key={itemIndex}
|
||||
fontFamily="inter"
|
||||
@@ -0,0 +1,23 @@
|
||||
import { Box, useColorMode } from "@chakra-ui/react";
|
||||
import { MenuButtonOption, SideMenu } from "src/components/SideMenu";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
interface SideMenuLayoutProps {
|
||||
menuButtonOptions: MenuButtonOption[];
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export const SideMenuLayout = (props: SideMenuLayoutProps) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
return (
|
||||
<Box backgroundColor={colorMode === "light" ? colors.light.bg : colors.dark.bg} className="sm:overflow-hidden">
|
||||
<Box className="sm:flex h-full gap-6">
|
||||
<Box className="p-6 sm:pr-0">
|
||||
<SideMenu buttonOptions={props.menuButtonOptions} />
|
||||
</Box>
|
||||
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">{props.children}</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import clsx from "clsx";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
@@ -14,18 +15,20 @@ export interface TaskControlsProps {
|
||||
}
|
||||
|
||||
export const TaskControls = (props: TaskControlsProps) => {
|
||||
const extraClases = props.className || "";
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto";
|
||||
const taskControlClases =
|
||||
colorMode === "light"
|
||||
? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}`
|
||||
: `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
|
||||
|
||||
const isLightMode = colorMode === "light";
|
||||
const endTask = props.tasks[props.tasks.length - 1];
|
||||
return (
|
||||
<section className={taskControlClases}>
|
||||
<section
|
||||
className={clsx(
|
||||
"flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto space-y-4 sm:space-y-0 sm:flex",
|
||||
props.className,
|
||||
{
|
||||
"bg-white text-gray-800 shadow-lg": isLightMode,
|
||||
"bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset": !isLightMode,
|
||||
}
|
||||
)}
|
||||
>
|
||||
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
ModalOverlay,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { TaskControls, TaskControlsProps } from "./TaskControls";
|
||||
import { TaskControls, TaskControlsProps } from "src/components/Survey/TaskControls";
|
||||
|
||||
interface TaskControlsOverridableProps extends TaskControlsProps {
|
||||
isValid: boolean;
|
||||
|
||||
@@ -12,7 +12,7 @@ interface TrackedTextboxProps {
|
||||
}
|
||||
|
||||
export const TrackedTextarea = (props: TrackedTextboxProps) => {
|
||||
const wordCount = props.text.split(" ").length - 1;
|
||||
const wordCount = (props.text.match(/\w+/g) || []).length;
|
||||
|
||||
let progressColor: string;
|
||||
switch (true) {
|
||||
@@ -28,7 +28,7 @@ export const TrackedTextarea = (props: TrackedTextboxProps) => {
|
||||
|
||||
return (
|
||||
<Stack direction={"column"}>
|
||||
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} onCapture />
|
||||
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} />
|
||||
<Progress size={"md"} rounded={"md"} value={wordCount} colorScheme={progressColor} max={props.thresholds.goal} />
|
||||
</Stack>
|
||||
);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export const TaskInfo = ({ id, output }: { id: string; output: string }) => {
|
||||
return (
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2 ">
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2">
|
||||
<b>Prompt</b>
|
||||
<span data-cy="task-id">{id}</span>
|
||||
<b>Output</b>
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import { Card, CardBody, Flex, Heading } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
|
||||
export type OptionProps = {
|
||||
img: string;
|
||||
alt: string;
|
||||
title: string;
|
||||
link: string;
|
||||
};
|
||||
|
||||
export const TaskOption = (props: OptionProps) => {
|
||||
const { alt, img, title, link } = props;
|
||||
return (
|
||||
<Link href={link}>
|
||||
<Card
|
||||
maxW="300"
|
||||
minW="300"
|
||||
minH="300"
|
||||
maxH="300"
|
||||
className="transition ease-in-out duration-500 sm:grayscale hover:grayscale-0"
|
||||
>
|
||||
<CardBody width="full" height="full">
|
||||
<Flex direction="column" alignItems="center" justifyContent="center">
|
||||
<Image src={img} alt={alt} width={200} height={200} />
|
||||
<Heading
|
||||
mt={-10}
|
||||
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
|
||||
textAlign="center"
|
||||
fontSize="3xl"
|
||||
>
|
||||
{title}
|
||||
</Heading>
|
||||
</Flex>
|
||||
</CardBody>
|
||||
</Card>
|
||||
</Link>
|
||||
);
|
||||
};
|
||||
@@ -1,23 +0,0 @@
|
||||
import { Divider, Flex, Heading } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
export type TaskOptionsProps = {
|
||||
title: string;
|
||||
children: JSX.Element | JSX.Element[];
|
||||
};
|
||||
|
||||
export const TaskOptions = (props: TaskOptionsProps) => {
|
||||
const { title, children } = props;
|
||||
return (
|
||||
<Flex gap={10} wrap="wrap" justifyContent="center">
|
||||
<Heading
|
||||
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
|
||||
fontSize="5xl"
|
||||
>
|
||||
{title}
|
||||
</Heading>
|
||||
<Divider mt={-8} />
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,73 +0,0 @@
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
import { TaskOption } from "./TaskOption";
|
||||
import { TaskOptions } from "./TaskOptions";
|
||||
|
||||
export const TaskSelection = () => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
return (
|
||||
<Flex
|
||||
gap={10}
|
||||
wrap="wrap"
|
||||
justifyContent="space-evenly"
|
||||
width="full"
|
||||
height="full"
|
||||
alignItems={"center"}
|
||||
className={mainBgClasses}
|
||||
>
|
||||
<TaskOptions key="create" title="Create">
|
||||
{/* <TaskOption
|
||||
alt="Summarize Stories"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Summarize stories"
|
||||
link="/create/summarize_story"
|
||||
/> */}
|
||||
<TaskOption
|
||||
alt="Create Initial Prompt"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Create Initial Prompt"
|
||||
link="/create/initial_prompt"
|
||||
/>
|
||||
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
|
||||
<TaskOption
|
||||
alt="Reply as Assistant"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Reply as Assistant"
|
||||
link="/create/assistant_reply"
|
||||
/>
|
||||
</TaskOptions>
|
||||
<TaskOptions key="evaluate" title="Evaluate">
|
||||
{/*
|
||||
Commented out while the backend does not support them.
|
||||
<TaskOption
|
||||
alt="Rate Prompts"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rate Prompts"
|
||||
link="/evaluate/rate_summary"
|
||||
/> */}
|
||||
<TaskOption
|
||||
alt="Rank Initial Prompts"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rank Initial Prompts"
|
||||
link="/evaluate/rank_initial_prompts"
|
||||
/>
|
||||
<TaskOption
|
||||
alt="Rank User Replies"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rank User Replies"
|
||||
link="/evaluate/rank_user_replies"
|
||||
/>
|
||||
<TaskOption
|
||||
alt="Rank Assistant Replies"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rank Assistant Replies"
|
||||
link="/evaluate/rank_assistant_replies"
|
||||
/>
|
||||
</TaskOptions>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,3 +0,0 @@
|
||||
export { TaskOption } from "./TaskOption";
|
||||
export { TaskOptions } from "./TaskOptions";
|
||||
export { TaskSelection } from "./TaskSelection";
|
||||
@@ -0,0 +1,54 @@
|
||||
import { useState } from "react";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
|
||||
export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses }) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
const [inputText, setInputText] = useState("");
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputText.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_reply_to_message",
|
||||
content: {
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setInputText("");
|
||||
mutate();
|
||||
};
|
||||
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInputText(event.target.value);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.label}</h5>
|
||||
<p className="text-lg py-1">{taskType.overview}</p>
|
||||
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
|
||||
<TrackedTextarea
|
||||
text={inputText}
|
||||
onTextChange={textChangeHandler}
|
||||
thresholds={{ low: 20, medium: 40, goal: 50 }}
|
||||
textareaProps={{ placeholder: "Reply..." }}
|
||||
/>
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,52 @@
|
||||
import { useState } from "react";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
|
||||
|
||||
import { MessageTable } from "../Messages/MessageTable";
|
||||
|
||||
export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "message_ranking",
|
||||
content: {
|
||||
ranking,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setRanking([]);
|
||||
mutate();
|
||||
};
|
||||
let messages = null;
|
||||
if (tasks[0].task.conversation) {
|
||||
messages = tasks[0].task.conversation.messages;
|
||||
messages = messages.map((message, index) => ({ ...message, id: index }));
|
||||
}
|
||||
|
||||
const sortables = tasks[0].task.replies ? "replies" : "prompts";
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
{messages ? <MessageTable messages={messages} /> : null}
|
||||
<Sortable items={tasks[0].task[sortables]} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<TaskControlsOverridable
|
||||
tasks={tasks}
|
||||
isValid={ranking.length == tasks[0].task[sortables].length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,28 @@
|
||||
import { CreateTask } from "./CreateTask";
|
||||
import { EvaluateTask } from "./EvaluateTask";
|
||||
import { TaskCategory, TaskTypes } from "./TaskTypes";
|
||||
|
||||
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
function taskTypeComponent(type) {
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === type);
|
||||
const category = taskType.category;
|
||||
switch (category) {
|
||||
case TaskCategory.Create:
|
||||
return (
|
||||
<CreateTask
|
||||
tasks={tasks}
|
||||
trigger={trigger}
|
||||
mutate={mutate}
|
||||
taskType={taskType}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Evaluate:
|
||||
return <EvaluateTask tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />;
|
||||
}
|
||||
}
|
||||
|
||||
return taskTypeComponent(task.type);
|
||||
};
|
||||
@@ -0,0 +1,66 @@
|
||||
export enum TaskCategory {
|
||||
Create = "Create",
|
||||
Evaluate = "Evaluate",
|
||||
Label = "Label",
|
||||
}
|
||||
|
||||
export const TaskTypes = [
|
||||
// create
|
||||
{
|
||||
label: "Create Initial Prompts",
|
||||
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
category: TaskCategory.Create,
|
||||
pathname: "/create/initial_prompt",
|
||||
type: "initial_prompt",
|
||||
overview: "Create an initial message to send to the assistant",
|
||||
instruction: "Provide the initial prompt",
|
||||
},
|
||||
{
|
||||
label: "Reply as User",
|
||||
desc: "Chat with Open Assistant and help improve it’s responses as you interact with it.",
|
||||
category: TaskCategory.Create,
|
||||
pathname: "/create/user_reply",
|
||||
type: "prompter_reply",
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the user`s reply",
|
||||
},
|
||||
{
|
||||
label: "Reply as Assistant",
|
||||
desc: "Help Open Assistant improve its responses to conversations with other users.",
|
||||
category: TaskCategory.Create,
|
||||
pathname: "/create/assistant_reply",
|
||||
type: "assistant_reply",
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the assistant`s reply",
|
||||
},
|
||||
// evaluate
|
||||
{
|
||||
label: "Rank User Replies",
|
||||
category: TaskCategory.Evaluate,
|
||||
desc: "Help Open Assistant improve its responses to conversations with other users.",
|
||||
pathname: "/evaluate/rank_user_replies",
|
||||
type: "rank_prompter_replies",
|
||||
},
|
||||
{
|
||||
label: "Rank Assistant Replies",
|
||||
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
category: TaskCategory.Evaluate,
|
||||
pathname: "/evaluate/rank_assistant_replies",
|
||||
type: "rank_assistant_replies",
|
||||
},
|
||||
{
|
||||
label: "Rank Initial Prompts",
|
||||
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
category: TaskCategory.Evaluate,
|
||||
pathname: "/evaluate/rank_initial_prompts",
|
||||
type: "rank_initial_prompts",
|
||||
},
|
||||
// label
|
||||
{
|
||||
label: "Label Initial Prompt",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
type: "label_initial_prompt",
|
||||
},
|
||||
];
|
||||
@@ -0,0 +1,44 @@
|
||||
import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import useSWR from "swr";
|
||||
|
||||
/**
|
||||
* Fetches users from the users api route and then presents them in a simple Chakra table.
|
||||
*/
|
||||
const UsersCell = () => {
|
||||
// Fetch and save the users.
|
||||
const [users, setUsers] = useState([]);
|
||||
const { isLoading } = useSWR("/api/admin/users", fetcher, {
|
||||
onSuccess: setUsers,
|
||||
});
|
||||
|
||||
// Present users in a naive table.
|
||||
return (
|
||||
<TableContainer>
|
||||
<Table variant="simple">
|
||||
<TableCaption>Users</TableCaption>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>Id</Th>
|
||||
<Th>Email</Th>
|
||||
<Th>Name</Th>
|
||||
<Th>Role</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{users.map((user, index) => (
|
||||
<Tr key={index}>
|
||||
<Td>{user.id}</Td>
|
||||
<Td>{user.email}</Td>
|
||||
<Td>{user.name}</Td>
|
||||
<Td>{user.role}</Td>
|
||||
</Tr>
|
||||
))}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
);
|
||||
};
|
||||
|
||||
export default UsersCell;
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Container } from "./Container";
|
||||
import Image from "next/image";
|
||||
import { Container } from "src/components/Container";
|
||||
|
||||
const Vision = () => {
|
||||
return (
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
message_id: string;
|
||||
prompt: string;
|
||||
type: string;
|
||||
valid_labels: string[];
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
|
||||
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
|
||||
onSuccess: (data: ConcreteTaskResponse) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length === 0 && !isLoading && !error) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks, isLoading, mutate, error]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (reply) => {
|
||||
const newTask: ConcreteTaskResponse = await reply.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
|
||||
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
|
||||
return { tasks, isLoading, submit, error, reset: mutate };
|
||||
};
|
||||
@@ -42,7 +42,7 @@ export class OasstApiClient {
|
||||
} catch (e) {
|
||||
throw new OasstError(errorText, 0, resp.status);
|
||||
}
|
||||
throw new OasstError(error.message, error.error_code, resp.status);
|
||||
throw new OasstError(error.message ?? error, error.error_code, resp.status);
|
||||
}
|
||||
|
||||
return await resp.json();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import Image from "next/image";
|
||||
import { CallToAction } from "src/components/CallToAction";
|
||||
import { Container } from "src/components/Container";
|
||||
import Roadmap from "src/components/Roadmap";
|
||||
import Services from "src/components/Services";
|
||||
import Vision from "src/components/Vision";
|
||||
import Roadmap from "src/components/Roadmap";
|
||||
import { CallToAction } from "src/components/CallToAction";
|
||||
import Image from "next/image";
|
||||
|
||||
const AboutPage = () => {
|
||||
return (
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useEffect } from "react";
|
||||
import { getAdminLayout } from "src/components/Layout";
|
||||
import UsersCell from "src/components/UsersCell";
|
||||
|
||||
/**
|
||||
* Provides the admin index page that will display a list of users and give
|
||||
* admins the ability to manage their access rights.
|
||||
*/
|
||||
const AdminIndex = () => {
|
||||
const router = useRouter();
|
||||
const { data: session, status } = useSession();
|
||||
|
||||
// Check when the user session is loaded and re-route if the user is not an
|
||||
// admin. This follows the suggestion by NextJS for handling private pages:
|
||||
// https://nextjs.org/docs/api-reference/next/router#usage
|
||||
//
|
||||
// All admin pages should use the same check and routing steps.
|
||||
useEffect(() => {
|
||||
if (status === "loading") {
|
||||
return;
|
||||
}
|
||||
if (session?.user?.role === "admin") {
|
||||
return;
|
||||
}
|
||||
router.push("/");
|
||||
}, [session, status]);
|
||||
|
||||
// Show the final page.
|
||||
// TODO(#237): Display a component that fetches actual user data.
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className="oa-basic-theme">{status === "loading" ? "loading..." : <UsersCell />}</main>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
AdminIndex.getLayout = getAdminLayout;
|
||||
|
||||
export default AdminIndex;
|
||||
@@ -0,0 +1,31 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import client from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* Returns a list of user results from the database when the requesting user is
|
||||
* a logged in admin.
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered or if the user isn't an admin.
|
||||
if (!token || token.role !== "admin") {
|
||||
res.status(403).end();
|
||||
return;
|
||||
}
|
||||
|
||||
// Fetch 20 users.
|
||||
const users = await client.user.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
name: true,
|
||||
email: true,
|
||||
},
|
||||
take: 20,
|
||||
});
|
||||
|
||||
res.status(200).json(users);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -0,0 +1,27 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const { id } = req.query;
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -0,0 +1,27 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const { id } = req.query;
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -0,0 +1,27 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const { id } = req.query;
|
||||
|
||||
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
const message = await messageRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(message);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -0,0 +1,48 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const { id } = req.query;
|
||||
|
||||
if (!id) {
|
||||
res.status(400).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
const message = await messageRes.json();
|
||||
|
||||
if (!message.parent_id) {
|
||||
res.status(404).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const parentRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${message.parent_id}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
const parent = await parentRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(parent);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -35,7 +35,12 @@ const handler = async (req, res) => {
|
||||
},
|
||||
});
|
||||
|
||||
const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
|
||||
let newTask;
|
||||
try {
|
||||
newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
|
||||
} catch (err) {
|
||||
return res.status(500).json(err);
|
||||
}
|
||||
|
||||
// Stores the new task with our database.
|
||||
const newRegisteredTask = await prisma.registeredTask.create({
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { getSession } from "next-auth/react";
|
||||
import prisma from "../../lib/prismadb";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
// POST /api/post
|
||||
// Required fields in body: title
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -13,7 +11,6 @@ import useSWRMutation from "swr/mutation";
|
||||
|
||||
const AssistantReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
const [inputText, setInputText] = useState("");
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/assistant_reply ", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
@@ -34,26 +31,6 @@ const AssistantReply = () => {
|
||||
},
|
||||
});
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputText.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_reply_to_message",
|
||||
content: {
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setInputText("");
|
||||
mutate();
|
||||
};
|
||||
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInputText(event.target.value);
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
@@ -65,29 +42,14 @@ const AssistantReply = () => {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Reply as the assistant</h5>
|
||||
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
|
||||
<Messages messages={task.conversation.messages} post_id={task.id} />
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Provide the assistant`s reply</h5>
|
||||
<TrackedTextarea
|
||||
text={inputText}
|
||||
onTextChange={textChangeHandler}
|
||||
thresholds={{ low: 20, medium: 40, goal: 50 }}
|
||||
textareaProps={{ placeholder: "Reply..." }}
|
||||
/>
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
<>
|
||||
<Head>
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -12,7 +11,6 @@ import useSWRMutation from "swr/mutation";
|
||||
|
||||
const InitialPrompt = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
const [inputText, setInputText] = useState("");
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/initial_prompt ", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
@@ -33,26 +31,6 @@ const InitialPrompt = () => {
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputText.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_reply_to_message",
|
||||
content: {
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setInputText("");
|
||||
mutate();
|
||||
};
|
||||
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInputText(event.target.value);
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
@@ -65,25 +43,13 @@ const InitialPrompt = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Start a conversation</h5>
|
||||
<p className="text-lg py-1">Create an initial message to send to the assistant</p>
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Provide the initial prompt</h5>
|
||||
<TrackedTextarea
|
||||
text={inputText}
|
||||
onTextChange={textChangeHandler}
|
||||
thresholds={{ low: 20, medium: 40, goal: 50 }}
|
||||
textareaProps={{ placeholder: "Question, task, greeting or similar..." }}
|
||||
/>
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
<>
|
||||
<Head>
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -12,7 +10,6 @@ import useSWRMutation from "swr/mutation";
|
||||
|
||||
const UserReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
const [inputText, setInputText] = useState("");
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/prompter_reply", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
@@ -33,26 +30,6 @@ const UserReply = () => {
|
||||
},
|
||||
});
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputText.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_reply_to_message",
|
||||
content: {
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setInputText("");
|
||||
mutate();
|
||||
};
|
||||
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInputText(event.target.value);
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
@@ -70,30 +47,14 @@ const UserReply = () => {
|
||||
);
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Reply as a user</h5>
|
||||
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
|
||||
<Messages messages={task.conversation.messages} post_id={task.id} />
|
||||
{task.hint && <p className="text-lg py-1">Hint: {task.hint}</p>}
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Provide the user`s reply</h5>
|
||||
<TrackedTextarea
|
||||
text={inputText}
|
||||
onTextChange={textChangeHandler}
|
||||
thresholds={{ low: 20, medium: 40, goal: 50 }}
|
||||
textareaProps={{ placeholder: "Reply..." }}
|
||||
/>
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
<>
|
||||
<Head>
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,29 +1,16 @@
|
||||
import { Box, useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
|
||||
import { LeaderboardTable, TaskOption } from "src/components/Dashboard";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LeaderboardTable, SideMenu, TaskOption } from "src/components/Dashboard";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
const Dashboard = () => {
|
||||
const { colorMode } = useColorMode();
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Dashboard - Open Assistant</title>
|
||||
<meta name="description" content="Chat with Open Assistant and provide feedback." />
|
||||
</Head>
|
||||
<Box backgroundColor={colorMode === "light" ? colors.light.bg : colors.dark.bg} className="sm:overflow-hidden">
|
||||
<Box className="sm:flex h-full gap-6">
|
||||
<Box className="p-6 sm:pr-0">
|
||||
<SideMenu />
|
||||
</Box>
|
||||
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">
|
||||
<TaskOption />
|
||||
<LeaderboardTable />
|
||||
</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
<TaskOption />
|
||||
<LeaderboardTable />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { ContextMessages } from "src/components/ContextMessages";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -14,11 +10,6 @@ import useSWRMutation from "swr/mutation";
|
||||
|
||||
const RankAssistantReplies = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
/**
|
||||
* This array will contain the ranked indices of the replies
|
||||
* The best reply will have index 0, and the worst is the last.
|
||||
*/
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_assistant_replies", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
@@ -39,21 +30,6 @@ const RankAssistantReplies = () => {
|
||||
},
|
||||
});
|
||||
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "message_ranking",
|
||||
content: {
|
||||
ranking,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setRanking([]);
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
@@ -71,33 +47,13 @@ const RankAssistantReplies = () => {
|
||||
);
|
||||
}
|
||||
|
||||
const replies = tasks[0].task.replies as string[];
|
||||
const messages = tasks[0].task.conversation.messages as Message[];
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Rank Assistant Replies</title>
|
||||
<meta name="description" content="Rank Assistant Replies." />
|
||||
</Head>
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following replies, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
<ContextMessages messages={messages} />
|
||||
<Sortable items={replies} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<TaskControlsOverridable
|
||||
tasks={tasks}
|
||||
isValid={ranking.length == tasks[0].task.replies.length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task.replies.map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
/>
|
||||
</div>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -2,9 +2,7 @@ import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -12,12 +10,6 @@ import useSWRMutation from "swr/mutation";
|
||||
|
||||
const RankInitialPrompts = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
/**
|
||||
* This array will contain the ranked indices of the prompts
|
||||
* The best prompt will have index 0, and the worst is the last.
|
||||
*/
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
// const bg = useColorModeValue("gray.100", "gray.800");
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_initial_prompts", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
@@ -38,21 +30,6 @@ const RankInitialPrompts = () => {
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "message_ranking",
|
||||
content: {
|
||||
ranking,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setRanking([]);
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
@@ -76,23 +53,7 @@ const RankInitialPrompts = () => {
|
||||
<title>Rank Initial Prompts</title>
|
||||
<meta name="description" content="Rank initial prompts." />
|
||||
</Head>
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following prompts, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
<Sortable items={tasks[0].task.prompts} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<TaskControlsOverridable
|
||||
tasks={tasks}
|
||||
isValid={ranking.length == tasks[0].task.prompts.length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task.prompts.map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
/>
|
||||
</div>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { ContextMessages } from "src/components/ContextMessages";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -14,11 +10,6 @@ import useSWRMutation from "swr/mutation";
|
||||
|
||||
const RankUserReplies = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
/**
|
||||
* This array will contain the ranked indices of the replies
|
||||
* The best reply will have index 0, and the worst is the last.
|
||||
*/
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_prompter_replies", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
@@ -39,21 +30,6 @@ const RankUserReplies = () => {
|
||||
},
|
||||
});
|
||||
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "message_ranking",
|
||||
content: {
|
||||
ranking,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setRanking([]);
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
@@ -70,8 +46,6 @@ const RankUserReplies = () => {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const replies = tasks[0].task.replies as string[];
|
||||
const messages = tasks[0].task.conversation.messages as Message[];
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -79,24 +53,7 @@ const RankUserReplies = () => {
|
||||
<title>Rank User Replies</title>
|
||||
<meta name="description" content="Rank User Replies." />
|
||||
</Head>
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following replies, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
<ContextMessages messages={messages} />
|
||||
<Sortable items={replies} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<TaskControlsOverridable
|
||||
tasks={tasks}
|
||||
isValid={ranking.length == tasks[0].task.replies.length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task.replies.map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
/>
|
||||
</div>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useEffect, useId, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
|
||||
taskApiEndpoint: "label_initial_prompt",
|
||||
});
|
||||
|
||||
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
|
||||
const labels = task.valid_labels.reduce((obj, label, i) => {
|
||||
obj[label] = sliderValues[i].toString();
|
||||
return obj;
|
||||
}, {} as Record<string, string>);
|
||||
|
||||
submit(id, task.message_id, task.prompt, labels);
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
|
||||
<p className="text-lg py-1">Provide labels for the following prompt</p>
|
||||
<MessageView text={task.prompt} is_assistant />
|
||||
</>
|
||||
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
|
||||
</TwoColumnsWithCards>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelInitialPrompt;
|
||||
|
||||
// TODO: consolidate with FlaggableElement
|
||||
|
||||
interface CheckboxSliderGroupProps {
|
||||
labelIDs: Array<string>;
|
||||
onChange: (sliderValues: number[]) => unknown;
|
||||
}
|
||||
|
||||
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
<CheckboxSliderItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
sliderValue={sliderValues[idx]}
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
function CheckboxSliderItem(props: {
|
||||
labelId: string;
|
||||
sliderValue: number;
|
||||
sliderHandler: (newVal: number) => unknown;
|
||||
}) {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
</label>
|
||||
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
import { Box, Container, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
import { MessageWithChildren } from "src/components/Messages/MessageWithChildren";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import useSWR from "swr";
|
||||
|
||||
const MessageDetail = ({ id }) => {
|
||||
const mainBg = useColorModeValue("bg-slate-300", "bg-slate-900");
|
||||
|
||||
const [parent, setParent] = useState(null);
|
||||
|
||||
const { isLoading: isLoadingParent } = useSWR(id ? `/api/messages/${id}/parent` : null, fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setParent(data);
|
||||
},
|
||||
onError: () => {
|
||||
setParent(null);
|
||||
},
|
||||
});
|
||||
|
||||
if (isLoadingParent) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className={`${mainBg}`}>
|
||||
<Container w="100%" pt={[2, 2, 4, 4]}>
|
||||
{parent && (
|
||||
<>
|
||||
<Text align="center" fontSize="xl">
|
||||
Parent
|
||||
</Text>
|
||||
<Box rounded="lg" p="2">
|
||||
<MessageTableEntry item={parent} idx={1} />
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
</Container>
|
||||
<Box pb="4" maxW="full" px="2">
|
||||
<MessageWithChildren id={id} maxDepth={2} />
|
||||
</Box>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
MessageDetail.getInitialProps = async ({ query }) => {
|
||||
const { id } = query;
|
||||
return { id };
|
||||
};
|
||||
|
||||
export default MessageDetail;
|
||||
@@ -1,79 +1,75 @@
|
||||
import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import { SideMenu } from "src/components/Dashboard";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
const MessagesDashboard = () => {
|
||||
const bgColor = useColorModeValue(colors.light.bg, colors.dark.bg);
|
||||
const boxBgColor = useColorModeValue("white", "gray.700");
|
||||
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
|
||||
|
||||
const [messages, setMessages] = useState([]);
|
||||
const [userMessages, setUserMessages] = useState([]);
|
||||
|
||||
const { isLoading: isLoadingAll } = useSWRImmutable("/api/messages", fetcher, {
|
||||
const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setMessages(data);
|
||||
},
|
||||
});
|
||||
|
||||
const { isLoading: isLoadingUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
|
||||
const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setUserMessages(data);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (messages.length == 0) {
|
||||
mutateAll();
|
||||
}
|
||||
if (userMessages.length == 0) {
|
||||
mutateUser();
|
||||
}
|
||||
}, [messages, userMessages]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Messages - Open Assistant</title>
|
||||
<meta name="description" content="Chat with Open Assistant and provide feedback." />
|
||||
</Head>
|
||||
<Box backgroundColor={bgColor} className="sm:overflow-hidden">
|
||||
<Box className="sm:flex h-full gap-6">
|
||||
<Box className="p-6 sm:pr-0">
|
||||
<SideMenu />
|
||||
</Box>
|
||||
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">
|
||||
<SimpleGrid columns={[1, 1, 1, 2]} gap={4}>
|
||||
<Box>
|
||||
<Text className="text-2xl font-bold" pb="4">
|
||||
Most recent messages
|
||||
</Text>
|
||||
<Box
|
||||
backgroundColor={boxBgColor}
|
||||
boxShadow="base"
|
||||
dropShadow={boxAccentColor}
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
|
||||
</Box>
|
||||
</Box>
|
||||
<Box>
|
||||
<Text className="text-2xl font-bold" pb="4">
|
||||
Your most recent messages
|
||||
</Text>
|
||||
<Box
|
||||
backgroundColor={boxBgColor}
|
||||
boxShadow="base"
|
||||
dropShadow={boxAccentColor}
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
|
||||
</Box>
|
||||
</Box>
|
||||
</SimpleGrid>
|
||||
<SimpleGrid columns={[1, 1, 1, 2]} gap={4}>
|
||||
<Box>
|
||||
<Text className="text-2xl font-bold" pb="4">
|
||||
Most recent messages
|
||||
</Text>
|
||||
<Box
|
||||
backgroundColor={boxBgColor}
|
||||
boxShadow="base"
|
||||
dropShadow={boxAccentColor}
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
|
||||
</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
<Box>
|
||||
<Text className="text-2xl font-bold" pb="4">
|
||||
Your most recent messages
|
||||
</Text>
|
||||
<Box
|
||||
backgroundColor={boxBgColor}
|
||||
boxShadow="base"
|
||||
dropShadow={boxAccentColor}
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
|
||||
</Box>
|
||||
</Box>
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { Container, Heading } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Footer } from "src/components/Footer";
|
||||
import { Header } from "src/components/Header";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
|
||||
const PrivacyPolicy = () => {
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { Container, Heading } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Footer } from "src/components/Footer";
|
||||
import { Header } from "src/components/Header";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
|
||||
const TermsOfService = () => {
|
||||
|
||||
Reference in New Issue
Block a user