mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
[NEW] Solving merge conflicts
This commit is contained in:
@@ -16,3 +16,12 @@ jobs:
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
- name: Post PR comment on failure
|
||||
if: failure() && github.event_name == 'pull_request'
|
||||
uses: peter-evans/create-or-update-comment@v2
|
||||
with:
|
||||
issue-number: ${{ github.event.pull_request.number }}
|
||||
body: |
|
||||
:x: **pre-commit** failed.
|
||||
Please run `pre-commit run --all-files` locally and commit the changes.
|
||||
Find more information in the repository's CONTRIBUTING.md
|
||||
|
||||
+4
-2
@@ -96,8 +96,10 @@ The website is built using Next.js and is in the `website` folder.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit
|
||||
hooks.
|
||||
We are using `pre-commit` to enforce code style and formatting.
|
||||
|
||||
Install `pre-commit` from [its website](https://pre-commit.com) and run
|
||||
`pre-commit install` to install the pre-commit hooks.
|
||||
|
||||
In case you haven't done this, have already committed, and CI is failing, you
|
||||
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
|
||||
+11
-6
@@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
with Session(engine) as db:
|
||||
api_client = get_dummy_api_client(db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
|
||||
|
||||
ur = UserRepository(db=db, api_client=api_client)
|
||||
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
@@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
|
||||
for msg in dummy_messages:
|
||||
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = pr.store_task(
|
||||
task = tr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
@@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
for cmsg in conversation_messages
|
||||
]
|
||||
)
|
||||
task = pr.store_task(
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
pr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_message_by_frontend_id(
|
||||
"""
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
@@ -29,7 +29,7 @@ def get_conv_by_frontend_id(
|
||||
Get a conversation from the tree root and up to the message with given frontend ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return utils.prepare_conversation(messages)
|
||||
@@ -43,7 +43,7 @@ def get_tree_by_frontend_id(
|
||||
Get all messages belonging to the same message tree.
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -56,7 +56,7 @@ def get_children_by_frontend_id(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
return utils.prepare_message_list(messages)
|
||||
@@ -70,7 +70,7 @@ def get_descendants_by_frontend_id(
|
||||
Get a subtree which starts with this message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id(
|
||||
Get the longest conversation from the tree of the message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -98,7 +98,7 @@ def get_max_children_by_frontend_id(
|
||||
Get message with the most children from the tree of the provided message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_frontend_user_messages(
|
||||
"""
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
@@ -47,6 +47,6 @@ def query_frontend_user_messages(
|
||||
def mark_frontend_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
@@ -11,15 +12,15 @@ router = APIRouter()
|
||||
def get_assistant_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="assistant")
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="assistant")
|
||||
|
||||
|
||||
@router.get("/create/prompter")
|
||||
def get_prompter_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="prompter")
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="prompter")
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_messages(
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
@@ -51,7 +51,7 @@ def get_message(
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
@@ -64,7 +64,7 @@ def get_conv(
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
@@ -76,7 +76,7 @@ def get_tree(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -89,7 +89,7 @@ def get_children(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
@@ -101,7 +101,7 @@ def get_descendants(
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -114,7 +114,7 @@ def get_longest_conv(
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -127,7 +127,7 @@ def get_max_children(
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
@@ -137,5 +137,5 @@ def get_max_children(
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
|
||||
@@ -13,5 +13,5 @@ def get_message_stats(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
return pr.get_stats()
|
||||
|
||||
@@ -8,7 +8,7 @@ from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
||||
from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -192,9 +192,9 @@ def request_task(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
pr = PromptRepository(db, api_client, client_user=request.user)
|
||||
task, message_tree_id, parent_message_id = generate_task(request, pr)
|
||||
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -219,11 +219,11 @@ def tasks_acknowledge(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -247,8 +247,8 @@ def tasks_acknowledge_failure(
|
||||
try:
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.acknowledge_task_failure(task_id)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
@@ -267,7 +267,7 @@ async def tasks_interaction(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=interaction.user)
|
||||
pr = PromptRepository(db, api_client, client_user=interaction.user)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
@@ -339,6 +339,6 @@ def close_collective_task(
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.close_task(close_task_request.message_id)
|
||||
tr = TaskRepository(db, api_client)
|
||||
tr.close_task(close_task_request.message_id)
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@@ -25,7 +25,7 @@ def label_text(
|
||||
|
||||
try:
|
||||
logger.info(f"Labeling text {text_labels=}.")
|
||||
pr = PromptRepository(db, api_client, user=text_labels.user)
|
||||
pr = PromptRepository(db, api_client, client_user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_user_messages(
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
@@ -48,6 +48,6 @@ def query_user_messages(
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
@@ -6,27 +6,56 @@ import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
# The types of States a message tree can have.
|
||||
|
||||
class States(str, Enum):
|
||||
"""States of the Open-Assistant message tree state machine."""
|
||||
|
||||
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
|
||||
"""In this state the message tree consists only of a single inital prompt root node.
|
||||
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
|
||||
`aborted_low_grade`."""
|
||||
|
||||
class States(Enum):
|
||||
INITIAL = "initial"
|
||||
BREEDING_PHASE = "breeding_phase"
|
||||
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
|
||||
are handed out to check if the quality of the replies surpasses the minimum acceptable
|
||||
quality.
|
||||
When the required number of messages passing the initial labelling-quality check has been
|
||||
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
|
||||
are received the tree can also enter the `aborted_low_grade` state."""
|
||||
|
||||
RANKING_PHASE = "ranking_phase"
|
||||
"""The tree has been successfully populated with the desired number of messages. Ranking
|
||||
tasks are now handed out for all nodes with more than one child."""
|
||||
|
||||
READY_FOR_SCORING = "ready_for_scoring"
|
||||
CHILDREN_SCORED = "children_scored"
|
||||
FINAL = "final"
|
||||
"""Required ranking responses have been collected and the scoring algorithm can now
|
||||
compute the aggergated ranking scores that will appear in the dataset."""
|
||||
|
||||
READY_FOR_EXPORT = "ready_for_export"
|
||||
"""The Scoring algorithm computed rankings scores for all childern. The message tree can be
|
||||
exported as part of an Open-Assistant message tree dataset."""
|
||||
|
||||
SCORING_FAILED = "scoring_failed"
|
||||
"""An exception occured in the scoring algorithm."""
|
||||
|
||||
ABORTED_LOW_GRADE = "aborted_low_grade"
|
||||
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
|
||||
|
||||
HALTED_BY_MODERATOR = "halted_by_moderator"
|
||||
"""A moderator decided to manually halt the message tree construction process."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
States.INITIAL,
|
||||
States.INITIAL_PROMPT_REVIEW,
|
||||
States.BREEDING_PHASE,
|
||||
States.RANKING_PHASE,
|
||||
States.READY_FOR_SCORING,
|
||||
States.CHILDREN_SCORED,
|
||||
States.FINAL,
|
||||
States.READY_FOR_EXPORT,
|
||||
States.ABORTED_LOW_GRADE,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
|
||||
|
||||
|
||||
class MessageTreeState(SQLModel, table=True):
|
||||
__tablename__ = "message_tree_state"
|
||||
|
||||
@@ -8,98 +8,39 @@ from uuid import UUID, uuid4
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User] = None,
|
||||
user_repository: Optional[UserRepository] = None,
|
||||
task_repository: Optional[TaskRepository] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user = self.lookup_user(user)
|
||||
self.user_repository = user_repository or UserRepository(db, api_client)
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
self.task_repository = task_repository or TaskRepository(
|
||||
db, api_client, client_user, user_repository=self.user_repository
|
||||
)
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user is None:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def validate_frontend_message_id(self, message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
message: Message = (
|
||||
self.db.query(Message)
|
||||
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
|
||||
@@ -113,22 +54,54 @@ class PromptRepository:
|
||||
)
|
||||
return message
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
self.validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
return task
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def store_text_reply(
|
||||
self, text: str, frontend_message_id: str, user_frontend_message_id: str, miniLM_embedding: List[float] = None
|
||||
self,
|
||||
text: str,
|
||||
frontend_message_id: str,
|
||||
user_frontend_message_id: str,
|
||||
miniLM_embedding: Optional[List[float]] = None,
|
||||
) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
self.validate_frontend_message_id(user_frontend_message_id)
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
@@ -176,7 +149,7 @@ class PromptRepository:
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
task = self.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task_payload: db_payload.RateSummaryPayload = task.payload.payload
|
||||
if type(task_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
@@ -203,7 +176,7 @@ class PromptRepository:
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
|
||||
# fetch task
|
||||
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
@@ -257,142 +230,6 @@ class PromptRepository:
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
payload = db_payload.LabelPrompterReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
payload = db_payload.LabelAssistantReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
task = Task(
|
||||
id=id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
if None in (message_id, model, embedding):
|
||||
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
|
||||
@@ -527,28 +364,6 @@ class PromptRepository:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not task:
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
@@ -740,24 +555,3 @@ class PromptRepository:
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from oasst_backend.models import ApiClient, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def validate_frontend_message_id(message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
|
||||
class TaskRepository:
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User],
|
||||
user_repository: UserRepository,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user_repository = user_repository
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
payload = db_payload.LabelPrompterReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
payload = db_payload.LabelAssistantReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not task:
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
task = Task(
|
||||
id=id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
return task
|
||||
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
|
||||
from oasst_backend.models import ApiClient, Message, User
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session, func
|
||||
|
||||
|
||||
class UserRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user is None:
|
||||
if create_missing:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
@@ -276,17 +276,21 @@ class TextLabel(str, enum.Enum):
|
||||
fails_task = "fails_task", "Fails to follow the correct instruction / task"
|
||||
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
|
||||
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
|
||||
harmful = (
|
||||
"harmful",
|
||||
"Harmful content",
|
||||
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the label for encouraging violence/abuse/terrorism/self-harm.",
|
||||
excessive_harm = (
|
||||
"excessive_harm",
|
||||
"Content likely to cause excessive harm not justifiable in the context",
|
||||
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
|
||||
)
|
||||
sexual_content = "sexual_content", "Contains sexual content"
|
||||
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
|
||||
moral_judgement = "moral_judgement", "Expresses moral judgement"
|
||||
political_content = "political_content", "Expresses political views"
|
||||
humor = "humor", "Contains humorous content including sarcasm"
|
||||
hate_speech = "hate_speech", "Expresses sentiment which is discriminatory against a grouping of people"
|
||||
hate_speech = (
|
||||
"hate_speech",
|
||||
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
|
||||
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
|
||||
)
|
||||
threat = "threat", "Contains a threat against a person or persons"
|
||||
misleading = "misleading", "Contains text which is incorrect or misleading"
|
||||
helpful = "helpful", "Completes the task to a high standard"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Progress } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
|
||||
export const LoadingScreen = ({ text }) => {
|
||||
export const LoadingScreen = ({ text = "Loading..." } = {}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ export interface Message {
|
||||
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
const { message_id } = messageProps;
|
||||
const { text } = messageProps;
|
||||
const { message_id, text } = messageProps;
|
||||
return (
|
||||
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
|
||||
<MessageView {...messageProps} />
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export const LabelTask = ({
|
||||
title,
|
||||
desc,
|
||||
messages,
|
||||
inputs,
|
||||
controls,
|
||||
}: {
|
||||
title: string;
|
||||
desc: string;
|
||||
messages: ReactNode;
|
||||
inputs: ReactNode;
|
||||
controls: ReactNode;
|
||||
}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
const card = useMemo(
|
||||
() => (
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{title}</h5>
|
||||
<p className="text-lg py-1">{desc}</p>
|
||||
{messages}
|
||||
</>
|
||||
),
|
||||
[title, desc, messages]
|
||||
);
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
{card}
|
||||
{inputs}
|
||||
</TwoColumnsWithCards>
|
||||
{controls}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// TODO: consolidate with FlaggableElement
|
||||
interface LabelSliderGroupProps {
|
||||
labelIDs: Array<string>;
|
||||
onChange: (sliderValues: number[]) => unknown;
|
||||
}
|
||||
|
||||
export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
<CheckboxSliderItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
sliderValue={sliderValues[idx]}
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
function CheckboxSliderItem(props: {
|
||||
labelId: string;
|
||||
sliderValue: number;
|
||||
sliderHandler: (newVal: number) => unknown;
|
||||
}) {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
</label>
|
||||
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -63,4 +63,18 @@ export const TaskTypes = [
|
||||
pathname: "/label/label_initial_prompt",
|
||||
type: "label_initial_prompt",
|
||||
},
|
||||
{
|
||||
label: "Label Prompter Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
type: "label_prompter_reply",
|
||||
},
|
||||
{
|
||||
label: "Label Assistant Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_assistant_reply",
|
||||
type: "label_assistant_reply",
|
||||
},
|
||||
];
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
Spacer,
|
||||
Stack,
|
||||
Table,
|
||||
TableCaption,
|
||||
TableContainer,
|
||||
Tbody,
|
||||
Td,
|
||||
Th,
|
||||
Thead,
|
||||
Tr,
|
||||
} from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
@@ -8,41 +21,60 @@ import useSWR from "swr";
|
||||
* Fetches users from the users api route and then presents them in a simple Chakra table.
|
||||
*/
|
||||
const UsersCell = () => {
|
||||
// Fetch and save the users.
|
||||
const [pageIndex, setPageIndex] = useState(0);
|
||||
const [users, setUsers] = useState([]);
|
||||
const { isLoading } = useSWR("/api/admin/users", fetcher, {
|
||||
|
||||
// Fetch and save the users.
|
||||
// This follows useSWR's recommendation for simple pagination:
|
||||
// https://swr.vercel.app/docs/pagination#when-to-use-useswr
|
||||
useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, {
|
||||
onSuccess: setUsers,
|
||||
});
|
||||
|
||||
const toPreviousPage = () => {
|
||||
setPageIndex(Math.max(0, pageIndex - 1));
|
||||
};
|
||||
|
||||
const toNextPage = () => {
|
||||
setPageIndex(pageIndex + 1);
|
||||
};
|
||||
|
||||
// Present users in a naive table.
|
||||
return (
|
||||
<TableContainer>
|
||||
<Table variant="simple">
|
||||
<TableCaption>Users</TableCaption>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>Id</Th>
|
||||
<Th>Email</Th>
|
||||
<Th>Name</Th>
|
||||
<Th>Role</Th>
|
||||
<Th>Update</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{users.map((user, index) => (
|
||||
<Tr key={index}>
|
||||
<Td>{user.id}</Td>
|
||||
<Td>{user.email}</Td>
|
||||
<Td>{user.name}</Td>
|
||||
<Td>{user.role}</Td>
|
||||
<Td>
|
||||
<Link href={`/admin/manage_user/${user.id}`}>Manage</Link>
|
||||
</Td>
|
||||
<Stack>
|
||||
<Flex p="2">
|
||||
<Button onClick={toPreviousPage}>Previous</Button>
|
||||
<Spacer />
|
||||
<Button onClick={toNextPage}>Next</Button>
|
||||
</Flex>
|
||||
<TableContainer>
|
||||
<Table variant="simple">
|
||||
<TableCaption>Users</TableCaption>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>Id</Th>
|
||||
<Th>Email</Th>
|
||||
<Th>Name</Th>
|
||||
<Th>Role</Th>
|
||||
<Th>Update</Th>
|
||||
</Tr>
|
||||
))}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{users.map((user, index) => (
|
||||
<Tr key={index}>
|
||||
<Td>{user.id}</Td>
|
||||
<Td>{user.email}</Td>
|
||||
<Td>{user.name}</Td>
|
||||
<Td>{user.role}</Td>
|
||||
<Td>
|
||||
<Link href={`/admin/manage_user/${user.id}`}>Manage</Link>
|
||||
</Td>
|
||||
</Tr>
|
||||
))}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</Stack>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelAssistantReplyTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_assistant_reply;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
reply: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
|
||||
|
||||
export const useLabelAssistantReplyTask = () =>
|
||||
useLabelingTask<LabelAssistantReplyTask>(LabelingTaskType.label_assistant_reply);
|
||||
@@ -0,0 +1,15 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_initial_prompt;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelInitialPromptTask = () =>
|
||||
useLabelingTask<LabelInitialPromptTask>(LabelingTaskType.label_initial_prompt);
|
||||
@@ -0,0 +1,22 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelPrompterReplyTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_prompter_reply;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
reply: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
|
||||
|
||||
export const useLabelPrompterReplyTask = () =>
|
||||
useLabelingTask<LabelPrompterReplyTask>(LabelingTaskType.label_prompter_reply);
|
||||
@@ -0,0 +1,20 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
export const enum LabelingTaskType {
|
||||
label_initial_prompt = "label_initial_prompt",
|
||||
label_prompter_reply = "label_prompter_reply",
|
||||
label_assistant_reply = "label_assistant_reply",
|
||||
}
|
||||
|
||||
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
|
||||
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<TaskType>(endpoint);
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
|
||||
console.assert(validLabels.length === labelWeights.length);
|
||||
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
|
||||
|
||||
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
};
|
||||
|
||||
return { tasks, isLoading, submit, reset, error };
|
||||
};
|
||||
@@ -0,0 +1,42 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
|
||||
export interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
|
||||
type ConcreteTaskResponse = TaskResponse<TaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>(
|
||||
"/api/new_task/" + taskApiEndpoint,
|
||||
fetcher,
|
||||
{
|
||||
onSuccess: (data) => setTasks([data]),
|
||||
}
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length === 0 && !isLoading && !error) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks, isLoading, mutate, error]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (response) => {
|
||||
const newTask: ConcreteTaskResponse = await response.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
return { tasks, isLoading, trigger, error, reset: mutate };
|
||||
};
|
||||
@@ -1,52 +0,0 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
message_id: string;
|
||||
prompt: string;
|
||||
type: string;
|
||||
valid_labels: string[];
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
|
||||
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
|
||||
onSuccess: (data: ConcreteTaskResponse) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length === 0 && !isLoading && !error) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks, isLoading, mutate, error]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (reply) => {
|
||||
const newTask: ConcreteTaskResponse = await reply.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
|
||||
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
|
||||
return { tasks, isLoading, submit, error, reset: mutate };
|
||||
};
|
||||
@@ -5,7 +5,7 @@ import { getToken } from "next-auth/jwt";
|
||||
* Wraps any API Route handler and verifies that the user has the appropriate
|
||||
* role before running the handler. Returns a 403 otherwise.
|
||||
*/
|
||||
const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse<any>) => any) => {
|
||||
const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
|
||||
return async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
const token = await getToken({ req });
|
||||
if (!token || token.role !== role) {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
export { default } from "next-auth/middleware";
|
||||
|
||||
/**
|
||||
* Guards all pages under `/grading` and redirects them to the sign in page.
|
||||
* Guards these pages and redirects them to the sign in page.
|
||||
*/
|
||||
export const config = {
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"],
|
||||
};
|
||||
|
||||
@@ -26,7 +26,7 @@ const AdminIndex = () => {
|
||||
return;
|
||||
}
|
||||
router.push("/");
|
||||
}, [session, status]);
|
||||
}, [router, session, status]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Box, Button, Container, Flex, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react";
|
||||
import { Button, Container, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react";
|
||||
import { Field, Form, Formik } from "formik";
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
@@ -27,7 +27,7 @@ const ManageUser = ({ user }) => {
|
||||
return;
|
||||
}
|
||||
router.push("/");
|
||||
}, [session, status]);
|
||||
}, [router, session, status]);
|
||||
|
||||
// Trigger to let us update the user's role. Triggers a toast when complete.
|
||||
const { trigger } = useSWRMutation("/api/admin/update_user", poster, {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import withRole from "src/lib/auth";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import withRole from "src/lib/auth";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
// The number of users to fetch in any request.
|
||||
const PAGE_SIZE = 20;
|
||||
|
||||
/**
|
||||
* Returns a list of user results from the database when the requesting user is
|
||||
* a logged in admin.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
// Figure out the pagination index and skip that number of users.
|
||||
//
|
||||
// Note: with Prisma this isn't the most efficient but it's the only possible
|
||||
// option with cuid based User IDs.
|
||||
const { pageIndex } = req.query;
|
||||
const skip = parseInt(pageIndex as string) * PAGE_SIZE || 0;
|
||||
|
||||
// Fetch 20 users.
|
||||
const users = await prisma.user.findMany({
|
||||
select: {
|
||||
@@ -15,7 +24,8 @@ const handler = withRole("admin", async (req, res) => {
|
||||
name: true,
|
||||
email: true,
|
||||
},
|
||||
take: 20,
|
||||
skip,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
res.status(200).json(users);
|
||||
|
||||
@@ -2,17 +2,59 @@ import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
|
||||
import React, { useRef } from "react";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
import { Footer } from "src/components/Footer";
|
||||
import { Header } from "src/components/Header";
|
||||
|
||||
export type SignInErrorTypes =
|
||||
| "Signin"
|
||||
| "OAuthSignin"
|
||||
| "OAuthCallback"
|
||||
| "OAuthCreateAccount"
|
||||
| "EmailCreateAccount"
|
||||
| "Callback"
|
||||
| "OAuthAccountNotLinked"
|
||||
| "EmailSignin"
|
||||
| "CredentialsSignin"
|
||||
| "SessionRequired"
|
||||
| "default";
|
||||
|
||||
const errorMessages: Record<SignInErrorTypes, string> = {
|
||||
Signin: "Try signing in with a different account.",
|
||||
OAuthSignin: "Try signing in with a different account.",
|
||||
OAuthCallback: "Try signing in with the same account you used originally.",
|
||||
OAuthCreateAccount: "Try signing in with a different account.",
|
||||
EmailCreateAccount: "Try signing in with a different account.",
|
||||
Callback: "Try signing in with a different account.",
|
||||
OAuthAccountNotLinked: "To confirm your identity, sign in with the same account you used originally.",
|
||||
EmailSignin: "The e-mail could not be sent.",
|
||||
CredentialsSignin: "Sign in failed. Check the details you provided are correct.",
|
||||
SessionRequired: "Please sign in to access this page.",
|
||||
default: "Unable to sign in.",
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
function Signin({ csrfToken, providers }) {
|
||||
const router = useRouter();
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const emailEl = useRef(null);
|
||||
const [error, setError] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const err = router?.query?.error;
|
||||
if (err) {
|
||||
if (typeof err === "string") {
|
||||
setError(errorMessages[err]);
|
||||
} else {
|
||||
setError(errorMessages[err[0]]);
|
||||
}
|
||||
}
|
||||
}, [router]);
|
||||
|
||||
const signinWithEmail = (ev: React.FormEvent) => {
|
||||
ev.preventDefault();
|
||||
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
|
||||
@@ -110,6 +152,11 @@ function Signin({ csrfToken, providers }) {
|
||||
</Link>
|
||||
.
|
||||
</div>
|
||||
{error && (
|
||||
<div className="text-center mt-8">
|
||||
<p className="text-orange-600">Error: {error}</p>
|
||||
</div>
|
||||
)}
|
||||
</AuthLayout>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { getCsrfToken, getProviders } from "next-auth/react";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
|
||||
export default function Verify() {
|
||||
const { colorMode } = useColorMode();
|
||||
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Sign Up - Open Assistant</title>
|
||||
<meta name="Sign Up" content="Sign up to access Open Assistant" />
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
|
||||
</AuthLayout>
|
||||
<div className={`flex h-full justify-center items-center ${bgColorClass}`}>
|
||||
<div className={bgColorClass}>
|
||||
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelAssistantReplyTaskResponse,
|
||||
useLabelAssistantReplyTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
|
||||
|
||||
const LabelAssistantReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: true, message_id: task.message_id },
|
||||
];
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Assistant Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelAssistantReply;
|
||||
@@ -1,113 +1,41 @@
|
||||
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useEffect, useId, useState } from "react";
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelInitialPromptTaskResponse,
|
||||
useLabelInitialPromptTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
|
||||
taskApiEndpoint: "label_initial_prompt",
|
||||
});
|
||||
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
|
||||
|
||||
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
|
||||
const labels = task.valid_labels.reduce((obj, label, i) => {
|
||||
obj[label] = sliderValues[i].toString();
|
||||
return obj;
|
||||
}, {} as Record<string, string>);
|
||||
|
||||
submit(id, task.message_id, task.prompt, labels);
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
|
||||
<p className="text-lg py-1">Provide labels for the following prompt</p>
|
||||
<MessageView text={task.prompt} is_assistant message_id={task.message_id} />
|
||||
</>
|
||||
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
|
||||
</TwoColumnsWithCards>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
|
||||
</div>
|
||||
<LabelTask
|
||||
title="Label Initial Prompt"
|
||||
desc="Provide labels for the following prompt"
|
||||
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
|
||||
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelInitialPrompt;
|
||||
|
||||
// TODO: consolidate with FlaggableElement
|
||||
|
||||
interface CheckboxSliderGroupProps {
|
||||
labelIDs: Array<string>;
|
||||
onChange: (sliderValues: number[]) => unknown;
|
||||
}
|
||||
|
||||
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
<CheckboxSliderItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
sliderValue={sliderValues[idx]}
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
function CheckboxSliderItem(props: {
|
||||
labelId: string;
|
||||
sliderValue: number;
|
||||
sliderHandler: (newVal: number) => unknown;
|
||||
}) {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
</label>
|
||||
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelPrompterReplyTaskResponse,
|
||||
useLabelPrompterReplyTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
|
||||
|
||||
const LabelPrompterReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: false, message_id: task.message_id },
|
||||
];
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Prompter Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelPrompterReply;
|
||||
@@ -2,6 +2,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -10,29 +11,28 @@ const MessagesDashboard = () => {
|
||||
const boxBgColor = useColorModeValue("white", "gray.700");
|
||||
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
|
||||
|
||||
const [messages, setMessages] = useState([]);
|
||||
const [userMessages, setUserMessages] = useState([]);
|
||||
const [messages, setMessages] = useState<Message[]>(null);
|
||||
const [userMessages, setUserMessages] = useState<Message[]>(null);
|
||||
|
||||
const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setMessages(data);
|
||||
},
|
||||
onSuccess: setMessages,
|
||||
});
|
||||
|
||||
const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setUserMessages(data);
|
||||
},
|
||||
onSuccess: setUserMessages,
|
||||
});
|
||||
|
||||
const receivedMessages = !isLoadingAll && Array.isArray(messages);
|
||||
const receivedUserMessages = !isLoadingUser && Array.isArray(userMessages);
|
||||
|
||||
useEffect(() => {
|
||||
if (messages.length == 0) {
|
||||
if (!receivedMessages) {
|
||||
mutateAll();
|
||||
}
|
||||
if (userMessages.length == 0) {
|
||||
if (!receivedUserMessages) {
|
||||
mutateUser();
|
||||
}
|
||||
}, [messages, userMessages]);
|
||||
}, [receivedMessages, mutateAll, receivedUserMessages, mutateUser]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -52,7 +52,7 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
|
||||
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
|
||||
</Box>
|
||||
</Box>
|
||||
<Box>
|
||||
@@ -66,7 +66,7 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
|
||||
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
|
||||
</Box>
|
||||
</Box>
|
||||
</SimpleGrid>
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import Head from "next/head";
|
||||
import { TaskOption } from "src/components/Dashboard";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
|
||||
const AllTasks = () => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>All Tasks - Open Assistant</title>
|
||||
<meta name="description" content="All tasks for Open Assistant." />
|
||||
</Head>
|
||||
<TaskOption />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
AllTasks.getLayout = (page) => getDashboardLayout(page);
|
||||
|
||||
export default AllTasks;
|
||||
Reference in New Issue
Block a user