Merge branch 'LAION-AI:main' into main

This commit is contained in:
Andrew Maguire
2023-01-10 13:08:22 +00:00
committed by GitHub
66 changed files with 668 additions and 436 deletions
+11 -3
View File
@@ -4,20 +4,28 @@ on:
push:
branches:
- main
pull_request:
workflow_call:
pull_request_target:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
# in case of PR, check out the PR's head branch
- uses: actions/checkout@v3
if: github.event_name == 'pull_request_target'
with:
ref: ${{ github.event.pull_request.head.sha }}
# in case of push, check out the main branch
- uses: actions/checkout@v3
if: github.event_name == 'push'
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.0
- name: Post PR comment on failure
if: failure() && github.event_name == 'pull_request'
if: failure() && github.event_name == 'pull_request_target'
uses: peter-evans/create-or-update-comment@v2
with:
issue-number: ${{ github.event.pull_request.number }}
+39
View File
@@ -0,0 +1,39 @@
name: E2E Tests (Website)
on:
push:
branches:
- main
paths:
- oasst-shared/**
- backend/**
- website/**
pull_request:
paths:
- oasst-shared/**
- backend/**
- website/**
jobs:
test-e2e:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Start website, backend, etc
run: docker compose up ci --build -d
- name: Run Cypress tests
uses: cypress-io/github-action@v5.0.2
with:
browser: chrome
working-directory: website
- uses: actions/upload-artifact@v3
if: failure() # NOTE: screenshots will be generated only if E2E test failed
with:
name: cypress-screenshots
path: website/cypress/screenshots
- uses: actions/upload-artifact@v3
if: always()
with:
name: cypress-videos
path: website/cypress/videos
+5
View File
@@ -110,3 +110,8 @@ Upon making a release on GitHub, all docker images are automatically built and
pushed to ghcr.io. The docker images are tagged with the release version, and
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
automatically deploy the built release to the dev machine.
### Contribute a Dataset
See
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)
+1
View File
@@ -81,6 +81,7 @@
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
ports:
- 8080:8080
@@ -0,0 +1,27 @@
"""added miniLM_embedding column to message
Revision ID: 023548d474f7
Revises: ba61fe17fb6e
Create Date: 2023-01-08 11:06:25.613290
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "023548d474f7"
down_revision = "ba61fe17fb6e"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###
@@ -0,0 +1,49 @@
"""embedding for message now in its own table
Revision ID: 35bdc1a08bb8
Revises: 023548d474f7
Create Date: 2023-01-08 16:03:48.454207
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35bdc1a08bb8"
down_revision = "023548d474f7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_embedding",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True),
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.ForeignKeyConstraint(
["message_id"],
["message.id"],
),
sa.PrimaryKeyConstraint("message_id", "model"),
)
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message",
sa.Column(
"miniLM_embedding",
postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)),
autoincrement=False,
nullable=True,
),
)
op.drop_table("message_embedding")
# ### end Alembic commands ###
@@ -0,0 +1,30 @@
"""Created date
Revision ID: aac6b2f66006
Revises: 35bdc1a08bb8
Create Date: 2023-01-08 21:28:27.342729
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "aac6b2f66006"
down_revision = "35bdc1a08bb8"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message_embedding",
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message_embedding", "created_date")
# ### end Alembic commands ###
+2 -7
View File
@@ -1,19 +1,14 @@
from enum import Enum
from typing import List
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HuggingFaceAPI
from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI
router = APIRouter()
class HF_url(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
@router.get("/text_toxicity")
async def get_text_toxicity(
msg: str,
@@ -30,7 +25,7 @@ async def get_text_toxicity(
ToxicityClassification: the score of toxicity of the message.
"""
api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value
hugging_face_api = HuggingFaceAPI(api_url)
response = await hugging_face_api.post(msg)
+18 -2
View File
@@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -253,7 +255,7 @@ def tasks_acknowledge_failure(
@router.post("/interaction", response_model=protocol_schema.TaskDone)
def tasks_interaction(
async def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
@@ -274,12 +276,26 @@ def tasks_interaction(
)
# here we store the text reply in the database
pr.store_text_reply(
newMessage = pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
)
embedding = await hugging_face_api.post(interaction.text)
pr.insert_message_embedding(
message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
)
except OasstError:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
return protocol_schema.TaskDone()
case protocol_schema.MessageRating:
logger.info(
+1
View File
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
)
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
HUGGING_FACE_API_KEY: str = ""
+2
View File
@@ -1,6 +1,7 @@
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
from .message_reaction import MessageReaction
from .message_tree_state import MessageTreeState
from .task import Task
@@ -13,6 +14,7 @@ __all__ = [
"User",
"UserStats",
"Message",
"MessageEmbedding",
"MessageReaction",
"MessageTreeState",
"Task",
@@ -0,0 +1,21 @@
from datetime import datetime
from typing import List, Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import ARRAY, Field, Float, SQLModel
class MessageEmbedding(SQLModel, table=True):
__tablename__ = "message_embedding"
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)
message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
model: str = Field(max_length=256, nullable=False)
embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
# In the case that the Message Embedding is created afterwards
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
+32 -3
View File
@@ -2,13 +2,13 @@ import datetime
import random
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from typing import List, Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
@@ -91,7 +91,12 @@ class PromptRepository:
self.db.refresh(message)
return message
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
def store_text_reply(
self,
text: str,
frontend_message_id: str,
user_frontend_message_id: str,
) -> Message:
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
@@ -224,6 +229,30 @@ class PromptRepository:
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
"""Insert the embedding of a new message in the database.
Args:
message_id (UUID): the identifier of the message we want to save its embedding
model (str): the model used for creating the embedding
embedding (List[float]): the values obtained from the message & model
Raises:
OasstError: if misses some of the before params
Returns:
MessageEmbedding: the instance in the database of the embedding saved for that message
"""
if None in (message_id, model, embedding):
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
self.db.add(message_embedding)
self.db.commit()
self.db.refresh(message_embedding)
return message_embedding
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
@@ -1,10 +1,21 @@
from enum import Enum
from typing import Any, Dict
import aiohttp
from loguru import logger
from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode
class HfUrl(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
class HfEmbeddingModel(str, Enum):
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
class HuggingFaceAPI:
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
@@ -41,6 +52,9 @@ class HuggingFaceAPI:
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
# If we get a bad response
if response.status != 200:
logger.error(response)
logger.info(self.headers)
raise OasstError(
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
)
+6
View File
@@ -11,6 +11,11 @@ services:
image: sverrirab/sleep
depends_on: [db, webdb, adminer, maildev, backend, redis]
# Used by CI automations.
ci:
image: sverrirab/sleep
depends_on: [db, webdb, maildev, backend, redis, web]
# This DB is for the FastAPI Backend.
db:
image: postgres
@@ -95,6 +100,7 @@ services:
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
depends_on:
db:
condition: service_healthy
+1 -1
View File
@@ -1,6 +1,6 @@
# Example Notebook
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/andrewm4894/Open-Assistant/blob/main/notebooks/example/example.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/example/example.ipynb)
This folder contains an example reference notebook structure and approach for
this project. Please try and follow this structure as closely as possible. While
+5 -4
View File
@@ -9,10 +9,11 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/andrewm4894/Open-Assistant/blob/example-notebook/notebooks/example/example.ipynb)"
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/example-notebook/notebooks/example/example.ipynb)"
]
},
{
@@ -22,7 +23,7 @@
"outputs": [],
"source": [
"# uncomment and run below lines to set up if running in colab\n",
"# !git clone https://github.com/andrewm4894/Open-Assistant.git\n",
"# !git clone https://github.com/LAION-AI/Open-Assistant.git\n",
"# %cd Open-Assistant/notebooks/example\n",
"# !pip install -r requirements.txt"
]
@@ -146,12 +147,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.7.4 (tags/v3.7.4:e09359112e, Jul 8 2019, 20:34:20) [MSC v.1916 64 bit (AMD64)]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
"hash": "25d5c2324055587ceaeef27650c79ce8358ea61d7689f2e0b8ada5d53f85bce4"
}
}
},
+1
View File
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
uvicorn main:app --reload --port 8080 --host 0.0.0.0
@@ -1,7 +1,7 @@
describe("ranking prompter replies", () => {
it("completes the current task on submit and on request shows a new task", () => {
cy.signInWithEmail("cypress@example.com");
cy.visit("/evaluate/rank_user_replies");
cy.visit("/evaluate/rank_assistant_replies");
cy.get('[data-cy="task-id"').then((taskIdElement) => {
const taskId = taskIdElement.text();
@@ -1,7 +1,7 @@
describe("ranking assistant replies", () => {
it("completes the current task on submit and on request shows a new task", () => {
cy.signInWithEmail("cypress@example.com");
cy.visit("/evaluate/rank_assistant_replies");
cy.visit("/evaluate/rank_user_replies");
cy.get('[data-cy="task-id"').then((taskIdElement) => {
const taskId = taskIdElement.text();
-5
View File
@@ -1,10 +1,6 @@
import {
Button,
ButtonProps,
Menu,
MenuButton,
MenuItem,
MenuList,
Modal,
ModalBody,
ModalCloseButton,
@@ -16,7 +12,6 @@ import {
useDisclosure,
} from "@chakra-ui/react";
import { useState } from "react";
import { FaChevronDown } from "react-icons/fa";
interface SkipButtonProps extends ButtonProps {
onSkip: (reason: string) => void;
+1 -1
View File
@@ -12,7 +12,7 @@ import React from "react";
export const CollapsableText = ({ text, maxLength = 220 }) => {
const { isOpen, onOpen, onClose } = useDisclosure();
if (typeof text != "string" || text.length <= maxLength) {
if (typeof text !== "string" || text.length <= maxLength) {
return text;
} else {
return (
+18 -48
View File
@@ -27,11 +27,26 @@ import poster from "src/lib/poster";
import { colors } from "styles/Theme/colors";
import useSWRMutation from "swr/mutation";
interface textFlagLabels {
attributeName: string;
labelText: string;
additionalExplanation?: string;
}
export const FlaggableElement = (props) => {
const [isEditing, setIsEditing] = useBoolean();
const flaggable_labels = props.flaggable_labels;
const TEXT_LABEL_FLAGS =
flaggable_labels?.valid_labels?.map((valid_label) => {
return {
attributeName: valid_label.name,
labelText: valid_label.display_text,
additionalExplanation: valid_label.help_text,
};
}) || [];
const { trigger } = useSWRMutation("/api/set_label", poster, {
onSuccess: () => {
setIsEditing.off;
setIsEditing.off();
},
});
@@ -55,14 +70,14 @@ export const FlaggableElement = (props) => {
const handleCheckboxState = (isChecked, idx) => {
setCheckboxValues(
checkboxValues.map((val, i) => {
return i == idx ? isChecked : val;
return i === idx ? isChecked : val;
})
);
};
const handleSliderState = (newVal, idx) => {
setSliderValues(
sliderValues.map((val, i) => {
return i == idx ? newVal : val;
return i === idx ? newVal : val;
})
);
};
@@ -181,48 +196,3 @@ export function FlagCheckbox(props: {
</Flex>
);
}
interface textFlagLabels {
attributeName: string;
labelText: string;
additionalExplanation?: string;
}
const TEXT_LABEL_FLAGS: textFlagLabels[] = [
// For the time being this list is configured on the FE.
// In the future it may be provided by the API.
// {
// attributeName: "fails_task",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation: "__TODO__",
// },
// {
// attributeName: "not_customer_assistant_appropriate",
// labelText: "Inappropriate for customer assistant",
// additionalExplanation: "__TODO__",
// },
{
attributeName: "sexual_content",
labelText: "Contains sexual content",
},
{
attributeName: "violence",
labelText: "Contains violent content",
},
// {
// attributeName: "encourages_violence",
// labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
// },
// {
// attributeName: "denigrates_a_protected_class",
// labelText: "Denigrates a protected class",
// },
// {
// attributeName: "gives_harmful_advice",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation:
// "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
// },
// {
// attributeName: "expresses_moral_judgement",
// labelText: "Expresses moral judgement",
// },
];
+28 -12
View File
@@ -1,20 +1,30 @@
import { Grid } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { forwardRef, useColorMode } from "@chakra-ui/react";
import { useMemo } from "react";
import { Message } from "src/types/Conversation";
import { ValidLabel } from "src/types/Task";
import { FlaggableElement } from "./FlaggableElement";
export interface Message {
text: string;
is_assistant: boolean;
message_id: string;
}
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
export const Messages = ({
messages,
post_id,
valid_labels,
}: {
messages: Message[];
post_id: string;
valid_labels: ValidLabel[];
}) => {
const items = messages.map((messageProps: Message, i: number) => {
const { message_id, text } = messageProps;
return (
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
<FlaggableElement
text={text}
post_id={post_id}
message_id={message_id}
key={i + text}
flaggable_labels={valid_labels}
>
<MessageView {...messageProps} />
</FlaggableElement>
);
@@ -23,7 +33,7 @@ export const Messages = ({ messages, post_id }: { messages: Message[]; post_id:
return <Grid gap={2}>{items}</Grid>;
};
export const MessageView = ({ is_assistant, text, message_id }: Message) => {
export const MessageView = forwardRef<Message, "div">(({ is_assistant, text }: Message, ref) => {
const { colorMode } = useColorMode();
const bgColor = useMemo(() => {
@@ -34,5 +44,11 @@ export const MessageView = ({ is_assistant, text, message_id }: Message) => {
}
}, [colorMode, is_assistant]);
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
};
return (
<div ref={ref} className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>
{text}
</div>
);
});
MessageView.displayName = "MessageView";
@@ -1,11 +1,11 @@
import { Stack, StackDivider } from "@chakra-ui/react";
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
export function MessageTable({ messages }) {
export function MessageTable({ messages, valid_labels }) {
return (
<Stack divider={<StackDivider />} spacing="4">
{messages.map((item, idx) => (
<MessageTableEntry item={item} idx={idx} key={item.id} />
<MessageTableEntry item={item} idx={idx} key={item.message_id || item.id} valid_labels={valid_labels} />
))}
</Stack>
);
@@ -2,6 +2,7 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
import { boolean } from "boolean";
import NextLink from "next/link";
import { FlaggableElement } from "src/components/FlaggableElement";
import type { ValidLabel } from "src/types/Task";
interface Message {
text: string;
@@ -11,13 +12,14 @@ interface Message {
interface MessageTableEntryProps {
item: Message;
idx: number;
valid_labels: ValidLabel[];
}
export function MessageTableEntry(props: MessageTableEntryProps) {
const { item, idx } = props;
const { item, idx, valid_labels } = props;
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
return (
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`}>
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`} flaggable_labels={valid_labels}>
<HStack>
<Avatar
name={`${boolean(item.is_assistant) ? "Assitant" : "User"}`}
@@ -64,7 +64,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
<Flex justifyContent="center" pb="2">
<Box maxWidth="container.sm" flex="1" px={isFirstOrOnly ? [4, 6, 8, 9] : "0"}>
<Box px={isFirstOrOnly ? "2" : "0"}>
<MessageTableEntry item={message} idx={1} />
<MessageTableEntry item={message} idx={1} valid_labels={[]} />
</Box>
</Box>
</Flex>
@@ -90,7 +90,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
<HStack {...MessageStackProps}>
{children.map((item, idx) => (
<Box maxWidth="container.sm" flex="1" key={`recursiveMessageWChildren_${idx}`}>
<MessageTableEntry item={item} idx={idx * 2} />
<MessageTableEntry item={item} idx={idx * 2} valid_labels={[]} />
</Box>
))}
</HStack>
+6 -5
View File
@@ -1,16 +1,15 @@
import { useState } from "react";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { TaskType } from "./TaskTypes";
import { TaskInfo } from "src/components/Tasks/TaskTypes";
export interface CreateTaskProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tasks: any[];
taskType: TaskType;
taskType: TaskInfo;
trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
onSkipTask: (task: { id: string }, reason: string) => void;
onNextTask: () => void;
@@ -18,7 +17,7 @@ export interface CreateTaskProps {
}
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
const task = tasks[0].task;
const valid_labels = tasks[0].valid_labels;
const [inputText, setInputText] = useState("");
const submitResponse = (task: { id: string }) => {
@@ -42,7 +41,9 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m
<>
<h5 className="text-lg font-semibold">{taskType.label}</h5>
<p className="text-lg py-1">{taskType.overview}</p>
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
{task.conversation ? (
<Messages messages={task.conversation.messages} post_id={task.id} valid_labels={valid_labels} />
) : null}
</>
<>
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
@@ -33,6 +33,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
messages = messages.map((message, index) => ({ ...message, id: index }));
}
const valid_labels = tasks[0].valid_labels;
const sortables = tasks[0].task.replies ? "replies" : "prompts";
return (
@@ -42,13 +43,13 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
<p className="text-lg py-1">
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
</p>
{messages ? <MessageTable messages={messages} /> : null}
{messages ? <MessageTable messages={messages} valid_labels={valid_labels} /> : null}
<Sortable items={tasks[0].task[sortables]} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task[sortables].length}
isValid={ranking.length === tasks[0].task[sortables].length}
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkipTask={(task, reason) => {
+57 -29
View File
@@ -1,43 +1,71 @@
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
import { useEffect, useId, useState } from "react";
import { MessageView } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { TaskInfo } from "src/components/Tasks/TaskTypes";
import { TaskType } from "src/types/Task";
import { colors } from "styles/Theme/colors";
export const LabelTask = ({
title,
desc,
messages,
inputs,
controls,
}: {
title: string;
desc: string;
messages: ReactNode;
inputs: ReactNode;
controls: ReactNode;
}) => {
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
export interface LabelTaskProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tasks: any[];
taskType: TaskInfo;
trigger: (update: {
id: string;
update_type: string;
content: { text: string; labels: { [k: string]: number }; message_id: string };
}) => void;
onSkipTask: (task: { id: string }, reason: string) => void;
onNextTask: () => void;
mainBgClasses: string;
}
export const LabelTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: LabelTaskProps) => {
const task = tasks[0].task;
const valid_labels = tasks[0].valid_labels;
const card = useMemo(
() => (
<>
<h5 className="text-lg font-semibold">{title}</h5>
<p className="text-lg py-1">{desc}</p>
{messages}
</>
),
[title, desc, messages]
);
const [sliderValues, setSliderValues] = useState<number[]>([]);
const submitResponse = (task: { id: string; reply: string; message_id: string }) => {
console.assert(valid_labels.length === sliderValues.length);
const labels = Object.fromEntries(valid_labels.valid_labels.map((label, i) => [label, sliderValues[i]]));
trigger({
id: task.id,
update_type: "text_labels",
content: { labels, text: task.reply, message_id: task.message_id },
});
};
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
{card}
{inputs}
<>
<h5 className="text-lg font-semibold">{taskType.label}</h5>
<p className="text-lg py-1">{taskType.overview}</p>
{task.conversation ? (
<MessageTable
messages={[
...(task.conversation ? task.conversation.messages : []),
{
text: task.reply,
is_assistant: task.type === TaskType.label_assistant_reply,
message_id: task.message_id,
},
]}
valid_labels={valid_labels}
/>
) : (
<MessageView text={task.prompt} is_assistant={false} message_id={task.message_id} />
)}
</>
<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
</TwoColumnsWithCards>
{controls}
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkipTask={onSkipTask} onNextTask={onNextTask} />
</div>
);
};
+21 -5
View File
@@ -1,12 +1,17 @@
import { CreateTask } from "./CreateTask";
import { EvaluateTask } from "./EvaluateTask";
import { TaskCategory, TaskTypes } from "./TaskTypes";
import useSWRMutation from "swr/mutation";
import { useColorMode } from "@chakra-ui/react";
import { CreateTask } from "src/components/Tasks/CreateTask";
import { EvaluateTask } from "src/components/Tasks/EvaluateTask";
import { LabelTask } from "src/components/Tasks/LabelTask";
import { TaskCategory, TaskTypes } from "src/components/Tasks/TaskTypes";
import poster from "src/lib/poster";
import useSWRMutation from "swr/mutation";
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
export const Task = ({ tasks, trigger, mutate }) => {
const task = tasks[0].task;
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, {
onSuccess: async () => {
mutate();
@@ -45,6 +50,17 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
mainBgClasses={mainBgClasses}
/>
);
case TaskCategory.Label:
return (
<LabelTask
tasks={tasks}
taskType={taskType}
trigger={trigger}
onSkipTask={rejectTask}
onNextTask={mutate}
mainBgClasses={mainBgClasses}
/>
);
}
}
+5 -2
View File
@@ -4,7 +4,7 @@ export enum TaskCategory {
Label = "Label",
}
export interface TaskType {
export interface TaskInfo {
label: string;
desc: string;
category: TaskCategory;
@@ -14,7 +14,7 @@ export interface TaskType {
instruction?: string;
}
export const TaskTypes: TaskType[] = [
export const TaskTypes: TaskInfo[] = [
// create
{
label: "Create Initial Prompts",
@@ -71,6 +71,7 @@ export const TaskTypes: TaskType[] = [
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_initial_prompt",
overview: "Provide labels for the following prompt",
type: "label_initial_prompt",
},
{
@@ -78,6 +79,7 @@ export const TaskTypes: TaskType[] = [
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_prompter_reply",
overview: "Given the following discussion, provide labels for the final promp",
type: "label_prompter_reply",
},
{
@@ -85,6 +87,7 @@ export const TaskTypes: TaskType[] = [
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_assistant_reply",
overview: "Given the following discussion, provide labels for the final prompt.",
type: "label_assistant_reply",
},
];
@@ -1,9 +0,0 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface CreateInitialPromptTask {
id: string;
type: "initial_prompt";
hint: string;
}
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>("initial_prompt");
@@ -1,24 +0,0 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface BaseCreateReplyTask {
id: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export interface CreateAssistantReplyTask extends BaseCreateReplyTask {
type: "assistant_reply";
}
export interface CreatePrompterReplyTask extends BaseCreateReplyTask {
type: "prompter_reply";
}
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>("assistant_reply");
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>("prompter_reply");
@@ -1,9 +0,0 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface RankInitialPromptsTask {
id: string;
type: "rank_initial_prompts";
prompts: string[];
}
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>("rank_initial_prompts");
@@ -1,25 +0,0 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface BaseRankRepliesTask {
id: string;
replies: string[];
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
interface RankAssistantRepliesTask extends BaseRankRepliesTask {
type: "rank_assistant_replies";
}
interface RankPrompterRepliesTask extends BaseRankRepliesTask {
type: "rank_prompter_replies";
}
export const useRankAssistantRepliesTask = () => useGenericTaskAPI<RankAssistantRepliesTask>("rank_assistant_replies");
export const useRankPrompterRepliesTask = () => useGenericTaskAPI<RankPrompterRepliesTask>("rank_prompter_replies");
@@ -1,22 +0,0 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelAssistantReplyTask {
id: string;
type: LabelingTaskType.label_assistant_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
export const useLabelAssistantReplyTask = () =>
useLabelingTask<LabelAssistantReplyTask>(LabelingTaskType.label_assistant_reply);
@@ -1,15 +0,0 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelInitialPromptTask {
id: string;
type: LabelingTaskType.label_initial_prompt;
message_id: string;
valid_labels: string[];
prompt: string;
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelInitialPromptTask = () =>
useLabelingTask<LabelInitialPromptTask>(LabelingTaskType.label_initial_prompt);
@@ -1,22 +0,0 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelPrompterReplyTask {
id: string;
type: LabelingTaskType.label_prompter_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
export const useLabelPrompterReplyTask = () =>
useLabelingTask<LabelPrompterReplyTask>(LabelingTaskType.label_prompter_reply);
@@ -1,20 +0,0 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
export const enum LabelingTaskType {
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
}
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<TaskType>(endpoint);
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
};
return { tasks, isLoading, submit, reset, error };
};
@@ -0,0 +1,8 @@
import { TaskType } from "src/types/Task";
import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks";
import { useGenericTaskAPI } from "./useGenericTaskAPI";
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>(TaskType.assistant_reply);
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>(TaskType.prompter_reply);
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>(TaskType.initial_prompt);
@@ -1,18 +1,11 @@
import { useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import { BaseTask, TaskResponse } from "src/types/Task";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
export interface TaskResponse<TaskType> {
id: string;
userId: string;
task: TaskType;
}
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
export const useGenericTaskAPI = <TaskType extends BaseTask>(taskApiEndpoint: string) => {
type ConcreteTaskResponse = TaskResponse<TaskType>;
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
@@ -0,0 +1,9 @@
import { TaskType } from "src/types/Task";
import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks";
import { useGenericTaskAPI } from "./useGenericTaskAPI";
export const useLabelAssistantReplyTask = () =>
useGenericTaskAPI<LabelAssistantReplyTask>(TaskType.label_assistant_reply);
export const useLabelInitialPromptTask = () => useGenericTaskAPI<LabelInitialPromptTask>(TaskType.label_initial_prompt);
export const useLabelPrompterReplyTask = () => useGenericTaskAPI<LabelPrompterReplyTask>(TaskType.label_prompter_reply);
+12
View File
@@ -0,0 +1,12 @@
import { TaskType } from "src/types/Task";
import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks";
import { useGenericTaskAPI } from "./useGenericTaskAPI";
export const useRankAssistantRepliesTask = () =>
useGenericTaskAPI<RankAssistantRepliesTask>(TaskType.rank_assistant_replies);
export const useRankPrompterRepliesTask = () =>
useGenericTaskAPI<RankPrompterRepliesTask>(TaskType.rank_prompter_replies);
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>(TaskType.rank_initial_prompts);
+34 -1
View File
@@ -30,7 +30,34 @@ export class OasstApiClient {
body: JSON.stringify(body),
});
if (resp.status == 204) {
if (resp.status === 204) {
return null;
}
if (resp.status >= 300) {
const errorText = await resp.text();
let error: any;
try {
error = JSON.parse(errorText);
} catch (e) {
throw new OasstError(errorText, 0, resp.status);
}
throw new OasstError(error.message ?? error, error.error_code, resp.status);
}
return await resp.json();
}
private async get(path: string): Promise<any> {
const resp = await fetch(`${this.oasstApiUrl}${path}`, {
method: "GET",
headers: {
"X-API-Key": this.oasstApiKey,
"Content-Type": "application/json",
},
});
if (resp.status === 204) {
return null;
}
@@ -96,6 +123,12 @@ export class OasstApiClient {
...content,
});
}
//Fetch valid labels. This is called every task. though the call may be redundant
//keeping this for future where the valid labels may change per task
async fetch_valid_text(): Promise<void> {
return this.get(`/api/v1/text_labels/valid_labels`);
}
}
export const oasstApiClient =
@@ -23,6 +23,7 @@ const handler = async (req, res) => {
// Fetch the new task.
const task = await oasstApiClient.fetchTask(task_type, token);
const valid_labels = await oasstApiClient.fetch_valid_text();
// Store the task and link it to the user..
const registeredTask = await prisma.registeredTask.create({
@@ -36,6 +37,9 @@ const handler = async (req, res) => {
},
});
// Add the valid labels that can be used to flag messages in this Task
registeredTask["valid_labels"] = valid_labels;
// Send the results to the client.
res.status(200).json(registeredTask);
};
+1
View File
@@ -1,6 +1,7 @@
import { Prisma } from "@prisma/client";
import { getToken } from "next-auth/jwt";
import { oasstApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
const handler = async (req, res) => {
const token = await getToken({ req });
+1 -2
View File
@@ -1,5 +1,4 @@
import { getToken } from "next-auth/jwt";
import prisma from "src/lib/prismadb";
/**
* Sets the Label in the Backend.
@@ -15,7 +14,7 @@ const handler = async (req, res) => {
}
// Parse out the local message_id, task ID and the interaction contents.
const { message_id, post_id, label_map, text } = await JSON.parse(req.body);
const { message_id, label_map, text } = await JSON.parse(req.body);
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
method: "POST",
-1
View File
@@ -1,7 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { getCsrfToken, getProviders } from "next-auth/react";
import { AuthLayout } from "src/components/AuthLayout";
export default function Verify() {
const { colorMode } = useColorMode();
+2 -6
View File
@@ -1,16 +1,12 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply";
import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
const AssistantReply = () => {
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
@@ -25,7 +21,7 @@ const AssistantReply = () => {
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
+2 -6
View File
@@ -1,16 +1,12 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt";
import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
const InitialPrompt = () => {
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
@@ -25,7 +21,7 @@ const InitialPrompt = () => {
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
+1 -1
View File
@@ -63,7 +63,7 @@ const SummarizeStory = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
if (tasks.length === 0) {
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
}
+2 -6
View File
@@ -1,16 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply";
import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
const UserReply = () => {
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
@@ -25,7 +21,7 @@ const UserReply = () => {
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
@@ -1,16 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
const RankAssistantReplies = () => {
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
@@ -25,7 +21,7 @@ const RankAssistantReplies = () => {
<title>Rank Assistant Replies</title>
<meta name="description" content="Rank Assistant Replies." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
@@ -1,16 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts";
import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
const RankInitialPrompts = () => {
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
@@ -25,7 +21,7 @@ const RankInitialPrompts = () => {
<title>Rank Initial Prompts</title>
<meta name="description" content="Rank initial prompts." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
@@ -1,16 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
const RankUserReplies = () => {
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
@@ -25,7 +21,7 @@ const RankUserReplies = () => {
<title>Rank User Replies</title>
<meta name="description" content="Rank User Replies." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
+1 -1
View File
@@ -60,7 +60,7 @@ const RateSummary = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
if (tasks.length === 0) {
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
+11
View File
@@ -1,10 +1,21 @@
import Head from "next/head";
import { useRouter } from "next/router";
import { useSession } from "next-auth/react";
import { useEffect } from "react";
import { CallToAction } from "src/components/CallToAction";
import { Faq } from "src/components/Faq";
import { Hero } from "src/components/Hero";
import { getTransparentHeaderLayout } from "src/components/Layout";
const Home = () => {
const router = useRouter();
const { status } = useSession();
useEffect(() => {
if (status === "authenticated") {
router.push("/dashboard");
}
}, [router, status]);
return (
<>
<Head>
@@ -1,46 +1,28 @@
import { useState } from "react";
import { Container } from "@chakra-ui/react";
import Head from "next/head";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelAssistantReplyTaskResponse,
useLabelAssistantReplyTask,
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
import { Task } from "src/components/Tasks/Task";
import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
const LabelAssistantReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask();
const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
if (isLoading || tasks.length === 0) {
if (isLoading) {
return <LoadingScreen />;
}
const task = tasks[0].task;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: true, message_id: task.message_id },
];
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
return (
<LabelTask
title="Label Assistant Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkipTask={() => reset()}
onNextTask={reset}
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
}
/>
}
/>
<>
<Head>
<title>Label Assistant Reply</title>
<meta name="description" content="Label Assistant Reply" />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
@@ -1,41 +1,28 @@
import { useState } from "react";
import { Container } from "@chakra-ui/react";
import Head from "next/head";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelInitialPromptTaskResponse,
useLabelInitialPromptTask,
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
import { Task } from "src/components/Tasks/Task";
import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask();
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
if (isLoading || tasks.length === 0) {
if (isLoading) {
return <LoadingScreen />;
}
const task = tasks[0].task;
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
return (
<LabelTask
title="Label Initial Prompt"
desc="Provide labels for the following prompt"
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkipTask={() => reset()}
onNextTask={reset}
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
}
/>
}
/>
<>
<Head>
<title>Label Initial Prompt</title>
<meta name="description" content="Label Initial Prompt" />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
@@ -1,46 +1,28 @@
import { useState } from "react";
import { Container } from "@chakra-ui/react";
import Head from "next/head";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelPrompterReplyTaskResponse,
useLabelPrompterReplyTask,
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
import { Task } from "src/components/Tasks/Task";
import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
const LabelPrompterReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask();
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
if (isLoading || tasks.length === 0) {
if (isLoading) {
return <LoadingScreen />;
}
const task = tasks[0].task;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: false, message_id: task.message_id },
];
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
return (
<LabelTask
title="Label Prompter Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkipTask={() => reset()}
onNextTask={reset}
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
}
/>
}
/>
<>
<Head>
<title>Label Prompter Reply</title>
<meta name="description" content="Label Prompter Reply" />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={reset} />
</>
);
};
+1 -1
View File
@@ -41,7 +41,7 @@ const MessageDetail = ({ id }) => {
Parent
</Text>
<Box rounded="lg" p="2">
<MessageTableEntry item={parent} idx={1} />
<MessageTableEntry item={parent} idx={1} valid_labels={[]} />
</Box>
</>
)}
+12 -4
View File
@@ -2,9 +2,9 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
import Head from "next/head";
import { useEffect, useState } from "react";
import { getDashboardLayout } from "src/components/Layout";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import fetcher from "src/lib/fetcher";
import { Message } from "src/types/Conversation";
import useSWRImmutable from "swr/immutable";
const MessagesDashboard = () => {
@@ -52,7 +52,11 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
{receivedMessages ? (
<MessageTable messages={messages} valid_labels={[]} />
) : (
<CircularProgress isIndeterminate />
)}
</Box>
</Box>
<Box>
@@ -66,7 +70,11 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
{receivedUserMessages ? (
<MessageTable messages={userMessages} valid_labels={[]} />
) : (
<CircularProgress isIndeterminate />
)}
</Box>
</Box>
</SimpleGrid>
@@ -74,6 +82,6 @@ const MessagesDashboard = () => {
);
};
MessagesDashboard.getLayout = (page) => getDashboardLayout(page);
MessagesDashboard.getLayout = getDashboardLayout;
export default MessagesDashboard;
+9
View File
@@ -0,0 +1,9 @@
export interface Message {
text: string;
is_assistant: boolean;
message_id: string;
}
export interface Conversation {
messages: Message[];
}
+31
View File
@@ -0,0 +1,31 @@
export const enum TaskType {
initial_prompt = "initial_prompt",
assistant_reply = "assistant_reply",
prompter_reply = "prompter_reply",
rank_initial_prompts = "rank_initial_prompts",
rank_assistant_replies = "rank_assistant_replies",
rank_prompter_replies = "rank_prompter_replies",
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
}
export interface ValidLabel {
name: string;
display_text: string;
help_text: string;
}
export interface BaseTask {
id: string;
type: TaskType;
}
export interface TaskResponse<Task extends BaseTask> {
id: string;
userId: string;
task: Task;
valid_labels: ValidLabel[];
}
+57
View File
@@ -0,0 +1,57 @@
import { Conversation } from "./Conversation";
import { BaseTask, TaskType } from "./Task";
export interface CreateInitialPromptTask extends BaseTask {
type: TaskType.initial_prompt;
hint: string;
}
export interface CreateAssistantReplyTask extends BaseTask {
type: TaskType.assistant_reply;
conversation: Conversation;
}
export interface CreatePrompterReplyTask extends BaseTask {
type: TaskType.prompter_reply;
conversation: Conversation;
}
export interface RankInitialPromptsTask extends BaseTask {
type: TaskType.rank_initial_prompts;
prompts: string[];
}
export interface RankAssistantRepliesTask extends BaseTask {
type: TaskType.rank_assistant_replies;
conversation: Conversation;
replies: string[];
}
export interface RankPrompterRepliesTask extends BaseTask {
type: TaskType.rank_prompter_replies;
conversation: Conversation;
replies: string[];
}
export interface LabelAssistantReplyTask extends BaseTask {
type: TaskType.label_assistant_reply;
message_id: string;
conversation: Conversation;
reply: string;
valid_labels: string[];
}
export interface LabelInitialPromptTask extends BaseTask {
type: TaskType.label_initial_prompt;
message_id: string;
valid_labels: string[];
prompt: string;
}
export interface LabelPrompterReplyTask extends BaseTask {
type: TaskType.label_prompter_reply;
message_id: string;
conversation: Conversation;
reply: string;
valid_labels: string[];
}