mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Merge branch 'main' of github.com:LAION-AI/Open-Assistant
This commit is contained in:
@@ -19,7 +19,7 @@ devcontainers in this repo.
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
A successfull run should look something like this:
|
||||
A successful run should look something like this:
|
||||
|
||||
```
|
||||
@andrewm4894 ➜ /workspaces/Open-Assistant (devcontainer-improvements) $ pre-commit run --all-files
|
||||
|
||||
+6
-6
@@ -50,12 +50,12 @@ contributions smoothly we recommend the following:
|
||||
[Here](https://github.com/LAION-AI/Open-Assistant/pull/658) is an example PR
|
||||
for this project to illustrate this flow.
|
||||
1. If you're lucky, we can merge your change into `main` without any problems.
|
||||
If there's changes to files you're working on, resolve them by:
|
||||
1. First try rebase as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
|
||||
1. If rebase feels too painful, merge as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
|
||||
1. Once you've resolved any conflicts, finish the review and
|
||||
If there's changes to files you're working on, resolve them by :
|
||||
1. First try rebase as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
|
||||
1. If rebase feels too painful, merge as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
|
||||
1. Once you've resolved conflicts (if any), finish the review and
|
||||
[squash and merge](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits)
|
||||
your PR (when squashing try to clean up or update the individual commit
|
||||
messages to be one sensible single one).
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add user.show_on_leaderboard
|
||||
|
||||
Revision ID: f856bf19d32b
|
||||
Revises: c84fcd6900dc
|
||||
Create Date: 2023-01-27 20:13:56.533374
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f856bf19d32b"
|
||||
down_revision = "c84fcd6900dc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"user", sa.Column("show_on_leaderboard", sa.Boolean(), server_default=sa.text("true"), nullable=False)
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("user", "show_on_leaderboard")
|
||||
# ### end Alembic commands ###
|
||||
+31
@@ -0,0 +1,31 @@
|
||||
"""add origin column to message_tree_state
|
||||
|
||||
Revision ID: 49d8445b4c90
|
||||
Revises: f856bf19d32b
|
||||
Create Date: 2023-01-28 11:57:45.580027
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "49d8445b4c90"
|
||||
down_revision = "f856bf19d32b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message", sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
op.add_column("message", sa.Column("model_name", sa.String(length=1024), nullable=True))
|
||||
op.add_column("message_tree_state", sa.Column("origin", sa.String(length=1024), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message_tree_state", "origin")
|
||||
op.drop_column("message", "model_name")
|
||||
op.drop_column("message", "synthetic")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,187 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import oasst_backend.utils.database_utils as db_utils
|
||||
import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import create_api_client
|
||||
from oasst_backend.models import ApiClient, Message
|
||||
from oasst_backend.models.message_tree_state import MessageTreeState
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree
|
||||
from sqlmodel import Session
|
||||
|
||||
# well known id
|
||||
IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363")
|
||||
|
||||
|
||||
class Importer:
|
||||
def __init__(self, db: Session, origin: str, model_name: Optional[str] = None):
|
||||
self.db = db
|
||||
self.origin = origin
|
||||
self.model_name = model_name
|
||||
|
||||
# get import api client
|
||||
api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first()
|
||||
if not api_client:
|
||||
api_client = create_api_client(
|
||||
session=db,
|
||||
description="API client used for importing data",
|
||||
frontend_type="import",
|
||||
force_id=IMPORT_API_CLIENT_ID,
|
||||
)
|
||||
|
||||
ur = UserRepository(db, api_client)
|
||||
self.import_user = ur.lookup_system_user(username="import")
|
||||
self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur)
|
||||
self.api_client = api_client
|
||||
|
||||
def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
|
||||
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none()
|
||||
|
||||
def import_message(
|
||||
self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None
|
||||
) -> Message:
|
||||
payload = db_payload.MessagePayload(text=message.text)
|
||||
msg = Message(
|
||||
id=message.message_id,
|
||||
message_tree_id=message_tree_id,
|
||||
frontend_message_id=message.message_id,
|
||||
parent_id=parent_id,
|
||||
review_count=message.review_count or 0,
|
||||
lang=message.lang or "en",
|
||||
review_result=True,
|
||||
synthetic=message.synthetic if message.synthetic is not None else True,
|
||||
model_name=message.model_name or self.model_name,
|
||||
role=message.role,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
user_id=self.import_user.id,
|
||||
)
|
||||
self.db.add(msg)
|
||||
if message.replies:
|
||||
for r in message.replies:
|
||||
self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id)
|
||||
self.db.flush()
|
||||
if parent_id is None:
|
||||
self.pr.update_children_counts(msg.id)
|
||||
self.db.refresh(msg)
|
||||
return msg
|
||||
|
||||
def import_tree(
|
||||
self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING
|
||||
) -> tuple[MessageTreeState, Message]:
|
||||
assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id
|
||||
root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id)
|
||||
assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import"
|
||||
active = state == TreeState.RANKING
|
||||
mts = MessageTreeState(
|
||||
message_tree_id=root_msg.id,
|
||||
goal_tree_size=0,
|
||||
max_depth=0,
|
||||
max_children_count=0,
|
||||
state=state,
|
||||
origin=self.origin,
|
||||
active=active,
|
||||
)
|
||||
self.db.add(mts)
|
||||
return mts, root_msg
|
||||
|
||||
|
||||
def import_file(
|
||||
input_file_path: Path,
|
||||
origin: str,
|
||||
*,
|
||||
model_name: Optional[str] = None,
|
||||
num_activate: int = 0,
|
||||
max_count: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
) -> int:
|
||||
@db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT)
|
||||
def import_tx(db: Session) -> int:
|
||||
importer = Importer(db, origin=origin, model_name=model_name)
|
||||
i = 0
|
||||
with input_file_path.open() as file_in:
|
||||
# read line tree object
|
||||
for line in file_in:
|
||||
dict_tree = json.loads(line)
|
||||
|
||||
# validate data
|
||||
tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree)
|
||||
existing_mts = importer.fetch_message_tree_state(tree.message_tree_id)
|
||||
if existing_mts:
|
||||
logger.info(f"Skipping existing message tree: {tree.message_tree_id}")
|
||||
else:
|
||||
state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING
|
||||
mts, root_msg = importer.import_tree(tree, state=state)
|
||||
i += 1
|
||||
logger.info(
|
||||
f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}"
|
||||
)
|
||||
|
||||
if max_count and i >= max_count:
|
||||
logger.info(f"Reached max count {max_count} of trees to import.")
|
||||
break
|
||||
return i
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN with rollback")
|
||||
return import_tx()
|
||||
|
||||
|
||||
def parse_args():
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"input_file_path",
|
||||
help="Input file path",
|
||||
)
|
||||
parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees")
|
||||
parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)")
|
||||
parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state")
|
||||
parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import")
|
||||
parser.add_argument("--dry_run", type=str2bool, default=False)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
input_file_path = Path(args.input_file_path)
|
||||
if not input_file_path.exists() or not input_file_path.is_file():
|
||||
print("Invalid input file:", args.input_file_path)
|
||||
exit(1)
|
||||
|
||||
dry_run = args.dry_run
|
||||
num_imported = import_file(
|
||||
input_file_path,
|
||||
origin=args.origin or input_file_path.name,
|
||||
model_name=args.model_name,
|
||||
num_activate=args.num_activate,
|
||||
max_count=args.max_count,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
logger.info(f"Done ({num_imported=}, {dry_run=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+3
-1
@@ -191,6 +191,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
review_count=5,
|
||||
review_result=True,
|
||||
check_tree_state=False,
|
||||
check_duplicate=False,
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(
|
||||
@@ -215,7 +216,8 @@ def ensure_tree_states():
|
||||
try:
|
||||
logger.info("Startup: TreeManager.ensure_tree_states()")
|
||||
with Session(engine) as db:
|
||||
tm = TreeManager(db, None)
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
tm = TreeManager(db, PromptRepository(db, api_client=api_client))
|
||||
tm.ensure_tree_states()
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from typing import Generator, NamedTuple, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
@@ -19,22 +20,46 @@ def get_db() -> Generator:
|
||||
yield db
|
||||
|
||||
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
api_key_query = APIKeyQuery(name="api_key", scheme_name="api-key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", scheme_name="api-key", auto_error=False)
|
||||
oasst_user_query = APIKeyQuery(name="oasst_user", scheme_name="oasst-user", auto_error=False)
|
||||
oasst_user_header = APIKeyHeader(name="x-oasst-user", scheme_name="oasst-user", auto_error=False)
|
||||
|
||||
bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
def get_api_key(
|
||||
api_key_query: str = Security(api_key_query),
|
||||
api_key_header: str = Security(api_key_header),
|
||||
):
|
||||
) -> str:
|
||||
if api_key_query:
|
||||
return api_key_query
|
||||
else:
|
||||
return api_key_header
|
||||
|
||||
|
||||
class FrontendUserId(NamedTuple):
|
||||
auth_method: str
|
||||
username: str
|
||||
|
||||
|
||||
def get_frontend_user_id(
|
||||
user_query: str = Security(oasst_user_query),
|
||||
user_header: str = Security(oasst_user_header),
|
||||
) -> FrontendUserId:
|
||||
def split_user(v: str) -> tuple[str, str]:
|
||||
if type(v) is str:
|
||||
v = v.split(":", maxsplit=1)
|
||||
if len(v) == 2:
|
||||
return FrontendUserId(auth_method=v[0], username=v[1])
|
||||
return FrontendUserId(auth_method=None, username=None)
|
||||
|
||||
if user_query:
|
||||
return split_user(user_query)
|
||||
else:
|
||||
return split_user(user_header)
|
||||
|
||||
|
||||
def create_api_client(
|
||||
*,
|
||||
session: Session,
|
||||
@@ -43,6 +68,7 @@ def create_api_client(
|
||||
trusted: bool | None = False,
|
||||
admin_email: str | None = None,
|
||||
api_key: str | None = None,
|
||||
force_id: Optional[UUID] = None,
|
||||
) -> ApiClient:
|
||||
if api_key is None:
|
||||
api_key = token_hex(32)
|
||||
@@ -55,6 +81,8 @@ def create_api_client(
|
||||
trusted=trusted,
|
||||
admin_email=admin_email,
|
||||
)
|
||||
if force_id:
|
||||
api_client.id = force_id
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
session.refresh(api_client)
|
||||
|
||||
@@ -104,6 +104,7 @@ def query_frontend_user_messages_cursor(
|
||||
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
||||
desc: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
@@ -117,6 +118,7 @@ def query_frontend_user_messages_cursor(
|
||||
max_count=max_count,
|
||||
desc=desc,
|
||||
lang=lang,
|
||||
frontend_user=frontend_user,
|
||||
api_client=api_client,
|
||||
db=db,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=list[protocol.Message])
|
||||
def query_messages(
|
||||
*,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client_id: Optional[str] = None,
|
||||
@@ -28,13 +29,14 @@ def query_messages(
|
||||
desc: Optional[bool] = True,
|
||||
allow_deleted: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, auth_method=frontend_user.auth_method, username=frontend_user.username)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
@@ -53,6 +55,7 @@ def query_messages(
|
||||
|
||||
@router.get("/cursor", response_model=protocol.MessagePage)
|
||||
def get_messages_cursor(
|
||||
*,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
@@ -64,6 +67,7 @@ def get_messages_cursor(
|
||||
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
||||
desc: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
@@ -93,7 +97,7 @@ def get_messages_cursor(
|
||||
|
||||
qry_max_count = max_count + 1 if before is None or after is None else max_count
|
||||
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username, user_id=user_id)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
items = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
auth_method=auth_method,
|
||||
@@ -137,25 +141,25 @@ def get_messages_cursor(
|
||||
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message(
|
||||
*,
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
|
||||
def get_conv(
|
||||
*,
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
@@ -163,23 +167,23 @@ def get_conv(
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
|
||||
def get_tree(
|
||||
*,
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -187,32 +191,32 @@ def get_tree(
|
||||
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children(
|
||||
*,
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
|
||||
def get_descendants(
|
||||
*,
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -220,16 +224,16 @@ def get_descendants(
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
|
||||
def get_longest_conv(
|
||||
*,
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -237,16 +241,16 @@ def get_longest_conv(
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
|
||||
def get_max_children(
|
||||
*,
|
||||
message_id: UUID,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
@@ -254,9 +258,13 @@ def get_max_children(
|
||||
|
||||
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ def tasks_acknowledge(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
task_id: UUID,
|
||||
ack_request: protocol_schema.TaskAck,
|
||||
) -> None:
|
||||
@@ -87,7 +88,7 @@ def tasks_acknowledge(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
@@ -105,6 +106,7 @@ def tasks_acknowledge_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
task_id: UUID,
|
||||
nack_request: protocol_schema.TaskNAck,
|
||||
) -> None:
|
||||
@@ -115,7 +117,7 @@ def tasks_acknowledge_failure(
|
||||
try:
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
@@ -43,11 +49,28 @@ def label_text(
|
||||
|
||||
|
||||
@router.get("/valid_labels")
|
||||
def get_valid_lables() -> ValidLabelsResponse:
|
||||
def get_valid_lables(
|
||||
*,
|
||||
message_id: Optional[UUID] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
) -> ValidLabelsResponse:
|
||||
if message_id:
|
||||
pr = PromptRepository(db, api_client=api_client)
|
||||
message = pr.fetch_message(message_id=message_id)
|
||||
if message.parent_id is None:
|
||||
valid_labels = settings.tree_manager.labels_initial_prompt
|
||||
elif message.role == "assistant":
|
||||
valid_labels = settings.tree_manager.labels_assistant_reply
|
||||
else:
|
||||
valid_labels = settings.tree_manager.labels_prompter_reply
|
||||
else:
|
||||
valid_labels = [l for l in TextLabel if l != TextLabel.fails_task]
|
||||
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in TextLabel
|
||||
for l in valid_labels
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -190,6 +190,7 @@ def update_user(
|
||||
user_id: UUID,
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
@@ -197,7 +198,7 @@ def update_user(
|
||||
Update a user by global user ID. Only trusted clients can update users.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
ur.update_user(user_id, enabled, notes)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
@@ -224,13 +225,14 @@ def query_user_messages(
|
||||
desc: bool = True,
|
||||
include_deleted: bool = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user_id=user_id)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
@@ -256,6 +258,7 @@ def query_user_messages_cursor(
|
||||
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
||||
desc: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
@@ -268,6 +271,7 @@ def query_user_messages_cursor(
|
||||
max_count=max_count,
|
||||
desc=desc,
|
||||
lang=lang,
|
||||
frontend_user=frontend_user,
|
||||
api_client=api_client,
|
||||
db=db,
|
||||
)
|
||||
@@ -275,9 +279,12 @@ def query_user_messages_cursor(
|
||||
|
||||
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
user_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.query_messages_ordered_by_created_date(user_id=user_id, limit=None)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
|
||||
@@ -23,19 +23,20 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
return [prepare_message(m) for m in messages]
|
||||
|
||||
|
||||
def prepare_conversation_message(message: Message) -> protocol.ConversationMessage:
|
||||
return protocol.ConversationMessage(
|
||||
id=message.id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
text=message.text,
|
||||
lang=message.lang,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
emojis=message.emojis or {},
|
||||
user_emojis=message.user_emojis or [],
|
||||
)
|
||||
|
||||
|
||||
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
|
||||
return [
|
||||
protocol.ConversationMessage(
|
||||
id=message.id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
text=message.text,
|
||||
lang=message.lang,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
emojis=message.emojis or {},
|
||||
user_emojis=message.user_emojis or [],
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
return [prepare_conversation_message(message) for message in messages]
|
||||
|
||||
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
|
||||
@@ -46,6 +46,15 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_required_rankings: int = 3
|
||||
"""Number of rankings in which the message participated."""
|
||||
|
||||
p_activate_backlog_tree: float = 0.8
|
||||
"""Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
|
||||
a terminal state. Use this settting to control ratio of initial prompts and backlog tree
|
||||
activations."""
|
||||
|
||||
min_active_rankings_per_lang: int = 2
|
||||
"""When the number of active ranking tasks is below this value when a tree enters a terminal
|
||||
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
|
||||
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
@@ -138,6 +147,8 @@ class Settings(BaseSettings):
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
|
||||
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
|
||||
@@ -168,9 +179,9 @@ class Settings(BaseSettings):
|
||||
|
||||
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
|
||||
|
||||
USER_STATS_INTERVAL_DAY: int = 15 # minutes
|
||||
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
|
||||
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
|
||||
USER_STATS_INTERVAL_DAY: int = 5 # minutes
|
||||
USER_STATS_INTERVAL_WEEK: int = 15 # minutes
|
||||
USER_STATS_INTERVAL_MONTH: int = 60 # minutes
|
||||
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
|
||||
|
||||
@validator(
|
||||
|
||||
@@ -117,7 +117,8 @@ class LabelConversationReplyPayload(TaskPayload):
|
||||
|
||||
message_id: UUID
|
||||
conversation: protocol_schema.Conversation
|
||||
reply: str
|
||||
reply: str # deprecated
|
||||
reply_message: Optional[protocol_schema.ConversationMessage]
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[protocol_schema.LabelTaskMode]
|
||||
|
||||
@@ -57,6 +57,11 @@ class Message(SQLModel, table=True):
|
||||
|
||||
rank: Optional[int] = Field(nullable=True)
|
||||
|
||||
synthetic: Optional[bool] = Field(
|
||||
sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False)
|
||||
)
|
||||
model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
|
||||
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
|
||||
|
||||
|
||||
@@ -43,6 +43,9 @@ class State(str, Enum):
|
||||
HALTED_BY_MODERATOR = "halted_by_moderator"
|
||||
"""A moderator decided to manually halt the message tree construction process."""
|
||||
|
||||
BACKLOG_RANKING = "backlog_ranking"
|
||||
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
State.INITIAL_PROMPT_REVIEW,
|
||||
@@ -51,6 +54,7 @@ VALID_STATES = (
|
||||
State.READY_FOR_SCORING,
|
||||
State.READY_FOR_EXPORT,
|
||||
State.ABORTED_LOW_GRADE,
|
||||
State.BACKLOG_RANKING,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
|
||||
@@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True):
|
||||
max_children_count: int = Field(nullable=False)
|
||||
state: str = Field(nullable=False, max_length=128, index=True)
|
||||
active: bool = Field(nullable=False, index=True)
|
||||
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
|
||||
@@ -30,6 +30,7 @@ class User(SQLModel, table=True):
|
||||
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
|
||||
notes: str = Field(sa_column=sa.Column(AutoString(length=1024), nullable=False, server_default=""))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))
|
||||
show_on_leaderboard: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
|
||||
|
||||
def to_protocol_frontend_user(self):
|
||||
return protocol.FrontEndUser(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import random
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
@@ -8,6 +9,8 @@ from uuid import UUID, uuid4
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import FrontendUserId
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import (
|
||||
ApiClient,
|
||||
@@ -29,7 +32,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from oasst_shared.utils import unaware_to_utc, utcnow
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
|
||||
@@ -48,10 +51,15 @@ class PromptRepository:
|
||||
user_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: Optional[FrontendUserId] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user_repository = user_repository or UserRepository(db, api_client)
|
||||
|
||||
if frontend_user and not auth_method and not username:
|
||||
auth_method, username = frontend_user
|
||||
|
||||
if user_id:
|
||||
self.user = self.user_repository.get_user(id=user_id)
|
||||
self.user_id = self.user.id
|
||||
@@ -169,6 +177,7 @@ class PromptRepository:
|
||||
review_count: int = 0,
|
||||
review_result: bool = False,
|
||||
check_tree_state: bool = True,
|
||||
check_duplicate: bool = True,
|
||||
) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
@@ -182,6 +191,18 @@ class PromptRepository:
|
||||
role = None
|
||||
depth = 0
|
||||
|
||||
# reject whitespaces match with ^\s+$
|
||||
if re.match(r"^\s+$", text):
|
||||
raise OasstError("Message text is empty", OasstErrorCode.TASK_MESSAGE_TEXT_EMPTY)
|
||||
|
||||
# ensure message size is below the predefined limit
|
||||
if len(text) > settings.MESSAGE_SIZE_LIMIT:
|
||||
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
if check_duplicate and self.check_users_recent_replies_for_duplicates(text):
|
||||
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
|
||||
|
||||
if task.parent_message_id:
|
||||
parent_message = self.fetch_message(task.parent_message_id)
|
||||
|
||||
@@ -460,6 +481,7 @@ class PromptRepository:
|
||||
task_id=task.id if task else None,
|
||||
)
|
||||
|
||||
message: Message = None
|
||||
if message_id:
|
||||
if not task:
|
||||
if text_labels.is_report is True:
|
||||
@@ -550,6 +572,30 @@ class PromptRepository:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def check_users_recent_replies_for_duplicates(self, text: str) -> bool:
|
||||
"""
|
||||
Checks if the user has recently replied with the same text within a given time period.
|
||||
"""
|
||||
|
||||
user_id = self.user_id
|
||||
logger.debug(f"Checking for duplicate tasks for user {user_id}")
|
||||
# messages in the past 24 hours
|
||||
messages = (
|
||||
self.db.query(Message)
|
||||
.filter(Message.user_id == user_id)
|
||||
.order_by(Message.created_date.desc())
|
||||
.filter(
|
||||
Message.created_date > utcnow() - timedelta(minutes=settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not messages:
|
||||
return False
|
||||
for msg in messages:
|
||||
if msg.text == text:
|
||||
return True
|
||||
return False
|
||||
|
||||
def fetch_user_message_trees(
|
||||
self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False
|
||||
) -> list[Message]:
|
||||
@@ -865,8 +911,7 @@ FROM (
|
||||
) AS cc
|
||||
WHERE message.id = cc.id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
|
||||
self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
|
||||
@@ -101,6 +101,7 @@ class TaskRepository:
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
reply_message=task.reply_message,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
@@ -112,6 +113,7 @@ class TaskRepository:
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
reply_message=task.reply_message,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
|
||||
@@ -11,7 +11,11 @@ import numpy as np
|
||||
import pydantic
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.api.v1.utils import (
|
||||
prepare_conversation,
|
||||
prepare_conversation_message,
|
||||
prepare_conversation_message_list,
|
||||
)
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
@@ -21,7 +25,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func, not_, text, update
|
||||
from sqlmodel import Session, func, not_, or_, text, update
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -69,6 +73,7 @@ class IncompleteRankingsRow(pydantic.BaseModel):
|
||||
role: str
|
||||
children_count: int
|
||||
child_min_ranking_count: int
|
||||
message_tree_id: UUID
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
@@ -356,7 +361,7 @@ class TreeManager:
|
||||
random_reply_message = random.choice(replies_need_review)
|
||||
messages = self.pr.fetch_message_conversation(random_reply_message)
|
||||
|
||||
conversation = prepare_conversation(messages[:-1])
|
||||
conversation = prepare_conversation(messages)
|
||||
message = messages[-1]
|
||||
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
@@ -370,15 +375,18 @@ class TreeManager:
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
|
||||
):
|
||||
valid_labels = self.cfg.mandatory_labels_assistant_reply
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
valid_labels = list(self.cfg.mandatory_labels_assistant_reply)
|
||||
if protocol_schema.TextLabel.quality not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.quality)
|
||||
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
reply_message=prepare_conversation_message(message),
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
@@ -391,13 +399,18 @@ class TreeManager:
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
|
||||
):
|
||||
valid_labels = self.cfg.mandatory_labels_prompter_reply
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
valid_labels = list(self.cfg.mandatory_labels_prompter_reply)
|
||||
if protocol_schema.TextLabel.quality not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.quality)
|
||||
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
reply_message=prepare_conversation_message(message),
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
@@ -514,14 +527,6 @@ class TreeManager:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# ensure message size is below the predefined limit
|
||||
if len(interaction.text) > settings.MESSAGE_SIZE_LIMIT:
|
||||
logger.error(
|
||||
f"Message size {len(interaction.text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}."
|
||||
)
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
# here we store the text reply in the database
|
||||
message = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
@@ -621,19 +626,28 @@ class TreeManager:
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
|
||||
assert mts and mts.active
|
||||
assert mts
|
||||
|
||||
is_terminal = state in message_tree_state.TERMINAL_STATES
|
||||
|
||||
was_active = mts.active
|
||||
if is_terminal:
|
||||
mts.active = False
|
||||
mts.state = state.value
|
||||
self.db.add(mts)
|
||||
self.db.flush
|
||||
|
||||
if is_terminal:
|
||||
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
|
||||
root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False)
|
||||
if root_msg and was_active:
|
||||
if random.random() < self.cfg.p_activate_backlog_tree:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
|
||||
if self.cfg.min_active_rankings_per_lang > 0:
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang)
|
||||
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
else:
|
||||
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
@@ -676,24 +690,30 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.RANKING)
|
||||
return True
|
||||
|
||||
def check_condition_for_scoring_state(
|
||||
self, message_tree_id: UUID
|
||||
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.RANKING:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False, None
|
||||
if mts.state != message_tree_state.State.SCORING_FAILED:
|
||||
if not mts.active or mts.state not in (
|
||||
message_tree_state.State.RANKING,
|
||||
message_tree_state.State.READY_FOR_SCORING,
|
||||
):
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
return False, None
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
if (
|
||||
mts.state != message_tree_state.State.SCORING_FAILED
|
||||
and mts.state != message_tree_state.State.READY_FOR_SCORING
|
||||
):
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
self.update_message_ranks(message_tree_id, rankings_by_message)
|
||||
return True
|
||||
|
||||
@@ -755,8 +775,35 @@ class TreeManager:
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
|
||||
|
||||
return True
|
||||
|
||||
def activate_backlog_tree(self, lang: str) -> MessageTreeState:
|
||||
while True:
|
||||
# find tree in backlog state
|
||||
backlog_tree: MessageTreeState = (
|
||||
self.db.query(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id) # root msg
|
||||
.filter(MessageTreeState.state == message_tree_state.State.BACKLOG_RANKING)
|
||||
.filter(Message.lang == lang)
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if not backlog_tree:
|
||||
return None
|
||||
|
||||
if len(self.query_tree_ranking_results(message_tree_id=backlog_tree.message_tree_id)) == 0:
|
||||
logger.info(
|
||||
f"Backlog tree {backlog_tree.message_tree_id} has no children to rank, aborting with 'aborted_low_grade' state."
|
||||
)
|
||||
self._enter_state(backlog_tree, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
else:
|
||||
logger.info(f"Activating backlog tree {backlog_tree.message_tree_id}")
|
||||
backlog_tree.active = True
|
||||
self._enter_state(backlog_tree, message_tree_state.State.RANKING)
|
||||
return backlog_tree
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
@@ -824,7 +871,8 @@ class TreeManager:
|
||||
_sql_find_incomplete_rankings = """
|
||||
-- find incomplete rankings
|
||||
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings,
|
||||
mts.message_tree_id
|
||||
FROM message_tree_state mts
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active -- only consider active trees
|
||||
@@ -833,7 +881,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.lang = :lang -- matches lang
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id, m.role
|
||||
GROUP BY m.parent_id, m.role, mts.message_tree_id
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
@@ -842,7 +890,8 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
|
||||
SELECT ir.* FROM incomplete_rankings ir
|
||||
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings,
|
||||
ir.message_tree_id
|
||||
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
"""
|
||||
|
||||
@@ -981,8 +1030,8 @@ SELECT p.parent_id, mr.* FROM
|
||||
GROUP BY m.parent_id, m.message_tree_id
|
||||
HAVING COUNT(m.id) > 1
|
||||
) as p
|
||||
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
"""
|
||||
|
||||
def query_tree_ranking_results(
|
||||
@@ -1025,7 +1074,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
self._insert_default_state(id, state=state)
|
||||
|
||||
rankings = (
|
||||
self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all()
|
||||
self.db.query(MessageTreeState)
|
||||
.filter(
|
||||
or_(
|
||||
MessageTreeState.state == message_tree_state.State.RANKING,
|
||||
MessageTreeState.state == message_tree_state.State.READY_FOR_SCORING,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if len(rankings) > 0:
|
||||
logger.info(f"Checking state of {len(rankings)} message trees in ranking state.")
|
||||
@@ -1318,17 +1374,17 @@ DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def retry_scoring_failed_message_trees(self):
|
||||
query = self.db.query(MessageTreeState.message_tree_id).filter(
|
||||
query = self.db.query(MessageTreeState).filter(
|
||||
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
|
||||
)
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
for row in query.all():
|
||||
for mts in query.all():
|
||||
mts: MessageTreeState
|
||||
try:
|
||||
message_tree_id = row["message_tree_id"]
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message)
|
||||
if not self.check_condition_for_scoring_state(mts.message_tree_id):
|
||||
mts.active = True
|
||||
self._enter_state(message_tree_state.State.RANKING)
|
||||
except Exception:
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})")
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1362,8 +1418,8 @@ if __name__ == "__main__":
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
# print(
|
||||
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
|
||||
# )
|
||||
print(
|
||||
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
|
||||
)
|
||||
|
||||
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
|
||||
@@ -66,7 +66,13 @@ class UserRepository:
|
||||
return user
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
|
||||
def update_user(
|
||||
self,
|
||||
id: UUID,
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
|
||||
@@ -85,6 +91,8 @@ class UserRepository:
|
||||
user.enabled = enabled
|
||||
if notes is not None:
|
||||
user.notes = notes
|
||||
if show_on_leaderboard is not None:
|
||||
user.show_on_leaderboard = show_on_leaderboard
|
||||
|
||||
self.db.add(user)
|
||||
|
||||
@@ -109,15 +117,20 @@ class UserRepository:
|
||||
self.db.add(user)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
def _lookup_user_tx(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
auth_method: str,
|
||||
display_name: Optional[str] = None,
|
||||
create_missing: bool = True,
|
||||
) -> User | None:
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
User.username == username,
|
||||
User.auth_method == auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -125,28 +138,46 @@ class UserRepository:
|
||||
if create_missing:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
auth_method=auth_method,
|
||||
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
|
||||
)
|
||||
self.db.add(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
elif display_name and display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
user.display_name = display_name
|
||||
self.db.add(user)
|
||||
|
||||
return user
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
|
||||
if not client_user:
|
||||
return None
|
||||
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
return self._lookup_client_user_tx(client_user, create_missing)
|
||||
return self._lookup_user_tx(
|
||||
username=client_user.id,
|
||||
auth_method=client_user.auth_method,
|
||||
display_name=client_user.display_name,
|
||||
create_missing=create_missing,
|
||||
)
|
||||
except IntegrityError:
|
||||
# catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username
|
||||
if i + 1 == num_retries:
|
||||
raise
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def lookup_system_user(self, username: str, create_missing: bool = True) -> User | None:
|
||||
return self._lookup_user_tx(
|
||||
username=username,
|
||||
auth_method="system",
|
||||
display_name=f"__system__/{username}",
|
||||
create_missing=create_missing,
|
||||
)
|
||||
|
||||
def query_users_ordered_by_username(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
|
||||
@@ -39,7 +39,7 @@ class UserStatsRepository:
|
||||
qry = (
|
||||
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value)
|
||||
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
|
||||
.order_by(UserStats.rank)
|
||||
.limit(limit)
|
||||
)
|
||||
@@ -214,6 +214,10 @@ class UserStatsRepository:
|
||||
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
|
||||
self.session.execute(d)
|
||||
|
||||
if None in stats_by_user:
|
||||
logger.warning("Some messages in DB have NULL values in user_id column.")
|
||||
del stats_by_user[None]
|
||||
|
||||
# compute magic leader score
|
||||
for v in stats_by_user.values():
|
||||
v.leader_score = v.compute_leader_score()
|
||||
@@ -250,7 +254,8 @@ FROM
|
||||
PARTITION BY time_frame
|
||||
ORDER BY leader_score DESC, user_id
|
||||
) AS "rank", user_id, time_frame
|
||||
FROM user_stats
|
||||
FROM user_stats us2
|
||||
INNER JOIN "user" u ON us2.user_id = u.id AND u.show_on_leaderboard
|
||||
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
|
||||
WHERE
|
||||
us.user_id = r.user_id
|
||||
|
||||
@@ -7,9 +7,14 @@ from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from psycopg2.errors import DeadlockDetected, ExclusionViolation, SerializationFailure, UniqueViolation
|
||||
from sqlalchemy.exc import OperationalError, PendingRollbackError
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
"""
|
||||
Error Handling Reference: https://www.postgresql.org/docs/15/mvcc-serialization-failure-handling.html
|
||||
"""
|
||||
|
||||
|
||||
class CommitMode(IntEnum):
|
||||
"""
|
||||
@@ -34,28 +39,46 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s
|
||||
@wraps(f)
|
||||
def wrapped_f(self, *args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
result = None
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
retry_exhausted = True
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = f(self, *args, **kwargs)
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
retry_exhausted = False
|
||||
break
|
||||
except PendingRollbackError as e:
|
||||
logger.info(str(e))
|
||||
self.db.rollback()
|
||||
except OperationalError as e:
|
||||
if e.orig is not None and isinstance(
|
||||
e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
|
||||
):
|
||||
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
|
||||
self.db.rollback()
|
||||
else:
|
||||
raise e
|
||||
logger.info(f"Retry {i+1}/{num_retries}")
|
||||
if retry_exhausted:
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
result = f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error("DB Rollback Failure")
|
||||
logger.info(str(e))
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
@@ -70,28 +93,46 @@ def async_managed_tx_method(
|
||||
@wraps(f)
|
||||
async def wrapped_f(self, *args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = await f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
result = None
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
retry_exhausted = True
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = await f(self, *args, **kwargs)
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
retry_exhausted = False
|
||||
break
|
||||
except PendingRollbackError as e:
|
||||
logger.info(str(e))
|
||||
self.db.rollback()
|
||||
except OperationalError as e:
|
||||
if e.orig is not None and isinstance(
|
||||
e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
|
||||
):
|
||||
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
|
||||
self.db.rollback()
|
||||
else:
|
||||
raise e
|
||||
logger.info(f"Retry {i+1}/{num_retries}")
|
||||
if retry_exhausted:
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
result = await f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("DB Rollback Failure")
|
||||
logger.info(str(e))
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
@@ -107,7 +148,6 @@ def managed_tx_function(
|
||||
auto_commit: CommitMode = CommitMode.COMMIT,
|
||||
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
|
||||
session_factory: Callable[..., Session] = default_session_factor,
|
||||
refresh_result: bool = True,
|
||||
):
|
||||
"""Passes Session object as first argument to wrapped function."""
|
||||
|
||||
@@ -115,29 +155,49 @@ def managed_tx_function(
|
||||
@wraps(f)
|
||||
def wrapped_f(*args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
with session_factory() as session:
|
||||
try:
|
||||
result = f(session, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
result = None
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
retry_exhausted = True
|
||||
for i in range(num_retries):
|
||||
with session_factory() as session:
|
||||
try:
|
||||
result = f(session, *args, **kwargs)
|
||||
session.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
retry_exhausted = False
|
||||
break
|
||||
except PendingRollbackError as e:
|
||||
logger.info(str(e))
|
||||
session.rollback()
|
||||
if refresh_result and isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
session.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
except OperationalError as e:
|
||||
if e.orig is not None and isinstance(
|
||||
e.orig,
|
||||
(SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation),
|
||||
):
|
||||
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
|
||||
session.rollback()
|
||||
else:
|
||||
raise e
|
||||
logger.info(f"Retry {i+1}/{num_retries}")
|
||||
if retry_exhausted:
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
with session_factory() as session:
|
||||
result = f(session, *args, **kwargs)
|
||||
if auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
session.rollback()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error("DB Rollback Failure")
|
||||
logger.info(str(e))
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import Counter
|
||||
|
||||
from sklearn import metrics
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.svm import LinearSVC
|
||||
|
||||
|
||||
def load_and_split(foldername, num_words):
|
||||
ls = os.listdir(foldername)
|
||||
X = []
|
||||
Y = []
|
||||
langmap = dict()
|
||||
for idx, x in enumerate(ls):
|
||||
print("loading language", x)
|
||||
with open(foldername + "/" + x, "r") as reader:
|
||||
tmp = reader.read().split(" ")
|
||||
tmp = [" ".join(tmp[i : i + num_words]) for i in range(0, 100_000, num_words)]
|
||||
X.extend(tmp)
|
||||
Y.extend([idx] * len(tmp))
|
||||
langmap[idx] = x
|
||||
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.90)
|
||||
return x_train, x_test, y_train, y_test, langmap
|
||||
|
||||
|
||||
def build_and_train_pipeline(x_train, y_train):
|
||||
vectorizer = TfidfVectorizer(ngram_range=(1, 2), analyzer="char", use_idf=False)
|
||||
clf = Pipeline(
|
||||
[
|
||||
("vec", vectorizer),
|
||||
# ("nystrom", Nystroem(n_components=1000,n_jobs=6)),
|
||||
("clf", LinearSVC(C=0.5)),
|
||||
# ("clf",GaussianNB())
|
||||
# ("clf", HistGradientBoostingClassifier())
|
||||
]
|
||||
)
|
||||
print("fitting model...")
|
||||
clf.fit(x_train, y_train)
|
||||
return clf
|
||||
|
||||
|
||||
def benchmark(clf, x_test, y_test, langmap):
|
||||
print("benchmarking model...")
|
||||
y_pred = clf.predict(x_test)
|
||||
names = list(langmap.values())
|
||||
# print(y_test)
|
||||
# print(langmap)
|
||||
print(metrics.classification_report(y_test, y_pred, target_names=names))
|
||||
cm = metrics.confusion_matrix(y_test, y_pred)
|
||||
print(cm)
|
||||
|
||||
|
||||
def main(foldername, modelname, num_words):
|
||||
x_train, x_test, y_train, y_test, langmap = load_and_split(foldername=foldername, num_words=num_words)
|
||||
clf = build_and_train_pipeline(x_train, y_train)
|
||||
benchmark(clf, x_test, y_test, langmap)
|
||||
save_model(clf, langmap, num_words, modelname)
|
||||
model = load(modelname)
|
||||
print(
|
||||
"running infernence on long tests",
|
||||
inference_voter(
|
||||
model,
|
||||
"""
|
||||
What language is this text written in? Nobody knows until you fill in at least ten words.
|
||||
This test here is to check whether the moving window approach works,
|
||||
so I still need to fill in a little more text.
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def load(modelname):
|
||||
with open(modelname, "rb") as writer:
|
||||
data = pickle.load(writer)
|
||||
return data
|
||||
|
||||
|
||||
def save_model(model, idx_to_name, num_words, modelname):
|
||||
out = {
|
||||
"model": model,
|
||||
"idx_to_name": idx_to_name,
|
||||
"num_words": num_words,
|
||||
}
|
||||
with open(modelname, "wb") as writer:
|
||||
pickle.dump(out, writer)
|
||||
|
||||
|
||||
def inference_voter(model, text):
|
||||
tmp = text.split()
|
||||
# print(len(tmp), tmp)
|
||||
tmp = [" ".join(tmp[i : i + model["num_words"]]) for i in range(0, len(tmp) - model["num_words"])]
|
||||
predictions = model["model"].predict(tmp)
|
||||
# print("integer predictions", predictions)
|
||||
# print("name predictions", *[model["idx_to_name"][n] for n in predictions])
|
||||
result = Counter(predictions).most_common(1)[0][0]
|
||||
return model["idx_to_name"][result]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", help="save location for model and metadata")
|
||||
parser.add_argument("-d", "--data", help="specify the folder for data files")
|
||||
parser.add_argument("-n", "--num_words", help="number of words to use for statistics", type=int)
|
||||
args = parser.parse_args()
|
||||
# np.set_printoptions(threshold=np.inf)
|
||||
main(args.data, args.model, args.num_words)
|
||||
@@ -12,12 +12,15 @@ from pydantic import BaseModel
|
||||
|
||||
class ExportMessageNode(BaseModel):
|
||||
message_id: str
|
||||
parent_id: Optional[str]
|
||||
text: Optional[str]
|
||||
parent_id: str | None
|
||||
text: str
|
||||
role: str
|
||||
review_count: Optional[int]
|
||||
rank: Optional[int]
|
||||
replies: Optional[list[ExportMessageNode]]
|
||||
lang: str | None
|
||||
review_count: int | None
|
||||
rank: int | None
|
||||
synthetic: bool | None
|
||||
model_name: str | None
|
||||
replies: list[ExportMessageNode] | None
|
||||
|
||||
@classmethod
|
||||
def prep_message_export(cls, message: Message) -> ExportMessageNode:
|
||||
@@ -26,14 +29,17 @@ class ExportMessageNode(BaseModel):
|
||||
parent_id=str(message.parent_id) if message.parent_id else None,
|
||||
text=str(message.payload.payload.text),
|
||||
role=message.role,
|
||||
lang=message.lang,
|
||||
review_count=message.review_count,
|
||||
synthetic=message.synthetic,
|
||||
model_name=message.model_name,
|
||||
rank=message.rank,
|
||||
)
|
||||
|
||||
|
||||
class ExportMessageTree(BaseModel):
|
||||
message_tree_id: str
|
||||
replies: Optional[ExportMessageNode]
|
||||
prompt: Optional[ExportMessageNode]
|
||||
|
||||
|
||||
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
|
||||
|
||||
@@ -13,5 +13,6 @@ RUN pip install -e /oasst-shared
|
||||
COPY ./backend/alembic /app/alembic
|
||||
COPY ./backend/alembic.ini /app/alembic.ini
|
||||
COPY ./backend/main.py /app/main.py
|
||||
COPY ./backend/import.py /app/import.py
|
||||
COPY ./backend/oasst_backend /app/oasst_backend
|
||||
COPY ./backend/test_data /app/test_data
|
||||
|
||||
@@ -38,6 +38,8 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
|
||||
TASK_AVAILABILITY_QUERY_FAILED = 1007
|
||||
TASK_MESSAGE_TOO_LONG = 1008
|
||||
TASK_MESSAGE_DUPLICATED = 1009
|
||||
TASK_MESSAGE_TEXT_EMPTY = 1010
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
|
||||
@@ -258,7 +258,8 @@ class LabelConversationReplyTask(AbstractLabelTask):
|
||||
"""A task to label a reply to a conversation."""
|
||||
|
||||
type: Literal["label_conversation_reply"] = "label_conversation_reply"
|
||||
conversation: Conversation # the conversation so far
|
||||
conversation: Conversation # the conversation so far (new: including the reply message)
|
||||
reply_message: Optional[ConversationMessage]
|
||||
reply: str
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,10 @@ describe("labeling assistant replies", () => {
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
|
||||
@@ -11,6 +11,10 @@ describe("labeling initial prompts", () => {
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
|
||||
@@ -11,6 +11,10 @@ describe("labeling prompter replies", () => {
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
describe("no tasks available", () => {
|
||||
it("displays an empty state when no tasks are available", () => {
|
||||
cy.signInWithEmail("cypress@example.com");
|
||||
cy.intercept(
|
||||
{
|
||||
method: "GET",
|
||||
url: "/api/new_task/prompter_reply",
|
||||
},
|
||||
{
|
||||
statusCode: 500,
|
||||
body: {
|
||||
message: "No tasks of type 'label_prompter_reply' are currently available.",
|
||||
errorCode: 1006,
|
||||
httpStatusCode: 503,
|
||||
},
|
||||
}
|
||||
).as("newTaskPrompterReply");
|
||||
cy.visit("/create/user_reply");
|
||||
cy.wait("@newTaskPrompterReply").then(() => {
|
||||
cy.get('[data-cy="cy-no-tasks"]').should("exist");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -44,6 +44,10 @@ describe("handles random tasks", () => {
|
||||
break;
|
||||
}
|
||||
case "label-task": {
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
@@ -55,15 +59,6 @@ describe("handles random tasks", () => {
|
||||
|
||||
break;
|
||||
}
|
||||
case "spam-task": {
|
||||
cy.get('[data-cy="not-spam-button"]').click();
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
|
||||
break;
|
||||
}
|
||||
case undefined: {
|
||||
throw new Error("No tasks available, but at least create initial prompt expected");
|
||||
}
|
||||
|
||||
@@ -9,11 +9,14 @@
|
||||
"docs": "Docs",
|
||||
"github": "GitHub",
|
||||
"legal": "Legal",
|
||||
"loading": "Loading...",
|
||||
"more_information": "More Information",
|
||||
"no": "No",
|
||||
"privacy_policy": "Privacy Policy",
|
||||
"report_a_bug": "Report a Bug",
|
||||
"sign_in": "Sign In",
|
||||
"sign_out": "Sign Out",
|
||||
"terms_of_service": "Terms of Service",
|
||||
"title": "Open Assistant",
|
||||
"more_information": "More Information"
|
||||
"yes": "Yes"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"label_highlighted_yes_no_instruction": "Answer the following question(s) about the highlighted message:",
|
||||
"label_highlighted_flag_instruction": "Select any that apply to the highlighted message:",
|
||||
"label_highlighted_likert_instruction": "Rate the highlighted message:",
|
||||
"label_message_yes_no_instruction": "Answer the following question(s) about the message:",
|
||||
"label_message_flag_instruction": "Select any that apply to the message:",
|
||||
"label_message_likert_instruction": "Rate the message:",
|
||||
"spam.question": "Is the message spam?",
|
||||
"fails_task.question": "Does the reply fail the propmpters task?",
|
||||
"not_appropriate": "Not Appropriate",
|
||||
"pii": "Contains PII",
|
||||
"hate_speech": "Hate Speech",
|
||||
"sexual_content": "Sexual Content",
|
||||
"moral_judgement": "Judges Morality",
|
||||
"political_content": "Politcal"
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"label_action": "Label",
|
||||
"label_title": "Label",
|
||||
"message": "Message",
|
||||
"open_new_tab_action": "Open in new tab",
|
||||
"parent": "Parent",
|
||||
"reactions": "Reactions",
|
||||
"report_action": "Report",
|
||||
"report_placeholder": "Why should this message be reviewed?",
|
||||
"report_title": "Report",
|
||||
"send_report": "Send",
|
||||
"submit_labels": "Submit"
|
||||
}
|
||||
@@ -15,7 +15,7 @@ export const EmptyState = (props: EmptyStateProps) => {
|
||||
<Box data-cy={props["data-cy"]} bg={backgroundColor} p="10" borderRadius="xl" shadow="base">
|
||||
<Box display="flex" flexDirection="column" alignItems="center" gap="8" fontSize="lg">
|
||||
<props.icon size="30" color="DarkOrange" />
|
||||
<Text>{props.text}</Text>
|
||||
<Text data-cy="cy-no-tasks">{props.text}</Text>
|
||||
<NextLink href="/dashboard">
|
||||
<Text color="blue.500">Go back to the dashboard</Text>
|
||||
</NextLink>
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Popover,
|
||||
PopoverAnchor,
|
||||
PopoverTrigger,
|
||||
Tooltip,
|
||||
useColorModeValue,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { get, post } from "src/lib/api";
|
||||
import { colors } from "src/styles/Theme/colors";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
import { LabelInputGroup } from "./Survey/LabelInputGroup";
|
||||
|
||||
interface Label {
|
||||
name: string;
|
||||
display_text: string;
|
||||
help_text: string;
|
||||
}
|
||||
|
||||
interface FlaggableElementProps {
|
||||
children: React.ReactNode;
|
||||
message: Message;
|
||||
}
|
||||
|
||||
interface ValidLabelsResponse {
|
||||
valid_labels: Label[];
|
||||
}
|
||||
|
||||
export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
const { data: response } = useSWRImmutable<ValidLabelsResponse>("/api/valid_labels", get);
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const { valid_labels } = response || { valid_labels: [] };
|
||||
const [values, setValues] = useState<number[]>([]);
|
||||
|
||||
const submittable =
|
||||
values.some((value) => {
|
||||
return value !== null;
|
||||
}) &&
|
||||
values.length === valid_labels.length &&
|
||||
valid_labels.length > 0;
|
||||
|
||||
const { trigger } = useSWRMutation("/api/set_label", post, {
|
||||
onSuccess: onClose,
|
||||
onError: onClose,
|
||||
});
|
||||
|
||||
const submitResponse = () => {
|
||||
const label_map: Map<string, number> = new Map();
|
||||
console.assert(valid_labels.length === values.length);
|
||||
values.forEach((value, idx) => {
|
||||
if (value !== null) {
|
||||
label_map.set(valid_labels[idx].name, value);
|
||||
}
|
||||
});
|
||||
trigger({
|
||||
message_id: props.message.id,
|
||||
label_map: Object.fromEntries(label_map),
|
||||
text: props.message.text,
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<Popover isOpen={isOpen} onOpen={onOpen} onClose={onClose} closeOnBlur={false} isLazy lazyBehavior="keepMounted">
|
||||
<Box display="flex" alignItems="center" flexDirection={["column", "row"]} gap="2">
|
||||
<PopoverAnchor>{props.children}</PopoverAnchor>
|
||||
|
||||
<Tooltip label="Report" bg="red.500" aria-label="A tooltip">
|
||||
<Box>
|
||||
<PopoverTrigger>
|
||||
<Box as="button" display="flex" alignItems="center" justifyContent="center" borderRadius="full" p="1">
|
||||
<AlertCircle size="20" className="text-red-400" aria-hidden="true" />
|
||||
</Box>
|
||||
</PopoverTrigger>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
|
||||
<Modal isOpen={isOpen} onClose={onClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Select one or more labels that apply.</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<LabelInputGroup labelIDs={valid_labels.map(({ name }) => name)} onChange={setValues} />
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<Button
|
||||
isDisabled={!submittable}
|
||||
onClick={submitResponse}
|
||||
className={`bg-indigo-600 text-${useColorModeValue(
|
||||
colors.light.text,
|
||||
colors.dark.text
|
||||
)} hover:bg-indigo-700`}
|
||||
>
|
||||
Report
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
@@ -1,26 +1,8 @@
|
||||
import { Box, forwardRef, Grid, useColorMode } from "@chakra-ui/react";
|
||||
import { Box, forwardRef, useColorMode } from "@chakra-ui/react";
|
||||
import { useMemo } from "react";
|
||||
import { Message } from "src/types/Conversation";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
interface MessagesProps {
|
||||
messages: Message[];
|
||||
}
|
||||
|
||||
export const Messages = ({ messages }: MessagesProps) => {
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
return (
|
||||
<FlaggableElement message={messageProps} key={i + messageProps.id}>
|
||||
<MessageView {...messageProps} />
|
||||
</FlaggableElement>
|
||||
);
|
||||
});
|
||||
// Maybe also show a legend of the colors?
|
||||
return <Grid gap={2}>{items}</Grid>;
|
||||
};
|
||||
|
||||
export const MessageView = forwardRef<Message, "div">((message: Message, ref) => {
|
||||
export const MessageView = forwardRef<Partial<Message>, "div">((message: Partial<Message>, ref) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const bgColor = useMemo(() => {
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import { Button, Flex } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
|
||||
interface LabelFlagGroupProps {
|
||||
values: number[];
|
||||
labelNames: string[];
|
||||
isEditable?: boolean;
|
||||
onChange: (values: number[]) => void;
|
||||
}
|
||||
|
||||
export const LabelFlagGroup = ({ values, labelNames, isEditable = true, onChange }: LabelFlagGroupProps) => {
|
||||
const { t } = useTranslation("labelling");
|
||||
return (
|
||||
<Flex wrap="wrap" gap="4">
|
||||
{labelNames.map((name, idx) => (
|
||||
<Button
|
||||
key={name}
|
||||
onClick={() => {
|
||||
const newValues = values.slice();
|
||||
newValues[idx] = newValues[idx] ? 0 : 1;
|
||||
onChange(newValues);
|
||||
}}
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={values[idx] === 1 ? "blue" : undefined}
|
||||
>
|
||||
{t(getTypeSafei18nKey(name))}
|
||||
</Button>
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,84 @@
|
||||
import { Text, VStack } from "@chakra-ui/react";
|
||||
import { Label } from "src/types/Tasks";
|
||||
|
||||
import { LabelLikertGroup } from "../Survey/LabelLikertGroup";
|
||||
import { LabelFlagGroup } from "./LabelFlagGroup";
|
||||
import { LabelYesNoGroup } from "./LabelYesNoGroup";
|
||||
|
||||
export interface LabelInputInstructions {
|
||||
yesNoInstruction: string;
|
||||
flagInstruction: string;
|
||||
likertInstruction: string;
|
||||
}
|
||||
|
||||
interface LabelInputGroupProps {
|
||||
values: number[];
|
||||
labels: Label[];
|
||||
requiredLabels?: string[];
|
||||
isEditable?: boolean;
|
||||
instructions: LabelInputInstructions;
|
||||
onChange: (values: number[]) => void;
|
||||
}
|
||||
|
||||
export const LabelInputGroup = ({
|
||||
labels,
|
||||
values,
|
||||
requiredLabels,
|
||||
isEditable,
|
||||
instructions,
|
||||
onChange,
|
||||
}: LabelInputGroupProps) => {
|
||||
const yesNoIndexes = labels.map((label, idx) => (label.widget === "yes_no" ? idx : null)).filter((v) => v !== null);
|
||||
const flagIndexes = labels.map((label, idx) => (label.widget === "flag" ? idx : null)).filter((v) => v !== null);
|
||||
const likertIndexes = labels.map((label, idx) => (label.widget === "likert" ? idx : null)).filter((v) => v !== null);
|
||||
|
||||
return (
|
||||
<VStack alignItems="stretch" spacing={6}>
|
||||
{yesNoIndexes.length > 0 && (
|
||||
<VStack alignItems="stretch" spacing={2}>
|
||||
<Text>{instructions.yesNoInstruction}</Text>
|
||||
<LabelYesNoGroup
|
||||
values={yesNoIndexes.map((idx) => values[idx])}
|
||||
labelNames={yesNoIndexes.map((idx) => labels[idx].name)}
|
||||
isEditable={isEditable}
|
||||
requiredLabels={requiredLabels}
|
||||
onChange={(yesNoValues) => {
|
||||
const newValues = values.slice();
|
||||
yesNoIndexes.forEach((idx, yesNoIndex) => (newValues[idx] = yesNoValues[yesNoIndex]));
|
||||
onChange(newValues);
|
||||
}}
|
||||
/>
|
||||
</VStack>
|
||||
)}
|
||||
{flagIndexes.length > 0 && (
|
||||
<VStack alignItems="stretch" spacing={2}>
|
||||
<Text>{instructions.flagInstruction}</Text>
|
||||
<LabelFlagGroup
|
||||
values={flagIndexes.map((idx) => values[idx])}
|
||||
labelNames={flagIndexes.map((idx) => labels[idx].name)}
|
||||
isEditable={isEditable}
|
||||
onChange={(flagValues) => {
|
||||
const newValues = values.slice();
|
||||
flagIndexes.forEach((idx, flagIndex) => (newValues[idx] = flagValues[flagIndex]));
|
||||
onChange(newValues);
|
||||
}}
|
||||
/>
|
||||
</VStack>
|
||||
)}
|
||||
{likertIndexes.length > 0 && (
|
||||
<VStack alignItems="stretch" spacing={2}>
|
||||
<Text>{instructions.likertInstruction}</Text>
|
||||
<LabelLikertGroup
|
||||
labelIDs={likertIndexes.map((idx) => labels[idx].name)}
|
||||
isEditable={isEditable}
|
||||
onChange={(likertValues) => {
|
||||
const newValues = values.slice();
|
||||
likertIndexes.forEach((idx, likertIndex) => (newValues[idx] = likertValues[likertIndex]));
|
||||
onChange(newValues);
|
||||
}}
|
||||
/>
|
||||
</VStack>
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,84 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
} from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LabelInputGroup } from "src/components/Messages/LabelInputGroup";
|
||||
import { get, post } from "src/lib/api";
|
||||
import { Label } from "src/types/Tasks";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
interface LabelMessagePopupProps {
|
||||
messageId: string;
|
||||
show: boolean;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
interface ValidLabelsResponse {
|
||||
valid_labels: Label[];
|
||||
}
|
||||
|
||||
export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopupProps) => {
|
||||
const { t } = useTranslation();
|
||||
const { data: response } = useSWRImmutable<ValidLabelsResponse>("/api/valid_labels", get);
|
||||
const valid_labels = response?.valid_labels ?? [];
|
||||
const [values, setValues] = useState<number[]>(new Array(valid_labels.length).fill(null));
|
||||
|
||||
useEffect(() => {
|
||||
setValues(new Array(valid_labels.length).fill(null));
|
||||
}, [messageId, valid_labels.length]);
|
||||
|
||||
const { trigger: setLabels } = useSWRMutation("/api/set_label", post);
|
||||
|
||||
const submit = () => {
|
||||
const label_map: Map<string, number> = new Map();
|
||||
console.assert(valid_labels.length === values.length);
|
||||
values.forEach((value, idx) => {
|
||||
if (value !== null) {
|
||||
label_map.set(valid_labels[idx].name, value);
|
||||
}
|
||||
});
|
||||
setLabels({
|
||||
message_id: messageId,
|
||||
label_map: Object.fromEntries(label_map),
|
||||
});
|
||||
|
||||
setValues(null);
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal isOpen={show} onClose={onClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{t("message:label_title")}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<LabelInputGroup
|
||||
labels={valid_labels}
|
||||
values={values}
|
||||
instructions={{
|
||||
yesNoInstruction: t("labelling:label_message_yes_no_instruction"),
|
||||
flagInstruction: t("labelling:label_message_flag_instruction"),
|
||||
likertInstruction: t("labelling:label_message_likert_instruction"),
|
||||
}}
|
||||
onChange={setValues}
|
||||
/>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<Button colorScheme="blue" mr={3} onClick={submit}>
|
||||
{t("message:submit_labels")}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,89 @@
|
||||
import { Button, HStack, Text, Tooltip } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
|
||||
interface LabelYesNoGroupProps {
|
||||
values: number[];
|
||||
labelNames: string[];
|
||||
requiredLabels?: string[];
|
||||
isEditable?: boolean;
|
||||
onChange: (values: number[]) => void;
|
||||
}
|
||||
|
||||
export const LabelYesNoGroup = ({
|
||||
values,
|
||||
labelNames,
|
||||
requiredLabels = [],
|
||||
isEditable = true,
|
||||
onChange,
|
||||
}: LabelYesNoGroupProps) => {
|
||||
const { t } = useTranslation("labelling");
|
||||
return (
|
||||
<>
|
||||
{labelNames.map((name, idx) => {
|
||||
return (
|
||||
<YesNoQuestion
|
||||
key={name}
|
||||
question={t(getTypeSafei18nKey(`${name}.question`))}
|
||||
value={values[idx] === null ? null : values[idx] > 0.1 ? true : false}
|
||||
onChange={(value) => {
|
||||
const newValues = values.slice();
|
||||
newValues[idx] = value;
|
||||
onChange(newValues);
|
||||
}}
|
||||
isEditable={isEditable}
|
||||
isRequired={requiredLabels.includes(name)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const YesNoQuestion = ({
|
||||
isEditable,
|
||||
question,
|
||||
value,
|
||||
isRequired,
|
||||
onChange,
|
||||
}: {
|
||||
isEditable: boolean;
|
||||
question: string;
|
||||
value: boolean;
|
||||
isRequired?: boolean;
|
||||
onChange: (boolean) => void;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<div data-cy="label-question" style={{ maxWidth: "30em" }}>
|
||||
<Text display="inline">
|
||||
{question}
|
||||
{isRequired ? <RequiredMark /> : undefined}
|
||||
</Text>
|
||||
<HStack style={{ float: "right" }}>
|
||||
<Button
|
||||
data-cy="yes"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === true ? "blue" : undefined}
|
||||
onClick={() => onChange(isRequired ? true : value === null ? true : null)}
|
||||
>
|
||||
{t("yes")}
|
||||
</Button>
|
||||
<Button
|
||||
data-cy="no"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === false ? "blue" : undefined}
|
||||
onClick={() => onChange(isRequired ? false : value === null ? false : null)}
|
||||
>
|
||||
{t("no")}
|
||||
</Button>
|
||||
</HStack>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const RequiredMark = () => (
|
||||
<Tooltip label="Required">
|
||||
<span style={{ color: "red" }}>*</span>
|
||||
</Tooltip>
|
||||
);
|
||||
@@ -0,0 +1,34 @@
|
||||
import React from "react";
|
||||
|
||||
import { MessageEmojiButton } from "./MessageEmojiButton";
|
||||
|
||||
// eslint-disable-next-line import/no-anonymous-default-export
|
||||
export default {
|
||||
title: "Messages/MessageEmojiButton",
|
||||
component: MessageEmojiButton,
|
||||
};
|
||||
|
||||
const Template = ({ emoji, count, checked }: { emoji: string; count: number; checked?: boolean }) => {
|
||||
return <MessageEmojiButton emoji={{ name: emoji, count }} checked={checked} onClick={undefined} />;
|
||||
};
|
||||
|
||||
export const Default = Template.bind({});
|
||||
Default.args = {
|
||||
emoji: "+1",
|
||||
count: 7,
|
||||
checked: false,
|
||||
};
|
||||
|
||||
export const BigNumber = Template.bind({});
|
||||
BigNumber.args = {
|
||||
emoji: "+1",
|
||||
count: 999,
|
||||
checked: false,
|
||||
};
|
||||
|
||||
export const Checked = Template.bind({});
|
||||
Checked.args = {
|
||||
emoji: "+1",
|
||||
count: 2,
|
||||
checked: true,
|
||||
};
|
||||
@@ -0,0 +1,48 @@
|
||||
import { Button } from "@chakra-ui/react";
|
||||
import { BoxSelect, Flag, LucideProps, ThumbsDown, ThumbsUp } from "lucide-react";
|
||||
import { ReactElement } from "react";
|
||||
import { MessageEmoji } from "src/types/Conversation";
|
||||
|
||||
type EmojiIconPurpose = "MINI_BUTTON" | "NORMAL";
|
||||
|
||||
const defaultIconProps: (purpose: EmojiIconPurpose) => LucideProps = (purpose: EmojiIconPurpose) => {
|
||||
if (purpose === "MINI_BUTTON") return { height: "1em" };
|
||||
return {};
|
||||
};
|
||||
|
||||
export const getEmojiIcon = (name: string, purpose: EmojiIconPurpose): ReactElement => {
|
||||
switch (name) {
|
||||
case "+1":
|
||||
return <ThumbsUp {...defaultIconProps(purpose)} />;
|
||||
case "-1":
|
||||
return <ThumbsDown {...defaultIconProps(purpose)} />;
|
||||
case "flag":
|
||||
case "red_flag":
|
||||
return <Flag {...defaultIconProps(purpose)} />;
|
||||
default:
|
||||
return <BoxSelect {...defaultIconProps(purpose)} />;
|
||||
}
|
||||
};
|
||||
|
||||
interface MessageEmojiButtonProps {
|
||||
emoji: MessageEmoji;
|
||||
checked?: boolean;
|
||||
onClick: () => void;
|
||||
}
|
||||
|
||||
export const MessageEmojiButton = ({ emoji, checked, onClick }: MessageEmojiButtonProps) => {
|
||||
return (
|
||||
<Button
|
||||
onClick={onClick}
|
||||
variant={checked ? "solid" : "ghost"}
|
||||
colorScheme={checked ? "blue" : undefined}
|
||||
size="sm"
|
||||
height="1.6em"
|
||||
minWidth={0}
|
||||
padding="0"
|
||||
>
|
||||
{getEmojiIcon(emoji.name, "MINI_BUTTON")}
|
||||
<span style={{ marginInlineEnd: "0.25em" }}>{emoji.count}</span>
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
@@ -29,18 +29,24 @@ Default.args = {
|
||||
is_assistant: true,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
},
|
||||
{
|
||||
text: "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?",
|
||||
is_assistant: false,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: { "-1": 11, red_flag: 2 },
|
||||
user_emojis: [],
|
||||
},
|
||||
{
|
||||
text: "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?",
|
||||
is_assistant: false,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
},
|
||||
],
|
||||
enableLink: true,
|
||||
@@ -50,12 +56,21 @@ Default.args = {
|
||||
export const Conversation = Template.bind({});
|
||||
Conversation.args = {
|
||||
messages: [
|
||||
{ text: "Hello! How can I help you?", is_assistant: true, id: "", frontend_message_id: "" },
|
||||
{
|
||||
text: "Hello! How can I help you?",
|
||||
is_assistant: true,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
},
|
||||
{
|
||||
text: "Who were the 8 presidents before George Washington?",
|
||||
is_assistant: false,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
},
|
||||
],
|
||||
enableLink: false,
|
||||
@@ -70,18 +85,24 @@ LongText.args = {
|
||||
is_assistant: true,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
},
|
||||
{
|
||||
text: "Yes, I think they can be helpful when the child misbehaves, but they should be used with a little bit of compassion and understanding that it\u2019s not the natural state of things to have an adult yelling at them. Time outs are also often used without letting the child know how they\u2019re getting out of the time out, which can make it feel arbitrary or like a punishment, rather than a consequence for something they did. It\u2019s really easy for adults to do this kind of thing unconsciously. It\u2019s easy to get caught up in the notion that \u201cThey\u2019re in time out, and that\u2019s the end of it!\u201d but kids can be pretty imaginative, and they can use their own creativity to make their way out of time outs. A compassionate time out ends when the child shows a sign of understanding what they\u2019ve done wrong, and are ready to begin again. That way the child knows they\u2019re learning, and that the parent is seeing them as an intelligent person, even if they sometimes mess up. You can still use the other techniques you were using to be tough when necessary, but using a compassionate approach will let you use them without actually using them!",
|
||||
is_assistant: false,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
},
|
||||
{
|
||||
text: "No. The USA was founded by a Puritan group of Protestants, but it didn\u2019t adopt the religion of the Puritans until much later, and it was always a secular state. The Puritans observed the Sabbath on Sunday, and the Puritans only had a small influence in the early history of the USA. It\u2019s difficult to trace the origins of closing stores on Sunday, but one early and short-lived attempt at forcing the Sabbath on people in the 1800s was motivated by the Protestant ideal that people should spend Sunday focusing on spiritual activities. By the mid-1800s, when the Sunday closing law was made, there was not a lot of pressure from that standpoint, but the church had begun to advocate for Sunday closing laws as a way of counteracting the negative effects of industrialization on the day of rest. Even after that shift, closing stores on Sunday was not always possible, since the religious Sunday was not always chosen for observance. And as industrialization accelerated and mechanization made it possible to operate stores on Sunday, the law was not enforced as much as people liked. The day of rest was also being violated by stores that stayed open all day on Sunday, so closing stores on Sundays became an effort to protect the Sabbath for all citizens.",
|
||||
is_assistant: false,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
},
|
||||
],
|
||||
enableLink: true,
|
||||
|
||||
@@ -11,11 +11,11 @@ interface MessageTableProps {
|
||||
export function MessageTable({ messages, enableLink, highlightLastMessage }: MessageTableProps) {
|
||||
return (
|
||||
<Stack spacing="4">
|
||||
{messages.map((item, idx) => (
|
||||
{messages.map((message, idx) => (
|
||||
<MessageTableEntry
|
||||
enabled={enableLink}
|
||||
item={item}
|
||||
key={item.id + item.frontend_message_id}
|
||||
message={message}
|
||||
key={message.id + message.frontend_message_id}
|
||||
highlight={highlightLastMessage && idx === messages.length - 1}
|
||||
/>
|
||||
))}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import React from "react";
|
||||
import { Message } from "src/types/Conversation";
|
||||
|
||||
import { MessageTableEntry } from "./MessageTableEntry";
|
||||
|
||||
@@ -8,10 +9,8 @@ export default {
|
||||
component: MessageTableEntry,
|
||||
};
|
||||
|
||||
const Template = ({ text, is_assistant, id, frontend_message_id, enabled, highlight }) => {
|
||||
return (
|
||||
<MessageTableEntry item={{ text, is_assistant, id, frontend_message_id }} enabled={enabled} highlight={highlight} />
|
||||
);
|
||||
const Template = ({ enabled, highlight, ...message }) => {
|
||||
return <MessageTableEntry message={message as Message} enabled={enabled} highlight={highlight} />;
|
||||
};
|
||||
|
||||
export const Default = Template.bind({});
|
||||
@@ -22,6 +21,8 @@ Default.args = {
|
||||
frontend_message_id: "",
|
||||
enabled: true,
|
||||
highlight: false,
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
};
|
||||
|
||||
export const Asistant = Template.bind({});
|
||||
@@ -32,6 +33,8 @@ Asistant.args = {
|
||||
frontend_message_id: "",
|
||||
enabled: true,
|
||||
highlight: false,
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
};
|
||||
|
||||
export const LongText = Template.bind({});
|
||||
@@ -42,4 +45,18 @@ LongText.args = {
|
||||
frontend_message_id: "",
|
||||
enabled: true,
|
||||
highlight: false,
|
||||
emojis: {},
|
||||
user_emojis: [],
|
||||
};
|
||||
|
||||
export const WithEmoji = Template.bind({});
|
||||
WithEmoji.args = {
|
||||
text: "As you\u2019ve mentioned, Star Wars has many sequels, prequels, and crossovers. The official list of movies in Star Wars is:",
|
||||
is_assistant: true,
|
||||
id: "",
|
||||
frontend_message_id: "",
|
||||
enabled: true,
|
||||
highlight: false,
|
||||
emojis: { "-1": 5, "+1": 1 },
|
||||
user_emojis: ["-1"],
|
||||
};
|
||||
|
||||
@@ -1,23 +1,47 @@
|
||||
import { Avatar, Box, HStack, useBreakpointValue, useColorModeValue } from "@chakra-ui/react";
|
||||
import {
|
||||
Avatar,
|
||||
Box,
|
||||
HStack,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuDivider,
|
||||
MenuGroup,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
SimpleGrid,
|
||||
useBreakpointValue,
|
||||
useColorModeValue,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import { ClipboardList, Flag, MessageSquare, MoreHorizontal } from "lucide-react";
|
||||
import { useRouter } from "next/router";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import { FlaggableElement } from "src/components/FlaggableElement";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { LabelMessagePopup } from "src/components/Messages/LabelPopup";
|
||||
import { getEmojiIcon, MessageEmojiButton } from "src/components/Messages/MessageEmojiButton";
|
||||
import { ReportPopup } from "src/components/Messages/ReportPopup";
|
||||
import { post } from "src/lib/api";
|
||||
import { Message, MessageEmojis } from "src/types/Conversation";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
interface MessageTableEntryProps {
|
||||
item: Message;
|
||||
message: Message;
|
||||
enabled?: boolean;
|
||||
highlight?: boolean;
|
||||
}
|
||||
|
||||
export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
export function MessageTableEntry({ message, enabled, highlight }: MessageTableEntryProps) {
|
||||
const router = useRouter();
|
||||
const [emojiState, setEmojis] = useState<MessageEmojis>({ emojis: {}, user_emojis: [] });
|
||||
useEffect(() => {
|
||||
setEmojis({ emojis: message.emojis || {}, user_emojis: message.user_emojis || [] });
|
||||
}, [message.emojis, message.user_emojis]);
|
||||
|
||||
const { item } = props;
|
||||
|
||||
const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]);
|
||||
const goToMessage = useCallback(() => router.push(`/messages/${message.id}`), [router, message.id]);
|
||||
const { isOpen: reportPopupOpen, onOpen: showReportPopup, onClose: closeReportPopup } = useDisclosure();
|
||||
const { isOpen: labelPopupOpen, onOpen: showLabelPopup, onClose: closeLabelPopup } = useDisclosure();
|
||||
|
||||
const backgroundColor = useColorModeValue("gray.100", "gray.700");
|
||||
const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B");
|
||||
@@ -32,34 +56,124 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
borderColor={borderColor}
|
||||
size={inlineAvatar ? "xs" : "sm"}
|
||||
mr={inlineAvatar ? 2 : 0}
|
||||
name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`}
|
||||
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
|
||||
name={`${boolean(message.is_assistant) ? "Assistant" : "User"}`}
|
||||
src={`${boolean(message.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
|
||||
/>
|
||||
),
|
||||
[borderColor, inlineAvatar, item.is_assistant]
|
||||
[borderColor, inlineAvatar, message.is_assistant]
|
||||
);
|
||||
const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight);
|
||||
|
||||
const { trigger: sendEmojiChange } = useSWRMutation(`/api/messages/${message.id}/emoji`, post, {
|
||||
onSuccess: setEmojis,
|
||||
});
|
||||
const react = (emoji: string, state: boolean) => {
|
||||
sendEmojiChange({ op: state ? "add" : "remove", emoji });
|
||||
};
|
||||
|
||||
return (
|
||||
<FlaggableElement message={item}>
|
||||
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
|
||||
{!inlineAvatar && avatar}
|
||||
<Box
|
||||
width={["full", "full", "full", "fit-content"]}
|
||||
maxWidth={["full", "full", "full", "2xl"]}
|
||||
p="4"
|
||||
borderRadius="md"
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
outline={props.highlight && "2px solid black"}
|
||||
outlineColor={highlightColor}
|
||||
onClick={props.enabled && goToMessage}
|
||||
_hover={props.enabled && { cursor: "pointer", opacity: 0.9 }}
|
||||
whiteSpace="pre-wrap"
|
||||
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
|
||||
{!inlineAvatar && avatar}
|
||||
<Box
|
||||
width={["full", "full", "full", "fit-content"]}
|
||||
maxWidth={["full", "full", "full", "2xl"]}
|
||||
p="4"
|
||||
borderRadius="md"
|
||||
bg={message.is_assistant ? backgroundColor : backgroundColor2}
|
||||
outline={highlight && "2px solid black"}
|
||||
outlineColor={highlightColor}
|
||||
onClick={enabled && goToMessage}
|
||||
whiteSpace="pre-wrap"
|
||||
cursor={enabled && "pointer"}
|
||||
style={{ position: "relative" }}
|
||||
>
|
||||
{inlineAvatar && avatar}
|
||||
{message.text}
|
||||
<HStack
|
||||
style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{inlineAvatar && avatar}
|
||||
{item.text}
|
||||
</Box>
|
||||
</HStack>
|
||||
</FlaggableElement>
|
||||
{Object.entries(emojiState.emojis).map(([emoji, count]) => (
|
||||
<MessageEmojiButton
|
||||
key={emoji}
|
||||
emoji={{ name: emoji, count }}
|
||||
checked={emojiState.user_emojis.includes(emoji)}
|
||||
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
|
||||
/>
|
||||
))}
|
||||
<MessageActions
|
||||
react={react}
|
||||
userEmoji={emojiState.user_emojis}
|
||||
onLabel={showLabelPopup}
|
||||
onReport={showReportPopup}
|
||||
messageId={message.id}
|
||||
/>
|
||||
<LabelMessagePopup messageId={message.id} show={labelPopupOpen} onClose={closeLabelPopup} />
|
||||
<ReportPopup messageId={message.id} show={reportPopupOpen} onClose={closeReportPopup} />
|
||||
</HStack>
|
||||
</Box>
|
||||
</HStack>
|
||||
);
|
||||
}
|
||||
|
||||
const EmojiMenuItem = ({
|
||||
emoji,
|
||||
checked,
|
||||
react,
|
||||
}: {
|
||||
emoji: string;
|
||||
checked?: boolean;
|
||||
react: (emoji: string, state: boolean) => void;
|
||||
}) => {
|
||||
const activeColor = useColorModeValue(colors.light.active, colors.dark.active);
|
||||
|
||||
return (
|
||||
<MenuItem onClick={() => react(emoji, !checked)} justifyContent="center" color={checked ? activeColor : undefined}>
|
||||
{getEmojiIcon(emoji, "NORMAL")}
|
||||
</MenuItem>
|
||||
);
|
||||
};
|
||||
|
||||
const MessageActions = ({
|
||||
react,
|
||||
userEmoji,
|
||||
onLabel,
|
||||
onReport,
|
||||
messageId,
|
||||
}: {
|
||||
react: (emoji: string, state: boolean) => void;
|
||||
userEmoji: string[];
|
||||
onLabel: () => void;
|
||||
onReport: () => void;
|
||||
messageId: string;
|
||||
}) => {
|
||||
const { t } = useTranslation("message");
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton>
|
||||
<MoreHorizontal />
|
||||
</MenuButton>
|
||||
<MenuList>
|
||||
<MenuGroup title={t("reactions")}>
|
||||
<SimpleGrid columns={4}>
|
||||
{["+1", "-1"].map((emoji) => (
|
||||
<EmojiMenuItem key={emoji} emoji={emoji} checked={userEmoji.includes(emoji)} react={react} />
|
||||
))}
|
||||
</SimpleGrid>
|
||||
</MenuGroup>
|
||||
<MenuDivider />
|
||||
<MenuItem onClick={onLabel} icon={<ClipboardList />}>
|
||||
{t("label_action")}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={onReport} icon={<Flag />}>
|
||||
{t("report_action")}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem as="a" href={`/messages/${messageId}`} target="_blank" icon={<MessageSquare />}>
|
||||
{t("open_new_tab_action")}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -52,7 +52,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
{isFirst ? "Message" : depth === 1 ? "Children" : "Ancestor"}
|
||||
</Text>
|
||||
<Box width="fit-content" bg={backgroundColor} padding="4" borderRadius="xl" boxShadow="base">
|
||||
<MessageTableEntry enabled item={message} />
|
||||
<MessageTableEntry enabled message={message} />
|
||||
</Box>
|
||||
</Box>
|
||||
</>
|
||||
@@ -86,9 +86,9 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
gap="4"
|
||||
shadow="base"
|
||||
>
|
||||
{children.map((item, idx) => (
|
||||
{children.map((message, idx) => (
|
||||
<Box flex="1" key={`recursiveMessageWChildren_${idx}`}>
|
||||
<MessageTableEntry enabled item={item} />
|
||||
<MessageTableEntry enabled message={message} />
|
||||
</Box>
|
||||
))}
|
||||
</Box>
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Textarea,
|
||||
} from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useState } from "react";
|
||||
import { post } from "src/lib/api";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
interface ReportPopupProps {
|
||||
messageId: string;
|
||||
show: boolean;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export const ReportPopup = ({ messageId, show, onClose }: ReportPopupProps) => {
|
||||
const { t } = useTranslation("message");
|
||||
const [text, setText] = useState("");
|
||||
const { trigger } = useSWRMutation("/api/report", post);
|
||||
|
||||
const submit = () => {
|
||||
trigger({
|
||||
message_id: messageId,
|
||||
text,
|
||||
});
|
||||
|
||||
setText("");
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal isOpen={show} onClose={onClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{t("report_title")}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Textarea onChange={(e) => setText(e.target.value)} resize="none" placeholder={t("report_placeholder")} />
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<Button colorScheme="blue" mr={3} onClick={submit}>
|
||||
{t("send_report")}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
+5
-5
@@ -135,9 +135,9 @@ const getLabelInfo = (label: string): LabelInfo => {
|
||||
oneDescription: ["Contains text which is incorrect or misleading"],
|
||||
inverted: true,
|
||||
};
|
||||
case "helpful":
|
||||
case "helpfulness":
|
||||
return {
|
||||
zeroText: "Unhelful",
|
||||
zeroText: "Unhelpful",
|
||||
zeroDescription: [],
|
||||
oneText: "Helpful",
|
||||
oneDescription: ["Completes the task to a high standard"],
|
||||
@@ -186,7 +186,7 @@ const getLabelInfo = (label: string): LabelInfo => {
|
||||
}
|
||||
};
|
||||
|
||||
export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => {
|
||||
export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => {
|
||||
const [labelValues, setLabelValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => null));
|
||||
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
@@ -211,7 +211,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label
|
||||
}}
|
||||
alignItems="center"
|
||||
>
|
||||
<Text>
|
||||
<Text as="div">
|
||||
{textA}
|
||||
{descriptionA.length > 0 ? <Explain explanation={descriptionA} /> : null}
|
||||
</Text>
|
||||
@@ -229,7 +229,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label
|
||||
/>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Text textAlign="right">
|
||||
<Text textAlign="right" as="div">
|
||||
{textB}
|
||||
{descriptionB.length > 0 ? <Explain explanation={descriptionB} /> : null}
|
||||
</Text>
|
||||
@@ -4,12 +4,10 @@ import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { TaskStatus } from "src/components/Tasks/Task";
|
||||
import { BaseTask } from "src/types/Task";
|
||||
|
||||
export interface TaskControlsProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
task: any;
|
||||
className?: string;
|
||||
task: BaseTask;
|
||||
taskStatus: TaskStatus;
|
||||
onEdit: () => void;
|
||||
onReview: () => void;
|
||||
@@ -17,7 +15,7 @@ export interface TaskControlsProps {
|
||||
onSkip: (reason: string) => void;
|
||||
}
|
||||
|
||||
export const TaskControls = (props: TaskControlsProps) => {
|
||||
export const TaskControls = ({ task, taskStatus, onEdit, onReview, onSubmit, onSkip }: TaskControlsProps) => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
|
||||
return (
|
||||
@@ -31,38 +29,32 @@ export const TaskControls = (props: TaskControlsProps) => {
|
||||
shadow="base"
|
||||
gap="4"
|
||||
>
|
||||
<TaskInfo id={props.task.id} output="Submit your answer" />
|
||||
<TaskInfo id={task.id} output="Submit your answer" />
|
||||
<Flex width={["full", "fit-content"]} justify="center" ml="auto" gap={2}>
|
||||
{props.taskStatus === "REVIEW" || props.taskStatus === "SUBMITTED" ? (
|
||||
{taskStatus.mode === "EDIT" ? (
|
||||
<>
|
||||
<Tooltip label="Edit">
|
||||
<IconButton
|
||||
size="lg"
|
||||
data-cy="edit"
|
||||
aria-label="edit"
|
||||
onClick={props.onEdit}
|
||||
icon={<Edit2 size="1em" />}
|
||||
/>
|
||||
</Tooltip>
|
||||
<SkipButton onSkip={onSkip} />
|
||||
<SubmitButton
|
||||
colorScheme="green"
|
||||
data-cy="submit"
|
||||
isDisabled={props.taskStatus === "SUBMITTED"}
|
||||
onClick={props.onSubmit}
|
||||
colorScheme="blue"
|
||||
data-cy="review"
|
||||
isDisabled={taskStatus.replyValidity === "INVALID"}
|
||||
onClick={onReview}
|
||||
>
|
||||
Submit
|
||||
Review
|
||||
</SubmitButton>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<SkipButton onSkip={props.onSkip} />
|
||||
<Tooltip label="Edit">
|
||||
<IconButton size="lg" data-cy="edit" aria-label="edit" onClick={onEdit} icon={<Edit2 size="1em" />} />
|
||||
</Tooltip>
|
||||
<SubmitButton
|
||||
colorScheme="blue"
|
||||
data-cy="review"
|
||||
isDisabled={props.taskStatus === "NOT_SUBMITTABLE"}
|
||||
onClick={props.onReview}
|
||||
colorScheme="green"
|
||||
data-cy="submit"
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onClick={onSubmit}
|
||||
>
|
||||
Review
|
||||
Submit
|
||||
</SubmitButton>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import Head from "next/head";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { TaskInfos } from "src/components/Tasks/TaskTypes";
|
||||
import { ERROR_CODES, taskApiHooks } from "src/lib/constants";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
import { TaskType } from "src/types/Task";
|
||||
|
||||
type TaskPageProps = {
|
||||
type: TaskType;
|
||||
};
|
||||
|
||||
export const TaskPage = ({ type }: TaskPageProps) => {
|
||||
const { t } = useTranslation(["tasks", "common"]);
|
||||
const taskApiHook = taskApiHooks[type];
|
||||
const { tasks, isLoading, reset, trigger, error } = taskApiHook(type);
|
||||
const taskInfo = TaskInfos.find((taskType) => taskType.type === type);
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text={t("common:loading")} />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0 || error?.errorCode === ERROR_CODES.TASK_REQUESTED_TYPE_NOT_AVAILABLE) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
const task = tasks[0];
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>{t(getTypeSafei18nKey(`${taskInfo.id}.label`))}</title>
|
||||
<meta name="description" content={t(getTypeSafei18nKey(`${taskInfo.id}.desc`))} />
|
||||
</Head>
|
||||
<Task key={task.task.id} frontendId={task.id} task={task.task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -7,6 +7,8 @@ import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks";
|
||||
|
||||
export const CreateTask = ({
|
||||
task,
|
||||
@@ -15,7 +17,7 @@ export const CreateTask = ({
|
||||
isDisabled,
|
||||
onReplyChanged,
|
||||
onValidityChanged,
|
||||
}: TaskSurveyProps<{ text: string }>) => {
|
||||
}: TaskSurveyProps<CreateInitialPromptTask | CreateAssistantReplyTask | CreatePrompterReplyTask, { text: string }>) => {
|
||||
const { t, i18n } = useTranslation(["tasks", "common"]);
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const titleColor = useColorModeValue("gray.800", "gray.300");
|
||||
@@ -39,7 +41,7 @@ export const CreateTask = ({
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<TaskHeader taskType={taskType} />
|
||||
{!!task.conversation && (
|
||||
{task.type !== TaskType.initial_prompt && (
|
||||
<Box mt="4" borderRadius="lg" bg={cardColor} className="p-3 sm:p-6">
|
||||
<MessageTable messages={task.conversation.messages} highlightLastMessage />
|
||||
</Box>
|
||||
|
||||
@@ -5,6 +5,8 @@ import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks";
|
||||
|
||||
export const EvaluateTask = ({
|
||||
task,
|
||||
@@ -13,20 +15,25 @@ export const EvaluateTask = ({
|
||||
isDisabled,
|
||||
onReplyChanged,
|
||||
onValidityChanged,
|
||||
}: TaskSurveyProps<{ ranking: number[] }>) => {
|
||||
}: TaskSurveyProps<
|
||||
RankInitialPromptsTask | RankAssistantRepliesTask | RankPrompterRepliesTask,
|
||||
{ ranking: number[] }
|
||||
>) => {
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const [ranking, setRanking] = useState<number[]>(null);
|
||||
|
||||
let messages = [];
|
||||
if (task.conversation) {
|
||||
if (task.type !== TaskType.rank_initial_prompts) {
|
||||
messages = task.conversation.messages;
|
||||
messages = messages.map((message, index) => ({ ...message, id: index }));
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (ranking === null) {
|
||||
const defaultRanking = (task.replies ?? task.prompts).map((_, idx) => idx);
|
||||
onReplyChanged({ ranking: defaultRanking });
|
||||
if (task.type === TaskType.rank_initial_prompts) {
|
||||
onReplyChanged({ ranking: task.prompts.map((_, idx) => idx) });
|
||||
} else {
|
||||
onReplyChanged({ ranking: task.replies.map((_, idx) => idx) });
|
||||
}
|
||||
onValidityChanged("DEFAULT");
|
||||
} else {
|
||||
onReplyChanged({ ranking });
|
||||
@@ -34,7 +41,7 @@ export const EvaluateTask = ({
|
||||
}
|
||||
}, [task, ranking, onReplyChanged, onValidityChanged]);
|
||||
|
||||
const sortables = task.replies ? "replies" : "prompts";
|
||||
const sortables = task.type === TaskType.rank_initial_prompts ? "prompts" : "replies";
|
||||
|
||||
return (
|
||||
<div data-cy="task" data-task-type="evaluate-task">
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
import { Box, Button, Flex, HStack, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Box, useBoolean, useColorModeValue } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useEffect, useState } from "react";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { LabelInputGroup } from "src/components/Messages/LabelInputGroup";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { LabelInputGroup } from "src/components/Survey/LabelInputGroup";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { LabelTaskType } from "src/types/Tasks";
|
||||
|
||||
const isRequired = (labelName: string, requiredLabels?: string[]) => {
|
||||
return requiredLabels ? requiredLabels.includes(labelName) : false;
|
||||
};
|
||||
|
||||
export const LabelTask = ({
|
||||
task,
|
||||
@@ -14,15 +20,33 @@ export const LabelTask = ({
|
||||
isEditable,
|
||||
onReplyChanged,
|
||||
onValidityChanged,
|
||||
}: TaskSurveyProps<{ text: string; labels: Record<string, number>; message_id: string }>) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(new Array(task.valid_labels.length).fill(null));
|
||||
}: TaskSurveyProps<LabelTaskType, { text: string; labels: Record<string, number>; message_id: string }>) => {
|
||||
const { t } = useTranslation("labelling");
|
||||
const [values, setValues] = useState<number[]>(new Array(task.labels.length).fill(null));
|
||||
const [userInputMade, setUserInputMade] = useBoolean(false);
|
||||
|
||||
// Initial setup to run when the task changes
|
||||
useEffect(() => {
|
||||
console.assert(task.valid_labels.length === sliderValues.length);
|
||||
const labels = Object.fromEntries(task.valid_labels.map((label, i) => [label, sliderValues[i]]));
|
||||
onReplyChanged({ labels, text: task.reply || task.prompt, message_id: task.message_id });
|
||||
onValidityChanged(sliderValues.every((value) => value !== null) ? "VALID" : "INVALID");
|
||||
}, [task, sliderValues, onReplyChanged, onValidityChanged]);
|
||||
setValues(new Array(task.labels.length).fill(null));
|
||||
onValidityChanged(task.labels.some(({ name }) => isRequired(name, task.mandatory_labels)) ? "INVALID" : "DEFAULT");
|
||||
setUserInputMade.off();
|
||||
}, [task, setUserInputMade, onValidityChanged]);
|
||||
|
||||
// Update the reply and validity when the values change
|
||||
useEffect(() => {
|
||||
onReplyChanged({
|
||||
text: "unused?",
|
||||
labels: Object.fromEntries(task.labels.map(({ name }, idx) => [name, values[idx] || 0])),
|
||||
message_id: task.message_id,
|
||||
});
|
||||
onValidityChanged(
|
||||
task.labels.some(({ name }, idx) => values[idx] === null && isRequired(name, task.mandatory_labels))
|
||||
? "INVALID"
|
||||
: userInputMade
|
||||
? "VALID"
|
||||
: "DEFAULT"
|
||||
);
|
||||
}, [task, values, onReplyChanged, userInputMade, onValidityChanged]);
|
||||
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const isSpamTask = task.mode === "simple" && task.valid_labels.length === 1 && task.valid_labels[0] === "spam";
|
||||
@@ -32,71 +56,32 @@ export const LabelTask = ({
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<TaskHeader taskType={taskType} />
|
||||
{task.conversation ? (
|
||||
{task.type !== TaskType.label_initial_prompt ? (
|
||||
<Box mt="4" p={[4, 6]} borderRadius="lg" bg={cardColor}>
|
||||
<MessageTable
|
||||
messages={[
|
||||
...(task.conversation?.messages ?? []),
|
||||
{
|
||||
text: task.reply,
|
||||
is_assistant: task.type === TaskType.label_assistant_reply,
|
||||
message_id: task.message_id,
|
||||
},
|
||||
]}
|
||||
highlightLastMessage
|
||||
/>
|
||||
<MessageTable messages={task.conversation.messages} highlightLastMessage />
|
||||
</Box>
|
||||
) : (
|
||||
<Box mt="4">
|
||||
<MessageView text={task.prompt} is_assistant={false} id={task.message_id} />
|
||||
<MessageView text={task.prompt} is_assistant={false} id={task.message_id} emojis={{}} user_emojis={[]} />
|
||||
</Box>
|
||||
)}
|
||||
</>
|
||||
{isSpamTask ? (
|
||||
<SpamTaskInput
|
||||
value={sliderValues[0]}
|
||||
onChange={(value) => setSliderValues([value])}
|
||||
isEditable={isEditable}
|
||||
/>
|
||||
) : (
|
||||
<Flex direction="column" alignItems="stretch">
|
||||
<Text>The highlighted message:</Text>
|
||||
<LabelInputGroup labelIDs={task.valid_labels} isEditable={isEditable} onChange={setSliderValues} />
|
||||
</Flex>
|
||||
)}
|
||||
<LabelInputGroup
|
||||
labels={task.labels}
|
||||
values={values}
|
||||
requiredLabels={task.mandatory_labels}
|
||||
isEditable={isEditable}
|
||||
instructions={{
|
||||
yesNoInstruction: t("label_highlighted_yes_no_instruction"),
|
||||
flagInstruction: t("label_highlighted_flag_instruction"),
|
||||
likertInstruction: t("label_highlighted_likert_instruction"),
|
||||
}}
|
||||
onChange={(values) => {
|
||||
setValues(values);
|
||||
setUserInputMade.on();
|
||||
}}
|
||||
/>
|
||||
</TwoColumnsWithCards>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const SpamTaskInput = ({
|
||||
isEditable,
|
||||
value,
|
||||
onChange,
|
||||
}: {
|
||||
isEditable: boolean;
|
||||
value: number;
|
||||
onChange: (number) => void;
|
||||
}) => {
|
||||
return (
|
||||
<HStack>
|
||||
<Text>Is the highlighted message spam?</Text>
|
||||
<Button
|
||||
data-cy="spam-button"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === 1 ? "blue" : undefined}
|
||||
onClick={() => onChange(1)}
|
||||
>
|
||||
Yes
|
||||
</Button>
|
||||
<Button
|
||||
data-cy="not-spam-button"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === 0 ? "blue" : undefined}
|
||||
onClick={() => onChange(0)}
|
||||
>
|
||||
No
|
||||
</Button>
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useRef, useState } from "react";
|
||||
import { useCallback, useEffect, useReducer } from "react";
|
||||
import { useMemo, useRef } from "react";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { CreateTask } from "src/components/Tasks/CreateTask";
|
||||
import { EvaluateTask } from "src/components/Tasks/EvaluateTask";
|
||||
@@ -8,15 +9,52 @@ import { TaskCategory, TaskInfo, TaskInfos } from "src/components/Tasks/TaskType
|
||||
import { UnchangedWarning } from "src/components/Tasks/UnchangedWarning";
|
||||
import { post } from "src/lib/api";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
import { TaskContent, TaskReplyValidity } from "src/types/Task";
|
||||
import { BaseTask, TaskContent, TaskReplyValidity } from "src/types/Task";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
export type TaskStatus = "NOT_SUBMITTABLE" | "DEFAULT" | "VALID" | "REVIEW" | "SUBMITTED";
|
||||
interface EditMode {
|
||||
mode: "EDIT";
|
||||
replyValidity: TaskReplyValidity;
|
||||
}
|
||||
interface ReviewMode {
|
||||
mode: "REVIEW";
|
||||
}
|
||||
interface DefaultWarnMode {
|
||||
mode: "DEFAULT_WARN";
|
||||
}
|
||||
interface SubmittedMode {
|
||||
mode: "SUBMITTED";
|
||||
}
|
||||
|
||||
export interface TaskSurveyProps<T> {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
task: any;
|
||||
export type TaskStatus = EditMode | DefaultWarnMode | ReviewMode | SubmittedMode;
|
||||
|
||||
interface NewTask {
|
||||
action: "NEW_TASK";
|
||||
}
|
||||
|
||||
interface Review {
|
||||
action: "REVIEW";
|
||||
}
|
||||
|
||||
interface SetSubmitted {
|
||||
action: "SET_SUBMITTED";
|
||||
}
|
||||
|
||||
interface ReturnToEdit {
|
||||
action: "RETURN_EDIT";
|
||||
}
|
||||
|
||||
interface AcceptDefault {
|
||||
action: "ACCEPT_DEFAULT";
|
||||
}
|
||||
|
||||
interface UpdateValidity {
|
||||
action: "UPDATE_VALIDITY";
|
||||
replyValidity: TaskReplyValidity;
|
||||
}
|
||||
|
||||
export interface TaskSurveyProps<TaskType extends BaseTask, T> {
|
||||
task: TaskType;
|
||||
taskType: TaskInfo;
|
||||
isEditable: boolean;
|
||||
isDisabled?: boolean;
|
||||
@@ -26,13 +64,63 @@ export interface TaskSurveyProps<T> {
|
||||
|
||||
export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
const { t } = useTranslation("tasks");
|
||||
const [taskStatus, setTaskStatus] = useState<TaskStatus>("NOT_SUBMITTABLE");
|
||||
const [taskStatus, taskEvent] = useReducer(
|
||||
(
|
||||
status: TaskStatus,
|
||||
event: NewTask | UpdateValidity | AcceptDefault | Review | ReturnToEdit | SetSubmitted
|
||||
): TaskStatus => {
|
||||
switch (event.action) {
|
||||
case "NEW_TASK":
|
||||
return { mode: "EDIT", replyValidity: "INVALID" };
|
||||
case "UPDATE_VALIDITY":
|
||||
return status.mode === "EDIT" ? { mode: "EDIT", replyValidity: event.replyValidity } : status;
|
||||
case "ACCEPT_DEFAULT":
|
||||
return status.mode === "DEFAULT_WARN" ? { mode: "REVIEW" } : status;
|
||||
case "REVIEW": {
|
||||
if (status.mode === "EDIT") {
|
||||
switch (status.replyValidity) {
|
||||
case "DEFAULT":
|
||||
return { mode: "DEFAULT_WARN" };
|
||||
case "VALID":
|
||||
return { mode: "REVIEW" };
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
case "RETURN_EDIT": {
|
||||
switch (status.mode) {
|
||||
case "REVIEW":
|
||||
return { mode: "EDIT", replyValidity: "VALID" };
|
||||
case "DEFAULT_WARN":
|
||||
return { mode: "EDIT", replyValidity: "DEFAULT" };
|
||||
default:
|
||||
return status;
|
||||
}
|
||||
}
|
||||
case "SET_SUBMITTED": {
|
||||
return status.mode === "REVIEW" ? { mode: "SUBMITTED" } : status;
|
||||
}
|
||||
}
|
||||
},
|
||||
{ mode: "EDIT", replyValidity: "INVALID" }
|
||||
);
|
||||
|
||||
const replyContent = useRef<TaskContent>(null);
|
||||
const [showUnchangedWarning, setShowUnchangedWarning] = useState(false);
|
||||
const updateValidity = useCallback(
|
||||
(replyValidity: TaskReplyValidity) => taskEvent({ action: "UPDATE_VALIDITY", replyValidity }),
|
||||
[taskEvent]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
taskEvent({ action: "NEW_TASK" });
|
||||
}, [task.id, updateValidity]);
|
||||
|
||||
const rootEl = useRef<HTMLDivElement>(null);
|
||||
|
||||
const taskType = TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode);
|
||||
const taskType = useMemo(
|
||||
() => TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode),
|
||||
[task.type, task.mode]
|
||||
);
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, {
|
||||
onSuccess: async () => {
|
||||
@@ -47,79 +135,36 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
});
|
||||
};
|
||||
|
||||
const edit_mode = taskStatus === "NOT_SUBMITTABLE" || taskStatus === "DEFAULT" || taskStatus === "VALID";
|
||||
const submitted = taskStatus === "SUBMITTED";
|
||||
|
||||
const onValidityChanged = (validity: TaskReplyValidity) => {
|
||||
if (!edit_mode) return;
|
||||
switch (validity) {
|
||||
case "DEFAULT":
|
||||
if (taskStatus !== "DEFAULT") setTaskStatus("DEFAULT");
|
||||
break;
|
||||
case "VALID":
|
||||
if (taskStatus !== "VALID") setTaskStatus("VALID");
|
||||
break;
|
||||
case "INVALID":
|
||||
if (taskStatus !== "NOT_SUBMITTABLE") setTaskStatus("NOT_SUBMITTABLE");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
const onReplyChanged = (content: TaskContent) => {
|
||||
replyContent.current = content;
|
||||
};
|
||||
|
||||
const reviewResponse = () => {
|
||||
switch (taskStatus) {
|
||||
case "DEFAULT":
|
||||
setShowUnchangedWarning(true);
|
||||
break;
|
||||
case "VALID":
|
||||
setTaskStatus("REVIEW");
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
const editResponse = () => {
|
||||
switch (taskStatus) {
|
||||
case "REVIEW":
|
||||
setTaskStatus("VALID");
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
};
|
||||
const onReplyChanged = useCallback(
|
||||
(content: TaskContent) => {
|
||||
replyContent.current = content;
|
||||
},
|
||||
[replyContent]
|
||||
);
|
||||
|
||||
const submitResponse = () => {
|
||||
switch (taskStatus) {
|
||||
case "REVIEW": {
|
||||
trigger({
|
||||
id: frontendId,
|
||||
update_type: taskType.update_type,
|
||||
content: replyContent.current,
|
||||
});
|
||||
setTaskStatus("SUBMITTED");
|
||||
scrollToTop(rootEl.current);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return;
|
||||
if (taskStatus.mode === "REVIEW") {
|
||||
trigger({
|
||||
id: frontendId,
|
||||
update_type: taskType.update_type,
|
||||
content: replyContent.current,
|
||||
});
|
||||
taskEvent({ action: "SET_SUBMITTED" });
|
||||
scrollToTop(rootEl.current);
|
||||
}
|
||||
};
|
||||
|
||||
function taskTypeComponent() {
|
||||
const taskTypeComponent = useMemo(() => {
|
||||
switch (taskType.category) {
|
||||
case TaskCategory.Create:
|
||||
return (
|
||||
<CreateTask
|
||||
task={task}
|
||||
taskType={taskType}
|
||||
isEditable={edit_mode}
|
||||
isDisabled={submitted}
|
||||
isEditable={taskStatus.mode === "EDIT"}
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onReplyChanged={onReplyChanged}
|
||||
onValidityChanged={onValidityChanged}
|
||||
onValidityChanged={updateValidity}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Evaluate:
|
||||
@@ -127,10 +172,10 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
<EvaluateTask
|
||||
task={task}
|
||||
taskType={taskType}
|
||||
isEditable={edit_mode}
|
||||
isDisabled={submitted}
|
||||
isEditable={taskStatus.mode === "EDIT"}
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onReplyChanged={onReplyChanged}
|
||||
onValidityChanged={onValidityChanged}
|
||||
onValidityChanged={updateValidity}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Label:
|
||||
@@ -138,37 +183,34 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
<LabelTask
|
||||
task={task}
|
||||
taskType={taskType}
|
||||
isEditable={edit_mode}
|
||||
isDisabled={submitted}
|
||||
isEditable={taskStatus.mode === "EDIT"}
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onReplyChanged={onReplyChanged}
|
||||
onValidityChanged={onValidityChanged}
|
||||
onValidityChanged={updateValidity}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
}, [task, taskType, taskStatus.mode, onReplyChanged, updateValidity]);
|
||||
|
||||
return (
|
||||
<div ref={rootEl}>
|
||||
{taskTypeComponent()}
|
||||
{taskTypeComponent}
|
||||
<TaskControls
|
||||
task={task}
|
||||
taskStatus={taskStatus}
|
||||
onEdit={editResponse}
|
||||
onReview={reviewResponse}
|
||||
onEdit={() => taskEvent({ action: "RETURN_EDIT" })}
|
||||
onReview={() => taskEvent({ action: "REVIEW" })}
|
||||
onSubmit={submitResponse}
|
||||
onSkip={rejectTask}
|
||||
/>
|
||||
<UnchangedWarning
|
||||
show={showUnchangedWarning}
|
||||
show={taskStatus.mode === "DEFAULT_WARN"}
|
||||
title={t(getTypeSafei18nKey(`${taskType.id}.unchanged_title`)) || t("default.unchanged_title")}
|
||||
message={t(getTypeSafei18nKey(`${taskType.id}.unchanged_message`)) || t("default.unchanged_message")}
|
||||
continueButtonText={"Continue anyway"}
|
||||
onClose={() => setShowUnchangedWarning(false)}
|
||||
onClose={() => taskEvent({ action: "RETURN_EDIT" })}
|
||||
onContinueAnyway={() => {
|
||||
if (taskStatus === "DEFAULT") {
|
||||
setTaskStatus("REVIEW");
|
||||
setShowUnchangedWarning(false);
|
||||
}
|
||||
taskEvent({ action: "ACCEPT_DEFAULT" });
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -17,7 +17,8 @@ export const post = (url: string, { arg: data }) => api.post(url, data).then((re
|
||||
api.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error) => {
|
||||
throw new OasstError(error.message ?? error, error.error_code, error?.response?.status || -1);
|
||||
const err = error?.response?.data;
|
||||
throw new OasstError(err?.message ?? error, err?.errorCode, error?.response?.httpStatusCode || -1);
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe
|
||||
* Wraps any API Route handler and verifies that the user has the appropriate
|
||||
* role before running the handler. Returns a 403 otherwise.
|
||||
*/
|
||||
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
|
||||
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => {
|
||||
return async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
const token = await getToken({ req });
|
||||
if (!token || token.role !== role) {
|
||||
res.status(403).end();
|
||||
return;
|
||||
}
|
||||
return handler(req, res);
|
||||
return handler(req, res, token);
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import {
|
||||
useCreateAssistantReply,
|
||||
useCreateInitialPrompt,
|
||||
useCreatePrompterReply,
|
||||
} from "src/hooks/tasks/useCreateReply";
|
||||
import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI";
|
||||
import {
|
||||
useLabelAssistantReplyTask,
|
||||
useLabelInitialPromptTask,
|
||||
useLabelPrompterReplyTask,
|
||||
} from "src/hooks/tasks/useLabelingTask";
|
||||
import {
|
||||
useRankAssistantRepliesTask,
|
||||
useRankInitialPromptsTask,
|
||||
useRankPrompterRepliesTask,
|
||||
} from "src/hooks/tasks/useRankReplies";
|
||||
import { TaskApiHooks } from "src/types/Hooks";
|
||||
import { TaskType } from "src/types/Task";
|
||||
|
||||
export const ERROR_CODES = {
|
||||
TASK_REQUESTED_TYPE_NOT_AVAILABLE: 1006,
|
||||
TASK_INVALID_REQUEST_TYPE: 1000,
|
||||
TASK_ACK_FAILED: 1001,
|
||||
TASK_NACK_FAILED: 1002,
|
||||
TASK_INVALID_RESPONSE_TYPE: 1003,
|
||||
TASK_INTERACTION_REQUEST_FAILED: 1004,
|
||||
TASK_GENERATION_FAILED: 1005,
|
||||
TASK_AVAILABILITY_QUERY_FAILED: 1007,
|
||||
TASK_MESSAGE_TOO_LONG: 1008,
|
||||
};
|
||||
|
||||
export const taskApiHooks: TaskApiHooks = {
|
||||
[TaskType.random]: useGenericTaskAPI,
|
||||
[TaskType.assistant_reply]: useCreateAssistantReply,
|
||||
[TaskType.initial_prompt]: useCreateInitialPrompt,
|
||||
[TaskType.label_assistant_reply]: useLabelAssistantReplyTask,
|
||||
[TaskType.label_initial_prompt]: useLabelInitialPromptTask,
|
||||
[TaskType.label_prompter_reply]: useLabelPrompterReplyTask,
|
||||
[TaskType.prompter_reply]: useCreatePrompterReply,
|
||||
[TaskType.rank_assistant_replies]: useRankAssistantRepliesTask,
|
||||
[TaskType.rank_initial_prompts]: useRankInitialPromptsTask,
|
||||
[TaskType.rank_prompter_replies]: useRankPrompterRepliesTask,
|
||||
};
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { Message } from "src/types/Conversation";
|
||||
import type { EmojiOp, Message } from "src/types/Conversation";
|
||||
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
import type { AvailableTasks } from "src/types/Task";
|
||||
import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users";
|
||||
@@ -18,10 +18,16 @@ export class OasstError {
|
||||
export class OasstApiClient {
|
||||
oasstApiUrl: string;
|
||||
oasstApiKey: string;
|
||||
userHeaders: Record<string, string> = {};
|
||||
|
||||
constructor(oasstApiUrl: string, oasstApiKey: string) {
|
||||
constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) {
|
||||
this.oasstApiUrl = oasstApiUrl;
|
||||
this.oasstApiKey = oasstApiKey;
|
||||
if (user) {
|
||||
this.userHeaders = {
|
||||
"X-OASST-USER": `${user.auth_method}:${user.id}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
// TODO return a strongly typed Task?
|
||||
// This method is used to store a task in RegisteredTask.task.
|
||||
@@ -76,6 +82,27 @@ export class OasstApiClient {
|
||||
return this.post<AvailableTasks>("/api/v1/tasks/availability", user);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the `Message`s associated with `user_id` in the backend.
|
||||
*/
|
||||
async fetch_message(message_id: string, user: BackendUserCore): Promise<Message> {
|
||||
return this.get<Message>(`/api/v1/messages/${message_id}?username=${user.id}&auth_method=${user.auth_method}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a report about a message
|
||||
*/
|
||||
async send_report(message_id: string, user: BackendUserCore, text: string) {
|
||||
return this.post("/api/v1/text_labels", {
|
||||
type: "text_labels",
|
||||
message_id,
|
||||
labels: [], // Not yet implemented
|
||||
text,
|
||||
is_report: true,
|
||||
user,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the message stats from the backend.
|
||||
*/
|
||||
@@ -144,7 +171,7 @@ export class OasstApiClient {
|
||||
time_frame: LeaderboardTimeFrame,
|
||||
{ limit = 20 }: { limit?: number }
|
||||
): Promise<LeaderboardReply | null> {
|
||||
return this.get<LeaderboardReply>(`/api/v1/leaderboards/${time_frame}`, { limit });
|
||||
return this.get<LeaderboardReply>(`/api/v1/leaderboards/${time_frame}`, { max_count: limit });
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -154,6 +181,17 @@ export class OasstApiClient {
|
||||
return this.post<AvailableTasks>(`/api/v1/tasks/availability?lang=${lang}`, user);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add/remove an emoji on a message for a user
|
||||
*/
|
||||
async set_user_message_emoji(message_id: string, user: BackendUserCore, emoji: string, op: EmojiOp): Promise<void> {
|
||||
await this.post(`/api/v1/messages/${message_id}/emoji`, {
|
||||
user,
|
||||
emoji,
|
||||
op,
|
||||
});
|
||||
}
|
||||
|
||||
private async post<T>(path: string, body: unknown) {
|
||||
return this.request<T>("POST", path, {
|
||||
body: JSON.stringify(body),
|
||||
@@ -183,9 +221,10 @@ export class OasstApiClient {
|
||||
method,
|
||||
...init,
|
||||
headers: {
|
||||
...init?.headers,
|
||||
...this.userHeaders,
|
||||
"X-API-Key": this.oasstApiKey,
|
||||
"Content-Type": "application/json",
|
||||
...init?.headers,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -195,8 +234,7 @@ export class OasstApiClient {
|
||||
|
||||
if (resp.status >= 300) {
|
||||
const errorText = await resp.text();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let error: any;
|
||||
let error;
|
||||
try {
|
||||
error = JSON.parse(errorText);
|
||||
} catch (e) {
|
||||
@@ -207,8 +245,24 @@ export class OasstApiClient {
|
||||
|
||||
return await resp.json();
|
||||
}
|
||||
|
||||
fetch_my_messages(user: BackendUserCore) {
|
||||
const params = new URLSearchParams({
|
||||
username: user.id,
|
||||
auth_method: user.auth_method,
|
||||
});
|
||||
return this.get<Message[]>(`/api/v1/messages?${params}`);
|
||||
}
|
||||
|
||||
fetch_recent_messages() {
|
||||
return this.get<Message[]>(`/api/v1/messages`);
|
||||
}
|
||||
|
||||
fetch_message_children(messageId: string) {
|
||||
return this.get<Message[]>(`/api/v1/messages/${messageId}/children`);
|
||||
}
|
||||
|
||||
fetch_conversation(messageId: string) {
|
||||
return this.get(`/api/v1/messages/${messageId}/conversation`);
|
||||
}
|
||||
}
|
||||
|
||||
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
|
||||
|
||||
export { oasstApiClient };
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import { JWT } from "next-auth/jwt";
|
||||
import { OasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
import { BackendUserCore } from "src/types/Users";
|
||||
|
||||
export const createApiClientFromUser = (user: BackendUserCore) =>
|
||||
new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user);
|
||||
|
||||
export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub));
|
||||
|
||||
export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
|
||||
@@ -10,7 +10,7 @@ import { getAdminLayout } from "src/components/Layout";
|
||||
import { Role, RoleSelect } from "src/components/RoleSelect";
|
||||
import { UserMessagesCell } from "src/components/UserMessagesCell";
|
||||
import { post } from "src/lib/api";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { userlessApiClient } from "src/lib/oasst_client_factory";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
@@ -113,7 +113,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType<typeof getServerSidePr
|
||||
* Fetch the user's data on the server side when rendering.
|
||||
*/
|
||||
export async function getServerSideProps({ query, locale }) {
|
||||
const backend_user = await oasstApiClient.fetch_user(query.id);
|
||||
const backend_user = await userlessApiClient.fetch_user(query.id);
|
||||
const local_user = await prisma.user.findUnique({
|
||||
where: { id: backend_user.id },
|
||||
select: {
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
|
||||
/**
|
||||
* Returns tasks availability, stats, and tree manager stats.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
// NOTE: why are we using a dummy user here?
|
||||
const dummyUser = {
|
||||
id: "__dummy_user__",
|
||||
display_name: "Dummy User",
|
||||
auth_method: "local",
|
||||
};
|
||||
const oasstApiClient = createApiClientFromUser(dummyUser);
|
||||
const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([
|
||||
oasstApiClient.fetch_tasks_availability(dummyUser),
|
||||
oasstApiClient.fetch_stats(),
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* Update's the user's data in the database. Accessible only to admins.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
const handler = withRole("admin", async (req, res, token) => {
|
||||
const { id, auth_method, user_id, notes, role } = req.body;
|
||||
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
// If the user is authorized by the web, update their role.
|
||||
if (auth_method === "local") {
|
||||
await prisma.user.update({
|
||||
where: {
|
||||
id,
|
||||
},
|
||||
data: {
|
||||
role,
|
||||
},
|
||||
where: { id },
|
||||
data: { role },
|
||||
});
|
||||
}
|
||||
// Tell the backend the user's enabled or not enabled status.
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
import type { Message } from "src/types/Conversation";
|
||||
|
||||
/**
|
||||
* Returns the messages recorded by the backend for a user.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
const handler = withRole("admin", async (req, res, token) => {
|
||||
const { user } = req.query;
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string);
|
||||
res.status(200).json(messages);
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { FetchUsersParams } from "src/types/Users";
|
||||
|
||||
@@ -17,9 +17,10 @@ const PAGE_SIZE = 20;
|
||||
* - `direction`: Either "forward" or "backward" representing the pagination
|
||||
* direction.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
const handler = withRole("admin", async (req, res, token) => {
|
||||
const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query;
|
||||
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
// First, get all the users according to the backend.
|
||||
const { items: all_users, ...rest } = await oasstApiClient.fetch_users({
|
||||
searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"],
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const oasstApiClient = createApiClientFromUser(user);
|
||||
const userLanguage = getUserLanguage(req);
|
||||
const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage);
|
||||
res.status(200).json(availableTasks);
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
|
||||
/**
|
||||
* Returns the set of valid labels that can be applied to messages.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day;
|
||||
const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number });
|
||||
res.status(200).json(info);
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const { id } = req.query;
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
const client = await createApiClient(token);
|
||||
const messages = await client.fetch_message_children(id as string);
|
||||
res.status(200).json(messages);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const { id } = req.query;
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
const client = await createApiClient(token);
|
||||
const messages = await client.fetch_conversation(id as string);
|
||||
res.status(200).json(messages);
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const { id } = req.query;
|
||||
|
||||
if (!id) {
|
||||
res.status(400).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const messageId = id as string;
|
||||
|
||||
const { emoji, op } = req.body;
|
||||
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const oasstApiClient = createApiClientFromUser(user);
|
||||
try {
|
||||
await oasstApiClient.set_user_message_emoji(messageId, user, emoji, op);
|
||||
} catch (err) {
|
||||
console.error(JSON.stringify(err));
|
||||
return res.status(500).json(err);
|
||||
}
|
||||
|
||||
// Get updated emoji
|
||||
const message = await oasstApiClient.fetch_message(messageId, user);
|
||||
res.status(200).json({ emojis: message.emojis, user_emojis: message.user_emojis });
|
||||
});
|
||||
|
||||
export default handler;
|
||||
@@ -1,18 +1,12 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const { id } = req.query;
|
||||
|
||||
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
const message = await messageRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const client = createApiClientFromUser(user);
|
||||
const message = await client.fetch_message(id as string, user);
|
||||
res.status(200).json(message);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClient, createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const { id } = req.query;
|
||||
|
||||
if (!id) {
|
||||
@@ -8,32 +10,16 @@ const handler = withoutRole("banned", async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
const message = await messageRes.json();
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const client = createApiClientFromUser(user);
|
||||
const message = await client.fetch_message(id as string, user);
|
||||
|
||||
if (!message.parent_id) {
|
||||
res.status(404).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const parentRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${message.parent_id}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
const parent = await parentRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
const parent = await client.fetch_message(message.parent_id, user);
|
||||
res.status(200).json(parent);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const client = await createApiClient(token);
|
||||
const messages = await client.fetch_recent_messages();
|
||||
res.status(200).json(messages);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,23 +1,11 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
//TODO: add params if needed
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const params = new URLSearchParams({
|
||||
username: user.id,
|
||||
auth_method: user.auth_method,
|
||||
});
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
const client = createApiClientFromUser(user);
|
||||
const messages = await client.fetch_my_messages(user);
|
||||
res.status(200).json(messages);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
|
||||
|
||||
@@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const userLanguage = getUserLanguage(req);
|
||||
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const oasstApiClient = createApiClientFromUser(user);
|
||||
let task;
|
||||
try {
|
||||
task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage);
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, reason } = req.body;
|
||||
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
const [oasstApiClient, registeredTask] = await Promise.all([
|
||||
createApiClient(token),
|
||||
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
|
||||
]);
|
||||
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const id = task.id as string;
|
||||
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
|
||||
|
||||
// Update the backend with the rejection
|
||||
await oasstApiClient.nackTask(id, reason);
|
||||
await oasstApiClient.nackTask(taskId, reason);
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json({});
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
/**
|
||||
* Adds a report for a message
|
||||
*
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local message_id, and the interaction contents.
|
||||
const { message_id, text } = req.body;
|
||||
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const oasstApiClient = createApiClientFromUser(user);
|
||||
try {
|
||||
await oasstApiClient.send_report(message_id, user, text);
|
||||
} catch (err) {
|
||||
console.error(JSON.stringify(err));
|
||||
return res.status(500).json(err);
|
||||
}
|
||||
|
||||
res.status(200).end();
|
||||
});
|
||||
|
||||
export default handler;
|
||||
@@ -5,8 +5,10 @@ import { withoutRole } from "src/lib/auth";
|
||||
*
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// TODO: move to oasst_api_client
|
||||
// Parse out the local message_id, and the interaction contents.
|
||||
const { message_id, label_map, text } = req.body;
|
||||
const { message_id, label_map } = req.body;
|
||||
|
||||
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
@@ -17,7 +19,8 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
type: "text_labels",
|
||||
message_id: message_id,
|
||||
labels: label_map,
|
||||
text: text,
|
||||
text: "", // used only in reporting
|
||||
is_report: false,
|
||||
user: {
|
||||
id: token.sub,
|
||||
display_name: token.name || token.email,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
|
||||
|
||||
@@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, content, update_type } = req.body;
|
||||
|
||||
// Record that the user has done meaningful work and is no longer new.
|
||||
await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } });
|
||||
// do in parallel since they are independent
|
||||
const [_, registeredTask, oasstApiClient] = await Promise.all([
|
||||
// Record that the user has done meaningful work and is no longer new.
|
||||
prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }),
|
||||
// Accept the task so that we can complete it, this will probably go away soon.
|
||||
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
|
||||
// Create client for upcoming requests
|
||||
createApiClient(token),
|
||||
]);
|
||||
|
||||
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
|
||||
|
||||
// Accept the task so that we can complete it, this will probably go away soon.
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const taskId = task.id as string;
|
||||
await oasstApiClient.ackTask(taskId, registeredTask.id);
|
||||
|
||||
// Log the interaction locally to create our user_post_id needed by the Task
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
|
||||
/**
|
||||
* Returns the set of valid labels that can be applied to messages.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const valid_labels = await oasstApiClient.fetch_valid_text();
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const client = await createApiClient(token);
|
||||
const valid_labels = await client.fetch_valid_text();
|
||||
res.status(200).json(valid_labels);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,32 +1,9 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const AssistantReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
const AssistantReply = () => <TaskPage type={TaskType.assistant_reply} />;
|
||||
|
||||
AssistantReply.getLayout = getDashboardLayout;
|
||||
|
||||
|
||||
@@ -1,32 +1,9 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const InitialPrompt = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Initial Prompt</title>
|
||||
<meta name="description" content="Add an initial Prompt." />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
const InitialPrompt = () => <TaskPage type={TaskType.initial_prompt} />;
|
||||
|
||||
InitialPrompt.getLayout = getDashboardLayout;
|
||||
|
||||
|
||||
@@ -1,33 +1,10 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const UserReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
|
||||
const PrompterReply = () => <TaskPage type={TaskType.prompter_reply} />;
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
PrompterReply.getLayout = getDashboardLayout;
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Reply as User</title>
|
||||
<meta name="description" content="Reply as User." />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
UserReply.getLayout = getDashboardLayout;
|
||||
|
||||
export default UserReply;
|
||||
export default PrompterReply;
|
||||
|
||||
@@ -1,32 +1,9 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const RankAssistantReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Rank Assistant Replies</title>
|
||||
<meta name="description" content="Rank Assistant Replies." />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
const RankAssistantReplies = () => <TaskPage type={TaskType.rank_assistant_replies} />;
|
||||
|
||||
RankAssistantReplies.getLayout = getDashboardLayout;
|
||||
|
||||
|
||||
@@ -1,32 +1,9 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const RankInitialPrompts = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Rank Initial Prompts</title>
|
||||
<meta name="description" content="Rank initial prompts." />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
const RankInitialPrompts = () => <TaskPage type={TaskType.rank_initial_prompts} />;
|
||||
|
||||
RankInitialPrompts.getLayout = getDashboardLayout;
|
||||
|
||||
|
||||
@@ -1,33 +1,10 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const RankUserReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
|
||||
const RankPrompterReplies = () => <TaskPage type={TaskType.rank_prompter_replies} />;
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
RankPrompterReplies.getLayout = getDashboardLayout;
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Rank User Replies</title>
|
||||
<meta name="description" content="Rank User Replies." />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
RankUserReplies.getLayout = getDashboardLayout;
|
||||
|
||||
export default RankUserReplies;
|
||||
export default RankPrompterReplies;
|
||||
|
||||
@@ -1,32 +1,9 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const LabelAssistantReply = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Assistant Reply</title>
|
||||
<meta name="description" content="Label Assistant Reply" />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
const LabelAssistantReply = () => <TaskPage type={TaskType.label_assistant_reply} />;
|
||||
|
||||
LabelAssistantReply.getLayout = getDashboardLayout;
|
||||
|
||||
|
||||
@@ -1,32 +1,9 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask();
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Initial Prompt</title>
|
||||
<meta name="description" content="Label Initial Prompt" />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
const LabelInitialPrompt = () => <TaskPage type={TaskType.label_initial_prompt} />;
|
||||
|
||||
LabelInitialPrompt.getLayout = getDashboardLayout;
|
||||
|
||||
|
||||
@@ -1,32 +1,9 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
import { TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const LabelPrompterReply = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Prompter Reply</title>
|
||||
<meta name="description" content="Label Prompter Reply" />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
const LabelPrompterReply = () => <TaskPage type={TaskType.label_prompter_reply} />;
|
||||
|
||||
LabelPrompterReply.getLayout = getDashboardLayout;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { MessageLoading } from "src/components/Loading/MessageLoading";
|
||||
@@ -10,6 +11,7 @@ import { Message } from "src/types/Conversation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const MessageDetail = ({ id }: { id: string }) => {
|
||||
const { t } = useTranslation(["message", "common"]);
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
|
||||
const { isLoading: isLoadingParent, data: parent } = useSWRImmutable<Message>(`/api/messages/${id}/parent`, get);
|
||||
@@ -20,7 +22,7 @@ const MessageDetail = ({ id }: { id: string }) => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<title>{t("common:title")}</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
@@ -32,10 +34,10 @@ const MessageDetail = ({ id }: { id: string }) => {
|
||||
<>
|
||||
<Box pb="4">
|
||||
<Text fontWeight="bold" fontSize="xl" pb="2">
|
||||
Parent
|
||||
{t("parent")}
|
||||
</Text>
|
||||
<Box bg={backgroundColor} padding="4" borderRadius="xl" boxShadow="base" width="fit-content">
|
||||
<MessageTableEntry enabled item={parent} />
|
||||
<MessageTableEntry enabled message={parent} />
|
||||
</Box>
|
||||
</Box>
|
||||
</>
|
||||
@@ -54,7 +56,7 @@ MessageDetail.getLayout = (page) => getDashboardLayout(page);
|
||||
export const getServerSideProps = async ({ locale, query }) => ({
|
||||
props: {
|
||||
id: query.id,
|
||||
...(await serverSideTranslations(locale, ["common"])),
|
||||
...(await serverSideTranslations(locale, ["common", "message"])),
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -1,34 +1,10 @@
|
||||
import Head from "next/head";
|
||||
import { TaskEmptyState } from "src/components/EmptyState";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI";
|
||||
import { TaskPage } from "src/components/TaskPage/TaskPage";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
import { TaskType } from "src/types/Task";
|
||||
|
||||
const RandomTask = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random);
|
||||
const Random = () => <TaskPage type={TaskType.random} />;
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
Random.getLayout = getDashboardLayout;
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <TaskEmptyState />;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Random Task</title>
|
||||
<meta name="description" content="Random Task." />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} frontendId={tasks[0].id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
RandomTask.getLayout = (page) => getDashboardLayout(page);
|
||||
|
||||
export default RandomTask;
|
||||
export default Random;
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
export interface Message {
|
||||
export type EmojiOp = "add" | "remove" | "toggle";
|
||||
|
||||
export interface MessageEmoji {
|
||||
name: string;
|
||||
count: number;
|
||||
}
|
||||
|
||||
export interface MessageEmojis {
|
||||
emojis: { [emoji: string]: number };
|
||||
user_emojis: string[];
|
||||
}
|
||||
|
||||
export interface Message extends MessageEmojis {
|
||||
id: string;
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
id: string;
|
||||
lang: string;
|
||||
created_date: string; // iso date string
|
||||
parent_id: string;
|
||||
frontend_message_id?: string;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { MutatorCallback, MutatorOptions } from "swr";
|
||||
|
||||
import { BaseTask, TaskResponse, TaskType } from "./Task";
|
||||
|
||||
type ConcreteTaskResponse = TaskResponse<BaseTask>;
|
||||
type TaskError = { errorCode: number; message: string };
|
||||
|
||||
type Trigger = (
|
||||
extraArgument?: unknown,
|
||||
options?: MutatorOptions<ConcreteTaskResponse>
|
||||
) => Promise<ConcreteTaskResponse>;
|
||||
|
||||
type Reset = (
|
||||
data?: ConcreteTaskResponse | Promise<ConcreteTaskResponse> | MutatorCallback<ConcreteTaskResponse>,
|
||||
opts?: boolean | MutatorOptions<ConcreteTaskResponse>
|
||||
) => Promise<ConcreteTaskResponse>;
|
||||
|
||||
type TaskAPIHook = {
|
||||
tasks: TaskResponse<BaseTask>[];
|
||||
isLoading: boolean;
|
||||
error: TaskError;
|
||||
trigger: Trigger;
|
||||
reset: Reset;
|
||||
};
|
||||
|
||||
export type TaskApiHooks = Record<TaskType, (args: TaskType) => TaskAPIHook>;
|
||||
@@ -1,4 +1,4 @@
|
||||
export const enum TaskType {
|
||||
export enum TaskType {
|
||||
initial_prompt = "initial_prompt",
|
||||
assistant_reply = "assistant_reply",
|
||||
prompter_reply = "prompter_reply",
|
||||
|
||||
+24
-14
@@ -1,4 +1,4 @@
|
||||
import { Conversation } from "./Conversation";
|
||||
import { Conversation, Message } from "./Conversation";
|
||||
import { BaseTask, TaskType } from "./Task";
|
||||
|
||||
export interface CreateInitialPromptTask extends BaseTask {
|
||||
@@ -33,29 +33,39 @@ export interface RankPrompterRepliesTask extends BaseTask {
|
||||
replies: string[];
|
||||
}
|
||||
|
||||
export interface LabelAssistantReplyTask extends BaseTask {
|
||||
export interface Label {
|
||||
display_text: string;
|
||||
help_text: string;
|
||||
name: string;
|
||||
widget: "flag" | "yes_no" | "likert";
|
||||
}
|
||||
|
||||
export interface BaseLabelTask extends BaseTask {
|
||||
message_id: string;
|
||||
labels: Label[];
|
||||
valid_labels: string[];
|
||||
disposition: "spam" | "quality";
|
||||
mode: "simple" | "full";
|
||||
mandatory_labels?: string[];
|
||||
}
|
||||
|
||||
export interface LabelAssistantReplyTask extends BaseLabelTask {
|
||||
type: TaskType.label_assistant_reply;
|
||||
message_id: string;
|
||||
conversation: Conversation;
|
||||
reply_message: Message;
|
||||
reply: string;
|
||||
valid_labels: string[];
|
||||
mode: "simple" | "full";
|
||||
mandatory_labels?: string[];
|
||||
}
|
||||
|
||||
export interface LabelPrompterReplyTask extends BaseTask {
|
||||
export interface LabelPrompterReplyTask extends BaseLabelTask {
|
||||
type: TaskType.label_prompter_reply;
|
||||
message_id: string;
|
||||
conversation: Conversation;
|
||||
reply_message: Message;
|
||||
reply: string;
|
||||
valid_labels: string[];
|
||||
mode: "simple" | "full";
|
||||
mandatory_labels?: string[];
|
||||
}
|
||||
|
||||
export interface LabelInitialPromptTask extends BaseTask {
|
||||
export interface LabelInitialPromptTask extends BaseLabelTask {
|
||||
type: TaskType.label_initial_prompt;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
export type LabelTaskType = LabelAssistantReplyTask | LabelPrompterReplyTask | LabelInitialPromptTask;
|
||||
|
||||
@@ -5,6 +5,7 @@ export const colors = {
|
||||
div: "white",
|
||||
text: "black",
|
||||
highlight: "blue.400",
|
||||
active: "blue.400",
|
||||
},
|
||||
dark: {
|
||||
bg: "gray.900",
|
||||
@@ -12,5 +13,6 @@ export const colors = {
|
||||
div: "gray.700",
|
||||
text: "gray.200",
|
||||
highlight: "blue.500",
|
||||
active: "blue.500",
|
||||
},
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user