mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'LAION-AI:main' into main
This commit is contained in:
@@ -4,20 +4,28 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
workflow_call:
|
||||
pull_request_target:
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# in case of PR, check out the PR's head branch
|
||||
- uses: actions/checkout@v3
|
||||
if: github.event_name == 'pull_request_target'
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
# in case of push, check out the main branch
|
||||
- uses: actions/checkout@v3
|
||||
if: github.event_name == 'push'
|
||||
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
- name: Post PR comment on failure
|
||||
if: failure() && github.event_name == 'pull_request'
|
||||
if: failure() && github.event_name == 'pull_request_target'
|
||||
uses: peter-evans/create-or-update-comment@v2
|
||||
with:
|
||||
issue-number: ${{ github.event.pull_request.number }}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
name: E2E Tests (Website)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- oasst-shared/**
|
||||
- backend/**
|
||||
- website/**
|
||||
pull_request:
|
||||
paths:
|
||||
- oasst-shared/**
|
||||
- backend/**
|
||||
- website/**
|
||||
|
||||
jobs:
|
||||
test-e2e:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Start website, backend, etc
|
||||
run: docker compose up ci --build -d
|
||||
- name: Run Cypress tests
|
||||
uses: cypress-io/github-action@v5.0.2
|
||||
with:
|
||||
browser: chrome
|
||||
working-directory: website
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: failure() # NOTE: screenshots will be generated only if E2E test failed
|
||||
with:
|
||||
name: cypress-screenshots
|
||||
path: website/cypress/screenshots
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: always()
|
||||
with:
|
||||
name: cypress-videos
|
||||
path: website/cypress/videos
|
||||
@@ -110,3 +110,8 @@ Upon making a release on GitHub, all docker images are automatically built and
|
||||
pushed to ghcr.io. The docker images are tagged with the release version, and
|
||||
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
|
||||
automatically deploy the built release to the dev machine.
|
||||
|
||||
### Contribute a Dataset
|
||||
|
||||
See
|
||||
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)
|
||||
|
||||
@@ -81,6 +81,7 @@
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "false"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
ports:
|
||||
- 8080:8080
|
||||
|
||||
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
"""added miniLM_embedding column to message
|
||||
|
||||
Revision ID: 023548d474f7
|
||||
Revises: ba61fe17fb6e
|
||||
Create Date: 2023-01-08 11:06:25.613290
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "023548d474f7"
|
||||
down_revision = "ba61fe17fb6e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message", "miniLM_embedding")
|
||||
# ### end Alembic commands ###
|
||||
+49
@@ -0,0 +1,49 @@
|
||||
"""embedding for message now in its own table
|
||||
|
||||
Revision ID: 35bdc1a08bb8
|
||||
Revises: 023548d474f7
|
||||
Create Date: 2023-01-08 16:03:48.454207
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35bdc1a08bb8"
|
||||
down_revision = "023548d474f7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"message_embedding",
|
||||
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True),
|
||||
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["message_id"],
|
||||
["message.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("message_id", "model"),
|
||||
)
|
||||
op.drop_column("message", "miniLM_embedding")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"message",
|
||||
sa.Column(
|
||||
"miniLM_embedding",
|
||||
postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.drop_table("message_embedding")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Created date
|
||||
|
||||
Revision ID: aac6b2f66006
|
||||
Revises: 35bdc1a08bb8
|
||||
Create Date: 2023-01-08 21:28:27.342729
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "aac6b2f66006"
|
||||
down_revision = "35bdc1a08bb8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"message_embedding",
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message_embedding", "created_date")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,19 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.schemas.hugging_face import ToxicityClassification
|
||||
from oasst_backend.utils.hugging_face import HuggingFaceAPI
|
||||
from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HF_url(str, Enum):
|
||||
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
|
||||
|
||||
|
||||
@router.get("/text_toxicity")
|
||||
async def get_text_toxicity(
|
||||
msg: str,
|
||||
@@ -30,7 +25,7 @@ async def get_text_toxicity(
|
||||
ToxicityClassification: the score of toxicity of the message.
|
||||
"""
|
||||
|
||||
api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
|
||||
api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value
|
||||
hugging_face_api = HuggingFaceAPI(api_url)
|
||||
response = await hugging_face_api.post(msg)
|
||||
|
||||
|
||||
@@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
||||
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -253,7 +255,7 @@ def tasks_acknowledge_failure(
|
||||
|
||||
|
||||
@router.post("/interaction", response_model=protocol_schema.TaskDone)
|
||||
def tasks_interaction(
|
||||
async def tasks_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -274,12 +276,26 @@ def tasks_interaction(
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
pr.store_text_reply(
|
||||
newMessage = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
hugging_face_api = HuggingFaceAPI(
|
||||
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
|
||||
)
|
||||
embedding = await hugging_face_api.post(interaction.text)
|
||||
pr.insert_message_embedding(
|
||||
message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
|
||||
)
|
||||
except OasstError:
|
||||
logger.error(
|
||||
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
|
||||
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
|
||||
)
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
from .message_reaction import MessageReaction
|
||||
from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
@@ -13,6 +14,7 @@ __all__ = [
|
||||
"User",
|
||||
"UserStats",
|
||||
"Message",
|
||||
"MessageEmbedding",
|
||||
"MessageReaction",
|
||||
"MessageTreeState",
|
||||
"Task",
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import ARRAY, Field, Float, SQLModel
|
||||
|
||||
|
||||
class MessageEmbedding(SQLModel, table=True):
|
||||
__tablename__ = "message_embedding"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)
|
||||
|
||||
message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
|
||||
model: str = Field(max_length=256, nullable=False)
|
||||
embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
|
||||
|
||||
# In the case that the Message Embedding is created afterwards
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
@@ -2,13 +2,13 @@ import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
@@ -91,7 +91,12 @@ class PromptRepository:
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
|
||||
def store_text_reply(
|
||||
self,
|
||||
text: str,
|
||||
frontend_message_id: str,
|
||||
user_frontend_message_id: str,
|
||||
) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
@@ -224,6 +229,30 @@ class PromptRepository:
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
"""Insert the embedding of a new message in the database.
|
||||
|
||||
Args:
|
||||
message_id (UUID): the identifier of the message we want to save its embedding
|
||||
model (str): the model used for creating the embedding
|
||||
embedding (List[float]): the values obtained from the message & model
|
||||
|
||||
Raises:
|
||||
OasstError: if misses some of the before params
|
||||
|
||||
Returns:
|
||||
MessageEmbedding: the instance in the database of the embedding saved for that message
|
||||
"""
|
||||
|
||||
if None in (message_id, model, embedding):
|
||||
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
|
||||
self.db.add(message_embedding)
|
||||
self.db.commit()
|
||||
self.db.refresh(message_embedding)
|
||||
return message_embedding
|
||||
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
|
||||
|
||||
class HfUrl(str, Enum):
|
||||
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
|
||||
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
|
||||
|
||||
|
||||
class HfEmbeddingModel(str, Enum):
|
||||
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
|
||||
|
||||
class HuggingFaceAPI:
|
||||
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
|
||||
|
||||
@@ -41,6 +52,9 @@ class HuggingFaceAPI:
|
||||
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
|
||||
# If we get a bad response
|
||||
if response.status != 200:
|
||||
|
||||
logger.error(response)
|
||||
logger.info(self.headers)
|
||||
raise OasstError(
|
||||
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
|
||||
)
|
||||
|
||||
@@ -11,6 +11,11 @@ services:
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, webdb, adminer, maildev, backend, redis]
|
||||
|
||||
# Used by CI automations.
|
||||
ci:
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, webdb, maildev, backend, redis, web]
|
||||
|
||||
# This DB is for the FastAPI Backend.
|
||||
db:
|
||||
image: postgres
|
||||
@@ -95,6 +100,7 @@ services:
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- MAX_WORKERS=1
|
||||
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Example Notebook
|
||||
|
||||
[](https://colab.research.google.com/github/andrewm4894/Open-Assistant/blob/main/notebooks/example/example.ipynb)
|
||||
[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/example/example.ipynb)
|
||||
|
||||
This folder contains an example reference notebook structure and approach for
|
||||
this project. Please try and follow this structure as closely as possible. While
|
||||
|
||||
@@ -9,10 +9,11 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/andrewm4894/Open-Assistant/blob/example-notebook/notebooks/example/example.ipynb)"
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/example-notebook/notebooks/example/example.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -22,7 +23,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# uncomment and run below lines to set up if running in colab\n",
|
||||
"# !git clone https://github.com/andrewm4894/Open-Assistant.git\n",
|
||||
"# !git clone https://github.com/LAION-AI/Open-Assistant.git\n",
|
||||
"# %cd Open-Assistant/notebooks/example\n",
|
||||
"# !pip install -r requirements.txt"
|
||||
]
|
||||
@@ -146,12 +147,12 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
"version": "3.7.4 (tags/v3.7.4:e09359112e, Jul 8 2019, 20:34:20) [MSC v.1916 64 bit (AMD64)]"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
|
||||
"hash": "25d5c2324055587ceaeef27650c79ce8358ea61d7689f2e0b8ada5d53f85bce4"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
describe("ranking prompter replies", () => {
|
||||
it("completes the current task on submit and on request shows a new task", () => {
|
||||
cy.signInWithEmail("cypress@example.com");
|
||||
cy.visit("/evaluate/rank_user_replies");
|
||||
cy.visit("/evaluate/rank_assistant_replies");
|
||||
|
||||
cy.get('[data-cy="task-id"').then((taskIdElement) => {
|
||||
const taskId = taskIdElement.text();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
describe("ranking assistant replies", () => {
|
||||
it("completes the current task on submit and on request shows a new task", () => {
|
||||
cy.signInWithEmail("cypress@example.com");
|
||||
cy.visit("/evaluate/rank_assistant_replies");
|
||||
cy.visit("/evaluate/rank_user_replies");
|
||||
|
||||
cy.get('[data-cy="task-id"').then((taskIdElement) => {
|
||||
const taskId = taskIdElement.text();
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import {
|
||||
Button,
|
||||
ButtonProps,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
@@ -16,7 +12,6 @@ import {
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { FaChevronDown } from "react-icons/fa";
|
||||
|
||||
interface SkipButtonProps extends ButtonProps {
|
||||
onSkip: (reason: string) => void;
|
||||
|
||||
@@ -12,7 +12,7 @@ import React from "react";
|
||||
|
||||
export const CollapsableText = ({ text, maxLength = 220 }) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
if (typeof text != "string" || text.length <= maxLength) {
|
||||
if (typeof text !== "string" || text.length <= maxLength) {
|
||||
return text;
|
||||
} else {
|
||||
return (
|
||||
|
||||
@@ -27,11 +27,26 @@ import poster from "src/lib/poster";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
interface textFlagLabels {
|
||||
attributeName: string;
|
||||
labelText: string;
|
||||
additionalExplanation?: string;
|
||||
}
|
||||
|
||||
export const FlaggableElement = (props) => {
|
||||
const [isEditing, setIsEditing] = useBoolean();
|
||||
const flaggable_labels = props.flaggable_labels;
|
||||
const TEXT_LABEL_FLAGS =
|
||||
flaggable_labels?.valid_labels?.map((valid_label) => {
|
||||
return {
|
||||
attributeName: valid_label.name,
|
||||
labelText: valid_label.display_text,
|
||||
additionalExplanation: valid_label.help_text,
|
||||
};
|
||||
}) || [];
|
||||
const { trigger } = useSWRMutation("/api/set_label", poster, {
|
||||
onSuccess: () => {
|
||||
setIsEditing.off;
|
||||
setIsEditing.off();
|
||||
},
|
||||
});
|
||||
|
||||
@@ -55,14 +70,14 @@ export const FlaggableElement = (props) => {
|
||||
const handleCheckboxState = (isChecked, idx) => {
|
||||
setCheckboxValues(
|
||||
checkboxValues.map((val, i) => {
|
||||
return i == idx ? isChecked : val;
|
||||
return i === idx ? isChecked : val;
|
||||
})
|
||||
);
|
||||
};
|
||||
const handleSliderState = (newVal, idx) => {
|
||||
setSliderValues(
|
||||
sliderValues.map((val, i) => {
|
||||
return i == idx ? newVal : val;
|
||||
return i === idx ? newVal : val;
|
||||
})
|
||||
);
|
||||
};
|
||||
@@ -181,48 +196,3 @@ export function FlagCheckbox(props: {
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
interface textFlagLabels {
|
||||
attributeName: string;
|
||||
labelText: string;
|
||||
additionalExplanation?: string;
|
||||
}
|
||||
const TEXT_LABEL_FLAGS: textFlagLabels[] = [
|
||||
// For the time being this list is configured on the FE.
|
||||
// In the future it may be provided by the API.
|
||||
// {
|
||||
// attributeName: "fails_task",
|
||||
// labelText: "Fails to follow the correct instruction / task",
|
||||
// additionalExplanation: "__TODO__",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "not_customer_assistant_appropriate",
|
||||
// labelText: "Inappropriate for customer assistant",
|
||||
// additionalExplanation: "__TODO__",
|
||||
// },
|
||||
{
|
||||
attributeName: "sexual_content",
|
||||
labelText: "Contains sexual content",
|
||||
},
|
||||
{
|
||||
attributeName: "violence",
|
||||
labelText: "Contains violent content",
|
||||
},
|
||||
// {
|
||||
// attributeName: "encourages_violence",
|
||||
// labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "denigrates_a_protected_class",
|
||||
// labelText: "Denigrates a protected class",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "gives_harmful_advice",
|
||||
// labelText: "Fails to follow the correct instruction / task",
|
||||
// additionalExplanation:
|
||||
// "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "expresses_moral_judgement",
|
||||
// labelText: "Expresses moral judgement",
|
||||
// },
|
||||
];
|
||||
|
||||
@@ -1,20 +1,30 @@
|
||||
import { Grid } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { forwardRef, useColorMode } from "@chakra-ui/react";
|
||||
import { useMemo } from "react";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import { ValidLabel } from "src/types/Task";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
export interface Message {
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}
|
||||
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
export const Messages = ({
|
||||
messages,
|
||||
post_id,
|
||||
valid_labels,
|
||||
}: {
|
||||
messages: Message[];
|
||||
post_id: string;
|
||||
valid_labels: ValidLabel[];
|
||||
}) => {
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
const { message_id, text } = messageProps;
|
||||
return (
|
||||
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
|
||||
<FlaggableElement
|
||||
text={text}
|
||||
post_id={post_id}
|
||||
message_id={message_id}
|
||||
key={i + text}
|
||||
flaggable_labels={valid_labels}
|
||||
>
|
||||
<MessageView {...messageProps} />
|
||||
</FlaggableElement>
|
||||
);
|
||||
@@ -23,7 +33,7 @@ export const Messages = ({ messages, post_id }: { messages: Message[]; post_id:
|
||||
return <Grid gap={2}>{items}</Grid>;
|
||||
};
|
||||
|
||||
export const MessageView = ({ is_assistant, text, message_id }: Message) => {
|
||||
export const MessageView = forwardRef<Message, "div">(({ is_assistant, text }: Message, ref) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const bgColor = useMemo(() => {
|
||||
@@ -34,5 +44,11 @@ export const MessageView = ({ is_assistant, text, message_id }: Message) => {
|
||||
}
|
||||
}, [colorMode, is_assistant]);
|
||||
|
||||
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
|
||||
};
|
||||
return (
|
||||
<div ref={ref} className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>
|
||||
{text}
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
MessageView.displayName = "MessageView";
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Stack, StackDivider } from "@chakra-ui/react";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
|
||||
export function MessageTable({ messages }) {
|
||||
export function MessageTable({ messages, valid_labels }) {
|
||||
return (
|
||||
<Stack divider={<StackDivider />} spacing="4">
|
||||
{messages.map((item, idx) => (
|
||||
<MessageTableEntry item={item} idx={idx} key={item.id} />
|
||||
<MessageTableEntry item={item} idx={idx} key={item.message_id || item.id} valid_labels={valid_labels} />
|
||||
))}
|
||||
</Stack>
|
||||
);
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import NextLink from "next/link";
|
||||
import { FlaggableElement } from "src/components/FlaggableElement";
|
||||
import type { ValidLabel } from "src/types/Task";
|
||||
|
||||
interface Message {
|
||||
text: string;
|
||||
@@ -11,13 +12,14 @@ interface Message {
|
||||
interface MessageTableEntryProps {
|
||||
item: Message;
|
||||
idx: number;
|
||||
valid_labels: ValidLabel[];
|
||||
}
|
||||
export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
const { item, idx } = props;
|
||||
const { item, idx, valid_labels } = props;
|
||||
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
|
||||
|
||||
return (
|
||||
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`}>
|
||||
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`} flaggable_labels={valid_labels}>
|
||||
<HStack>
|
||||
<Avatar
|
||||
name={`${boolean(item.is_assistant) ? "Assitant" : "User"}`}
|
||||
|
||||
@@ -64,7 +64,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
<Flex justifyContent="center" pb="2">
|
||||
<Box maxWidth="container.sm" flex="1" px={isFirstOrOnly ? [4, 6, 8, 9] : "0"}>
|
||||
<Box px={isFirstOrOnly ? "2" : "0"}>
|
||||
<MessageTableEntry item={message} idx={1} />
|
||||
<MessageTableEntry item={message} idx={1} valid_labels={[]} />
|
||||
</Box>
|
||||
</Box>
|
||||
</Flex>
|
||||
@@ -90,7 +90,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
<HStack {...MessageStackProps}>
|
||||
{children.map((item, idx) => (
|
||||
<Box maxWidth="container.sm" flex="1" key={`recursiveMessageWChildren_${idx}`}>
|
||||
<MessageTableEntry item={item} idx={idx * 2} />
|
||||
<MessageTableEntry item={item} idx={idx * 2} valid_labels={[]} />
|
||||
</Box>
|
||||
))}
|
||||
</HStack>
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import { useState } from "react";
|
||||
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskType } from "./TaskTypes";
|
||||
import { TaskInfo } from "src/components/Tasks/TaskTypes";
|
||||
|
||||
export interface CreateTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
taskType: TaskType;
|
||||
taskType: TaskInfo;
|
||||
trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
@@ -18,7 +17,7 @@ export interface CreateTaskProps {
|
||||
}
|
||||
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
const valid_labels = tasks[0].valid_labels;
|
||||
const [inputText, setInputText] = useState("");
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
@@ -42,7 +41,9 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.label}</h5>
|
||||
<p className="text-lg py-1">{taskType.overview}</p>
|
||||
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
|
||||
{task.conversation ? (
|
||||
<Messages messages={task.conversation.messages} post_id={task.id} valid_labels={valid_labels} />
|
||||
) : null}
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
|
||||
|
||||
@@ -33,6 +33,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
|
||||
messages = messages.map((message, index) => ({ ...message, id: index }));
|
||||
}
|
||||
|
||||
const valid_labels = tasks[0].valid_labels;
|
||||
const sortables = tasks[0].task.replies ? "replies" : "prompts";
|
||||
|
||||
return (
|
||||
@@ -42,13 +43,13 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
|
||||
<p className="text-lg py-1">
|
||||
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
{messages ? <MessageTable messages={messages} /> : null}
|
||||
{messages ? <MessageTable messages={messages} valid_labels={valid_labels} /> : null}
|
||||
<Sortable items={tasks[0].task[sortables]} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<TaskControlsOverridable
|
||||
tasks={tasks}
|
||||
isValid={ranking.length == tasks[0].task[sortables].length}
|
||||
isValid={ranking.length === tasks[0].task[sortables].length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={(task, reason) => {
|
||||
|
||||
@@ -1,43 +1,71 @@
|
||||
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
|
||||
import { useEffect, useId, useState } from "react";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskInfo } from "src/components/Tasks/TaskTypes";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export const LabelTask = ({
|
||||
title,
|
||||
desc,
|
||||
messages,
|
||||
inputs,
|
||||
controls,
|
||||
}: {
|
||||
title: string;
|
||||
desc: string;
|
||||
messages: ReactNode;
|
||||
inputs: ReactNode;
|
||||
controls: ReactNode;
|
||||
}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
export interface LabelTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
taskType: TaskInfo;
|
||||
trigger: (update: {
|
||||
id: string;
|
||||
update_type: string;
|
||||
content: { text: string; labels: { [k: string]: number }; message_id: string };
|
||||
}) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
mainBgClasses: string;
|
||||
}
|
||||
export const LabelTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: LabelTaskProps) => {
|
||||
const task = tasks[0].task;
|
||||
const valid_labels = tasks[0].valid_labels;
|
||||
|
||||
const card = useMemo(
|
||||
() => (
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{title}</h5>
|
||||
<p className="text-lg py-1">{desc}</p>
|
||||
{messages}
|
||||
</>
|
||||
),
|
||||
[title, desc, messages]
|
||||
);
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const submitResponse = (task: { id: string; reply: string; message_id: string }) => {
|
||||
console.assert(valid_labels.length === sliderValues.length);
|
||||
const labels = Object.fromEntries(valid_labels.valid_labels.map((label, i) => [label, sliderValues[i]]));
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_labels",
|
||||
content: { labels, text: task.reply, message_id: task.message_id },
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
{card}
|
||||
{inputs}
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.label}</h5>
|
||||
<p className="text-lg py-1">{taskType.overview}</p>
|
||||
|
||||
{task.conversation ? (
|
||||
<MessageTable
|
||||
messages={[
|
||||
...(task.conversation ? task.conversation.messages : []),
|
||||
{
|
||||
text: task.reply,
|
||||
is_assistant: task.type === TaskType.label_assistant_reply,
|
||||
message_id: task.message_id,
|
||||
},
|
||||
]}
|
||||
valid_labels={valid_labels}
|
||||
/>
|
||||
) : (
|
||||
<MessageView text={task.prompt} is_assistant={false} message_id={task.message_id} />
|
||||
)}
|
||||
</>
|
||||
<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
|
||||
</TwoColumnsWithCards>
|
||||
{controls}
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkipTask={onSkipTask} onNextTask={onNextTask} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import { CreateTask } from "./CreateTask";
|
||||
import { EvaluateTask } from "./EvaluateTask";
|
||||
import { TaskCategory, TaskTypes } from "./TaskTypes";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { CreateTask } from "src/components/Tasks/CreateTask";
|
||||
import { EvaluateTask } from "src/components/Tasks/EvaluateTask";
|
||||
import { LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import { TaskCategory, TaskTypes } from "src/components/Tasks/TaskTypes";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
export const Task = ({ tasks, trigger, mutate }) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, {
|
||||
onSuccess: async () => {
|
||||
mutate();
|
||||
@@ -45,6 +50,17 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Label:
|
||||
return (
|
||||
<LabelTask
|
||||
tasks={tasks}
|
||||
taskType={taskType}
|
||||
trigger={trigger}
|
||||
onSkipTask={rejectTask}
|
||||
onNextTask={mutate}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ export enum TaskCategory {
|
||||
Label = "Label",
|
||||
}
|
||||
|
||||
export interface TaskType {
|
||||
export interface TaskInfo {
|
||||
label: string;
|
||||
desc: string;
|
||||
category: TaskCategory;
|
||||
@@ -14,7 +14,7 @@ export interface TaskType {
|
||||
instruction?: string;
|
||||
}
|
||||
|
||||
export const TaskTypes: TaskType[] = [
|
||||
export const TaskTypes: TaskInfo[] = [
|
||||
// create
|
||||
{
|
||||
label: "Create Initial Prompts",
|
||||
@@ -71,6 +71,7 @@ export const TaskTypes: TaskType[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
overview: "Provide labels for the following prompt",
|
||||
type: "label_initial_prompt",
|
||||
},
|
||||
{
|
||||
@@ -78,6 +79,7 @@ export const TaskTypes: TaskType[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
overview: "Given the following discussion, provide labels for the final promp",
|
||||
type: "label_prompter_reply",
|
||||
},
|
||||
{
|
||||
@@ -85,6 +87,7 @@ export const TaskTypes: TaskType[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_assistant_reply",
|
||||
overview: "Given the following discussion, provide labels for the final prompt.",
|
||||
type: "label_assistant_reply",
|
||||
},
|
||||
];
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface CreateInitialPromptTask {
|
||||
id: string;
|
||||
type: "initial_prompt";
|
||||
hint: string;
|
||||
}
|
||||
|
||||
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>("initial_prompt");
|
||||
@@ -1,24 +0,0 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface BaseCreateReplyTask {
|
||||
id: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export interface CreateAssistantReplyTask extends BaseCreateReplyTask {
|
||||
type: "assistant_reply";
|
||||
}
|
||||
|
||||
export interface CreatePrompterReplyTask extends BaseCreateReplyTask {
|
||||
type: "prompter_reply";
|
||||
}
|
||||
|
||||
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>("assistant_reply");
|
||||
|
||||
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>("prompter_reply");
|
||||
@@ -1,9 +0,0 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface RankInitialPromptsTask {
|
||||
id: string;
|
||||
type: "rank_initial_prompts";
|
||||
prompts: string[];
|
||||
}
|
||||
|
||||
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>("rank_initial_prompts");
|
||||
@@ -1,25 +0,0 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface BaseRankRepliesTask {
|
||||
id: string;
|
||||
replies: string[];
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
interface RankAssistantRepliesTask extends BaseRankRepliesTask {
|
||||
type: "rank_assistant_replies";
|
||||
}
|
||||
|
||||
interface RankPrompterRepliesTask extends BaseRankRepliesTask {
|
||||
type: "rank_prompter_replies";
|
||||
}
|
||||
|
||||
export const useRankAssistantRepliesTask = () => useGenericTaskAPI<RankAssistantRepliesTask>("rank_assistant_replies");
|
||||
|
||||
export const useRankPrompterRepliesTask = () => useGenericTaskAPI<RankPrompterRepliesTask>("rank_prompter_replies");
|
||||
@@ -1,22 +0,0 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelAssistantReplyTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_assistant_reply;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
reply: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
|
||||
|
||||
export const useLabelAssistantReplyTask = () =>
|
||||
useLabelingTask<LabelAssistantReplyTask>(LabelingTaskType.label_assistant_reply);
|
||||
@@ -1,15 +0,0 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_initial_prompt;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelInitialPromptTask = () =>
|
||||
useLabelingTask<LabelInitialPromptTask>(LabelingTaskType.label_initial_prompt);
|
||||
@@ -1,22 +0,0 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelPrompterReplyTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_prompter_reply;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
reply: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
|
||||
|
||||
export const useLabelPrompterReplyTask = () =>
|
||||
useLabelingTask<LabelPrompterReplyTask>(LabelingTaskType.label_prompter_reply);
|
||||
@@ -1,20 +0,0 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
export const enum LabelingTaskType {
|
||||
label_initial_prompt = "label_initial_prompt",
|
||||
label_prompter_reply = "label_prompter_reply",
|
||||
label_assistant_reply = "label_assistant_reply",
|
||||
}
|
||||
|
||||
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
|
||||
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<TaskType>(endpoint);
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
|
||||
console.assert(validLabels.length === labelWeights.length);
|
||||
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
|
||||
|
||||
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
};
|
||||
|
||||
return { tasks, isLoading, submit, reset, error };
|
||||
};
|
||||
@@ -0,0 +1,8 @@
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks";
|
||||
|
||||
import { useGenericTaskAPI } from "./useGenericTaskAPI";
|
||||
|
||||
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>(TaskType.assistant_reply);
|
||||
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>(TaskType.prompter_reply);
|
||||
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>(TaskType.initial_prompt);
|
||||
@@ -1,18 +1,11 @@
|
||||
import { useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import { BaseTask, TaskResponse } from "src/types/Task";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
|
||||
export interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
|
||||
export const useGenericTaskAPI = <TaskType extends BaseTask>(taskApiEndpoint: string) => {
|
||||
type ConcreteTaskResponse = TaskResponse<TaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks";
|
||||
|
||||
import { useGenericTaskAPI } from "./useGenericTaskAPI";
|
||||
|
||||
export const useLabelAssistantReplyTask = () =>
|
||||
useGenericTaskAPI<LabelAssistantReplyTask>(TaskType.label_assistant_reply);
|
||||
export const useLabelInitialPromptTask = () => useGenericTaskAPI<LabelInitialPromptTask>(TaskType.label_initial_prompt);
|
||||
export const useLabelPrompterReplyTask = () => useGenericTaskAPI<LabelPrompterReplyTask>(TaskType.label_prompter_reply);
|
||||
@@ -0,0 +1,12 @@
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks";
|
||||
|
||||
import { useGenericTaskAPI } from "./useGenericTaskAPI";
|
||||
|
||||
export const useRankAssistantRepliesTask = () =>
|
||||
useGenericTaskAPI<RankAssistantRepliesTask>(TaskType.rank_assistant_replies);
|
||||
|
||||
export const useRankPrompterRepliesTask = () =>
|
||||
useGenericTaskAPI<RankPrompterRepliesTask>(TaskType.rank_prompter_replies);
|
||||
|
||||
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>(TaskType.rank_initial_prompts);
|
||||
@@ -30,7 +30,34 @@ export class OasstApiClient {
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (resp.status == 204) {
|
||||
if (resp.status === 204) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (resp.status >= 300) {
|
||||
const errorText = await resp.text();
|
||||
let error: any;
|
||||
try {
|
||||
error = JSON.parse(errorText);
|
||||
} catch (e) {
|
||||
throw new OasstError(errorText, 0, resp.status);
|
||||
}
|
||||
throw new OasstError(error.message ?? error, error.error_code, resp.status);
|
||||
}
|
||||
|
||||
return await resp.json();
|
||||
}
|
||||
|
||||
private async get(path: string): Promise<any> {
|
||||
const resp = await fetch(`${this.oasstApiUrl}${path}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": this.oasstApiKey,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (resp.status === 204) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -96,6 +123,12 @@ export class OasstApiClient {
|
||||
...content,
|
||||
});
|
||||
}
|
||||
|
||||
//Fetch valid labels. This is called every task. though the call may be redundant
|
||||
//keeping this for future where the valid labels may change per task
|
||||
async fetch_valid_text(): Promise<void> {
|
||||
return this.get(`/api/v1/text_labels/valid_labels`);
|
||||
}
|
||||
}
|
||||
|
||||
export const oasstApiClient =
|
||||
|
||||
@@ -23,6 +23,7 @@ const handler = async (req, res) => {
|
||||
|
||||
// Fetch the new task.
|
||||
const task = await oasstApiClient.fetchTask(task_type, token);
|
||||
const valid_labels = await oasstApiClient.fetch_valid_text();
|
||||
|
||||
// Store the task and link it to the user..
|
||||
const registeredTask = await prisma.registeredTask.create({
|
||||
@@ -36,6 +37,9 @@ const handler = async (req, res) => {
|
||||
},
|
||||
});
|
||||
|
||||
// Add the valid labels that can be used to flag messages in this Task
|
||||
registeredTask["valid_labels"] = valid_labels;
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json(registeredTask);
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* Sets the Label in the Backend.
|
||||
@@ -15,7 +14,7 @@ const handler = async (req, res) => {
|
||||
}
|
||||
|
||||
// Parse out the local message_id, task ID and the interaction contents.
|
||||
const { message_id, post_id, label_map, text } = await JSON.parse(req.body);
|
||||
const { message_id, label_map, text } = await JSON.parse(req.body);
|
||||
|
||||
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
|
||||
method: "POST",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { getCsrfToken, getProviders } from "next-auth/react";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
|
||||
export default function Verify() {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply";
|
||||
import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
|
||||
|
||||
const AssistantReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const AssistantReply = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt";
|
||||
import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
|
||||
|
||||
const InitialPrompt = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const InitialPrompt = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -63,7 +63,7 @@ const SummarizeStory = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
if (tasks.length === 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply";
|
||||
import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
|
||||
|
||||
const UserReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const UserReply = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
|
||||
import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
|
||||
const RankAssistantReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const RankAssistantReplies = () => {
|
||||
<title>Rank Assistant Replies</title>
|
||||
<meta name="description" content="Rank Assistant Replies." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts";
|
||||
import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
|
||||
|
||||
const RankInitialPrompts = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const RankInitialPrompts = () => {
|
||||
<title>Rank Initial Prompts</title>
|
||||
<meta name="description" content="Rank initial prompts." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
|
||||
import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
|
||||
const RankUserReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const RankUserReplies = () => {
|
||||
<title>Rank User Replies</title>
|
||||
<meta name="description" content="Rank User Replies." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -60,7 +60,7 @@ const RateSummary = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
if (tasks.length === 0) {
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useEffect } from "react";
|
||||
import { CallToAction } from "src/components/CallToAction";
|
||||
import { Faq } from "src/components/Faq";
|
||||
import { Hero } from "src/components/Hero";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
|
||||
const Home = () => {
|
||||
const router = useRouter();
|
||||
const { status } = useSession();
|
||||
useEffect(() => {
|
||||
if (status === "authenticated") {
|
||||
router.push("/dashboard");
|
||||
}
|
||||
}, [router, status]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
|
||||
@@ -1,46 +1,28 @@
|
||||
import { useState } from "react";
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelAssistantReplyTaskResponse,
|
||||
useLabelAssistantReplyTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
|
||||
const LabelAssistantReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: true, message_id: task.message_id },
|
||||
];
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Assistant Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Assistant Reply</title>
|
||||
<meta name="description" content="Label Assistant Reply" />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,41 +1,28 @@
|
||||
import { useState } from "react";
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelInitialPromptTaskResponse,
|
||||
useLabelInitialPromptTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask();
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Initial Prompt"
|
||||
desc="Provide labels for the following prompt"
|
||||
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
|
||||
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Initial Prompt</title>
|
||||
<meta name="description" content="Label Initial Prompt" />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,46 +1,28 @@
|
||||
import { useState } from "react";
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelPrompterReplyTaskResponse,
|
||||
useLabelPrompterReplyTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
|
||||
const LabelPrompterReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: false, message_id: task.message_id },
|
||||
];
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Prompter Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Prompter Reply</title>
|
||||
<meta name="description" content="Label Prompter Reply" />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ const MessageDetail = ({ id }) => {
|
||||
Parent
|
||||
</Text>
|
||||
<Box rounded="lg" p="2">
|
||||
<MessageTableEntry item={parent} idx={1} />
|
||||
<MessageTableEntry item={parent} idx={1} valid_labels={[]} />
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -2,9 +2,9 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const MessagesDashboard = () => {
|
||||
@@ -52,7 +52,11 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
|
||||
{receivedMessages ? (
|
||||
<MessageTable messages={messages} valid_labels={[]} />
|
||||
) : (
|
||||
<CircularProgress isIndeterminate />
|
||||
)}
|
||||
</Box>
|
||||
</Box>
|
||||
<Box>
|
||||
@@ -66,7 +70,11 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
|
||||
{receivedUserMessages ? (
|
||||
<MessageTable messages={userMessages} valid_labels={[]} />
|
||||
) : (
|
||||
<CircularProgress isIndeterminate />
|
||||
)}
|
||||
</Box>
|
||||
</Box>
|
||||
</SimpleGrid>
|
||||
@@ -74,6 +82,6 @@ const MessagesDashboard = () => {
|
||||
);
|
||||
};
|
||||
|
||||
MessagesDashboard.getLayout = (page) => getDashboardLayout(page);
|
||||
MessagesDashboard.getLayout = getDashboardLayout;
|
||||
|
||||
export default MessagesDashboard;
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
export interface Message {
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
messages: Message[];
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
export const enum TaskType {
|
||||
initial_prompt = "initial_prompt",
|
||||
assistant_reply = "assistant_reply",
|
||||
prompter_reply = "prompter_reply",
|
||||
|
||||
rank_initial_prompts = "rank_initial_prompts",
|
||||
rank_assistant_replies = "rank_assistant_replies",
|
||||
rank_prompter_replies = "rank_prompter_replies",
|
||||
|
||||
label_initial_prompt = "label_initial_prompt",
|
||||
label_prompter_reply = "label_prompter_reply",
|
||||
label_assistant_reply = "label_assistant_reply",
|
||||
}
|
||||
|
||||
export interface ValidLabel {
|
||||
name: string;
|
||||
display_text: string;
|
||||
help_text: string;
|
||||
}
|
||||
|
||||
export interface BaseTask {
|
||||
id: string;
|
||||
type: TaskType;
|
||||
}
|
||||
|
||||
export interface TaskResponse<Task extends BaseTask> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: Task;
|
||||
valid_labels: ValidLabel[];
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
import { Conversation } from "./Conversation";
|
||||
import { BaseTask, TaskType } from "./Task";
|
||||
|
||||
export interface CreateInitialPromptTask extends BaseTask {
|
||||
type: TaskType.initial_prompt;
|
||||
hint: string;
|
||||
}
|
||||
|
||||
export interface CreateAssistantReplyTask extends BaseTask {
|
||||
type: TaskType.assistant_reply;
|
||||
conversation: Conversation;
|
||||
}
|
||||
|
||||
export interface CreatePrompterReplyTask extends BaseTask {
|
||||
type: TaskType.prompter_reply;
|
||||
conversation: Conversation;
|
||||
}
|
||||
|
||||
export interface RankInitialPromptsTask extends BaseTask {
|
||||
type: TaskType.rank_initial_prompts;
|
||||
prompts: string[];
|
||||
}
|
||||
|
||||
export interface RankAssistantRepliesTask extends BaseTask {
|
||||
type: TaskType.rank_assistant_replies;
|
||||
conversation: Conversation;
|
||||
replies: string[];
|
||||
}
|
||||
|
||||
export interface RankPrompterRepliesTask extends BaseTask {
|
||||
type: TaskType.rank_prompter_replies;
|
||||
conversation: Conversation;
|
||||
replies: string[];
|
||||
}
|
||||
|
||||
export interface LabelAssistantReplyTask extends BaseTask {
|
||||
type: TaskType.label_assistant_reply;
|
||||
message_id: string;
|
||||
conversation: Conversation;
|
||||
reply: string;
|
||||
valid_labels: string[];
|
||||
}
|
||||
|
||||
export interface LabelInitialPromptTask extends BaseTask {
|
||||
type: TaskType.label_initial_prompt;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
export interface LabelPrompterReplyTask extends BaseTask {
|
||||
type: TaskType.label_prompter_reply;
|
||||
message_id: string;
|
||||
conversation: Conversation;
|
||||
reply: string;
|
||||
valid_labels: string[];
|
||||
}
|
||||
Reference in New Issue
Block a user