[NEW] Solving merge conflicts

This commit is contained in:
Nil-Andreu
2023-01-08 20:48:13 +01:00
40 changed files with 957 additions and 535 deletions
+9
View File
@@ -16,3 +16,12 @@ jobs:
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.0
- name: Post PR comment on failure
if: failure() && github.event_name == 'pull_request'
uses: peter-evans/create-or-update-comment@v2
with:
issue-number: ${{ github.event.pull_request.number }}
body: |
:x: **pre-commit** failed.
Please run `pre-commit run --all-files` locally and commit the changes.
Find more information in the repository's CONTRIBUTING.md
+4 -2
View File
@@ -96,8 +96,10 @@ The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit
hooks.
We are using `pre-commit` to enforce code style and formatting.
Install `pre-commit` from [its website](https://pre-commit.com) and run
`pre-commit install` to install the pre-commit hooks.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
+11 -6
View File
@@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA:
with Session(engine) as db:
api_client = get_dummy_api_client(db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
ur = UserRepository(db=db, api_client=api_client)
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
pr = PromptRepository(
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
@@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = pr.store_task(
task = tr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
@@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA:
for cmsg in conversation_messages
]
)
task = pr.store_task(
task = tr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
pr.bind_frontend_message_id(task.id, msg.task_message_id)
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
logger.info(
@@ -16,7 +16,7 @@ def get_message_by_frontend_id(
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
return utils.prepare_message(message)
@@ -29,7 +29,7 @@ def get_conv_by_frontend_id(
Get a conversation from the tree root and up to the message with given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)
@@ -43,7 +43,7 @@ def get_tree_by_frontend_id(
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -56,7 +56,7 @@ def get_children_by_frontend_id(
"""
Get the children of the message identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return utils.prepare_message_list(messages)
@@ -70,7 +70,7 @@ def get_descendants_by_frontend_id(
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id(
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -98,7 +98,7 @@ def get_max_children_by_frontend_id(
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -29,7 +29,7 @@ def query_frontend_user_messages(
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -47,6 +47,6 @@ def query_frontend_user_messages(
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
+8 -7
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
router = APIRouter()
@@ -11,15 +12,15 @@ router = APIRouter()
def get_assistant_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="assistant")
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="assistant")
@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="prompter")
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="prompter")
+9 -9
View File
@@ -29,7 +29,7 @@ def query_messages(
"""
Query messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -51,7 +51,7 @@ def get_message(
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
return utils.prepare_message(message)
@@ -64,7 +64,7 @@ def get_conv(
Get a conversation from the tree root and up to the message with given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.fetch_message_conversation(message_id)
return utils.prepare_conversation(messages)
@@ -76,7 +76,7 @@ def get_tree(
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -89,7 +89,7 @@ def get_children(
"""
Get the children of the message with the given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.fetch_message_children(message_id)
return utils.prepare_message_list(messages)
@@ -101,7 +101,7 @@ def get_descendants(
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -114,7 +114,7 @@ def get_longest_conv(
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -127,7 +127,7 @@ def get_max_children(
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -137,5 +137,5 @@ def get_max_children(
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
pr.mark_messages_deleted(message_id)
+1 -1
View File
@@ -13,5 +13,5 @@ def get_message_stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
return pr.get_stats()
+10 -10
View File
@@ -8,7 +8,7 @@ from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
@@ -192,9 +192,9 @@ def request_task(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, request.user)
pr = PromptRepository(db, api_client, client_user=request.user)
task, message_tree_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
raise
@@ -219,11 +219,11 @@ def tasks_acknowledge(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
except OasstError:
raise
@@ -247,8 +247,8 @@ def tasks_acknowledge_failure(
try:
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.acknowledge_task_failure(task_id)
pr = PromptRepository(db, api_client)
pr.task_repository.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@@ -267,7 +267,7 @@ async def tasks_interaction(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=interaction.user)
pr = PromptRepository(db, api_client, client_user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToMessage:
@@ -339,6 +339,6 @@ def close_collective_task(
api_key: APIKey = Depends(deps.get_api_key),
):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.close_task(close_task_request.message_id)
tr = TaskRepository(db, api_client)
tr.close_task(close_task_request.message_id)
return protocol_schema.TaskDone()
+1 -1
View File
@@ -25,7 +25,7 @@ def label_text(
try:
logger.info(f"Labeling text {text_labels=}.")
pr = PromptRepository(db, api_client, user=text_labels.user)
pr = PromptRepository(db, api_client, client_user=text_labels.user)
pr.store_text_labels(text_labels)
except Exception:
+2 -2
View File
@@ -29,7 +29,7 @@ def query_user_messages(
"""
Query user messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
user_id=user_id,
api_client_id=api_client_id,
@@ -48,6 +48,6 @@ def query_user_messages(
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
@@ -6,27 +6,56 @@ import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
# The types of States a message tree can have.
class States(str, Enum):
"""States of the Open-Assistant message tree state machine."""
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
"""In this state the message tree consists only of a single initial prompt root node.
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
`aborted_low_grade`."""
class States(Enum):
INITIAL = "initial"
BREEDING_PHASE = "breeding_phase"
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
are handed out to check if the quality of the replies surpasses the minimum acceptable
quality.
When the required number of messages passing the initial labelling-quality check has been
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
are received the tree can also enter the `aborted_low_grade` state."""
RANKING_PHASE = "ranking_phase"
"""The tree has been successfully populated with the desired number of messages. Ranking
tasks are now handed out for all nodes with more than one child."""
READY_FOR_SCORING = "ready_for_scoring"
CHILDREN_SCORED = "children_scored"
FINAL = "final"
"""Required ranking responses have been collected and the scoring algorithm can now
compute the aggregated ranking scores that will appear in the dataset."""
READY_FOR_EXPORT = "ready_for_export"
"""The scoring algorithm computed ranking scores for all children. The message tree can be
exported as part of an Open-Assistant message tree dataset."""
SCORING_FAILED = "scoring_failed"
"""An exception occurred in the scoring algorithm."""
ABORTED_LOW_GRADE = "aborted_low_grade"
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
HALTED_BY_MODERATOR = "halted_by_moderator"
"""A moderator decided to manually halt the message tree construction process."""
VALID_STATES = (
States.INITIAL,
States.INITIAL_PROMPT_REVIEW,
States.BREEDING_PHASE,
States.RANKING_PHASE,
States.READY_FOR_SCORING,
States.CHILDREN_SCORED,
States.FINAL,
States.READY_FOR_EXPORT,
States.ABORTED_LOW_GRADE,
)
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
class MessageTreeState(SQLModel, table=True):
__tablename__ = "message_tree_state"
+63 -269
View File
@@ -8,98 +8,39 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, Task, TextLabels, User
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
from oasst_shared.schemas.protocol import SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
def __init__(
self,
db: Session,
api_client: ApiClient,
client_user: Optional[protocol_schema.User] = None,
user_repository: Optional[UserRepository] = None,
task_repository: Optional[TaskRepository] = None,
):
self.db = db
self.api_client = api_client
self.user = self.lookup_user(user)
self.user_repository = user_repository or UserRepository(db, api_client)
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
self.task_repository = task_repository or TaskRepository(
db, api_client, client_user, user_repository=self.user_repository
)
self.journal = JournalWriter(db, api_client, self.user)
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
    """Resolve a frontend user to its ``User`` row, creating or refreshing it.

    The row is matched on this repository's api client id, the frontend
    user id (stored as ``username``) and the auth method.

    Args:
        client_user: User as reported by the frontend; a falsy value means
            "no user" and short-circuits to ``None``.

    Returns:
        The matching ``User`` row (newly created if none existed), or
        ``None`` when no ``client_user`` was supplied.
    """
    if not client_user:
        return None

    matches = self.db.query(User).filter(
        User.api_client_id == self.api_client.id,
        User.username == client_user.id,
        User.auth_method == client_user.auth_method,
    )
    record: User = matches.first()

    if record is None:
        # Unknown user: persist a fresh row and reload it after the commit.
        record = User(
            username=client_user.id,
            display_name=client_user.display_name,
            api_client_id=self.api_client.id,
            auth_method=client_user.auth_method,
        )
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return record

    if client_user.display_name and client_user.display_name != record.display_name:
        # Known user whose display name changed: keep the stored copy in sync.
        record.display_name = client_user.display_name
        self.db.add(record)
        self.db.commit()
    return record
def validate_frontend_message_id(self, message_id: str) -> None:
    """Ensure a frontend message id is a non-empty string.

    Raises:
        OasstError: with ``INVALID_FRONTEND_MESSAGE_ID`` when ``message_id``
            is not a ``str`` or is empty.
    """
    # TODO: Should it be replaced with fastapi/pydantic validation?
    if isinstance(message_id, str):
        if message_id:
            return
        raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
    raise OasstError(
        f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
    )
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
    """Acknowledge a task and attach the frontend's message id to it.

    Args:
        task_id: Id of the task to acknowledge.
        frontend_message_id: Message id chosen by the frontend for this task.

    Raises:
        OasstError: if the frontend message id is invalid, the task does not
            exist for this api client, has expired, or was already
            acknowledged/completed.
    """
    self.validate_frontend_message_id(frontend_message_id)

    # Only tasks belonging to the calling api client are considered.
    ownership = (Task.id == task_id, Task.api_client_id == self.api_client.id)
    record: Task = self.db.query(Task).filter(*ownership).first()

    if record is None:
        raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
    if record.expired:
        raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
    if record.done or record.ack is not None:
        raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)

    record.frontend_message_id = frontend_message_id
    record.ack = True

    # ToDo: check race-condition, transaction
    self.db.add(record)
    self.db.commit()
def acknowledge_task_failure(self, task_id):
    """Record that the frontend could not present the task (negative ack).

    Args:
        task_id: Id of the task the frontend failed to handle.

    Raises:
        OasstError: if the task does not exist for this api client, has
            expired, or was already acknowledged/completed.
    """
    # Only tasks belonging to the calling api client are considered.
    ownership = (Task.id == task_id, Task.api_client_id == self.api_client.id)
    record: Task = self.db.query(Task).filter(*ownership).first()

    if record is None:
        raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
    if record.expired:
        raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
    if record.done or record.ack is not None:
        raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)

    record.ack = False

    # ToDo: check race-condition, transaction
    self.db.add(record)
    self.db.commit()
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
self.validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(frontend_message_id)
message: Message = (
self.db.query(Message)
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
@@ -113,22 +54,54 @@ class PromptRepository:
)
return message
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
self.validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
def insert_message(
self,
*,
message_id: UUID,
frontend_message_id: str,
parent_id: UUID,
message_tree_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
) -> Message:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
message = Message(
id=message_id,
parent_id=parent_id,
message_tree_id=message_tree_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_message_id=frontend_message_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
return task
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def store_text_reply(
self, text: str, frontend_message_id: str, user_frontend_message_id: str, miniLM_embedding: List[float] = None
self,
text: str,
frontend_message_id: str,
user_frontend_message_id: str,
miniLM_embedding: Optional[List[float]] = None,
) -> Message:
self.validate_frontend_message_id(frontend_message_id)
self.validate_frontend_message_id(user_frontend_message_id)
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
if task is None:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
@@ -176,7 +149,7 @@ class PromptRepository:
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
task = self.fetch_task_by_frontend_message_id(rating.message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
task_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(task_payload) != db_payload.RateSummaryPayload:
raise OasstError(
@@ -203,7 +176,7 @@ class PromptRepository:
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
# fetch task
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
@@ -257,142 +230,6 @@ class PromptRepository:
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
def store_task(
self,
task: protocol_schema.Task,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task.id == task.id
return task
def insert_task(
    self,
    payload: db_payload.TaskPayload,
    id: UUID = None,
    message_tree_id: UUID = None,
    parent_message_id: UUID = None,
    collective: bool = False,
) -> Task:
    """Create, commit and reload a ``Task`` row for the given payload.

    The row is attributed to this repository's user and api client, and its
    ``payload_type`` is the payload's class name.

    Returns:
        The refreshed ``Task`` row.
    """
    row = Task(
        id=id,
        user_id=self.user_id,
        payload_type=type(payload).__name__,
        payload=PayloadContainer(payload=payload),
        api_client_id=self.api_client.id,
        message_tree_id=message_tree_id,
        parent_message_id=parent_message_id,
        collective=collective,
    )

    session = self.db
    session.add(row)
    session.commit()
    # Reload so values assigned during the commit are visible on the instance.
    session.refresh(row)
    return row
def insert_message(
    self,
    *,
    message_id: UUID,
    frontend_message_id: str,
    parent_id: UUID,
    message_tree_id: UUID,
    task_id: UUID,
    role: str,
    payload: db_payload.MessagePayload,
    payload_type: str = None,
    depth: int = 0,
) -> Message:
    """Create, commit and reload a ``Message`` row.

    When ``payload_type`` is not given it is derived from the payload: the
    payload's class name, or the literal ``"null"`` for a ``None`` payload.
    The row is attributed to this repository's user and api client.

    Returns:
        The refreshed ``Message`` row.
    """
    if payload_type is None:
        payload_type = "null" if payload is None else type(payload).__name__

    row = Message(
        id=message_id,
        parent_id=parent_id,
        message_tree_id=message_tree_id,
        task_id=task_id,
        user_id=self.user_id,
        role=role,
        frontend_message_id=frontend_message_id,
        api_client_id=self.api_client.id,
        payload_type=payload_type,
        payload=PayloadContainer(payload=payload),
        depth=depth,
    )

    session = self.db
    session.add(row)
    session.commit()
    session.refresh(row)
    return row
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
if None in (message_id, model, embedding):
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
@@ -527,28 +364,6 @@ class PromptRepository:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
return message
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
    """
    Mark task as done. No further messages will be accepted for this task.

    Args:
        frontend_message_id: Frontend message id the task was bound to.
        allow_personal_tasks: When false, only collective tasks may be closed.

    Raises:
        OasstError: if the id is invalid, no task is found, the task has
            expired, is not collective (and personal tasks are not allowed),
            or is already done.
    """
    self.validate_frontend_message_id(frontend_message_id)
    record = self.fetch_task_by_frontend_message_id(frontend_message_id)

    if not record:
        raise OasstError(
            f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
        )
    if record.expired:
        raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
    if not (allow_personal_tasks or record.collective):
        raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
    if record.done:
        raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)

    record.done = True
    self.db.add(record)
    self.db.commit()
@staticmethod
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
@@ -740,24 +555,3 @@ class PromptRepository:
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
    """
    Get leaderboard stats for Messages created,
    separate leaderboard for prompts & assistants

    Args:
        role: Message role to rank by (the endpoints pass "assistant" or
            "prompter").

    Returns:
        ``LeaderboardStats`` whose entries are ordered by descending message
        count, with ``ranking`` starting at 1.
    """
    query = (
        self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
        .join(User, User.id == Message.user_id, isouter=True)
        # `Message.deleted is not True` is a Python identity test that always
        # yields True, so nothing was filtered. `is_not(True)` emits the SQL
        # `IS NOT true` comparison and really excludes deleted messages.
        .filter(Message.deleted.is_not(True), Message.role == role)
        .group_by(Message.user_id, User.username, User.display_name)
        .order_by(func.count(Message.user_id).desc())
    )

    leaderboard = [
        {"ranking": rank, "user_id": user_id, "username": username, "display_name": display_name, "score": score}
        for rank, (user_id, username, display_name, score) in enumerate(query.all(), start=1)
    ]

    return LeaderboardStats(leaderboard=leaderboard)
+199
View File
@@ -0,0 +1,199 @@
from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_404_NOT_FOUND
def validate_frontend_message_id(message_id: str) -> None:
    """Ensure a frontend message id is a non-empty string.

    Raises:
        OasstError: with ``INVALID_FRONTEND_MESSAGE_ID`` when ``message_id``
            is not a ``str`` or is empty.
    """
    # TODO: Should it be replaced with fastapi/pydantic validation?
    if isinstance(message_id, str):
        if message_id:
            return
        raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
    raise OasstError(
        f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
    )
class TaskRepository:
def __init__(
    self,
    db: Session,
    api_client: ApiClient,
    client_user: Optional[protocol_schema.User] = None,
    user_repository: Optional[UserRepository] = None,
):
    """Set up a task repository bound to one api client and, optionally, a user.

    Args:
        db: Database session used for all queries and commits.
        api_client: Api client on whose behalf tasks are stored and looked up.
        client_user: Frontend user the tasks are attributed to. Defaults to
            ``None`` so user-independent callers can construct the
            repository as ``TaskRepository(db, api_client)``.
        user_repository: Repository used to resolve ``client_user``; a new
            ``UserRepository(db, api_client)`` is created when omitted.
    """
    self.db = db
    self.api_client = api_client
    # Same fallback PromptRepository uses, so both repositories can be built
    # from just a session and an api client.
    if user_repository is None:
        user_repository = UserRepository(db, api_client)
    self.user_repository = user_repository
    self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
    # The lookup result may be falsy (no user); keep user_id as None in that case.
    self.user_id = self.user.id if self.user else None
def store_task(
self,
task: protocol_schema.Task,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task.id == task.id
return task
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
validate_frontend_message_id(frontend_message_id)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
"""
validate_frontend_message_id(frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
if not task:
raise OasstError(
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
task.done = True
self.db.add(task)
self.db.commit()
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
c = PayloadContainer(payload=payload)
task = Task(
id=id,
user_id=self.user_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
)
return task
+64
View File
@@ -0,0 +1,64 @@
from typing import Optional
from oasst_backend.models import ApiClient, Message, User
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session, func
class UserRepository:
    """Looks up, creates and ranks ``User`` rows for one API client."""

    def __init__(self, db: Session, api_client: ApiClient):
        """
        :param db: session used for all reads and writes.
        :param api_client: client whose users this repository manages.
        """
        self.db = db
        self.api_client = api_client

    def lookup_client_user(
        self, client_user: Optional[protocol_schema.User], create_missing: bool = True
    ) -> Optional[User]:
        """
        Resolve a frontend user to its database row.

        The row is matched on this API client's id, ``client_user.id`` (stored
        as ``username``) and ``client_user.auth_method``. A changed, non-empty
        display name is written back to an existing row.

        :param client_user: user as reported by the frontend; a falsy value
            yields ``None``.
        :param create_missing: create the row when no match exists.
        :return: the matching or newly created ``User``; ``None`` when
            ``client_user`` is falsy or the user is unknown and
            ``create_missing`` is ``False``.
        """
        if not client_user:
            return None

        user: User = (
            self.db.query(User)
            .filter(
                User.api_client_id == self.api_client.id,
                User.username == client_user.id,
                User.auth_method == client_user.auth_method,
            )
            .first()
        )

        if user is None:
            if create_missing:
                # user is unknown, create new record
                user = User(
                    username=client_user.id,
                    display_name=client_user.display_name,
                    api_client_id=self.api_client.id,
                    auth_method=client_user.auth_method,
                )
                self.db.add(user)
                self.db.commit()
                self.db.refresh(user)
        elif client_user.display_name and client_user.display_name != user.display_name:
            # we found the user but the display name changed
            user.display_name = client_user.display_name
            self.db.add(user)
            self.db.commit()

        return user

    def get_user_leaderboard(self, role: str) -> LeaderboardStats:
        """
        Get leaderboard stats for Messages created,
        separate leaderboard for prompts & assistants

        :param role: value compared against ``Message.role``.
        :return: ``LeaderboardStats`` whose ``leaderboard`` entries hold
            ``ranking`` (1-based), ``user_id``, ``username``,
            ``display_name`` and ``score`` (the message count), ordered by
            score descending.
        """
        message_count = func.count(Message.user_id)
        query = (
            self.db.query(Message.user_id, User.username, User.display_name, message_count)
            # Outer join: messages whose user row is missing are still
            # counted, with username/display_name coming back as None.
            .join(User, User.id == Message.user_id, isouter=True)
            # Must be the SQL operator. The Python expression
            # `Message.deleted is not True` is an identity test on the column
            # object, always evaluates to True, and filtered nothing.
            .filter(Message.deleted.is_not(True), Message.role == role)
            .group_by(Message.user_id, User.username, User.display_name)
            .order_by(message_count.desc())
        )

        result = [
            {"ranking": rank, "user_id": row[0], "username": row[1], "display_name": row[2], "score": row[3]}
            for rank, row in enumerate(query.all(), start=1)
        ]
        return LeaderboardStats(leaderboard=result)
@@ -276,17 +276,21 @@ class TextLabel(str, enum.Enum):
fails_task = "fails_task", "Fails to follow the correct instruction / task"
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
harmful = (
"harmful",
"Harmful content",
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the label for encouraging violence/abuse/terrorism/self-harm.",
excessive_harm = (
"excessive_harm",
"Content likely to cause excessive harm not justifiable in the context",
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
)
sexual_content = "sexual_content", "Contains sexual content"
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
moral_judgement = "moral_judgement", "Expresses moral judgement"
political_content = "political_content", "Expresses political views"
humor = "humor", "Contains humorous content including sarcasm"
hate_speech = "hate_speech", "Expresses sentiment which is discriminatory against a grouping of people"
hate_speech = (
"hate_speech",
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
)
threat = "threat", "Contains a threat against a person or persons"
misleading = "misleading", "Contains text which is incorrect or misleading"
helpful = "helpful", "Completes the task to a high standard"
@@ -1,7 +1,7 @@
import { Progress } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
export const LoadingScreen = ({ text }) => {
export const LoadingScreen = ({ text = "Loading..." } = {}) => {
const { colorMode } = useColorMode();
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
+1 -2
View File
@@ -12,8 +12,7 @@ export interface Message {
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const items = messages.map((messageProps: Message, i: number) => {
const { message_id } = messageProps;
const { text } = messageProps;
const { message_id, text } = messageProps;
return (
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
<MessageView {...messageProps} />
+100
View File
@@ -0,0 +1,100 @@
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { colors } from "styles/Theme/colors";
/**
 * Shared page layout for the labeling tasks.
 *
 * Renders the title, description and `messages` as the first child of
 * `TwoColumnsWithCards`, `inputs` as the second child, and `controls`
 * underneath, all on a background chosen from the current color mode.
 */
export const LabelTask = (props: {
  title: string;
  desc: string;
  messages: ReactNode;
  inputs: ReactNode;
  controls: ReactNode;
}) => {
  const { title, desc, messages, inputs, controls } = props;
  const { colorMode } = useColorMode();

  let mainBgClasses = "bg-slate-900 text-white";
  if (colorMode === "light") {
    mainBgClasses = "bg-slate-300 text-gray-800";
  }

  // Memoized so the card is only rebuilt when one of its three inputs changes.
  const card = useMemo(() => {
    return (
      <>
        <h5 className="text-lg font-semibold">{title}</h5>
        <p className="text-lg py-1">{desc}</p>
        {messages}
      </>
    );
  }, [title, desc, messages]);

  return (
    <div className={`p-12 ${mainBgClasses}`}>
      <TwoColumnsWithCards>
        {card}
        {inputs}
      </TwoColumnsWithCards>
      {controls}
    </div>
  );
};
// TODO: consolidate with FlaggableElement
/** Props accepted by `LabelSliderGroup`. */
interface LabelSliderGroupProps {
// One slider is rendered per entry; the entry is shown as the slider's label.
labelIDs: Array<string>;
// Called with the full array of slider values (same order as `labelIDs`)
// on mount and whenever a value changes.
onChange: (sliderValues: number[]) => unknown;
}
/**
 * Renders one labeled slider per entry of `labelIDs` and reports the complete
 * array of slider values to `onChange`.
 *
 * Values start at 0 and are kept in the same order as `labelIDs`. `onChange`
 * runs after the first render and again every time a value (or the
 * `onChange` identity itself) changes.
 */
export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
  const [sliderValues, setSliderValues] = useState<number[]>(() => labelIDs.map(() => 0));

  useEffect(() => {
    onChange(sliderValues);
  }, [sliderValues, onChange]);

  // Derive the next state from the previous state rather than from the
  // `sliderValues` captured at render time, so updates queued before a
  // re-render cannot overwrite each other.
  const setSliderValue = (idx: number, sliderValue: number) => {
    setSliderValues((previous) => {
      const next = previous.slice();
      next[idx] = sliderValue;
      return next;
    });
  };

  return (
    <Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
      {labelIDs.map((labelId, idx) => (
        <CheckboxSliderItem
          key={idx}
          labelId={labelId}
          sliderValue={sliderValues[idx]}
          sliderHandler={(sliderValue) => setSliderValue(idx, sliderValue)}
        />
      ))}
    </Grid>
  );
};
/**
 * A single label/slider pair inside `LabelSliderGroup`.
 *
 * The slider is uncontrolled (starts at 0); when the user finishes a change
 * the raw slider value divided by 100 is passed to `sliderHandler`
 * (presumably mapping Chakra's 0-100 range onto 0-1; confirm against the
 * Slider defaults). `sliderValue` is accepted but not read here.
 */
function CheckboxSliderItem({
  labelId,
  sliderHandler,
}: {
  labelId: string;
  sliderValue: number;
  sliderHandler: (newVal: number) => unknown;
}) {
  const id = useId();
  const { colorMode } = useColorMode();

  const textColor = colorMode === "light" ? colors.light.text : colors.dark.text;
  const labelTextClass = `text-${textColor}`;

  const handleChangeEnd = (val: number) => sliderHandler(val / 100);

  return (
    <>
      <label className="text-sm" htmlFor={id}>
        {/* TODO: display real text instead of just the id */}
        <span className={labelTextClass}>{labelId}</span>
      </label>
      <Slider defaultValue={0} onChangeEnd={handleChangeEnd}>
        <SliderTrack>
          <SliderFilledTrack />
          <SliderThumb />
        </SliderTrack>
      </Slider>
    </>
  );
}
@@ -63,4 +63,18 @@ export const TaskTypes = [
pathname: "/label/label_initial_prompt",
type: "label_initial_prompt",
},
{
label: "Label Prompter Reply",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_prompter_reply",
type: "label_prompter_reply",
},
{
label: "Label Assistant Reply",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_assistant_reply",
type: "label_assistant_reply",
},
];
+61 -29
View File
@@ -1,4 +1,17 @@
import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
import {
Button,
Flex,
Spacer,
Stack,
Table,
TableCaption,
TableContainer,
Tbody,
Td,
Th,
Thead,
Tr,
} from "@chakra-ui/react";
import Link from "next/link";
import { useState } from "react";
import fetcher from "src/lib/fetcher";
@@ -8,41 +21,60 @@ import useSWR from "swr";
* Fetches users from the users api route and then presents them in a simple Chakra table.
*/
const UsersCell = () => {
// Fetch and save the users.
const [pageIndex, setPageIndex] = useState(0);
const [users, setUsers] = useState([]);
const { isLoading } = useSWR("/api/admin/users", fetcher, {
// Fetch and save the users.
// This follows useSWR's recommendation for simple pagination:
// https://swr.vercel.app/docs/pagination#when-to-use-useswr
useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, {
onSuccess: setUsers,
});
const toPreviousPage = () => {
setPageIndex(Math.max(0, pageIndex - 1));
};
const toNextPage = () => {
setPageIndex(pageIndex + 1);
};
// Present users in a naive table.
return (
<TableContainer>
<Table variant="simple">
<TableCaption>Users</TableCaption>
<Thead>
<Tr>
<Th>Id</Th>
<Th>Email</Th>
<Th>Name</Th>
<Th>Role</Th>
<Th>Update</Th>
</Tr>
</Thead>
<Tbody>
{users.map((user, index) => (
<Tr key={index}>
<Td>{user.id}</Td>
<Td>{user.email}</Td>
<Td>{user.name}</Td>
<Td>{user.role}</Td>
<Td>
<Link href={`/admin/manage_user/${user.id}`}>Manage</Link>
</Td>
<Stack>
<Flex p="2">
<Button onClick={toPreviousPage}>Previous</Button>
<Spacer />
<Button onClick={toNextPage}>Next</Button>
</Flex>
<TableContainer>
<Table variant="simple">
<TableCaption>Users</TableCaption>
<Thead>
<Tr>
<Th>Id</Th>
<Th>Email</Th>
<Th>Name</Th>
<Th>Role</Th>
<Th>Update</Th>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
</Thead>
<Tbody>
{users.map((user, index) => (
<Tr key={index}>
<Td>{user.id}</Td>
<Td>{user.email}</Td>
<Td>{user.name}</Td>
<Td>{user.role}</Td>
<Td>
<Link href={`/admin/manage_user/${user.id}`}>Manage</Link>
</Td>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
</Stack>
);
};
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
/**
 * Task payload for labeling an assistant reply.
 *
 * The label page shows `conversation.messages` followed by `reply` and
 * submits one weight per entry of `valid_labels`.
 */
export interface LabelAssistantReplyTask {
id: string;
// Discriminates this payload from the other labeling task payloads.
type: LabelingTaskType.label_assistant_reply;
// Id attached to `reply` when it is displayed and when labels are submitted.
message_id: string;
// Labels the user can weight; order matches the submitted weights.
valid_labels: string[];
// Text of the reply being labeled.
reply: string;
// Messages displayed before `reply`.
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
/** `TaskResponse` envelope carrying a `LabelAssistantReplyTask`. */
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
/** `useLabelingTask` bound to the `label_assistant_reply` endpoint and payload type. */
export const useLabelAssistantReplyTask = () => {
  const endpoint = LabelingTaskType.label_assistant_reply;
  return useLabelingTask<LabelAssistantReplyTask>(endpoint);
};
@@ -0,0 +1,15 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
/** Task payload for labeling an initial prompt. */
export interface LabelInitialPromptTask {
id: string;
// Discriminates this payload from the other labeling task payloads.
type: LabelingTaskType.label_initial_prompt;
// Id attached to `prompt` when it is displayed and when labels are submitted.
message_id: string;
// Labels the user can weight; order matches the submitted weights.
valid_labels: string[];
// Text of the prompt being labeled.
prompt: string;
}
/** `TaskResponse` envelope carrying a `LabelInitialPromptTask`. */
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
/** `useLabelingTask` bound to the `label_initial_prompt` endpoint and payload type. */
export const useLabelInitialPromptTask = () => {
  const endpoint = LabelingTaskType.label_initial_prompt;
  return useLabelingTask<LabelInitialPromptTask>(endpoint);
};
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
/**
 * Task payload for labeling a prompter reply.
 *
 * The label page shows `conversation.messages` followed by `reply` and
 * submits one weight per entry of `valid_labels`.
 */
export interface LabelPrompterReplyTask {
id: string;
// Discriminates this payload from the other labeling task payloads.
type: LabelingTaskType.label_prompter_reply;
// Id attached to `reply` when it is displayed and when labels are submitted.
message_id: string;
// Labels the user can weight; order matches the submitted weights.
valid_labels: string[];
// Text of the reply being labeled.
reply: string;
// Messages displayed before `reply`.
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
/** `TaskResponse` envelope carrying a `LabelPrompterReplyTask`. */
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
/** `useLabelingTask` bound to the `label_prompter_reply` endpoint and payload type. */
export const useLabelPrompterReplyTask = () => {
  const endpoint = LabelingTaskType.label_prompter_reply;
  return useLabelingTask<LabelPrompterReplyTask>(endpoint);
};
@@ -0,0 +1,20 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
/**
 * The labeling task kinds. Each value is passed to `useGenericTaskAPI` as the
 * task endpoint and is also used as the `type` discriminator of the matching
 * task payload interface.
 */
export const enum LabelingTaskType {
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
}
/**
 * Task hook for the labeling endpoints.
 *
 * Wraps `useGenericTaskAPI` and replaces its raw `trigger` with `submit`,
 * which pairs each valid label with its weight and posts them as a
 * `text_labels` update.
 */
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
  const api = useGenericTaskAPI<TaskType>(endpoint);
  const { tasks, isLoading, trigger, reset, error } = api;

  const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
    // Labels and weights are matched by position, so the lengths must agree.
    console.assert(validLabels.length === labelWeights.length);

    const entries: Array<[string, number]> = [];
    validLabels.forEach((label, i) => {
      entries.push([label, labelWeights[i]]);
    });
    const labels = Object.fromEntries(entries);

    const content = { labels, text, message_id };
    return trigger({ id, update_type: "text_labels", content });
  };

  return { tasks, isLoading, submit, reset, error };
};
@@ -0,0 +1,42 @@
import { useEffect, useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
/**
 * Envelope in which the task API routes return a task; `TaskType` is the
 * task-specific payload.
 */
export interface TaskResponse<TaskType> {
// Id passed back as `id` when the task is submitted.
id: string;
userId: string;
// The task-specific payload.
task: TaskType;
}
/**
 * Generic hook for fetching and updating tasks of one kind.
 *
 * Fetches a task from `/api/new_task/<taskApiEndpoint>` and keeps the
 * received tasks in local state. `trigger` posts to `/api/update_task` and
 * appends the task contained in the response; `reset` re-runs the fetch,
 * whose result replaces the task list.
 */
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
  type ConcreteTaskResponse = TaskResponse<TaskType>;

  const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);

  const newTaskUrl = "/api/new_task/" + taskApiEndpoint;
  const newTaskRequest = useSWRImmutable<ConcreteTaskResponse>(newTaskUrl, fetcher, {
    // A freshly fetched task replaces whatever was in the list before.
    onSuccess: (data) => {
      setTasks([data]);
    },
  });
  const { isLoading, mutate, error } = newTaskRequest;

  useEffect(() => {
    // Fetch again whenever the list is empty and no request is running or has failed.
    const needsTask = tasks.length === 0 && !isLoading && !error;
    if (needsTask) {
      mutate();
    }
  }, [tasks, isLoading, mutate, error]);

  const updateRequest = useSWRMutation("/api/update_task", poster, {
    onSuccess: async (response) => {
      const followUpTask: ConcreteTaskResponse = await response.json();
      setTasks((currentTasks) => [...currentTasks, followUpTask]);
    },
  });
  const { trigger } = updateRequest;

  return { tasks, isLoading, trigger, error, reset: mutate };
};
-52
View File
@@ -1,52 +0,0 @@
import { useEffect, useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
/**
 * Envelope in which the task API routes return a task; `TaskType` is the
 * task-specific payload.
 */
interface TaskResponse<TaskType> {
id: string;
userId: string;
// The task-specific payload.
task: TaskType;
}
/** Task payload for labeling an initial prompt. */
export interface LabelInitialPromptTask {
id: string;
message_id: string;
// Text of the prompt being labeled.
prompt: string;
type: string;
// Labels the user can provide values for.
valid_labels: string[];
}
/** `TaskResponse` envelope carrying a `LabelInitialPromptTask`. */
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
/**
 * Fetches labeling tasks from `/api/new_task/<taskApiEndpoint>` and submits
 * label values for them via `/api/update_task`.
 *
 * Returns the current task list, the loading/error state of the fetch,
 * `submit` for posting labels and `reset` for re-running the fetch.
 */
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
// A freshly fetched task replaces the whole task list.
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
onSuccess: (data: ConcreteTaskResponse) => {
setTasks([data]);
},
});
// Fetch again whenever the list is empty and no request is running or has failed.
useEffect(() => {
if (tasks.length === 0 && !isLoading && !error) {
mutate();
}
}, [tasks, isLoading, mutate, error]);
// The task contained in the update response is appended to the list.
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (reply) => {
const newTask: ConcreteTaskResponse = await reply.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
// Posts the label values for one message as a "text_labels" update.
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
return { tasks, isLoading, submit, error, reset: mutate };
};
+1 -1
View File
@@ -5,7 +5,7 @@ import { getToken } from "next-auth/jwt";
* Wraps any API Route handler and verifies that the user has the appropriate
* role before running the handler. Returns a 403 otherwise.
*/
const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse<any>) => any) => {
const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
return async (req: NextApiRequest, res: NextApiResponse) => {
const token = await getToken({ req });
if (!token || token.role !== role) {
+2 -2
View File
@@ -1,8 +1,8 @@
export { default } from "next-auth/middleware";
/**
* Guards all pages under `/grading` and redirects them to the sign in page.
* Guards these pages and redirects them to the sign in page.
*/
export const config = {
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"],
};
+1 -1
View File
@@ -26,7 +26,7 @@ const AdminIndex = () => {
return;
}
router.push("/");
}, [session, status]);
}, [router, session, status]);
return (
<>
+2 -2
View File
@@ -1,4 +1,4 @@
import { Box, Button, Container, Flex, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react";
import { Button, Container, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react";
import { Field, Form, Formik } from "formik";
import Head from "next/head";
import { useRouter } from "next/router";
@@ -27,7 +27,7 @@ const ManageUser = ({ user }) => {
return;
}
router.push("/");
}, [session, status]);
}, [router, session, status]);
// Trigger to let us update the user's role. Triggers a toast when complete.
const { trigger } = useSWRMutation("/api/admin/update_user", poster, {
@@ -1,4 +1,3 @@
import { getToken } from "next-auth/jwt";
import withRole from "src/lib/auth";
import prisma from "src/lib/prismadb";
+12 -2
View File
@@ -1,12 +1,21 @@
import { getToken } from "next-auth/jwt";
import withRole from "src/lib/auth";
import prisma from "src/lib/prismadb";
// The number of users to fetch in any request.
const PAGE_SIZE = 20;
/**
* Returns a list of user results from the database when the requesting user is
* a logged in admin.
*/
const handler = withRole("admin", async (req, res) => {
// Figure out the pagination index and skip that number of users.
//
// Note: with Prisma this isn't the most efficient but it's the only possible
// option with cuid based User IDs.
const { pageIndex } = req.query;
const skip = parseInt(pageIndex as string) * PAGE_SIZE || 0;
// Fetch 20 users.
const users = await prisma.user.findMany({
select: {
@@ -15,7 +24,8 @@ const handler = withRole("admin", async (req, res) => {
name: true,
email: true,
},
take: 20,
skip,
take: PAGE_SIZE,
});
res.status(200).json(users);
+48 -1
View File
@@ -2,17 +2,59 @@ import { Button, Input, Stack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
import { useRouter } from "next/router";
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
import React, { useRef } from "react";
import React, { useEffect, useRef, useState } from "react";
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
import { AuthLayout } from "src/components/AuthLayout";
import { Footer } from "src/components/Footer";
import { Header } from "src/components/Header";
/**
 * Error codes for which the sign-in page has a message; these are the keys
 * of `errorMessages`. `"default"` is the generic entry.
 */
export type SignInErrorTypes =
| "Signin"
| "OAuthSignin"
| "OAuthCallback"
| "OAuthCreateAccount"
| "EmailCreateAccount"
| "Callback"
| "OAuthAccountNotLinked"
| "EmailSignin"
| "CredentialsSignin"
| "SessionRequired"
| "default";
/**
 * User-facing message for each sign-in error code. The sign-in page looks up
 * the `error` query parameter in this table.
 * NOTE(review): a code that is not a key here yields `undefined` at the
 * lookup site; the `default` entry is not applied automatically.
 */
const errorMessages: Record<SignInErrorTypes, string> = {
Signin: "Try signing in with a different account.",
OAuthSignin: "Try signing in with a different account.",
OAuthCallback: "Try signing in with the same account you used originally.",
OAuthCreateAccount: "Try signing in with a different account.",
EmailCreateAccount: "Try signing in with a different account.",
Callback: "Try signing in with a different account.",
OAuthAccountNotLinked: "To confirm your identity, sign in with the same account you used originally.",
EmailSignin: "The e-mail could not be sent.",
CredentialsSignin: "Sign in failed. Check the details you provided are correct.",
SessionRequired: "Please sign in to access this page.",
default: "Unable to sign in.",
};
// eslint-disable-next-line @typescript-eslint/no-unused-vars
function Signin({ csrfToken, providers }) {
const router = useRouter();
const { discord, email, github, credentials } = providers;
const emailEl = useRef(null);
const [error, setError] = useState("");
useEffect(() => {
const err = router?.query?.error;
if (err) {
if (typeof err === "string") {
setError(errorMessages[err]);
} else {
setError(errorMessages[err[0]]);
}
}
}, [router]);
const signinWithEmail = (ev: React.FormEvent) => {
ev.preventDefault();
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
@@ -110,6 +152,11 @@ function Signin({ csrfToken, providers }) {
</Link>
.
</div>
{error && (
<div className="text-center mt-8">
<p className="text-orange-600">Error: {error}</p>
</div>
)}
</AuthLayout>
</div>
);
+9 -3
View File
@@ -1,17 +1,23 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { getCsrfToken, getProviders } from "next-auth/react";
import { AuthLayout } from "src/components/AuthLayout";
export default function Verify() {
const { colorMode } = useColorMode();
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
return (
<>
<Head>
<title>Sign Up - Open Assistant</title>
<meta name="Sign Up" content="Sign up to access Open Assistant" />
</Head>
<AuthLayout>
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
</AuthLayout>
<div className={`flex h-full justify-center items-center ${bgColorClass}`}>
<div className={bgColorClass}>
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
</div>
</div>
</>
);
}
@@ -0,0 +1,46 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelAssistantReplyTaskResponse,
useLabelAssistantReplyTask,
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
/**
 * Page for the "label assistant reply" task: shows the conversation with the
 * assistant reply appended and one slider per valid label, and submits the
 * slider values as label weights.
 */
const LabelAssistantReply = () => {
  // Latest slider values reported by LabelSliderGroup, one per valid label.
  const [sliderValues, setSliderValues] = useState<number[]>([]);
  const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();

  // Show the loading screen until a task is available.
  const hasTask = !isLoading && tasks.length > 0;
  if (!hasTask) {
    return <LoadingScreen />;
  }

  const { task } = tasks[0];

  // The reply under review is displayed as the last, assistant-authored message.
  const replyMessage: Message = { text: task.reply, is_assistant: true, message_id: task.message_id };
  const messages: Message[] = [...task.conversation.messages, replyMessage];

  const handleSubmit = ({ id, task: submitted }: LabelAssistantReplyTaskResponse) =>
    submit(id, submitted.message_id, submitted.reply, submitted.valid_labels, sliderValues);

  const messageTable = <MessageTable messages={messages} />;
  const sliders = <LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />;
  const taskControls = <TaskControls tasks={tasks} onSkip={reset} onSubmitResponse={handleSubmit} />;

  return (
    <LabelTask
      title="Label Assistant Reply"
      desc="Given the following discussion, provide labels for the final prompt"
      messages={messageTable}
      inputs={sliders}
      controls={taskControls}
    />
  );
};
export default LabelAssistantReply;
@@ -1,113 +1,41 @@
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useEffect, useId, useState } from "react";
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
import { colors } from "styles/Theme/colors";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelInitialPromptTaskResponse,
useLabelInitialPromptTask,
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
taskApiEndpoint: "label_initial_prompt",
});
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
const labels = task.valid_labels.reduce((obj, label, i) => {
obj[label] = sliderValues[i].toString();
return obj;
}, {} as Record<string, string>);
submit(id, task.message_id, task.prompt, labels);
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
if (isLoading || tasks.length === 0) {
return <LoadingScreen />;
}
const task = tasks[0].task;
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
<p className="text-lg py-1">Provide labels for the following prompt</p>
<MessageView text={task.prompt} is_assistant message_id={task.message_id} />
</>
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
</div>
<LabelTask
title="Label Initial Prompt"
desc="Provide labels for the following prompt"
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkip={reset}
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
}
/>
}
/>
);
};
export default LabelInitialPrompt;
// TODO: consolidate with FlaggableElement
interface CheckboxSliderGroupProps {
labelIDs: Array<string>;
onChange: (sliderValues: number[]) => unknown;
}
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
useEffect(() => {
onChange(sliderValues);
}, [sliderValues, onChange]);
return (
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
{labelIDs.map((labelId, idx) => (
<CheckboxSliderItem
key={idx}
labelId={labelId}
sliderValue={sliderValues[idx]}
sliderHandler={(sliderValue) => {
const newState = sliderValues.slice();
newState[idx] = sliderValue;
setSliderValues(newState);
}}
/>
))}
</Grid>
);
};
function CheckboxSliderItem(props: {
labelId: string;
sliderValue: number;
sliderHandler: (newVal: number) => unknown;
}) {
const id = useId();
const { colorMode } = useColorMode();
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
return (
<>
<label className="text-sm" htmlFor={id}>
{/* TODO: display real text instead of just the id */}
<span className={labelTextClass}>{props.labelId}</span>
</label>
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
<SliderTrack>
<SliderFilledTrack />
<SliderThumb />
</SliderTrack>
</Slider>
</>
);
}
@@ -0,0 +1,46 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelPrompterReplyTaskResponse,
useLabelPrompterReplyTask,
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
/**
 * Task page for labeling a prompter reply.
 *
 * Shows a loading screen until the task hook has finished loading and returned
 * at least one task. Otherwise it renders the conversation followed by the
 * reply under review, a slider group for the task's valid labels, and the
 * skip / submit controls.
 */
const LabelPrompterReply = () => {
  const [sliderValues, setSliderValues] = useState<number[]>([]);
  const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();

  const hasTask = !isLoading && tasks.length > 0;
  if (!hasTask) {
    return <LoadingScreen />;
  }

  const { conversation, reply, message_id, valid_labels } = tasks[0].task;

  // The reply being labeled is appended after the existing conversation as a non-assistant message.
  const replyUnderReview: Message = { text: reply, is_assistant: false, message_id };
  const messages: Message[] = [...conversation.messages, replyUnderReview];

  // Forwards the submitted task's own fields together with the current slider values.
  const handleSubmit = (response: LabelPrompterReplyTaskResponse) =>
    submit(
      response.id,
      response.task.message_id,
      response.task.reply,
      response.task.valid_labels,
      sliderValues
    );

  const messageTable = <MessageTable messages={messages} />;
  const labelInputs = <LabelSliderGroup labelIDs={valid_labels} onChange={setSliderValues} />;
  const taskControls = <TaskControls tasks={tasks} onSkip={reset} onSubmitResponse={handleSubmit} />;

  return (
    <LabelTask
      title="Label Prompter Reply"
      desc="Given the following discussion, provide labels for the final prompt"
      messages={messageTable}
      inputs={labelInputs}
      controls={taskControls}
    />
  );
};
export default LabelPrompterReply;
+13 -13
View File
@@ -2,6 +2,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
import Head from "next/head";
import { useEffect, useState } from "react";
import { getDashboardLayout } from "src/components/Layout";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import fetcher from "src/lib/fetcher";
import useSWRImmutable from "swr/immutable";
@@ -10,29 +11,28 @@ const MessagesDashboard = () => {
const boxBgColor = useColorModeValue("white", "gray.700");
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
const [messages, setMessages] = useState([]);
const [userMessages, setUserMessages] = useState([]);
const [messages, setMessages] = useState<Message[]>(null);
const [userMessages, setUserMessages] = useState<Message[]>(null);
const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, {
onSuccess: (data) => {
setMessages(data);
},
onSuccess: setMessages,
});
const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
onSuccess: (data) => {
setUserMessages(data);
},
onSuccess: setUserMessages,
});
const receivedMessages = !isLoadingAll && Array.isArray(messages);
const receivedUserMessages = !isLoadingUser && Array.isArray(userMessages);
useEffect(() => {
if (messages.length == 0) {
if (!receivedMessages) {
mutateAll();
}
if (userMessages.length == 0) {
if (!receivedUserMessages) {
mutateUser();
}
}, [messages, userMessages]);
}, [receivedMessages, mutateAll, receivedUserMessages, mutateUser]);
return (
<>
@@ -52,7 +52,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
</Box>
</Box>
<Box>
@@ -66,7 +66,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
</Box>
</Box>
</SimpleGrid>
+19
View File
@@ -0,0 +1,19 @@
import Head from "next/head";
import { TaskOption } from "src/components/Dashboard";
import { getDashboardLayout } from "src/components/Layout";
/**
 * "All Tasks" page: sets the document title and description, then renders the
 * `TaskOption` component.
 */
const AllTasks = () => {
  const pageTitle = "All Tasks - Open Assistant";
  const pageDescription = "All tasks for Open Assistant.";

  return (
    <>
      <Head>
        <title>{pageTitle}</title>
        <meta name="description" content={pageDescription} />
      </Head>
      <TaskOption />
    </>
  );
};
AllTasks.getLayout = (page) => getDashboardLayout(page);
export default AllTasks;