From c3696759917b2cc03c2424e1a99a9b55d59b4e57 Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Fri, 6 Jan 2023 20:51:49 +1100 Subject: [PATCH 01/39] website: Automate e2e testing with a simpler method than #376 --- .github/workflows/test-e2e.yaml | 37 +++++++++++++++++++++++++++++++++ docker-compose.yaml | 5 +++++ 2 files changed, 42 insertions(+) create mode 100644 .github/workflows/test-e2e.yaml diff --git a/.github/workflows/test-e2e.yaml b/.github/workflows/test-e2e.yaml new file mode 100644 index 00000000..f2759808 --- /dev/null +++ b/.github/workflows/test-e2e.yaml @@ -0,0 +1,37 @@ +name: E2E Tests (Website) + +on: + push: + branches: + - main + paths: + - backend/** + - website/** + pull_request: + paths: + - backend/** + - website/** + +jobs: + test-e2e: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Start website, backend, etc + run: docker compose up ci --build -d + - name: Run Cypress tests + uses: cypress-io/github-action@v5.0.2 + with: + browser: chrome + working-directory: website + - uses: actions/upload-artifact@v3 + if: failure() # NOTE: screenshots will be generated only if E2E test failed + with: + name: cypress-screenshots + path: website/cypress/screenshots + - uses: actions/upload-artifact@v3 + if: always() + with: + name: cypress-videos + path: website/cypress/videos diff --git a/docker-compose.yaml b/docker-compose.yaml index 6bc42c51..6896bf18 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -11,6 +11,11 @@ services: image: sverrirab/sleep depends_on: [db, webdb, adminer, maildev, backend, redis] + # Used by CI automations. + ci: + image: sverrirab/sleep + depends_on: [db, webdb, maildev, backend, redis, web] + # This DB is for the FastAPI Backend. db: image: postgres From 11d55d572a06ba42193f846405128d9da21f9957 Mon Sep 17 00:00:00 2001 From: jojopirker Date: Sun, 8 Jan 2023 12:28:38 +0100 Subject: [PATCH 02/39] message embeddings in Messages table --- ...dded_minilm_embedding_column_to_message.py | 29 +++++++++++++++++++ backend/oasst_backend/api/v1/hugging_face.py | 7 +---- backend/oasst_backend/api/v1/tasks.py | 15 +++++++++- backend/oasst_backend/config.py | 1 + backend/oasst_backend/models/message.py | 5 ++-- .../oasst_backend/models/message_embedding.py | 0 backend/oasst_backend/prompt_repository.py | 9 ++++-- backend/oasst_backend/utils/hugging_face.py | 10 +++++++ docker-compose.yaml | 1 + scripts/backend-development/run-local.sh | 1 + 10 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py create mode 100644 backend/oasst_backend/models/message_embedding.py diff --git a/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py new file mode 100644 index 00000000..843f03bc --- /dev/null +++ b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py @@ -0,0 +1,29 @@ +"""added miniLM_embedding column to message + +Revision ID: 023548d474f7 +Revises: ba61fe17fb6e +Create Date: 2023-01-08 11:06:25.613290 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision = '023548d474f7' +down_revision = 'ba61fe17fb6e' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('message', sa.Column('miniLM_embedding', sa.ARRAY(sa.Float()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('message', 'miniLM_embedding') + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py index 1e7f1ffe..a8d8aeb9 100644 --- a/backend/oasst_backend/api/v1/hugging_face.py +++ b/backend/oasst_backend/api/v1/hugging_face.py @@ -1,19 +1,14 @@ -from enum import Enum from typing import List from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.schemas.hugging_face import ToxicityClassification -from oasst_backend.utils.hugging_face import HuggingFaceAPI +from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI router = APIRouter() -class HF_url(str, Enum): - HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta" - - @router.get("/text_toxicity") async def get_text_toxicity( msg: str, diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index adfb2907..2248b85a 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation +from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -253,7 +255,7 @@ def tasks_acknowledge_failure( @router.post("/interaction", response_model=protocol_schema.TaskDone) -def tasks_interaction( +async def tasks_interaction( *, db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), @@ -273,11 +275,22 @@ def tasks_interaction( f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) + embedding = None + if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: + try: + hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value) + embedding = await hugging_face_api.post(interaction.text) + except: + logger.error( + f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + ) + # here we store the text reply in the database pr.store_text_reply( text=interaction.text, frontend_message_id=interaction.message_id, user_frontend_message_id=interaction.user_message_id, + miniLM_embedding=embedding, ) return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 1765af7a..ed394412 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -25,6 +25,7 @@ class Settings(BaseSettings): DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = ( Path(__file__).parent.parent / "test_data/generic/test_generic_data.json" ) + DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False HUGGING_FACE_API_KEY: str = "" diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 6d24fd13..c7c2abbb 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,6 +1,6 @@ from datetime import datetime from http import HTTPStatus -from typing import Optional +from typing import List, Optional from uuid import UUID, uuid4 import sqlalchemy as sa @@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from sqlalchemy import false -from sqlmodel import Field, Index, SQLModel +from sqlmodel import ARRAY, Field, Float, Index, SQLModel from .payload_column_type import PayloadContainer, payload_column_type @@ -40,6 +40,7 @@ class Message(SQLModel, table=True): depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) + miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7c7dd7b6..8bb1eb4a 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -2,7 +2,7 @@ import datetime import random from collections import defaultdict from http import HTTPStatus -from typing import Optional +from typing import List, Optional from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload @@ -122,7 +122,9 @@ class PromptRepository: ) return task - def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message: + def store_text_reply( + self, text: str, frontend_message_id: str, user_frontend_message_id: str, miniLM_embedding: List[float] = None + ) -> Message: self.validate_frontend_message_id(frontend_message_id) self.validate_frontend_message_id(user_frontend_message_id) @@ -163,6 +165,7 @@ class PromptRepository: role=role, payload=db_payload.MessagePayload(text=text), depth=depth, + miniLM_embedding=miniLM_embedding, ) if not task.collective: task.done = True @@ -366,6 +369,7 @@ class PromptRepository: payload: db_payload.MessagePayload, payload_type: str = None, depth: int = 0, + miniLM_embedding: List[float] = None, ) -> Message: if payload_type is None: if payload is None: @@ -385,6 +389,7 @@ class PromptRepository: payload_type=payload_type, payload=PayloadContainer(payload=payload), depth=depth, + miniLM_embedding=miniLM_embedding, ) self.db.add(message) self.db.commit() diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 0df913f5..867c537c 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict import aiohttp @@ -5,6 +6,11 @@ from oasst_backend.config import settings from oasst_shared.exceptions import OasstError, OasstErrorCode +class HF_url(str, Enum): + HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) + HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + + class HuggingFaceAPI: """Class Object to make post calls to endpoints for inference in models hosted in HuggingFace""" @@ -41,6 +47,10 @@ class HuggingFaceAPI: async with session.post(self.api_url, headers=self.headers, json=payload) as response: # If we get a bad response if response.status != 200: + from loguru import logger + + logger.error(response) + logger.info(self.headers) raise OasstError( "Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR ) diff --git a/docker-compose.yaml b/docker-compose.yaml index 6bc42c51..5b0032f4 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -95,6 +95,7 @@ services: - DEBUG_SKIP_API_KEY_CHECK=True - DEBUG_USE_SEED_DATA=True - MAX_WORKERS=1 + - DEBUG_SKIP_EMBEDDING_COMPUTATION=True depends_on: db: condition: service_healthy diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 3ed0e936..f0f6d16c 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -6,6 +6,7 @@ pushd "$parent_path/../../backend" export DEBUG_SKIP_API_KEY_CHECK=True export DEBUG_USE_SEED_DATA=True +export DEBUG_SKIP_EMBEDDING_COMPUTATION=True uvicorn main:app --reload --port 8080 --host 0.0.0.0 From 34e7d1db8a272551a9f29bddb1f34ba7a3df8952 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 14:36:38 +0100 Subject: [PATCH 03/39] [NEW] Except OasstError --- backend/oasst_backend/api/v1/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 2248b85a..8656a063 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -280,7 +280,7 @@ async def tasks_interaction( try: hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value) embedding = await hugging_face_api.post(interaction.text) - except: + except OasstError: logger.error( f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) From a677e40cffcc026a431c333af64b42b823c5af7a Mon Sep 17 00:00:00 2001 From: jojopirker Date: Sun, 8 Jan 2023 16:46:53 +0100 Subject: [PATCH 04/39] insert embedding now to new table --- ...dded_minilm_embedding_column_to_message.py | 12 ++--- ...8_embedding_for_message_now_in_its_own_.py | 49 +++++++++++++++++++ backend/oasst_backend/api/v1/tasks.py | 25 +++++----- backend/oasst_backend/models/__init__.py | 2 + backend/oasst_backend/models/message.py | 5 +- .../oasst_backend/models/message_embedding.py | 15 ++++++ backend/oasst_backend/prompt_repository.py | 15 ++++-- backend/oasst_backend/utils/hugging_face.py | 6 ++- 8 files changed, 103 insertions(+), 26 deletions(-) create mode 100644 backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py diff --git a/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py index 843f03bc..9b81105f 100644 --- a/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py +++ b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py @@ -5,25 +5,23 @@ Revises: ba61fe17fb6e Create Date: 2023-01-08 11:06:25.613290 """ -from alembic import op import sqlalchemy as sa -import sqlmodel - +from alembic import op # revision identifiers, used by Alembic. -revision = '023548d474f7' -down_revision = 'ba61fe17fb6e' +revision = "023548d474f7" +down_revision = "ba61fe17fb6e" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column('message', sa.Column('miniLM_embedding', sa.ARRAY(sa.Float()), nullable=True)) + op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True)) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('message', 'miniLM_embedding') + op.drop_column("message", "miniLM_embedding") # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py b/backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py new file mode 100644 index 00000000..b732b792 --- /dev/null +++ b/backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py @@ -0,0 +1,49 @@ +"""embedding for message now in its own table + +Revision ID: 35bdc1a08bb8 +Revises: 023548d474f7 +Create Date: 2023-01-08 16:03:48.454207 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "35bdc1a08bb8" +down_revision = "023548d474f7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "message_embedding", + sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True), + sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False), + sa.ForeignKeyConstraint( + ["message_id"], + ["message.id"], + ), + sa.PrimaryKeyConstraint("message_id", "model"), + ) + op.drop_column("message", "miniLM_embedding") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "message", + sa.Column( + "miniLM_embedding", + postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)), + autoincrement=False, + nullable=True, + ), + ) + op.drop_table("message_embedding") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 8656a063..e033c25e 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -9,7 +9,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository -from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -275,24 +275,27 @@ async def tasks_interaction( f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) - embedding = None + # here we store the text reply in the database + newMessage = pr.store_text_reply( + text=interaction.text, + frontend_message_id=interaction.message_id, + user_frontend_message_id=interaction.user_message_id, + ) + if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: - hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value) + hugging_face_api = HuggingFaceAPI( + f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}{HF_embeddingModel.MINILM.value}" + ) embedding = await hugging_face_api.post(interaction.text) + pr.insert_message_embedding( + message_id=newMessage.id, model=HF_embeddingModel.MINILM.value, embedding=embedding + ) except OasstError: logger.error( f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) - # here we store the text reply in the database - pr.store_text_reply( - text=interaction.text, - frontend_message_id=interaction.message_id, - user_frontend_message_id=interaction.user_message_id, - miniLM_embedding=embedding, - ) - return protocol_schema.TaskDone() case protocol_schema.MessageRating: logger.info( diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index a856b155..0873381c 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -1,6 +1,7 @@ from .api_client import ApiClient from .journal import Journal, JournalIntegration from .message import Message +from .message_embedding import MessageEmbedding from .message_reaction import MessageReaction from .message_tree_state import MessageTreeState from .task import Task @@ -13,6 +14,7 @@ __all__ = [ "User", "UserStats", "Message", + "MessageEmbedding", "MessageReaction", "MessageTreeState", "Task", diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index c7c2abbb..6d24fd13 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,6 +1,6 @@ from datetime import datetime from http import HTTPStatus -from typing import List, Optional +from typing import Optional from uuid import UUID, uuid4 import sqlalchemy as sa @@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from sqlalchemy import false -from sqlmodel import ARRAY, Field, Float, Index, SQLModel +from sqlmodel import Field, Index, SQLModel from .payload_column_type import PayloadContainer, payload_column_type @@ -40,7 +40,6 @@ class Message(SQLModel, table=True): depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) - miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py index e69de29b..697a776b 100644 --- a/backend/oasst_backend/models/message_embedding.py +++ b/backend/oasst_backend/models/message_embedding.py @@ -0,0 +1,15 @@ +from typing import List +from uuid import UUID + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import ARRAY, Field, Float, SQLModel + + +class MessageEmbedding(SQLModel, table=True): + __tablename__ = "message_embedding" + __table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),) + + message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False)) + model: str = Field(max_length=256, nullable=False) + embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 8bb1eb4a..efdc1b33 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -8,7 +8,7 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User +from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, Task, TextLabels, User from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema @@ -165,7 +165,6 @@ class PromptRepository: role=role, payload=db_payload.MessagePayload(text=text), depth=depth, - miniLM_embedding=miniLM_embedding, ) if not task.collective: task.done = True @@ -369,7 +368,6 @@ class PromptRepository: payload: db_payload.MessagePayload, payload_type: str = None, depth: int = 0, - miniLM_embedding: List[float] = None, ) -> Message: if payload_type is None: if payload is None: @@ -389,13 +387,22 @@ class PromptRepository: payload_type=payload_type, payload=PayloadContainer(payload=payload), depth=depth, - miniLM_embedding=miniLM_embedding, ) self.db.add(message) self.db.commit() self.db.refresh(message) return message + def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: + if None in (message_id, model, embedding): + raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR) + + model = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) + self.db.add(model) + self.db.commit() + self.db.refresh(model) + return model + def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 867c537c..7a9fa8e3 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -8,7 +8,11 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode class HF_url(str, Enum): HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) - HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction/" + + +class HF_embeddingModel(str, Enum): + MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" class HuggingFaceAPI: From 7101e0e7d527ec608362006927185bf0145de088 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:17:14 +0100 Subject: [PATCH 05/39] [NEW] Message embedding created_date --- backend/oasst_backend/models/message_embedding.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py index 697a776b..74da5004 100644 --- a/backend/oasst_backend/models/message_embedding.py +++ b/backend/oasst_backend/models/message_embedding.py @@ -1,4 +1,5 @@ -from typing import List +from datetime import datetime +from typing import List, Optional from uuid import UUID import sqlalchemy as sa @@ -13,3 +14,8 @@ class MessageEmbedding(SQLModel, table=True): message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False)) model: str = Field(max_length=256, nullable=False) embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) + + # In the case that the Message Embedding is created afterwards + created_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) + ) From 19eee6be58fd7bd7298169207ba195a331fa7185 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:53:25 +0100 Subject: [PATCH 06/39] [NEW] Removing embedding param in function Store Text Reply --- backend/oasst_backend/prompt_repository.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index dde5003d..919da19b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -96,7 +96,6 @@ class PromptRepository: text: str, frontend_message_id: str, user_frontend_message_id: str, - miniLM_embedding: Optional[List[float]] = None, ) -> Message: validate_frontend_message_id(frontend_message_id) validate_frontend_message_id(user_frontend_message_id) From 225a136ad1f65948c14da5cf513d7ed2f2420156 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:58:19 +0100 Subject: [PATCH 07/39] [NEW] Refactor name of message_embedding object --- backend/oasst_backend/prompt_repository.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 919da19b..3feb77bc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -233,11 +233,11 @@ class PromptRepository: if None in (message_id, model, embedding): raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR) - model = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) - self.db.add(model) + message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) + self.db.add(message_embedding) self.db.commit() - self.db.refresh(model) - return model + self.db.refresh(message_embedding) + return message_embedding def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: From 412736f52c67d3330733c94f773f0f96b923854e Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:59:46 +0100 Subject: [PATCH 08/39] [NEW] insert_message_embedding: documentation --- backend/oasst_backend/prompt_repository.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 3feb77bc..c31c0061 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -230,6 +230,20 @@ class PromptRepository: ) def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: + """Insert the embedding of a new message in the database. + + Args: + message_id (UUID): the identifier of the message we want to save its embedding + model (str): the model used for creating the embedding + embedding (List[float]): the values obtained from the message & model + + Raises: + OasstError: if misses some of the before params + + Returns: + MessageEmbedding: the instance in the database of the embedding saved for that message + """ + if None in (message_id, model, embedding): raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR) From e241a8bf28019261523f2b873f261a759a35a89c Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 21:01:15 +0100 Subject: [PATCH 09/39] [NEW] Adding consistency in the URLs --- backend/oasst_backend/api/v1/tasks.py | 2 +- backend/oasst_backend/utils/hugging_face.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 3bdacf49..186a6ac8 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -285,7 +285,7 @@ async def tasks_interaction( if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: hugging_face_api = HuggingFaceAPI( - f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}{HF_embeddingModel.MINILM.value}" + f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HF_embeddingModel.MINILM.value}" ) embedding = await hugging_face_api.post(interaction.text) pr.insert_message_embedding( diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 7a9fa8e3..80be074d 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -8,7 +8,7 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode class HF_url(str, Enum): HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) - HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction/" + HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction" class HF_embeddingModel(str, Enum): From 70620520b49680e9aa5cee89395fa87ffffeef55 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 21:29:12 +0100 Subject: [PATCH 10/39] [NEW] Created date --- ...23_01_08_2128-aac6b2f66006_created_date.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py diff --git a/backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py b/backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py new file mode 100644 index 00000000..6d40d896 --- /dev/null +++ b/backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py @@ -0,0 +1,30 @@ +"""Created date + +Revision ID: aac6b2f66006 +Revises: 35bdc1a08bb8 +Create Date: 2023-01-08 21:28:27.342729 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "aac6b2f66006" +down_revision = "35bdc1a08bb8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "message_embedding", + sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message_embedding", "created_date") + # ### end Alembic commands ### From b930d2b9c7f6e29823cbf348d21dc15e2d2181ab Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Mon, 9 Jan 2023 10:09:54 +1100 Subject: [PATCH 11/39] Trigger e2e testing on changes to oasst-shared --- .github/workflows/test-e2e.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test-e2e.yaml b/.github/workflows/test-e2e.yaml index f2759808..44ef5b04 100644 --- a/.github/workflows/test-e2e.yaml +++ b/.github/workflows/test-e2e.yaml @@ -5,10 +5,12 @@ on: branches: - main paths: + - oasst-shared/** - backend/** - website/** pull_request: paths: + - oasst-shared/** - backend/** - website/** From b39b86309cdaf8a6d8b207b3278781fd80a47140 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Mon, 9 Jan 2023 09:03:02 +0100 Subject: [PATCH 12/39] [FIX] Import on top --- backend/oasst_backend/utils/hugging_face.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 80be074d..ef440055 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Any, Dict import aiohttp +from loguru import logger from oasst_backend.config import settings from oasst_shared.exceptions import OasstError, OasstErrorCode @@ -51,7 +52,6 @@ class HuggingFaceAPI: async with session.post(self.api_url, headers=self.headers, json=payload) as response: # If we get a bad response if response.status != 200: - from loguru import logger logger.error(response) logger.info(self.headers) From ef7bd89df2124739973f703b51531c62324a6a15 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Mon, 9 Jan 2023 09:07:08 +0100 Subject: [PATCH 13/39] [NEW] ansible: DEBUG_SKIP_EMBEDDING_COMPUTATION --- ansible/dev.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/ansible/dev.yaml b/ansible/dev.yaml index 577abd68..81ab12ee 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -55,6 +55,7 @@ DEBUG_USE_SEED_DATA: "true" MAX_WORKERS: "1" RATE_LIMIT: "false" + DEBUG_SKIP_EMBEDDING_COMPUTATION: "true" ports: - 8080:8080 From 4b4a564a8fb8a6308527dbeb65b6cafd108e0f75 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Mon, 9 Jan 2023 10:10:19 +0100 Subject: [PATCH 14/39] [NEW] Camelcase & 2x space --- backend/oasst_backend/api/v1/hugging_face.py | 4 ++-- backend/oasst_backend/api/v1/tasks.py | 8 ++++---- backend/oasst_backend/utils/hugging_face.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py index a8d8aeb9..62d2ea6b 100644 --- a/backend/oasst_backend/api/v1/hugging_face.py +++ b/backend/oasst_backend/api/v1/hugging_face.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.schemas.hugging_face import ToxicityClassification -from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI router = APIRouter() @@ -25,7 +25,7 @@ async def get_text_toxicity( ToxicityClassification: the score of toxicity of the message. """ - api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value + api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value hugging_face_api = HuggingFaceAPI(api_url) response = await hugging_face_api.post(msg) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 186a6ac8..821ba562 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -9,7 +9,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository, TaskRepository -from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -285,15 +285,15 @@ async def tasks_interaction( if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: hugging_face_api = HuggingFaceAPI( - f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HF_embeddingModel.MINILM.value}" + f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}" ) embedding = await hugging_face_api.post(interaction.text) pr.insert_message_embedding( - message_id=newMessage.id, model=HF_embeddingModel.MINILM.value, embedding=embedding + message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding ) except OasstError: logger.error( - f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index ef440055..87c6288e 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -7,12 +7,12 @@ from oasst_backend.config import settings from oasst_shared.exceptions import OasstError, OasstErrorCode -class HF_url(str, Enum): +class HfUrl(str, Enum): HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction" -class HF_embeddingModel(str, Enum): +class HfEmbeddingModel(str, Enum): MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" From 420b3739eb90a58fd0ec99a6a0dc33d2467ee956 Mon Sep 17 00:00:00 2001 From: James Melvin Date: Mon, 9 Jan 2023 22:31:44 +0530 Subject: [PATCH 15/39] fix: retrieval of valid_labels from API to populate the TEXT_LABEL_FLAGS in FlaggableElements.tsx --- website/src/components/FlaggableElement.tsx | 59 +++++-------------- website/src/components/Messages.tsx | 24 +++++++- website/src/components/Tasks/CreateTask.tsx | 6 +- website/src/lib/oasst_api_client.ts | 33 +++++++++++ website/src/pages/api/new_task/[task_type].ts | 6 ++ 5 files changed, 79 insertions(+), 49 deletions(-) diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 63370444..df7296cc 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -27,8 +27,22 @@ import poster from "src/lib/poster"; import { colors } from "styles/Theme/colors"; import useSWRMutation from "swr/mutation"; +interface textFlagLabels { + attributeName: string; + labelText: string; + additionalExplanation?: string; +} + export const FlaggableElement = (props) => { const [isEditing, setIsEditing] = useBoolean(); + const flaggable_labels = props.flaggable_labels; + const TEXT_LABEL_FLAGS = flaggable_labels.valid_labels.map((valid_label) => { + return { + attributeName: valid_label.name, + labelText: valid_label.display_text, + additionalExplanation: valid_label.help_text, + }; + }); const { trigger } = useSWRMutation("/api/set_label", poster, { onSuccess: () => { setIsEditing.off; @@ -181,48 +195,3 @@ export function FlagCheckbox(props: { ); } -interface textFlagLabels { - attributeName: string; - labelText: string; - additionalExplanation?: string; -} -const TEXT_LABEL_FLAGS: textFlagLabels[] = [ - // For the time being this list is configured on the FE. - // In the future it may be provided by the API. - // { - // attributeName: "fails_task", - // labelText: "Fails to follow the correct instruction / task", - // additionalExplanation: "__TODO__", - // }, - // { - // attributeName: "not_customer_assistant_appropriate", - // labelText: "Inappropriate for customer assistant", - // additionalExplanation: "__TODO__", - // }, - { - attributeName: "sexual_content", - labelText: "Contains sexual content", - }, - { - attributeName: "violence", - labelText: "Contains violent content", - }, - // { - // attributeName: "encourages_violence", - // labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm", - // }, - // { - // attributeName: "denigrates_a_protected_class", - // labelText: "Denigrates a protected class", - // }, - // { - // attributeName: "gives_harmful_advice", - // labelText: "Fails to follow the correct instruction / task", - // additionalExplanation: - // "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.", - // }, - // { - // attributeName: "expresses_moral_judgement", - // labelText: "Expresses moral judgement", - // }, -]; diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index fb84559e..71ff237b 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -10,11 +10,31 @@ export interface Message { message_id: string; } -export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => { +export interface ValidLabel { + name: string; + display_text: string; + help_text: string; +} + +export const Messages = ({ + messages, + post_id, + valid_labels, +}: { + messages: Message[]; + post_id: string; + valid_labels: ValidLabel[]; +}) => { const items = messages.map((messageProps: Message, i: number) => { const { message_id, text } = messageProps; return ( - + ); diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index 7dcb0d0f..7af262e1 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -18,7 +18,7 @@ export interface CreateTaskProps { } export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => { const task = tasks[0].task; - + const valid_labels = tasks[0].valid_labels; const [inputText, setInputText] = useState(""); const submitResponse = (task: { id: string }) => { @@ -42,7 +42,9 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m <>
{taskType.label}

{taskType.overview}

- {task.conversation ? : null} + {task.conversation ? ( + + ) : null} <>
{taskType.instruction}
diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 889d8b5b..1bbb13a8 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -48,6 +48,33 @@ export class OasstApiClient { return await resp.json(); } + private async get(path: string): Promise { + const resp = await fetch(`${this.oasstApiUrl}${path}`, { + method: "GET", + headers: { + "X-API-Key": this.oasstApiKey, + "Content-Type": "application/json", + }, + }); + + if (resp.status == 204) { + return null; + } + + if (resp.status >= 300) { + const errorText = await resp.text(); + let error: any; + try { + error = JSON.parse(errorText); + } catch (e) { + throw new OasstError(errorText, 0, resp.status); + } + throw new OasstError(error.message ?? error, error.error_code, resp.status); + } + + return await resp.json(); + } + // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. // This is a raw Json type, so we can't use it to strongly type the task. @@ -96,6 +123,12 @@ export class OasstApiClient { ...content, }); } + + //Fetch valid labels. This is called every task. though the call may be redundant + //keeping this for future where the valid labels may change per task + async fetch_valid_text(): Promise { + return this.get(`/api/v1/text_labels/valid_labels`); + } } export const oasstApiClient = diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index 9f3be55c..a685d589 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -23,6 +23,7 @@ const handler = async (req, res) => { // Fetch the new task. const task = await oasstApiClient.fetchTask(task_type, token); + const valid_labels = await oasstApiClient.fetch_valid_text(); // Store the task and link it to the user.. const registeredTask = await prisma.registeredTask.create({ @@ -36,6 +37,11 @@ const handler = async (req, res) => { }, }); + // Add the valid labels that can be used to flag messages in this Task + registeredTask["valid_labels"] = valid_labels; + // Update the backend with our Task ID + await oasstApiClient.ackTask(task.id, registeredTask.id); + // Send the results to the client. res.status(200).json(registeredTask); }; From c33c9887dd1452828dcf0c9cb8145c3e96d91d9c Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Mon, 9 Jan 2023 21:23:39 +0100 Subject: [PATCH 16/39] improved pull request workflow --- .github/workflows/pre-commit.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 0f82185f..8395a06e 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -4,7 +4,7 @@ on: push: branches: - main - pull_request: + pull_request_target: workflow_call: jobs: @@ -17,7 +17,7 @@ jobs: python-version: "3.10" - uses: pre-commit/action@v3.0.0 - name: Post PR comment on failure - if: failure() && github.event_name == 'pull_request' + if: failure() && github.event_name == 'pull_request_target' uses: peter-evans/create-or-update-comment@v2 with: issue-number: ${{ github.event.pull_request.number }} From 156e1bca7d82ea8036297287cf0631bb440ed037 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Mon, 9 Jan 2023 21:25:04 +0100 Subject: [PATCH 17/39] it was a bad fix --- .github/workflows/pre-commit.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 8395a06e..0f82185f 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -4,7 +4,7 @@ on: push: branches: - main - pull_request_target: + pull_request: workflow_call: jobs: @@ -17,7 +17,7 @@ jobs: python-version: "3.10" - uses: pre-commit/action@v3.0.0 - name: Post PR comment on failure - if: failure() && github.event_name == 'pull_request_target' + if: failure() && github.event_name == 'pull_request' uses: peter-evans/create-or-update-comment@v2 with: issue-number: ${{ github.event.pull_request.number }} From defd453639f96798c3431dde454c6b51ee24624e Mon Sep 17 00:00:00 2001 From: Andrew Maguire Date: Mon, 9 Jan 2023 22:59:00 +0000 Subject: [PATCH 18/39] replace `andrewm4894` with `LAION-AI` now that example notebook is merged (#573) --- notebooks/example/README.md | 2 +- notebooks/example/example.ipynb | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/notebooks/example/README.md b/notebooks/example/README.md index 2136834d..763b4812 100644 --- a/notebooks/example/README.md +++ b/notebooks/example/README.md @@ -1,6 +1,6 @@ # Example Notebook -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/andrewm4894/Open-Assistant/blob/main/notebooks/example/example.ipynb) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/example/example.ipynb) This folder contains an example reference notebook structure and approach for this project. Please try and follow this structure as closely as possible. While diff --git a/notebooks/example/example.ipynb b/notebooks/example/example.ipynb index 2c6b1e01..5f938c13 100644 --- a/notebooks/example/example.ipynb +++ b/notebooks/example/example.ipynb @@ -9,10 +9,11 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/andrewm4894/Open-Assistant/blob/example-notebook/notebooks/example/example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/example-notebook/notebooks/example/example.ipynb)" ] }, { @@ -22,7 +23,7 @@ "outputs": [], "source": [ "# uncomment and run below lines to set up if running in colab\n", - "# !git clone https://github.com/andrewm4894/Open-Assistant.git\n", + "# !git clone https://github.com/LAION-AI/Open-Assistant.git\n", "# %cd Open-Assistant/notebooks/example\n", "# !pip install -r requirements.txt" ] @@ -146,12 +147,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.7.4 (tags/v3.7.4:e09359112e, Jul 8 2019, 20:34:20) [MSC v.1916 64 bit (AMD64)]" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" + "hash": "25d5c2324055587ceaeef27650c79ce8358ea61d7689f2e0b8ada5d53f85bce4" } } }, From 2e2efdec650bd0e660be91e528912e1bffffec96 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 12:51:56 +0900 Subject: [PATCH 19/39] Fixing a small remaining issues, reporting labels, rejecting tasks, and redirecting users on landing page --- website/src/components/FlaggableElement.tsx | 2 +- website/src/components/Messages.tsx | 8 +++++--- website/src/pages/api/reject_task.ts | 1 + website/src/pages/index.tsx | 11 +++++++++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index df7296cc..2f3f25b2 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -45,7 +45,7 @@ export const FlaggableElement = (props) => { }); const { trigger } = useSWRMutation("/api/set_label", poster, { onSuccess: () => { - setIsEditing.off; + setIsEditing.off(); }, }); diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index 71ff237b..b14b0aaa 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,6 +1,6 @@ import { Grid } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; -import { useMemo } from "react"; +import { forwardRef, useMemo } from "react"; import { FlaggableElement } from "./FlaggableElement"; @@ -43,7 +43,7 @@ export const Messages = ({ return {items}; }; -export const MessageView = ({ is_assistant, text, message_id }: Message) => { +export const MessageView = forwardRef(({ is_assistant, text, message_id }: Message, ref) => { const { colorMode } = useColorMode(); const bgColor = useMemo(() => { @@ -55,4 +55,6 @@ export const MessageView = ({ is_assistant, text, message_id }: Message) => { }, [colorMode, is_assistant]); return
{text}
; -}; +}); + +MessageView.displayName = "MessageView"; diff --git a/website/src/pages/api/reject_task.ts b/website/src/pages/api/reject_task.ts index d146c44b..fc807b67 100644 --- a/website/src/pages/api/reject_task.ts +++ b/website/src/pages/api/reject_task.ts @@ -1,6 +1,7 @@ import { Prisma } from "@prisma/client"; import { getToken } from "next-auth/jwt"; import { oasstApiClient } from "src/lib/oasst_api_client"; +import prisma from "src/lib/prismadb"; const handler = async (req, res) => { const token = await getToken({ req }); diff --git a/website/src/pages/index.tsx b/website/src/pages/index.tsx index 04f99829..64b1a0d5 100644 --- a/website/src/pages/index.tsx +++ b/website/src/pages/index.tsx @@ -1,10 +1,21 @@ import Head from "next/head"; +import { useRouter } from "next/router"; +import { useSession } from "next-auth/react"; +import { useEffect } from "react"; import { CallToAction } from "src/components/CallToAction"; import { Faq } from "src/components/Faq"; import { Hero } from "src/components/Hero"; import { getTransparentHeaderLayout } from "src/components/Layout"; const Home = () => { + const router = useRouter(); + const { status } = useSession(); + useEffect(() => { + if (status === "authenticated") { + router.push("/dashboard"); + } + }, [router, status]); + return ( <> From 34a1715923d4c8d908efdf44390620f583d87486 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 13:02:17 +0900 Subject: [PATCH 20/39] Cleaning up a suite of eslint warnings --- website/src/components/Buttons/Skip.tsx | 5 ----- website/src/components/CollapsableText.tsx | 2 +- website/src/components/FlaggableElement.tsx | 4 ++-- website/src/components/Messages.tsx | 8 ++++++-- website/src/components/Tasks/CreateTask.tsx | 3 +-- website/src/components/Tasks/EvaluateTask.tsx | 2 +- website/src/components/Tasks/Task.tsx | 8 ++++---- website/src/lib/oasst_api_client.ts | 4 ++-- website/src/pages/api/set_label.ts | 3 +-- website/src/pages/auth/verify.tsx | 1 - website/src/pages/create/summarize_story.tsx | 2 +- website/src/pages/evaluate/rate_summary.tsx | 2 +- 12 files changed, 20 insertions(+), 24 deletions(-) diff --git a/website/src/components/Buttons/Skip.tsx b/website/src/components/Buttons/Skip.tsx index 8440e348..bcd0eb79 100644 --- a/website/src/components/Buttons/Skip.tsx +++ b/website/src/components/Buttons/Skip.tsx @@ -1,10 +1,6 @@ import { Button, ButtonProps, - Menu, - MenuButton, - MenuItem, - MenuList, Modal, ModalBody, ModalCloseButton, @@ -16,7 +12,6 @@ import { useDisclosure, } from "@chakra-ui/react"; import { useState } from "react"; -import { FaChevronDown } from "react-icons/fa"; interface SkipButtonProps extends ButtonProps { onSkip: (reason: string) => void; diff --git a/website/src/components/CollapsableText.tsx b/website/src/components/CollapsableText.tsx index 1f34c508..5dcab595 100644 --- a/website/src/components/CollapsableText.tsx +++ b/website/src/components/CollapsableText.tsx @@ -12,7 +12,7 @@ import React from "react"; export const CollapsableText = ({ text, maxLength = 220 }) => { const { isOpen, onOpen, onClose } = useDisclosure(); - if (typeof text != "string" || text.length <= maxLength) { + if (typeof text !== "string" || text.length <= maxLength) { return text; } else { return ( diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 2f3f25b2..9606f425 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -69,14 +69,14 @@ export const FlaggableElement = (props) => { const handleCheckboxState = (isChecked, idx) => { setCheckboxValues( checkboxValues.map((val, i) => { - return i == idx ? isChecked : val; + return i === idx ? isChecked : val; }) ); }; const handleSliderState = (newVal, idx) => { setSliderValues( sliderValues.map((val, i) => { - return i == idx ? newVal : val; + return i === idx ? newVal : val; }) ); }; diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index b14b0aaa..ef02d4e0 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -43,7 +43,7 @@ export const Messages = ({ return {items}; }; -export const MessageView = forwardRef(({ is_assistant, text, message_id }: Message, ref) => { +export const MessageView = forwardRef(({ is_assistant, text }: Message, ref) => { const { colorMode } = useColorMode(); const bgColor = useMemo(() => { @@ -54,7 +54,11 @@ export const MessageView = forwardRef(({ is_assistant, text, message_id }: Messa } }, [colorMode, is_assistant]); - return
{text}
; + return ( +
+ {text} +
+ ); }); MessageView.displayName = "MessageView"; diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index 7af262e1..e02dcdeb 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -1,10 +1,9 @@ import { useState } from "react"; - import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { TrackedTextarea } from "src/components/Survey/TrackedTextarea"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; -import { TaskType } from "./TaskTypes"; +import { TaskType } from "src/components/Tasks/TaskTypes"; export interface CreateTaskProps { // we need a task type diff --git a/website/src/components/Tasks/EvaluateTask.tsx b/website/src/components/Tasks/EvaluateTask.tsx index 3871b2d9..d0a1f404 100644 --- a/website/src/components/Tasks/EvaluateTask.tsx +++ b/website/src/components/Tasks/EvaluateTask.tsx @@ -48,7 +48,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla setRanking(tasks[0].task[sortables].map((_, idx) => idx))} onSubmitResponse={submitResponse} onSkipTask={(task, reason) => { diff --git a/website/src/components/Tasks/Task.tsx b/website/src/components/Tasks/Task.tsx index 153e0a93..777f5dd5 100644 --- a/website/src/components/Tasks/Task.tsx +++ b/website/src/components/Tasks/Task.tsx @@ -1,8 +1,8 @@ -import { CreateTask } from "./CreateTask"; -import { EvaluateTask } from "./EvaluateTask"; -import { TaskCategory, TaskTypes } from "./TaskTypes"; -import useSWRMutation from "swr/mutation"; +import { CreateTask } from "src/components/Tasks/CreateTask"; +import { EvaluateTask } from "src/components/Tasks/EvaluateTask"; +import { TaskCategory, TaskTypes } from "src/components/Tasks/TaskTypes"; import poster from "src/lib/poster"; +import useSWRMutation from "swr/mutation"; export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => { const task = tasks[0].task; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 1bbb13a8..6a5ca58e 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -30,7 +30,7 @@ export class OasstApiClient { body: JSON.stringify(body), }); - if (resp.status == 204) { + if (resp.status === 204) { return null; } @@ -57,7 +57,7 @@ export class OasstApiClient { }, }); - if (resp.status == 204) { + if (resp.status === 204) { return null; } diff --git a/website/src/pages/api/set_label.ts b/website/src/pages/api/set_label.ts index 4db5ddaf..cfda114b 100644 --- a/website/src/pages/api/set_label.ts +++ b/website/src/pages/api/set_label.ts @@ -1,5 +1,4 @@ import { getToken } from "next-auth/jwt"; -import prisma from "src/lib/prismadb"; /** * Sets the Label in the Backend. @@ -15,7 +14,7 @@ const handler = async (req, res) => { } // Parse out the local message_id, task ID and the interaction contents. - const { message_id, post_id, label_map, text } = await JSON.parse(req.body); + const { message_id, label_map, text } = await JSON.parse(req.body); const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, { method: "POST", diff --git a/website/src/pages/auth/verify.tsx b/website/src/pages/auth/verify.tsx index b4d7d739..876aa677 100644 --- a/website/src/pages/auth/verify.tsx +++ b/website/src/pages/auth/verify.tsx @@ -1,7 +1,6 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { getCsrfToken, getProviders } from "next-auth/react"; -import { AuthLayout } from "src/components/AuthLayout"; export default function Verify() { const { colorMode } = useColorMode(); diff --git a/website/src/pages/create/summarize_story.tsx b/website/src/pages/create/summarize_story.tsx index 8620a8f5..61415962 100644 --- a/website/src/pages/create/summarize_story.tsx +++ b/website/src/pages/create/summarize_story.tsx @@ -63,7 +63,7 @@ const SummarizeStory = () => { return ; } - if (tasks.length == 0) { + if (tasks.length === 0) { return
No tasks found...
; } diff --git a/website/src/pages/evaluate/rate_summary.tsx b/website/src/pages/evaluate/rate_summary.tsx index 0d2352a2..e9da4a63 100644 --- a/website/src/pages/evaluate/rate_summary.tsx +++ b/website/src/pages/evaluate/rate_summary.tsx @@ -60,7 +60,7 @@ const RateSummary = () => { return ; } - if (tasks.length == 0) { + if (tasks.length === 0) { return (
From f028c07dfbd1ae46f9924bbef3f375ac37b64ae2 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 13:10:31 +0900 Subject: [PATCH 21/39] Fixing the typing for the MessageView forwardRef --- website/src/components/Messages.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index ef02d4e0..bb97ab5a 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,6 +1,6 @@ import { Grid } from "@chakra-ui/react"; -import { useColorMode } from "@chakra-ui/react"; -import { forwardRef, useMemo } from "react"; +import { forwardRef, useColorMode } from "@chakra-ui/react"; +import { useMemo } from "react"; import { FlaggableElement } from "./FlaggableElement"; @@ -43,7 +43,7 @@ export const Messages = ({ return {items}; }; -export const MessageView = forwardRef(({ is_assistant, text }: Message, ref) => { +export const MessageView = forwardRef(({ is_assistant, text }: Message, ref) => { const { colorMode } = useColorMode(); const bgColor = useMemo(() => { From f2c235476e1cbe522354c17e0ea9533d1a86c3da Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 13:24:56 +0900 Subject: [PATCH 22/39] Deleting the auto-ack on task fetching in favor of auto-acking on answer submission. Fixes broken e2e tests --- website/src/pages/api/new_task/[task_type].ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index a685d589..80334f76 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -39,8 +39,6 @@ const handler = async (req, res) => { // Add the valid labels that can be used to flag messages in this Task registeredTask["valid_labels"] = valid_labels; - // Update the backend with our Task ID - await oasstApiClient.ackTask(task.id, registeredTask.id); // Send the results to the client. res.status(200).json(registeredTask); From 54cc88bb1f8f7d1bc401e69f30cfc27737a4226d Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 14:08:43 +0900 Subject: [PATCH 23/39] Ensure FlaggableElement always has the set of labels --- website/cypress/e2e/evaluate/rank_assistant_replies.cy.ts | 2 +- website/cypress/e2e/evaluate/rank_user_replies.cy.ts | 2 +- website/src/components/Messages/MessageTable.tsx | 5 +++-- website/src/components/Messages/MessageTableEntry.tsx | 6 ++++-- website/src/components/Tasks/EvaluateTask.tsx | 3 ++- website/src/pages/label/label_assistant_reply.tsx | 3 ++- website/src/pages/label/label_prompter_reply.tsx | 3 ++- 7 files changed, 15 insertions(+), 9 deletions(-) diff --git a/website/cypress/e2e/evaluate/rank_assistant_replies.cy.ts b/website/cypress/e2e/evaluate/rank_assistant_replies.cy.ts index c7b85695..3093bd56 100644 --- a/website/cypress/e2e/evaluate/rank_assistant_replies.cy.ts +++ b/website/cypress/e2e/evaluate/rank_assistant_replies.cy.ts @@ -1,7 +1,7 @@ describe("ranking prompter replies", () => { it("completes the current task on submit and on request shows a new task", () => { cy.signInWithEmail("cypress@example.com"); - cy.visit("/evaluate/rank_user_replies"); + cy.visit("/evaluate/rank_assistant_replies"); cy.get('[data-cy="task-id"').then((taskIdElement) => { const taskId = taskIdElement.text(); diff --git a/website/cypress/e2e/evaluate/rank_user_replies.cy.ts b/website/cypress/e2e/evaluate/rank_user_replies.cy.ts index c448a4c7..55487f1d 100644 --- a/website/cypress/e2e/evaluate/rank_user_replies.cy.ts +++ b/website/cypress/e2e/evaluate/rank_user_replies.cy.ts @@ -1,7 +1,7 @@ describe("ranking assistant replies", () => { it("completes the current task on submit and on request shows a new task", () => { cy.signInWithEmail("cypress@example.com"); - cy.visit("/evaluate/rank_assistant_replies"); + cy.visit("/evaluate/rank_user_replies"); cy.get('[data-cy="task-id"').then((taskIdElement) => { const taskId = taskIdElement.text(); diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index 95ccc540..33ffc4a4 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -1,11 +1,12 @@ import { Stack, StackDivider } from "@chakra-ui/react"; import { MessageTableEntry } from "src/components/Messages/MessageTableEntry"; -export function MessageTable({ messages }) { +export function MessageTable({ messages, valid_labels }) { + console.log(messages); return ( } spacing="4"> {messages.map((item, idx) => ( - + ))} ); diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 9fad7262..0f58efad 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -2,6 +2,7 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; import NextLink from "next/link"; import { FlaggableElement } from "src/components/FlaggableElement"; +import type { ValidLabel } from "src/components/Messages"; interface Message { text: string; @@ -11,13 +12,14 @@ interface Message { interface MessageTableEntryProps { item: Message; idx: number; + valid_labels: ValidLabel[]; } export function MessageTableEntry(props: MessageTableEntryProps) { - const { item, idx } = props; + const { item, idx, valid_labels } = props; const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900"); return ( - + ({ ...message, id: index })); } + const valid_labels = tasks[0].valid_labels; const sortables = tasks[0].task.replies ? "replies" : "prompts"; return ( @@ -42,7 +43,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla

Given the following {sortables}, sort them from best to worst, best being first, worst being last.

- {messages ? : null} + {messages ? : null} diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index a0f961f7..89b612ca 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -19,6 +19,7 @@ const LabelAssistantReply = () => { } const task = tasks[0].task; + const valid_labels = tasks[0].valid_labels; const messages: Message[] = [ ...task.conversation.messages, { text: task.reply, is_assistant: true, message_id: task.message_id }, @@ -28,7 +29,7 @@ const LabelAssistantReply = () => { } + messages={} inputs={} controls={ { } const task = tasks[0].task; + const valid_labels = tasks[0].valid_labels; const messages: Message[] = [ ...task.conversation.messages, { text: task.reply, is_assistant: false, message_id: task.message_id }, @@ -28,7 +29,7 @@ const LabelPrompterReply = () => { } + messages={} inputs={} controls={ Date: Tue, 10 Jan 2023 14:16:19 +0900 Subject: [PATCH 24/39] Ensure FlaggableElement has an empty list of labels in the messages views --- .../src/components/Messages/MessageWithChildren.tsx | 4 ++-- website/src/hooks/tasks/useGenericTaskAPI.tsx | 2 ++ website/src/pages/messages/[id]/index.tsx | 2 +- website/src/pages/messages/index.tsx | 12 ++++++++++-- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/website/src/components/Messages/MessageWithChildren.tsx b/website/src/components/Messages/MessageWithChildren.tsx index c412f996..7e604dfa 100644 --- a/website/src/components/Messages/MessageWithChildren.tsx +++ b/website/src/components/Messages/MessageWithChildren.tsx @@ -64,7 +64,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { - + @@ -90,7 +90,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { {children.map((item, idx) => ( - + ))} diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx index a57c9da4..e300e220 100644 --- a/website/src/hooks/tasks/useGenericTaskAPI.tsx +++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx @@ -1,4 +1,5 @@ import { useState } from "react"; +import type { ValidLabel } from "src/components/Messages"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; import useSWRImmutable from "swr/immutable"; @@ -10,6 +11,7 @@ export interface TaskResponse { id: string; userId: string; task: TaskType; + valid_labels: ValidLabel[]; } export const useGenericTaskAPI = (taskApiEndpoint: string) => { diff --git a/website/src/pages/messages/[id]/index.tsx b/website/src/pages/messages/[id]/index.tsx index eacd5a72..933c7508 100644 --- a/website/src/pages/messages/[id]/index.tsx +++ b/website/src/pages/messages/[id]/index.tsx @@ -41,7 +41,7 @@ const MessageDetail = ({ id }) => { Parent - + )} diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 2809ba5c..28ec9c54 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -52,7 +52,11 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {receivedMessages ? : } + {receivedMessages ? ( + + ) : ( + + )} @@ -66,7 +70,11 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {receivedUserMessages ? : } + {receivedUserMessages ? ( + + ) : ( + + )} From 54a042d002b3ac34384a8e3d2d1e8e73c171af16 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Mon, 9 Jan 2023 20:08:26 +0100 Subject: [PATCH 25/39] Centralize task types --- website/src/components/Buttons/Skip.tsx | 5 -- website/src/components/Messages.tsx | 7 +-- website/src/components/Tasks/CreateTask.tsx | 4 +- website/src/components/Tasks/TaskTypes.tsx | 4 +- .../tasks/create/useCreateInitialPrompt.ts | 9 --- .../src/hooks/tasks/create/useCreateReply.ts | 24 -------- .../tasks/evaluate/useRankInitialPrompts.ts | 9 --- .../hooks/tasks/evaluate/useRankReplies.ts | 25 -------- .../tasks/labeling/useLabelAssistantReply.ts | 22 ------- .../tasks/labeling/useLabelInitialPrompt.tsx | 15 ----- .../tasks/labeling/useLabelPrompterReply.ts | 22 ------- .../hooks/tasks/labeling/useLabelingTask.ts | 20 ------- website/src/hooks/tasks/useCreateReply.ts | 8 +++ website/src/hooks/tasks/useGenericTaskAPI.tsx | 11 +--- website/src/hooks/tasks/useLabelingTask.ts | 32 +++++++++++ website/src/hooks/tasks/useRankReplies.ts | 12 ++++ website/src/pages/create/assistant_reply.tsx | 2 +- website/src/pages/create/initial_prompt.tsx | 2 +- website/src/pages/create/user_reply.tsx | 2 +- .../pages/evaluate/rank_assistant_replies.tsx | 2 +- .../pages/evaluate/rank_initial_prompts.tsx | 2 +- .../src/pages/evaluate/rank_user_replies.tsx | 2 +- .../src/pages/label/label_assistant_reply.tsx | 7 +-- .../src/pages/label/label_initial_prompt.tsx | 5 +- .../src/pages/label/label_prompter_reply.tsx | 7 +-- website/src/pages/messages/index.tsx | 4 +- website/src/types/Conversation.ts | 9 +++ website/src/types/Task.ts | 24 ++++++++ website/src/types/Tasks.ts | 57 +++++++++++++++++++ 29 files changed, 162 insertions(+), 192 deletions(-) delete mode 100644 website/src/hooks/tasks/create/useCreateInitialPrompt.ts delete mode 100644 website/src/hooks/tasks/create/useCreateReply.ts delete mode 100644 website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts delete mode 100644 website/src/hooks/tasks/evaluate/useRankReplies.ts delete mode 100644 website/src/hooks/tasks/labeling/useLabelAssistantReply.ts delete mode 100644 website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx delete mode 100644 website/src/hooks/tasks/labeling/useLabelPrompterReply.ts delete mode 100644 website/src/hooks/tasks/labeling/useLabelingTask.ts create mode 100644 website/src/hooks/tasks/useCreateReply.ts create mode 100644 website/src/hooks/tasks/useLabelingTask.ts create mode 100644 website/src/hooks/tasks/useRankReplies.ts create mode 100644 website/src/types/Conversation.ts create mode 100644 website/src/types/Task.ts create mode 100644 website/src/types/Tasks.ts diff --git a/website/src/components/Buttons/Skip.tsx b/website/src/components/Buttons/Skip.tsx index 8440e348..bcd0eb79 100644 --- a/website/src/components/Buttons/Skip.tsx +++ b/website/src/components/Buttons/Skip.tsx @@ -1,10 +1,6 @@ import { Button, ButtonProps, - Menu, - MenuButton, - MenuItem, - MenuList, Modal, ModalBody, ModalCloseButton, @@ -16,7 +12,6 @@ import { useDisclosure, } from "@chakra-ui/react"; import { useState } from "react"; -import { FaChevronDown } from "react-icons/fa"; interface SkipButtonProps extends ButtonProps { onSkip: (reason: string) => void; diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index 71ff237b..20d26808 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,15 +1,10 @@ import { Grid } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; import { useMemo } from "react"; +import { Message } from "src/types/Conversation"; import { FlaggableElement } from "./FlaggableElement"; -export interface Message { - text: string; - is_assistant: boolean; - message_id: string; -} - export interface ValidLabel { name: string; display_text: string; diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index 7af262e1..2c11cb01 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -4,13 +4,13 @@ import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { TrackedTextarea } from "src/components/Survey/TrackedTextarea"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; -import { TaskType } from "./TaskTypes"; +import { TaskInfo } from "./TaskTypes"; export interface CreateTaskProps { // we need a task type // eslint-disable-next-line @typescript-eslint/no-explicit-any tasks: any[]; - taskType: TaskType; + taskType: TaskInfo; trigger: (update: { id: string; update_type: string; content: { text: string } }) => void; onSkipTask: (task: { id: string }, reason: string) => void; onNextTask: () => void; diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 409e7038..d255756a 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -4,7 +4,7 @@ export enum TaskCategory { Label = "Label", } -export interface TaskType { +export interface TaskInfo { label: string; desc: string; category: TaskCategory; @@ -14,7 +14,7 @@ export interface TaskType { instruction?: string; } -export const TaskTypes: TaskType[] = [ +export const TaskTypes: TaskInfo[] = [ // create { label: "Create Initial Prompts", diff --git a/website/src/hooks/tasks/create/useCreateInitialPrompt.ts b/website/src/hooks/tasks/create/useCreateInitialPrompt.ts deleted file mode 100644 index cf0193e8..00000000 --- a/website/src/hooks/tasks/create/useCreateInitialPrompt.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface CreateInitialPromptTask { - id: string; - type: "initial_prompt"; - hint: string; -} - -export const useCreateInitialPrompt = () => useGenericTaskAPI("initial_prompt"); diff --git a/website/src/hooks/tasks/create/useCreateReply.ts b/website/src/hooks/tasks/create/useCreateReply.ts deleted file mode 100644 index 0bc78319..00000000 --- a/website/src/hooks/tasks/create/useCreateReply.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface BaseCreateReplyTask { - id: string; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -export interface CreateAssistantReplyTask extends BaseCreateReplyTask { - type: "assistant_reply"; -} - -export interface CreatePrompterReplyTask extends BaseCreateReplyTask { - type: "prompter_reply"; -} - -export const useCreateAssistantReply = () => useGenericTaskAPI("assistant_reply"); - -export const useCreatePrompterReply = () => useGenericTaskAPI("prompter_reply"); diff --git a/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts b/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts deleted file mode 100644 index da772c80..00000000 --- a/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface RankInitialPromptsTask { - id: string; - type: "rank_initial_prompts"; - prompts: string[]; -} - -export const useRankInitialPromptsTask = () => useGenericTaskAPI("rank_initial_prompts"); diff --git a/website/src/hooks/tasks/evaluate/useRankReplies.ts b/website/src/hooks/tasks/evaluate/useRankReplies.ts deleted file mode 100644 index 2d8d513f..00000000 --- a/website/src/hooks/tasks/evaluate/useRankReplies.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface BaseRankRepliesTask { - id: string; - replies: string[]; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -interface RankAssistantRepliesTask extends BaseRankRepliesTask { - type: "rank_assistant_replies"; -} - -interface RankPrompterRepliesTask extends BaseRankRepliesTask { - type: "rank_prompter_replies"; -} - -export const useRankAssistantRepliesTask = () => useGenericTaskAPI("rank_assistant_replies"); - -export const useRankPrompterRepliesTask = () => useGenericTaskAPI("rank_prompter_replies"); diff --git a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts deleted file mode 100644 index 3c44046e..00000000 --- a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { TaskResponse } from "../useGenericTaskAPI"; -import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; - -export interface LabelAssistantReplyTask { - id: string; - type: LabelingTaskType.label_assistant_reply; - message_id: string; - valid_labels: string[]; - reply: string; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -export type LabelAssistantReplyTaskResponse = TaskResponse; - -export const useLabelAssistantReplyTask = () => - useLabelingTask(LabelingTaskType.label_assistant_reply); diff --git a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx deleted file mode 100644 index f7ba8ab5..00000000 --- a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { TaskResponse } from "../useGenericTaskAPI"; -import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; - -export interface LabelInitialPromptTask { - id: string; - type: LabelingTaskType.label_initial_prompt; - message_id: string; - valid_labels: string[]; - prompt: string; -} - -export type LabelInitialPromptTaskResponse = TaskResponse; - -export const useLabelInitialPromptTask = () => - useLabelingTask(LabelingTaskType.label_initial_prompt); diff --git a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts deleted file mode 100644 index 9de2057f..00000000 --- a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { TaskResponse } from "../useGenericTaskAPI"; -import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; - -export interface LabelPrompterReplyTask { - id: string; - type: LabelingTaskType.label_prompter_reply; - message_id: string; - valid_labels: string[]; - reply: string; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -export type LabelPrompterReplyTaskResponse = TaskResponse; - -export const useLabelPrompterReplyTask = () => - useLabelingTask(LabelingTaskType.label_prompter_reply); diff --git a/website/src/hooks/tasks/labeling/useLabelingTask.ts b/website/src/hooks/tasks/labeling/useLabelingTask.ts deleted file mode 100644 index 27555284..00000000 --- a/website/src/hooks/tasks/labeling/useLabelingTask.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -export const enum LabelingTaskType { - label_initial_prompt = "label_initial_prompt", - label_prompter_reply = "label_prompter_reply", - label_assistant_reply = "label_assistant_reply", -} - -export const useLabelingTask = (endpoint: LabelingTaskType) => { - const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); - - const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { - console.assert(validLabels.length === labelWeights.length); - const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); - - return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); - }; - - return { tasks, isLoading, submit, reset, error }; -}; diff --git a/website/src/hooks/tasks/useCreateReply.ts b/website/src/hooks/tasks/useCreateReply.ts new file mode 100644 index 00000000..23bc041d --- /dev/null +++ b/website/src/hooks/tasks/useCreateReply.ts @@ -0,0 +1,8 @@ +import { TaskType } from "src/types/Task"; +import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks"; + +import { useGenericTaskAPI } from "./useGenericTaskAPI"; + +export const useCreateAssistantReply = () => useGenericTaskAPI(TaskType.assistant_reply); +export const useCreatePrompterReply = () => useGenericTaskAPI(TaskType.prompter_reply); +export const useCreateInitialPrompt = () => useGenericTaskAPI(TaskType.initial_prompt); diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx index a57c9da4..a456cbf1 100644 --- a/website/src/hooks/tasks/useGenericTaskAPI.tsx +++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx @@ -1,18 +1,11 @@ import { useState } from "react"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; +import { BaseTask, TaskResponse } from "src/types/Task"; import useSWRImmutable from "swr/immutable"; import useSWRMutation from "swr/mutation"; -// TODO: type & centralize types for all tasks - -export interface TaskResponse { - id: string; - userId: string; - task: TaskType; -} - -export const useGenericTaskAPI = (taskApiEndpoint: string) => { +export const useGenericTaskAPI = (taskApiEndpoint: string) => { type ConcreteTaskResponse = TaskResponse; const [tasks, setTasks] = useState([]); diff --git a/website/src/hooks/tasks/useLabelingTask.ts b/website/src/hooks/tasks/useLabelingTask.ts new file mode 100644 index 00000000..5e5050ab --- /dev/null +++ b/website/src/hooks/tasks/useLabelingTask.ts @@ -0,0 +1,32 @@ +import { BaseTask, TaskResponse, TaskType } from "src/types/Task"; +import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks"; + +import { useGenericTaskAPI } from "./useGenericTaskAPI"; + +const useLabelingTask = ( + endpoint: TaskType.label_assistant_reply | TaskType.label_prompter_reply | TaskType.label_initial_prompt +) => { + const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); + + const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { + console.assert(validLabels.length === labelWeights.length); + const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); + + return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); + }; + + return { tasks, isLoading, submit, reset, error }; +}; + +export type LabelAssistantReplyTaskResponse = TaskResponse; + +export const useLabelAssistantReplyTask = () => + useLabelingTask(TaskType.label_assistant_reply); + +export type LabelInitialPromptTaskResponse = TaskResponse; + +export const useLabelInitialPromptTask = () => useLabelingTask(TaskType.label_initial_prompt); + +export type LabelPrompterReplyTaskResponse = TaskResponse; + +export const useLabelPrompterReplyTask = () => useLabelingTask(TaskType.label_prompter_reply); diff --git a/website/src/hooks/tasks/useRankReplies.ts b/website/src/hooks/tasks/useRankReplies.ts new file mode 100644 index 00000000..d4accda0 --- /dev/null +++ b/website/src/hooks/tasks/useRankReplies.ts @@ -0,0 +1,12 @@ +import { TaskType } from "src/types/Task"; +import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks"; + +import { useGenericTaskAPI } from "./useGenericTaskAPI"; + +export const useRankAssistantRepliesTask = () => + useGenericTaskAPI(TaskType.rank_assistant_replies); + +export const useRankPrompterRepliesTask = () => + useGenericTaskAPI(TaskType.rank_prompter_replies); + +export const useRankInitialPromptsTask = () => useGenericTaskAPI(TaskType.rank_initial_prompts); diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index e9aee226..17facd5d 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply"; +import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; const AssistantReply = () => { const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index efea6474..57f0dabd 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt"; +import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; const InitialPrompt = () => { const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index 2394bd63..a0af0e95 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply"; +import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; const UserReply = () => { const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index 931c9194..8546e7a6 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies"; +import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; const RankAssistantReplies = () => { const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 4b717143..1898a93a 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts"; +import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; const RankInitialPrompts = () => { const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index 659874a2..e2a39977 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies"; +import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; const RankUserReplies = () => { const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index a0f961f7..59a7bbcc 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,13 +1,10 @@ import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { - LabelAssistantReplyTaskResponse, - useLabelAssistantReplyTask, -} from "src/hooks/tasks/labeling/useLabelAssistantReply"; +import { LabelAssistantReplyTaskResponse, useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { Message } from "src/types/Conversation"; const LabelAssistantReply = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 3c791f23..4cd4343b 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -3,10 +3,7 @@ import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { MessageView } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { - LabelInitialPromptTaskResponse, - useLabelInitialPromptTask, -} from "src/hooks/tasks/labeling/useLabelInitialPrompt"; +import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; const LabelInitialPrompt = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 2fd3d76a..e11e801d 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,13 +1,10 @@ import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { - LabelPrompterReplyTaskResponse, - useLabelPrompterReplyTask, -} from "src/hooks/tasks/labeling/useLabelPrompterReply"; +import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { Message } from "src/types/Conversation"; const LabelPrompterReply = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 2809ba5c..9cdc2ac5 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -2,9 +2,9 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha import Head from "next/head"; import { useEffect, useState } from "react"; import { getDashboardLayout } from "src/components/Layout"; -import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import fetcher from "src/lib/fetcher"; +import { Message } from "src/types/Conversation"; import useSWRImmutable from "swr/immutable"; const MessagesDashboard = () => { @@ -74,6 +74,6 @@ const MessagesDashboard = () => { ); }; -MessagesDashboard.getLayout = (page) => getDashboardLayout(page); +MessagesDashboard.getLayout = getDashboardLayout; export default MessagesDashboard; diff --git a/website/src/types/Conversation.ts b/website/src/types/Conversation.ts new file mode 100644 index 00000000..f12b2781 --- /dev/null +++ b/website/src/types/Conversation.ts @@ -0,0 +1,9 @@ +export interface Message { + text: string; + is_assistant: boolean; + message_id: string; +} + +export interface Conversation { + messages: Message[]; +} diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts new file mode 100644 index 00000000..0dca6a5b --- /dev/null +++ b/website/src/types/Task.ts @@ -0,0 +1,24 @@ +export const enum TaskType { + initial_prompt = "initial_prompt", + assistant_reply = "assistant_reply", + prompter_reply = "prompter_reply", + + rank_initial_prompts = "rank_initial_prompts", + rank_assistant_replies = "rank_assistant_replies", + rank_prompter_replies = "rank_prompter_replies", + + label_initial_prompt = "label_initial_prompt", + label_prompter_reply = "label_prompter_reply", + label_assistant_reply = "label_assistant_reply", +} + +export interface BaseTask { + id: string; + type: TaskType; +} + +export interface TaskResponse { + id: string; + userId: string; + task: Task; +} diff --git a/website/src/types/Tasks.ts b/website/src/types/Tasks.ts new file mode 100644 index 00000000..50c251bb --- /dev/null +++ b/website/src/types/Tasks.ts @@ -0,0 +1,57 @@ +import { Conversation } from "./Conversation"; +import { BaseTask, TaskType } from "./Task"; + +export interface CreateInitialPromptTask extends BaseTask { + type: TaskType.initial_prompt; + hint: string; +} + +export interface CreateAssistantReplyTask extends BaseTask { + type: TaskType.assistant_reply; + conversation: Conversation; +} + +export interface CreatePrompterReplyTask extends BaseTask { + type: TaskType.prompter_reply; + conversation: Conversation; +} + +export interface RankInitialPromptsTask extends BaseTask { + type: TaskType.rank_initial_prompts; + prompts: string[]; +} + +export interface RankAssistantRepliesTask extends BaseTask { + type: TaskType.rank_assistant_replies; + conversation: Conversation; + replies: string[]; +} + +export interface RankPrompterRepliesTask extends BaseTask { + type: TaskType.rank_prompter_replies; + conversation: Conversation; + replies: string[]; +} + +export interface LabelAssistantReplyTask extends BaseTask { + type: TaskType.label_assistant_reply; + message_id: string; + conversation: Conversation; + reply: string; + valid_labels: string[]; +} + +export interface LabelInitialPromptTask extends BaseTask { + type: TaskType.label_initial_prompt; + message_id: string; + valid_labels: string[]; + prompt: string; +} + +export interface LabelPrompterReplyTask extends BaseTask { + type: TaskType.label_prompter_reply; + message_id: string; + conversation: Conversation; + reply: string; + valid_labels: string[]; +} From aa22ed0d1ca25f8cae3fdddae07e4cd012ce6ba1 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 14:52:55 +0900 Subject: [PATCH 26/39] Remove debug log --- website/src/components/Messages/MessageTable.tsx | 1 - 1 file changed, 1 deletion(-) diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index 33ffc4a4..bacd27f9 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -2,7 +2,6 @@ import { Stack, StackDivider } from "@chakra-ui/react"; import { MessageTableEntry } from "src/components/Messages/MessageTableEntry"; export function MessageTable({ messages, valid_labels }) { - console.log(messages); return ( } spacing="4"> {messages.map((item, idx) => ( From 79f84697cc8b141abd9284ab06e6a3dc68885bf6 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Tue, 10 Jan 2023 07:46:53 +0100 Subject: [PATCH 27/39] changed pre-commit event type --- .github/workflows/pre-commit.yaml | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 0f82185f..47f21feb 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -4,20 +4,28 @@ on: push: branches: - main - pull_request: - workflow_call: + pull_request_target: jobs: pre-commit: runs-on: ubuntu-latest steps: + # in case of PR, check out the PR's head branch - uses: actions/checkout@v3 + if: github.event_name == 'pull_request_target' + with: + ref: ${{ github.event.pull_request.head.sha }} + + # in case of push, check out the main branch + - uses: actions/checkout@v3 + if: github.event_name == 'push' + - uses: actions/setup-python@v4 with: python-version: "3.10" - uses: pre-commit/action@v3.0.0 - name: Post PR comment on failure - if: failure() && github.event_name == 'pull_request' + if: failure() && github.event_name == 'pull_request_target' uses: peter-evans/create-or-update-comment@v2 with: issue-number: ${{ github.event.pull_request.number }} From 97c1f12e11a9d004fc19cc885d8b408d79a69a70 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Mon, 9 Jan 2023 20:08:26 +0100 Subject: [PATCH 28/39] Centralize task types --- website/src/components/Messages.tsx | 7 +-- website/src/components/Tasks/CreateTask.tsx | 5 +- website/src/components/Tasks/TaskTypes.tsx | 4 +- .../tasks/create/useCreateInitialPrompt.ts | 9 --- .../src/hooks/tasks/create/useCreateReply.ts | 24 -------- .../tasks/evaluate/useRankInitialPrompts.ts | 9 --- .../hooks/tasks/evaluate/useRankReplies.ts | 25 -------- .../tasks/labeling/useLabelAssistantReply.ts | 22 ------- .../tasks/labeling/useLabelInitialPrompt.tsx | 15 ----- .../tasks/labeling/useLabelPrompterReply.ts | 22 ------- .../hooks/tasks/labeling/useLabelingTask.ts | 20 ------- website/src/hooks/tasks/useCreateReply.ts | 8 +++ website/src/hooks/tasks/useGenericTaskAPI.tsx | 12 +--- website/src/hooks/tasks/useLabelingTask.ts | 32 +++++++++++ website/src/hooks/tasks/useRankReplies.ts | 12 ++++ website/src/pages/create/assistant_reply.tsx | 2 +- website/src/pages/create/initial_prompt.tsx | 2 +- website/src/pages/create/user_reply.tsx | 2 +- .../pages/evaluate/rank_assistant_replies.tsx | 2 +- .../pages/evaluate/rank_initial_prompts.tsx | 2 +- .../src/pages/evaluate/rank_user_replies.tsx | 2 +- .../src/pages/label/label_assistant_reply.tsx | 7 +-- .../src/pages/label/label_initial_prompt.tsx | 5 +- .../src/pages/label/label_prompter_reply.tsx | 7 +-- website/src/pages/messages/index.tsx | 4 +- website/src/types/Conversation.ts | 9 +++ website/src/types/Task.ts | 24 ++++++++ website/src/types/Tasks.ts | 57 +++++++++++++++++++ 28 files changed, 163 insertions(+), 188 deletions(-) delete mode 100644 website/src/hooks/tasks/create/useCreateInitialPrompt.ts delete mode 100644 website/src/hooks/tasks/create/useCreateReply.ts delete mode 100644 website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts delete mode 100644 website/src/hooks/tasks/evaluate/useRankReplies.ts delete mode 100644 website/src/hooks/tasks/labeling/useLabelAssistantReply.ts delete mode 100644 website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx delete mode 100644 website/src/hooks/tasks/labeling/useLabelPrompterReply.ts delete mode 100644 website/src/hooks/tasks/labeling/useLabelingTask.ts create mode 100644 website/src/hooks/tasks/useCreateReply.ts create mode 100644 website/src/hooks/tasks/useLabelingTask.ts create mode 100644 website/src/hooks/tasks/useRankReplies.ts create mode 100644 website/src/types/Conversation.ts create mode 100644 website/src/types/Task.ts create mode 100644 website/src/types/Tasks.ts diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index bb97ab5a..c814a6d6 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,15 +1,10 @@ import { Grid } from "@chakra-ui/react"; import { forwardRef, useColorMode } from "@chakra-ui/react"; import { useMemo } from "react"; +import { Message } from "src/types/Conversation"; import { FlaggableElement } from "./FlaggableElement"; -export interface Message { - text: string; - is_assistant: boolean; - message_id: string; -} - export interface ValidLabel { name: string; display_text: string; diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index e02dcdeb..a56432b7 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -3,13 +3,14 @@ import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { TrackedTextarea } from "src/components/Survey/TrackedTextarea"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; -import { TaskType } from "src/components/Tasks/TaskTypes"; +import {} from "src/components/Tasks/TaskTypes"; +import { TaskType } from "./TaskTypes"; export interface CreateTaskProps { // we need a task type // eslint-disable-next-line @typescript-eslint/no-explicit-any tasks: any[]; - taskType: TaskType; + taskType: TaskInfo; trigger: (update: { id: string; update_type: string; content: { text: string } }) => void; onSkipTask: (task: { id: string }, reason: string) => void; onNextTask: () => void; diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 409e7038..d255756a 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -4,7 +4,7 @@ export enum TaskCategory { Label = "Label", } -export interface TaskType { +export interface TaskInfo { label: string; desc: string; category: TaskCategory; @@ -14,7 +14,7 @@ export interface TaskType { instruction?: string; } -export const TaskTypes: TaskType[] = [ +export const TaskTypes: TaskInfo[] = [ // create { label: "Create Initial Prompts", diff --git a/website/src/hooks/tasks/create/useCreateInitialPrompt.ts b/website/src/hooks/tasks/create/useCreateInitialPrompt.ts deleted file mode 100644 index cf0193e8..00000000 --- a/website/src/hooks/tasks/create/useCreateInitialPrompt.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface CreateInitialPromptTask { - id: string; - type: "initial_prompt"; - hint: string; -} - -export const useCreateInitialPrompt = () => useGenericTaskAPI("initial_prompt"); diff --git a/website/src/hooks/tasks/create/useCreateReply.ts b/website/src/hooks/tasks/create/useCreateReply.ts deleted file mode 100644 index 0bc78319..00000000 --- a/website/src/hooks/tasks/create/useCreateReply.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface BaseCreateReplyTask { - id: string; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -export interface CreateAssistantReplyTask extends BaseCreateReplyTask { - type: "assistant_reply"; -} - -export interface CreatePrompterReplyTask extends BaseCreateReplyTask { - type: "prompter_reply"; -} - -export const useCreateAssistantReply = () => useGenericTaskAPI("assistant_reply"); - -export const useCreatePrompterReply = () => useGenericTaskAPI("prompter_reply"); diff --git a/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts b/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts deleted file mode 100644 index da772c80..00000000 --- a/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface RankInitialPromptsTask { - id: string; - type: "rank_initial_prompts"; - prompts: string[]; -} - -export const useRankInitialPromptsTask = () => useGenericTaskAPI("rank_initial_prompts"); diff --git a/website/src/hooks/tasks/evaluate/useRankReplies.ts b/website/src/hooks/tasks/evaluate/useRankReplies.ts deleted file mode 100644 index 2d8d513f..00000000 --- a/website/src/hooks/tasks/evaluate/useRankReplies.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -interface BaseRankRepliesTask { - id: string; - replies: string[]; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -interface RankAssistantRepliesTask extends BaseRankRepliesTask { - type: "rank_assistant_replies"; -} - -interface RankPrompterRepliesTask extends BaseRankRepliesTask { - type: "rank_prompter_replies"; -} - -export const useRankAssistantRepliesTask = () => useGenericTaskAPI("rank_assistant_replies"); - -export const useRankPrompterRepliesTask = () => useGenericTaskAPI("rank_prompter_replies"); diff --git a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts deleted file mode 100644 index 3c44046e..00000000 --- a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { TaskResponse } from "../useGenericTaskAPI"; -import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; - -export interface LabelAssistantReplyTask { - id: string; - type: LabelingTaskType.label_assistant_reply; - message_id: string; - valid_labels: string[]; - reply: string; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -export type LabelAssistantReplyTaskResponse = TaskResponse; - -export const useLabelAssistantReplyTask = () => - useLabelingTask(LabelingTaskType.label_assistant_reply); diff --git a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx deleted file mode 100644 index f7ba8ab5..00000000 --- a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { TaskResponse } from "../useGenericTaskAPI"; -import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; - -export interface LabelInitialPromptTask { - id: string; - type: LabelingTaskType.label_initial_prompt; - message_id: string; - valid_labels: string[]; - prompt: string; -} - -export type LabelInitialPromptTaskResponse = TaskResponse; - -export const useLabelInitialPromptTask = () => - useLabelingTask(LabelingTaskType.label_initial_prompt); diff --git a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts deleted file mode 100644 index 9de2057f..00000000 --- a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { TaskResponse } from "../useGenericTaskAPI"; -import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; - -export interface LabelPrompterReplyTask { - id: string; - type: LabelingTaskType.label_prompter_reply; - message_id: string; - valid_labels: string[]; - reply: string; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -export type LabelPrompterReplyTaskResponse = TaskResponse; - -export const useLabelPrompterReplyTask = () => - useLabelingTask(LabelingTaskType.label_prompter_reply); diff --git a/website/src/hooks/tasks/labeling/useLabelingTask.ts b/website/src/hooks/tasks/labeling/useLabelingTask.ts deleted file mode 100644 index 27555284..00000000 --- a/website/src/hooks/tasks/labeling/useLabelingTask.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { useGenericTaskAPI } from "../useGenericTaskAPI"; - -export const enum LabelingTaskType { - label_initial_prompt = "label_initial_prompt", - label_prompter_reply = "label_prompter_reply", - label_assistant_reply = "label_assistant_reply", -} - -export const useLabelingTask = (endpoint: LabelingTaskType) => { - const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); - - const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { - console.assert(validLabels.length === labelWeights.length); - const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); - - return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); - }; - - return { tasks, isLoading, submit, reset, error }; -}; diff --git a/website/src/hooks/tasks/useCreateReply.ts b/website/src/hooks/tasks/useCreateReply.ts new file mode 100644 index 00000000..23bc041d --- /dev/null +++ b/website/src/hooks/tasks/useCreateReply.ts @@ -0,0 +1,8 @@ +import { TaskType } from "src/types/Task"; +import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks"; + +import { useGenericTaskAPI } from "./useGenericTaskAPI"; + +export const useCreateAssistantReply = () => useGenericTaskAPI(TaskType.assistant_reply); +export const useCreatePrompterReply = () => useGenericTaskAPI(TaskType.prompter_reply); +export const useCreateInitialPrompt = () => useGenericTaskAPI(TaskType.initial_prompt); diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx index e300e220..4b9b3bae 100644 --- a/website/src/hooks/tasks/useGenericTaskAPI.tsx +++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx @@ -2,19 +2,11 @@ import { useState } from "react"; import type { ValidLabel } from "src/components/Messages"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; +import { BaseTask, TaskResponse } from "src/types/Task"; import useSWRImmutable from "swr/immutable"; import useSWRMutation from "swr/mutation"; -// TODO: type & centralize types for all tasks - -export interface TaskResponse { - id: string; - userId: string; - task: TaskType; - valid_labels: ValidLabel[]; -} - -export const useGenericTaskAPI = (taskApiEndpoint: string) => { +export const useGenericTaskAPI = (taskApiEndpoint: string) => { type ConcreteTaskResponse = TaskResponse; const [tasks, setTasks] = useState([]); diff --git a/website/src/hooks/tasks/useLabelingTask.ts b/website/src/hooks/tasks/useLabelingTask.ts new file mode 100644 index 00000000..5e5050ab --- /dev/null +++ b/website/src/hooks/tasks/useLabelingTask.ts @@ -0,0 +1,32 @@ +import { BaseTask, TaskResponse, TaskType } from "src/types/Task"; +import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks"; + +import { useGenericTaskAPI } from "./useGenericTaskAPI"; + +const useLabelingTask = ( + endpoint: TaskType.label_assistant_reply | TaskType.label_prompter_reply | TaskType.label_initial_prompt +) => { + const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); + + const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { + console.assert(validLabels.length === labelWeights.length); + const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); + + return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); + }; + + return { tasks, isLoading, submit, reset, error }; +}; + +export type LabelAssistantReplyTaskResponse = TaskResponse; + +export const useLabelAssistantReplyTask = () => + useLabelingTask(TaskType.label_assistant_reply); + +export type LabelInitialPromptTaskResponse = TaskResponse; + +export const useLabelInitialPromptTask = () => useLabelingTask(TaskType.label_initial_prompt); + +export type LabelPrompterReplyTaskResponse = TaskResponse; + +export const useLabelPrompterReplyTask = () => useLabelingTask(TaskType.label_prompter_reply); diff --git a/website/src/hooks/tasks/useRankReplies.ts b/website/src/hooks/tasks/useRankReplies.ts new file mode 100644 index 00000000..d4accda0 --- /dev/null +++ b/website/src/hooks/tasks/useRankReplies.ts @@ -0,0 +1,12 @@ +import { TaskType } from "src/types/Task"; +import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks"; + +import { useGenericTaskAPI } from "./useGenericTaskAPI"; + +export const useRankAssistantRepliesTask = () => + useGenericTaskAPI(TaskType.rank_assistant_replies); + +export const useRankPrompterRepliesTask = () => + useGenericTaskAPI(TaskType.rank_prompter_replies); + +export const useRankInitialPromptsTask = () => useGenericTaskAPI(TaskType.rank_initial_prompts); diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index e9aee226..17facd5d 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply"; +import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; const AssistantReply = () => { const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index efea6474..57f0dabd 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt"; +import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; const InitialPrompt = () => { const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index 2394bd63..a0af0e95 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply"; +import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; const UserReply = () => { const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index 931c9194..8546e7a6 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies"; +import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; const RankAssistantReplies = () => { const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 4b717143..1898a93a 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts"; +import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; const RankInitialPrompts = () => { const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index 659874a2..e2a39977 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies"; +import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; const RankUserReplies = () => { const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 89b612ca..99c10f56 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,13 +1,10 @@ import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { - LabelAssistantReplyTaskResponse, - useLabelAssistantReplyTask, -} from "src/hooks/tasks/labeling/useLabelAssistantReply"; +import { LabelAssistantReplyTaskResponse, useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { Message } from "src/types/Conversation"; const LabelAssistantReply = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 3c791f23..4cd4343b 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -3,10 +3,7 @@ import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { MessageView } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { - LabelInitialPromptTaskResponse, - useLabelInitialPromptTask, -} from "src/hooks/tasks/labeling/useLabelInitialPrompt"; +import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; const LabelInitialPrompt = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 812a3fcc..35654a47 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,13 +1,10 @@ import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { - LabelPrompterReplyTaskResponse, - useLabelPrompterReplyTask, -} from "src/hooks/tasks/labeling/useLabelPrompterReply"; +import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { Message } from "src/types/Conversation"; const LabelPrompterReply = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 28ec9c54..ed48d47b 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -2,9 +2,9 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha import Head from "next/head"; import { useEffect, useState } from "react"; import { getDashboardLayout } from "src/components/Layout"; -import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import fetcher from "src/lib/fetcher"; +import { Message } from "src/types/Conversation"; import useSWRImmutable from "swr/immutable"; const MessagesDashboard = () => { @@ -82,6 +82,6 @@ const MessagesDashboard = () => { ); }; -MessagesDashboard.getLayout = (page) => getDashboardLayout(page); +MessagesDashboard.getLayout = getDashboardLayout; export default MessagesDashboard; diff --git a/website/src/types/Conversation.ts b/website/src/types/Conversation.ts new file mode 100644 index 00000000..f12b2781 --- /dev/null +++ b/website/src/types/Conversation.ts @@ -0,0 +1,9 @@ +export interface Message { + text: string; + is_assistant: boolean; + message_id: string; +} + +export interface Conversation { + messages: Message[]; +} diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts new file mode 100644 index 00000000..0dca6a5b --- /dev/null +++ b/website/src/types/Task.ts @@ -0,0 +1,24 @@ +export const enum TaskType { + initial_prompt = "initial_prompt", + assistant_reply = "assistant_reply", + prompter_reply = "prompter_reply", + + rank_initial_prompts = "rank_initial_prompts", + rank_assistant_replies = "rank_assistant_replies", + rank_prompter_replies = "rank_prompter_replies", + + label_initial_prompt = "label_initial_prompt", + label_prompter_reply = "label_prompter_reply", + label_assistant_reply = "label_assistant_reply", +} + +export interface BaseTask { + id: string; + type: TaskType; +} + +export interface TaskResponse { + id: string; + userId: string; + task: Task; +} diff --git a/website/src/types/Tasks.ts b/website/src/types/Tasks.ts new file mode 100644 index 00000000..50c251bb --- /dev/null +++ b/website/src/types/Tasks.ts @@ -0,0 +1,57 @@ +import { Conversation } from "./Conversation"; +import { BaseTask, TaskType } from "./Task"; + +export interface CreateInitialPromptTask extends BaseTask { + type: TaskType.initial_prompt; + hint: string; +} + +export interface CreateAssistantReplyTask extends BaseTask { + type: TaskType.assistant_reply; + conversation: Conversation; +} + +export interface CreatePrompterReplyTask extends BaseTask { + type: TaskType.prompter_reply; + conversation: Conversation; +} + +export interface RankInitialPromptsTask extends BaseTask { + type: TaskType.rank_initial_prompts; + prompts: string[]; +} + +export interface RankAssistantRepliesTask extends BaseTask { + type: TaskType.rank_assistant_replies; + conversation: Conversation; + replies: string[]; +} + +export interface RankPrompterRepliesTask extends BaseTask { + type: TaskType.rank_prompter_replies; + conversation: Conversation; + replies: string[]; +} + +export interface LabelAssistantReplyTask extends BaseTask { + type: TaskType.label_assistant_reply; + message_id: string; + conversation: Conversation; + reply: string; + valid_labels: string[]; +} + +export interface LabelInitialPromptTask extends BaseTask { + type: TaskType.label_initial_prompt; + message_id: string; + valid_labels: string[]; + prompt: string; +} + +export interface LabelPrompterReplyTask extends BaseTask { + type: TaskType.label_prompter_reply; + message_id: string; + conversation: Conversation; + reply: string; + valid_labels: string[]; +} From 5b275ed804e813afe964511efbe0840d701b57bc Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Tue, 10 Jan 2023 07:46:53 +0100 Subject: [PATCH 29/39] changed pre-commit event type --- .github/workflows/pre-commit.yaml | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 0f82185f..47f21feb 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -4,20 +4,28 @@ on: push: branches: - main - pull_request: - workflow_call: + pull_request_target: jobs: pre-commit: runs-on: ubuntu-latest steps: + # in case of PR, check out the PR's head branch - uses: actions/checkout@v3 + if: github.event_name == 'pull_request_target' + with: + ref: ${{ github.event.pull_request.head.sha }} + + # in case of push, check out the main branch + - uses: actions/checkout@v3 + if: github.event_name == 'push' + - uses: actions/setup-python@v4 with: python-version: "3.10" - uses: pre-commit/action@v3.0.0 - name: Post PR comment on failure - if: failure() && github.event_name == 'pull_request' + if: failure() && github.event_name == 'pull_request_target' uses: peter-evans/create-or-update-comment@v2 with: issue-number: ${{ github.event.pull_request.number }} From 2c0463f6f89fd24ebd356a60d6525e0aeac8749f Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 16:16:14 +0900 Subject: [PATCH 30/39] rebasing with main --- website/src/components/Messages.tsx | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index c814a6d6..24f491f1 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,5 +1,4 @@ -import { Grid } from "@chakra-ui/react"; -import { forwardRef, useColorMode } from "@chakra-ui/react"; +import { Grid, forwardRef, useColorMode } from "@chakra-ui/react"; import { useMemo } from "react"; import { Message } from "src/types/Conversation"; From c4d5ed990bc5e65d262c4d2330d97a128d17a0de Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 16:17:25 +0900 Subject: [PATCH 31/39] rebasing --- website/src/components/Tasks/CreateTask.tsx | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index a56432b7..a424315a 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -3,8 +3,7 @@ import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { TrackedTextarea } from "src/components/Survey/TrackedTextarea"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; -import {} from "src/components/Tasks/TaskTypes"; -import { TaskType } from "./TaskTypes"; +import { TaskInfo } from "src/components/Tasks/TaskTypes"; export interface CreateTaskProps { // we need a task type From fe4e949f2fb4f92404a5068993c7988ccef99730 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 16:23:56 +0900 Subject: [PATCH 32/39] Revert "rebasing" This reverts commit c4d5ed990bc5e65d262c4d2330d97a128d17a0de. --- website/src/components/Tasks/CreateTask.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index a424315a..a56432b7 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -3,7 +3,8 @@ import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { TrackedTextarea } from "src/components/Survey/TrackedTextarea"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; -import { TaskInfo } from "src/components/Tasks/TaskTypes"; +import {} from "src/components/Tasks/TaskTypes"; +import { TaskType } from "./TaskTypes"; export interface CreateTaskProps { // we need a task type From 81cb88615bf01e0a4ea8d9ee451a13bdf85f1498 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 16:24:15 +0900 Subject: [PATCH 33/39] Revert "rebasing with main" This reverts commit 2c0463f6f89fd24ebd356a60d6525e0aeac8749f. --- website/src/components/Messages.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index 24f491f1..c814a6d6 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,4 +1,5 @@ -import { Grid, forwardRef, useColorMode } from "@chakra-ui/react"; +import { Grid } from "@chakra-ui/react"; +import { forwardRef, useColorMode } from "@chakra-ui/react"; import { useMemo } from "react"; import { Message } from "src/types/Conversation"; From 555113a6f222da5c17c4cd70d6b1966380528375 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 16:24:25 +0900 Subject: [PATCH 34/39] Revert "changed pre-commit event type" This reverts commit 5b275ed804e813afe964511efbe0840d701b57bc. --- .github/workflows/pre-commit.yaml | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 47f21feb..0f82185f 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -4,28 +4,20 @@ on: push: branches: - main - pull_request_target: + pull_request: + workflow_call: jobs: pre-commit: runs-on: ubuntu-latest steps: - # in case of PR, check out the PR's head branch - uses: actions/checkout@v3 - if: github.event_name == 'pull_request_target' - with: - ref: ${{ github.event.pull_request.head.sha }} - - # in case of push, check out the main branch - - uses: actions/checkout@v3 - if: github.event_name == 'push' - - uses: actions/setup-python@v4 with: python-version: "3.10" - uses: pre-commit/action@v3.0.0 - name: Post PR comment on failure - if: failure() && github.event_name == 'pull_request_target' + if: failure() && github.event_name == 'pull_request' uses: peter-evans/create-or-update-comment@v2 with: issue-number: ${{ github.event.pull_request.number }} From 062bfdba3afb36cb6fdf2f21d0f621adbb2b84e3 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 16:24:35 +0900 Subject: [PATCH 35/39] Revert "Centralize task types" This reverts commit 97c1f12e11a9d004fc19cc885d8b408d79a69a70. --- website/src/components/Messages.tsx | 7 ++- website/src/components/Tasks/CreateTask.tsx | 5 +- website/src/components/Tasks/TaskTypes.tsx | 4 +- .../tasks/create/useCreateInitialPrompt.ts | 9 +++ .../src/hooks/tasks/create/useCreateReply.ts | 24 ++++++++ .../tasks/evaluate/useRankInitialPrompts.ts | 9 +++ .../hooks/tasks/evaluate/useRankReplies.ts | 25 ++++++++ .../tasks/labeling/useLabelAssistantReply.ts | 22 +++++++ .../tasks/labeling/useLabelInitialPrompt.tsx | 15 +++++ .../tasks/labeling/useLabelPrompterReply.ts | 22 +++++++ .../hooks/tasks/labeling/useLabelingTask.ts | 20 +++++++ website/src/hooks/tasks/useCreateReply.ts | 8 --- website/src/hooks/tasks/useGenericTaskAPI.tsx | 12 +++- website/src/hooks/tasks/useLabelingTask.ts | 32 ----------- website/src/hooks/tasks/useRankReplies.ts | 12 ---- website/src/pages/create/assistant_reply.tsx | 2 +- website/src/pages/create/initial_prompt.tsx | 2 +- website/src/pages/create/user_reply.tsx | 2 +- .../pages/evaluate/rank_assistant_replies.tsx | 2 +- .../pages/evaluate/rank_initial_prompts.tsx | 2 +- .../src/pages/evaluate/rank_user_replies.tsx | 2 +- .../src/pages/label/label_assistant_reply.tsx | 7 ++- .../src/pages/label/label_initial_prompt.tsx | 5 +- .../src/pages/label/label_prompter_reply.tsx | 7 ++- website/src/pages/messages/index.tsx | 4 +- website/src/types/Conversation.ts | 9 --- website/src/types/Task.ts | 24 -------- website/src/types/Tasks.ts | 57 ------------------- 28 files changed, 188 insertions(+), 163 deletions(-) create mode 100644 website/src/hooks/tasks/create/useCreateInitialPrompt.ts create mode 100644 website/src/hooks/tasks/create/useCreateReply.ts create mode 100644 website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts create mode 100644 website/src/hooks/tasks/evaluate/useRankReplies.ts create mode 100644 website/src/hooks/tasks/labeling/useLabelAssistantReply.ts create mode 100644 website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx create mode 100644 website/src/hooks/tasks/labeling/useLabelPrompterReply.ts create mode 100644 website/src/hooks/tasks/labeling/useLabelingTask.ts delete mode 100644 website/src/hooks/tasks/useCreateReply.ts delete mode 100644 website/src/hooks/tasks/useLabelingTask.ts delete mode 100644 website/src/hooks/tasks/useRankReplies.ts delete mode 100644 website/src/types/Conversation.ts delete mode 100644 website/src/types/Task.ts delete mode 100644 website/src/types/Tasks.ts diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index c814a6d6..bb97ab5a 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,10 +1,15 @@ import { Grid } from "@chakra-ui/react"; import { forwardRef, useColorMode } from "@chakra-ui/react"; import { useMemo } from "react"; -import { Message } from "src/types/Conversation"; import { FlaggableElement } from "./FlaggableElement"; +export interface Message { + text: string; + is_assistant: boolean; + message_id: string; +} + export interface ValidLabel { name: string; display_text: string; diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index a56432b7..e02dcdeb 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -3,14 +3,13 @@ import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { TrackedTextarea } from "src/components/Survey/TrackedTextarea"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; -import {} from "src/components/Tasks/TaskTypes"; -import { TaskType } from "./TaskTypes"; +import { TaskType } from "src/components/Tasks/TaskTypes"; export interface CreateTaskProps { // we need a task type // eslint-disable-next-line @typescript-eslint/no-explicit-any tasks: any[]; - taskType: TaskInfo; + taskType: TaskType; trigger: (update: { id: string; update_type: string; content: { text: string } }) => void; onSkipTask: (task: { id: string }, reason: string) => void; onNextTask: () => void; diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index d255756a..409e7038 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -4,7 +4,7 @@ export enum TaskCategory { Label = "Label", } -export interface TaskInfo { +export interface TaskType { label: string; desc: string; category: TaskCategory; @@ -14,7 +14,7 @@ export interface TaskInfo { instruction?: string; } -export const TaskTypes: TaskInfo[] = [ +export const TaskTypes: TaskType[] = [ // create { label: "Create Initial Prompts", diff --git a/website/src/hooks/tasks/create/useCreateInitialPrompt.ts b/website/src/hooks/tasks/create/useCreateInitialPrompt.ts new file mode 100644 index 00000000..cf0193e8 --- /dev/null +++ b/website/src/hooks/tasks/create/useCreateInitialPrompt.ts @@ -0,0 +1,9 @@ +import { useGenericTaskAPI } from "../useGenericTaskAPI"; + +interface CreateInitialPromptTask { + id: string; + type: "initial_prompt"; + hint: string; +} + +export const useCreateInitialPrompt = () => useGenericTaskAPI("initial_prompt"); diff --git a/website/src/hooks/tasks/create/useCreateReply.ts b/website/src/hooks/tasks/create/useCreateReply.ts new file mode 100644 index 00000000..0bc78319 --- /dev/null +++ b/website/src/hooks/tasks/create/useCreateReply.ts @@ -0,0 +1,24 @@ +import { useGenericTaskAPI } from "../useGenericTaskAPI"; + +interface BaseCreateReplyTask { + id: string; + conversation: { + messages: Array<{ + text: string; + is_assistant: boolean; + message_id: string; + }>; + }; +} + +export interface CreateAssistantReplyTask extends BaseCreateReplyTask { + type: "assistant_reply"; +} + +export interface CreatePrompterReplyTask extends BaseCreateReplyTask { + type: "prompter_reply"; +} + +export const useCreateAssistantReply = () => useGenericTaskAPI("assistant_reply"); + +export const useCreatePrompterReply = () => useGenericTaskAPI("prompter_reply"); diff --git a/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts b/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts new file mode 100644 index 00000000..da772c80 --- /dev/null +++ b/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts @@ -0,0 +1,9 @@ +import { useGenericTaskAPI } from "../useGenericTaskAPI"; + +interface RankInitialPromptsTask { + id: string; + type: "rank_initial_prompts"; + prompts: string[]; +} + +export const useRankInitialPromptsTask = () => useGenericTaskAPI("rank_initial_prompts"); diff --git a/website/src/hooks/tasks/evaluate/useRankReplies.ts b/website/src/hooks/tasks/evaluate/useRankReplies.ts new file mode 100644 index 00000000..2d8d513f --- /dev/null +++ b/website/src/hooks/tasks/evaluate/useRankReplies.ts @@ -0,0 +1,25 @@ +import { useGenericTaskAPI } from "../useGenericTaskAPI"; + +interface BaseRankRepliesTask { + id: string; + replies: string[]; + conversation: { + messages: Array<{ + text: string; + is_assistant: boolean; + message_id: string; + }>; + }; +} + +interface RankAssistantRepliesTask extends BaseRankRepliesTask { + type: "rank_assistant_replies"; +} + +interface RankPrompterRepliesTask extends BaseRankRepliesTask { + type: "rank_prompter_replies"; +} + +export const useRankAssistantRepliesTask = () => useGenericTaskAPI("rank_assistant_replies"); + +export const useRankPrompterRepliesTask = () => useGenericTaskAPI("rank_prompter_replies"); diff --git a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts new file mode 100644 index 00000000..3c44046e --- /dev/null +++ b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts @@ -0,0 +1,22 @@ +import { TaskResponse } from "../useGenericTaskAPI"; +import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; + +export interface LabelAssistantReplyTask { + id: string; + type: LabelingTaskType.label_assistant_reply; + message_id: string; + valid_labels: string[]; + reply: string; + conversation: { + messages: Array<{ + text: string; + is_assistant: boolean; + message_id: string; + }>; + }; +} + +export type LabelAssistantReplyTaskResponse = TaskResponse; + +export const useLabelAssistantReplyTask = () => + useLabelingTask(LabelingTaskType.label_assistant_reply); diff --git a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx new file mode 100644 index 00000000..f7ba8ab5 --- /dev/null +++ b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx @@ -0,0 +1,15 @@ +import { TaskResponse } from "../useGenericTaskAPI"; +import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; + +export interface LabelInitialPromptTask { + id: string; + type: LabelingTaskType.label_initial_prompt; + message_id: string; + valid_labels: string[]; + prompt: string; +} + +export type LabelInitialPromptTaskResponse = TaskResponse; + +export const useLabelInitialPromptTask = () => + useLabelingTask(LabelingTaskType.label_initial_prompt); diff --git a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts new file mode 100644 index 00000000..9de2057f --- /dev/null +++ b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts @@ -0,0 +1,22 @@ +import { TaskResponse } from "../useGenericTaskAPI"; +import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; + +export interface LabelPrompterReplyTask { + id: string; + type: LabelingTaskType.label_prompter_reply; + message_id: string; + valid_labels: string[]; + reply: string; + conversation: { + messages: Array<{ + text: string; + is_assistant: boolean; + message_id: string; + }>; + }; +} + +export type LabelPrompterReplyTaskResponse = TaskResponse; + +export const useLabelPrompterReplyTask = () => + useLabelingTask(LabelingTaskType.label_prompter_reply); diff --git a/website/src/hooks/tasks/labeling/useLabelingTask.ts b/website/src/hooks/tasks/labeling/useLabelingTask.ts new file mode 100644 index 00000000..27555284 --- /dev/null +++ b/website/src/hooks/tasks/labeling/useLabelingTask.ts @@ -0,0 +1,20 @@ +import { useGenericTaskAPI } from "../useGenericTaskAPI"; + +export const enum LabelingTaskType { + label_initial_prompt = "label_initial_prompt", + label_prompter_reply = "label_prompter_reply", + label_assistant_reply = "label_assistant_reply", +} + +export const useLabelingTask = (endpoint: LabelingTaskType) => { + const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); + + const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { + console.assert(validLabels.length === labelWeights.length); + const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); + + return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); + }; + + return { tasks, isLoading, submit, reset, error }; +}; diff --git a/website/src/hooks/tasks/useCreateReply.ts b/website/src/hooks/tasks/useCreateReply.ts deleted file mode 100644 index 23bc041d..00000000 --- a/website/src/hooks/tasks/useCreateReply.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { TaskType } from "src/types/Task"; -import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks"; - -import { useGenericTaskAPI } from "./useGenericTaskAPI"; - -export const useCreateAssistantReply = () => useGenericTaskAPI(TaskType.assistant_reply); -export const useCreatePrompterReply = () => useGenericTaskAPI(TaskType.prompter_reply); -export const useCreateInitialPrompt = () => useGenericTaskAPI(TaskType.initial_prompt); diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx index 4b9b3bae..e300e220 100644 --- a/website/src/hooks/tasks/useGenericTaskAPI.tsx +++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx @@ -2,11 +2,19 @@ import { useState } from "react"; import type { ValidLabel } from "src/components/Messages"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; -import { BaseTask, TaskResponse } from "src/types/Task"; import useSWRImmutable from "swr/immutable"; import useSWRMutation from "swr/mutation"; -export const useGenericTaskAPI = (taskApiEndpoint: string) => { +// TODO: type & centralize types for all tasks + +export interface TaskResponse { + id: string; + userId: string; + task: TaskType; + valid_labels: ValidLabel[]; +} + +export const useGenericTaskAPI = (taskApiEndpoint: string) => { type ConcreteTaskResponse = TaskResponse; const [tasks, setTasks] = useState([]); diff --git a/website/src/hooks/tasks/useLabelingTask.ts b/website/src/hooks/tasks/useLabelingTask.ts deleted file mode 100644 index 5e5050ab..00000000 --- a/website/src/hooks/tasks/useLabelingTask.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { BaseTask, TaskResponse, TaskType } from "src/types/Task"; -import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks"; - -import { useGenericTaskAPI } from "./useGenericTaskAPI"; - -const useLabelingTask = ( - endpoint: TaskType.label_assistant_reply | TaskType.label_prompter_reply | TaskType.label_initial_prompt -) => { - const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); - - const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { - console.assert(validLabels.length === labelWeights.length); - const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); - - return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); - }; - - return { tasks, isLoading, submit, reset, error }; -}; - -export type LabelAssistantReplyTaskResponse = TaskResponse; - -export const useLabelAssistantReplyTask = () => - useLabelingTask(TaskType.label_assistant_reply); - -export type LabelInitialPromptTaskResponse = TaskResponse; - -export const useLabelInitialPromptTask = () => useLabelingTask(TaskType.label_initial_prompt); - -export type LabelPrompterReplyTaskResponse = TaskResponse; - -export const useLabelPrompterReplyTask = () => useLabelingTask(TaskType.label_prompter_reply); diff --git a/website/src/hooks/tasks/useRankReplies.ts b/website/src/hooks/tasks/useRankReplies.ts deleted file mode 100644 index d4accda0..00000000 --- a/website/src/hooks/tasks/useRankReplies.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { TaskType } from "src/types/Task"; -import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks"; - -import { useGenericTaskAPI } from "./useGenericTaskAPI"; - -export const useRankAssistantRepliesTask = () => - useGenericTaskAPI(TaskType.rank_assistant_replies); - -export const useRankPrompterRepliesTask = () => - useGenericTaskAPI(TaskType.rank_prompter_replies); - -export const useRankInitialPromptsTask = () => useGenericTaskAPI(TaskType.rank_initial_prompts); diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index 17facd5d..e9aee226 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; +import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply"; const AssistantReply = () => { const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 57f0dabd..efea6474 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; +import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt"; const InitialPrompt = () => { const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index a0af0e95..2394bd63 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; +import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply"; const UserReply = () => { const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index 8546e7a6..931c9194 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; +import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies"; const RankAssistantReplies = () => { const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 1898a93a..4b717143 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; +import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts"; const RankInitialPrompts = () => { const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index e2a39977..659874a2 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -3,7 +3,7 @@ import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; -import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; +import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies"; const RankUserReplies = () => { const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 99c10f56..89b612ca 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,10 +1,13 @@ import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelAssistantReplyTaskResponse, useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; -import { Message } from "src/types/Conversation"; +import { + LabelAssistantReplyTaskResponse, + useLabelAssistantReplyTask, +} from "src/hooks/tasks/labeling/useLabelAssistantReply"; const LabelAssistantReply = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 4cd4343b..3c791f23 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -3,7 +3,10 @@ import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { MessageView } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; +import { + LabelInitialPromptTaskResponse, + useLabelInitialPromptTask, +} from "src/hooks/tasks/labeling/useLabelInitialPrompt"; const LabelInitialPrompt = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 35654a47..812a3fcc 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,10 +1,13 @@ import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; -import { Message } from "src/types/Conversation"; +import { + LabelPrompterReplyTaskResponse, + useLabelPrompterReplyTask, +} from "src/hooks/tasks/labeling/useLabelPrompterReply"; const LabelPrompterReply = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index ed48d47b..28ec9c54 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -2,9 +2,9 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha import Head from "next/head"; import { useEffect, useState } from "react"; import { getDashboardLayout } from "src/components/Layout"; +import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import fetcher from "src/lib/fetcher"; -import { Message } from "src/types/Conversation"; import useSWRImmutable from "swr/immutable"; const MessagesDashboard = () => { @@ -82,6 +82,6 @@ const MessagesDashboard = () => { ); }; -MessagesDashboard.getLayout = getDashboardLayout; +MessagesDashboard.getLayout = (page) => getDashboardLayout(page); export default MessagesDashboard; diff --git a/website/src/types/Conversation.ts b/website/src/types/Conversation.ts deleted file mode 100644 index f12b2781..00000000 --- a/website/src/types/Conversation.ts +++ /dev/null @@ -1,9 +0,0 @@ -export interface Message { - text: string; - is_assistant: boolean; - message_id: string; -} - -export interface Conversation { - messages: Message[]; -} diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts deleted file mode 100644 index 0dca6a5b..00000000 --- a/website/src/types/Task.ts +++ /dev/null @@ -1,24 +0,0 @@ -export const enum TaskType { - initial_prompt = "initial_prompt", - assistant_reply = "assistant_reply", - prompter_reply = "prompter_reply", - - rank_initial_prompts = "rank_initial_prompts", - rank_assistant_replies = "rank_assistant_replies", - rank_prompter_replies = "rank_prompter_replies", - - label_initial_prompt = "label_initial_prompt", - label_prompter_reply = "label_prompter_reply", - label_assistant_reply = "label_assistant_reply", -} - -export interface BaseTask { - id: string; - type: TaskType; -} - -export interface TaskResponse { - id: string; - userId: string; - task: Task; -} diff --git a/website/src/types/Tasks.ts b/website/src/types/Tasks.ts deleted file mode 100644 index 50c251bb..00000000 --- a/website/src/types/Tasks.ts +++ /dev/null @@ -1,57 +0,0 @@ -import { Conversation } from "./Conversation"; -import { BaseTask, TaskType } from "./Task"; - -export interface CreateInitialPromptTask extends BaseTask { - type: TaskType.initial_prompt; - hint: string; -} - -export interface CreateAssistantReplyTask extends BaseTask { - type: TaskType.assistant_reply; - conversation: Conversation; -} - -export interface CreatePrompterReplyTask extends BaseTask { - type: TaskType.prompter_reply; - conversation: Conversation; -} - -export interface RankInitialPromptsTask extends BaseTask { - type: TaskType.rank_initial_prompts; - prompts: string[]; -} - -export interface RankAssistantRepliesTask extends BaseTask { - type: TaskType.rank_assistant_replies; - conversation: Conversation; - replies: string[]; -} - -export interface RankPrompterRepliesTask extends BaseTask { - type: TaskType.rank_prompter_replies; - conversation: Conversation; - replies: string[]; -} - -export interface LabelAssistantReplyTask extends BaseTask { - type: TaskType.label_assistant_reply; - message_id: string; - conversation: Conversation; - reply: string; - valid_labels: string[]; -} - -export interface LabelInitialPromptTask extends BaseTask { - type: TaskType.label_initial_prompt; - message_id: string; - valid_labels: string[]; - prompt: string; -} - -export interface LabelPrompterReplyTask extends BaseTask { - type: TaskType.label_prompter_reply; - message_id: string; - conversation: Conversation; - reply: string; - valid_labels: string[]; -} From 9e14de25706f3718eef6d8aaf48400b7358b40f0 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 16:30:44 +0900 Subject: [PATCH 36/39] Moving the ValidLabel type --- website/src/components/Messages.tsx | 7 +------ website/src/components/Messages/MessageTableEntry.tsx | 2 +- website/src/hooks/tasks/useGenericTaskAPI.tsx | 1 - website/src/types/Task.ts | 7 +++++++ 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index c814a6d6..b49662ed 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -2,15 +2,10 @@ import { Grid } from "@chakra-ui/react"; import { forwardRef, useColorMode } from "@chakra-ui/react"; import { useMemo } from "react"; import { Message } from "src/types/Conversation"; +import { ValidLabel } from "src/types/Task"; import { FlaggableElement } from "./FlaggableElement"; -export interface ValidLabel { - name: string; - display_text: string; - help_text: string; -} - export const Messages = ({ messages, post_id, diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 0f58efad..e9e8775a 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -2,7 +2,7 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; import NextLink from "next/link"; import { FlaggableElement } from "src/components/FlaggableElement"; -import type { ValidLabel } from "src/components/Messages"; +import type { ValidLabel } from "src/types/Task"; interface Message { text: string; diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx index 4b9b3bae..a456cbf1 100644 --- a/website/src/hooks/tasks/useGenericTaskAPI.tsx +++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx @@ -1,5 +1,4 @@ import { useState } from "react"; -import type { ValidLabel } from "src/components/Messages"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; import { BaseTask, TaskResponse } from "src/types/Task"; diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index 0dca6a5b..6975fa14 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -12,6 +12,12 @@ export const enum TaskType { label_assistant_reply = "label_assistant_reply", } +export interface ValidLabel { + name: string; + display_text: string; + help_text: string; +} + export interface BaseTask { id: string; type: TaskType; @@ -21,4 +27,5 @@ export interface TaskResponse { id: string; userId: string; task: Task; + valid_labels: ValidLabel[]; } From 80268136b0aeded35724169dde17d89cd39611be Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 17:52:48 +0900 Subject: [PATCH 37/39] A simpler fix to ensuring the messages view can render --- website/src/components/FlaggableElement.tsx | 15 ++++++++------- website/src/components/Messages/MessageTable.tsx | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 9606f425..a7157c4a 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -36,13 +36,14 @@ interface textFlagLabels { export const FlaggableElement = (props) => { const [isEditing, setIsEditing] = useBoolean(); const flaggable_labels = props.flaggable_labels; - const TEXT_LABEL_FLAGS = flaggable_labels.valid_labels.map((valid_label) => { - return { - attributeName: valid_label.name, - labelText: valid_label.display_text, - additionalExplanation: valid_label.help_text, - }; - }); + const TEXT_LABEL_FLAGS = + flaggable_labels?.valid_labels?.map((valid_label) => { + return { + attributeName: valid_label.name, + labelText: valid_label.display_text, + additionalExplanation: valid_label.help_text, + }; + }) || []; const { trigger } = useSWRMutation("/api/set_label", poster, { onSuccess: () => { setIsEditing.off(); diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index bacd27f9..872b79f1 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -5,7 +5,7 @@ export function MessageTable({ messages, valid_labels }) { return ( } spacing="4"> {messages.map((item, idx) => ( - + ))} ); From 15e1203be91934a6509ecb692b37e37d61df7723 Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Tue, 10 Jan 2023 21:08:40 +1100 Subject: [PATCH 38/39] website: Refactor remaining task pages to use Task.tsx --- website/src/components/Tasks/LabelTask.tsx | 86 ++++++++++++------- website/src/components/Tasks/Task.tsx | 18 +++- website/src/components/Tasks/TaskTypes.tsx | 3 + website/src/hooks/tasks/useLabelingTask.ts | 31 +------ website/src/pages/create/assistant_reply.tsx | 6 +- website/src/pages/create/initial_prompt.tsx | 6 +- website/src/pages/create/user_reply.tsx | 6 +- .../pages/evaluate/rank_assistant_replies.tsx | 6 +- .../pages/evaluate/rank_initial_prompts.tsx | 6 +- .../src/pages/evaluate/rank_user_replies.tsx | 6 +- .../src/pages/label/label_assistant_reply.tsx | 48 ++++------- .../src/pages/label/label_initial_prompt.tsx | 42 ++++----- .../src/pages/label/label_prompter_reply.tsx | 48 ++++------- 13 files changed, 135 insertions(+), 177 deletions(-) diff --git a/website/src/components/Tasks/LabelTask.tsx b/website/src/components/Tasks/LabelTask.tsx index bb9d417c..966c0a53 100644 --- a/website/src/components/Tasks/LabelTask.tsx +++ b/website/src/components/Tasks/LabelTask.tsx @@ -1,43 +1,71 @@ import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; -import { ReactNode, useEffect, useId, useMemo, useState } from "react"; +import { useEffect, useId, useState } from "react"; +import { MessageView } from "src/components/Messages"; +import { MessageTable } from "src/components/Messages/MessageTable"; +import { TaskControls } from "src/components/Survey/TaskControls"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; +import { TaskInfo } from "src/components/Tasks/TaskTypes"; +import { TaskType } from "src/types/Task"; import { colors } from "styles/Theme/colors"; -export const LabelTask = ({ - title, - desc, - messages, - inputs, - controls, -}: { - title: string; - desc: string; - messages: ReactNode; - inputs: ReactNode; - controls: ReactNode; -}) => { - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; +export interface LabelTaskProps { + // we need a task type + // eslint-disable-next-line @typescript-eslint/no-explicit-any + tasks: any[]; + taskType: TaskInfo; + trigger: (update: { + id: string; + update_type: string; + content: { text: string; labels: { [k: string]: number }; message_id: string }; + }) => void; + onSkipTask: (task: { id: string }, reason: string) => void; + onNextTask: () => void; + mainBgClasses: string; +} +export const LabelTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: LabelTaskProps) => { + const task = tasks[0].task; + const valid_labels = tasks[0].valid_labels; - const card = useMemo( - () => ( - <> -
{title}
-

{desc}

- {messages} - - ), - [title, desc, messages] - ); + const [sliderValues, setSliderValues] = useState([]); + + const submitResponse = (task: { id: string; reply: string; message_id: string }) => { + console.assert(valid_labels.length === sliderValues.length); + const labels = Object.fromEntries(valid_labels.valid_labels.map((label, i) => [label, sliderValues[i]])); + trigger({ + id: task.id, + update_type: "text_labels", + content: { labels, text: task.reply, message_id: task.message_id }, + }); + }; return (
- {card} - {inputs} + <> +
{taskType.label}
+

{taskType.overview}

+ + {task.conversation ? ( + + ) : ( + + )} + +
- {controls} + +
); }; diff --git a/website/src/components/Tasks/Task.tsx b/website/src/components/Tasks/Task.tsx index 777f5dd5..e95fe3e2 100644 --- a/website/src/components/Tasks/Task.tsx +++ b/website/src/components/Tasks/Task.tsx @@ -1,12 +1,17 @@ +import { useColorMode } from "@chakra-ui/react"; import { CreateTask } from "src/components/Tasks/CreateTask"; import { EvaluateTask } from "src/components/Tasks/EvaluateTask"; +import { LabelTask } from "src/components/Tasks/LabelTask"; import { TaskCategory, TaskTypes } from "src/components/Tasks/TaskTypes"; import poster from "src/lib/poster"; import useSWRMutation from "swr/mutation"; -export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => { +export const Task = ({ tasks, trigger, mutate }) => { const task = tasks[0].task; + const { colorMode } = useColorMode(); + const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; + const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, { onSuccess: async () => { mutate(); @@ -45,6 +50,17 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => { mainBgClasses={mainBgClasses} /> ); + case TaskCategory.Label: + return ( + + ); } } diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index d255756a..c9a978d6 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -71,6 +71,7 @@ export const TaskTypes: TaskInfo[] = [ desc: "Provide labels for a prompt.", category: TaskCategory.Label, pathname: "/label/label_initial_prompt", + overview: "Provide labels for the following prompt", type: "label_initial_prompt", }, { @@ -78,6 +79,7 @@ export const TaskTypes: TaskInfo[] = [ desc: "Provide labels for a prompt.", category: TaskCategory.Label, pathname: "/label/label_prompter_reply", + overview: "Given the following discussion, provide labels for the final promp", type: "label_prompter_reply", }, { @@ -85,6 +87,7 @@ export const TaskTypes: TaskInfo[] = [ desc: "Provide labels for a prompt.", category: TaskCategory.Label, pathname: "/label/label_assistant_reply", + overview: "Given the following discussion, provide labels for the final prompt.", type: "label_assistant_reply", }, ]; diff --git a/website/src/hooks/tasks/useLabelingTask.ts b/website/src/hooks/tasks/useLabelingTask.ts index 5e5050ab..3782c7a3 100644 --- a/website/src/hooks/tasks/useLabelingTask.ts +++ b/website/src/hooks/tasks/useLabelingTask.ts @@ -1,32 +1,9 @@ -import { BaseTask, TaskResponse, TaskType } from "src/types/Task"; +import { TaskType } from "src/types/Task"; import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks"; import { useGenericTaskAPI } from "./useGenericTaskAPI"; -const useLabelingTask = ( - endpoint: TaskType.label_assistant_reply | TaskType.label_prompter_reply | TaskType.label_initial_prompt -) => { - const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); - - const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { - console.assert(validLabels.length === labelWeights.length); - const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); - - return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); - }; - - return { tasks, isLoading, submit, reset, error }; -}; - -export type LabelAssistantReplyTaskResponse = TaskResponse; - export const useLabelAssistantReplyTask = () => - useLabelingTask(TaskType.label_assistant_reply); - -export type LabelInitialPromptTaskResponse = TaskResponse; - -export const useLabelInitialPromptTask = () => useLabelingTask(TaskType.label_initial_prompt); - -export type LabelPrompterReplyTaskResponse = TaskResponse; - -export const useLabelPrompterReplyTask = () => useLabelingTask(TaskType.label_prompter_reply); + useGenericTaskAPI(TaskType.label_assistant_reply); +export const useLabelInitialPromptTask = () => useGenericTaskAPI(TaskType.label_initial_prompt); +export const useLabelPrompterReplyTask = () => useGenericTaskAPI(TaskType.label_prompter_reply); diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index 17facd5d..6e389190 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -1,5 +1,4 @@ import { Container } from "@chakra-ui/react"; -import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; @@ -8,9 +7,6 @@ import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; const AssistantReply = () => { const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - if (isLoading) { return ; } @@ -25,7 +21,7 @@ const AssistantReply = () => { Reply as Assistant - + ); }; diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 57f0dabd..20b36467 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -1,5 +1,4 @@ import { Container } from "@chakra-ui/react"; -import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; @@ -8,9 +7,6 @@ import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; const InitialPrompt = () => { const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - if (isLoading) { return ; } @@ -25,7 +21,7 @@ const InitialPrompt = () => { Reply as Assistant - + ); }; diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index a0af0e95..41969e54 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -1,4 +1,3 @@ -import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -8,9 +7,6 @@ import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; const UserReply = () => { const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - if (isLoading) { return ; } @@ -25,7 +21,7 @@ const UserReply = () => { Reply as Assistant - + ); }; diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index 8546e7a6..16eee130 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -1,4 +1,3 @@ -import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -8,9 +7,6 @@ import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; const RankAssistantReplies = () => { const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - if (isLoading) { return ; } @@ -25,7 +21,7 @@ const RankAssistantReplies = () => { Rank Assistant Replies - + ); }; diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 1898a93a..0b305192 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -1,4 +1,3 @@ -import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -8,9 +7,6 @@ import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; const RankInitialPrompts = () => { const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - if (isLoading) { return ; } @@ -25,7 +21,7 @@ const RankInitialPrompts = () => { Rank Initial Prompts - + ); }; diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index e2a39977..c269745f 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -1,4 +1,3 @@ -import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { Container } from "src/components/Container"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -8,9 +7,6 @@ import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; const RankUserReplies = () => { const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - if (isLoading) { return ; } @@ -25,7 +21,7 @@ const RankUserReplies = () => { Rank User Replies - + ); }; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 99c10f56..945a612e 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,44 +1,28 @@ -import { useState } from "react"; +import { Container } from "@chakra-ui/react"; +import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { MessageTable } from "src/components/Messages/MessageTable"; -import { TaskControls } from "src/components/Survey/TaskControls"; -import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelAssistantReplyTaskResponse, useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; -import { Message } from "src/types/Conversation"; +import { Task } from "src/components/Tasks/Task"; +import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; const LabelAssistantReply = () => { - const [sliderValues, setSliderValues] = useState([]); + const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask(); - const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask(); - - if (isLoading || tasks.length === 0) { + if (isLoading) { return ; } - const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; - const messages: Message[] = [ - ...task.conversation.messages, - { text: task.reply, is_assistant: true, message_id: task.message_id }, - ]; + if (tasks.length === 0) { + return No tasks found...; + } return ( - } - inputs={} - controls={ - reset()} - onNextTask={reset} - onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) => - submit(id, task.message_id, task.reply, task.valid_labels, sliderValues) - } - /> - } - /> + <> + + Label Assistant Reply + + + + ); }; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 4cd4343b..bfacfdbe 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -1,38 +1,28 @@ -import { useState } from "react"; +import { Container } from "@chakra-ui/react"; +import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { MessageView } from "src/components/Messages"; -import { TaskControls } from "src/components/Survey/TaskControls"; -import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; +import { Task } from "src/components/Tasks/Task"; +import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; const LabelInitialPrompt = () => { - const [sliderValues, setSliderValues] = useState([]); + const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask(); - const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask(); - - if (isLoading || tasks.length === 0) { + if (isLoading) { return ; } - const task = tasks[0].task; + if (tasks.length === 0) { + return No tasks found...; + } return ( - } - inputs={} - controls={ - reset()} - onNextTask={reset} - onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) => - submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues) - } - /> - } - /> + <> + + Label Initial Prompt + + + + ); }; diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 35654a47..3d47f74b 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,44 +1,28 @@ -import { useState } from "react"; +import { Container } from "@chakra-ui/react"; +import Head from "next/head"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { MessageTable } from "src/components/Messages/MessageTable"; -import { TaskControls } from "src/components/Survey/TaskControls"; -import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; -import { Message } from "src/types/Conversation"; +import { Task } from "src/components/Tasks/Task"; +import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; const LabelPrompterReply = () => { - const [sliderValues, setSliderValues] = useState([]); + const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask(); - const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask(); - - if (isLoading || tasks.length === 0) { + if (isLoading) { return ; } - const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; - const messages: Message[] = [ - ...task.conversation.messages, - { text: task.reply, is_assistant: false, message_id: task.message_id }, - ]; + if (tasks.length === 0) { + return No tasks found...; + } return ( - } - inputs={} - controls={ - reset()} - onNextTask={reset} - onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) => - submit(id, task.message_id, task.reply, task.valid_labels, sliderValues) - } - /> - } - /> + <> + + Label Prompter Reply + + + + ); }; From a1fd2cc6380d03b062965f012363f988f2d7778a Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Tue, 10 Jan 2023 12:04:10 +0100 Subject: [PATCH 39/39] added to readme --- CONTRIBUTING.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 428f6a50..b290f5fe 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -110,3 +110,8 @@ Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine. + +### Contribute a Dataset + +See +[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)