mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merging from main
This commit is contained in:
@@ -19,7 +19,7 @@ devcontainers in this repo.
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
A successfull run should look something like this:
|
||||
A successful run should look something like this:
|
||||
|
||||
```
|
||||
@andrewm4894 ➜ /workspaces/Open-Assistant (devcontainer-improvements) $ pre-commit run --all-files
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
* text=auto eol=lf
|
||||
*.ipynb linguist-documentation
|
||||
|
||||
@@ -76,7 +76,12 @@ repos:
|
||||
rev: v2.7.1
|
||||
hooks:
|
||||
- id: prettier
|
||||
args: [--prose-wrap=always, --write]
|
||||
args:
|
||||
[
|
||||
--prose-wrap=always,
|
||||
--write,
|
||||
--ignore-path=./website/.prettierignore,
|
||||
]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
+6
-6
@@ -50,12 +50,12 @@ contributions smoothly we recommend the following:
|
||||
[Here](https://github.com/LAION-AI/Open-Assistant/pull/658) is an example PR
|
||||
for this project to illustrate this flow.
|
||||
1. If you're lucky, we can merge your change into `main` without any problems.
|
||||
If there's changes to files you're working on, resolve them by:
|
||||
1. First try rebase as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
|
||||
1. If rebase feels too painful, merge as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
|
||||
1. Once you've resolved any conflicts, finish the review and
|
||||
If there's changes to files you're working on, resolve them by :
|
||||
1. First try rebase as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
|
||||
1. If rebase feels too painful, merge as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
|
||||
1. Once you've resolved conflicts (if any), finish the review and
|
||||
[squash and merge](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits)
|
||||
your PR (when squashing try to clean up or update the individual commit
|
||||
messages to be one sensible single one).
|
||||
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
"""add message_id to message_reaction
|
||||
|
||||
Revision ID: 8ba17b5f467a
|
||||
Revises: 160ac010efcc
|
||||
Create Date: 2023-01-24 11:34:42.167575
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ba17b5f467a"
|
||||
down_revision = "160ac010efcc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message_reaction", sa.Column("message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
op.create_index(op.f("ix_message_reaction_message_id"), "message_reaction", ["message_id"], unique=False)
|
||||
op.add_column("text_labels", sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
op.create_index(op.f("ix_text_labels_task_id"), "text_labels", ["task_id"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_text_labels_task_id"), table_name="text_labels")
|
||||
op.drop_column("text_labels", "task_id")
|
||||
op.drop_index(op.f("ix_message_reaction_message_id"), table_name="message_reaction")
|
||||
op.drop_column("message_reaction", "message_id")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,44 @@
|
||||
"""add message_emoji
|
||||
|
||||
Revision ID: 40ed93df0ed5
|
||||
Revises: 8ba17b5f467a
|
||||
Create Date: 2023-01-24 22:56:28.229408
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "40ed93df0ed5"
|
||||
down_revision = "8ba17b5f467a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"message_emoji",
|
||||
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
|
||||
),
|
||||
sa.Column("emoji", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("message_id", "user_id", "emoji"),
|
||||
)
|
||||
op.create_index("ix_message_emoji__user_id__message_id", "message_emoji", ["user_id", "message_id"], unique=False)
|
||||
op.add_column("message", sa.Column("emojis", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message", "emojis")
|
||||
op.drop_index("ix_message_emoji__user_id__message_id", table_name="message_emoji")
|
||||
op.drop_table("message_emoji")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add task created date index
|
||||
|
||||
Revision ID: c84fcd6900dc
|
||||
Revises: 40ed93df0ed5
|
||||
Create Date: 2023-01-26 18:35:43.061589
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c84fcd6900dc"
|
||||
down_revision = "40ed93df0ed5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index(op.f("ix_task_created_date"), "task", ["created_date"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_task_created_date"), table_name="task")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add user.show_on_leaderboard
|
||||
|
||||
Revision ID: f856bf19d32b
|
||||
Revises: c84fcd6900dc
|
||||
Create Date: 2023-01-27 20:13:56.533374
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f856bf19d32b"
|
||||
down_revision = "c84fcd6900dc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"user", sa.Column("show_on_leaderboard", sa.Boolean(), server_default=sa.text("true"), nullable=False)
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("user", "show_on_leaderboard")
|
||||
# ### end Alembic commands ###
|
||||
@@ -273,6 +273,38 @@ def get_openapi_schema():
|
||||
return json.dumps(app.openapi())
|
||||
|
||||
|
||||
def export_ready_trees(file: Optional[str] = None, use_compression: bool = False):
|
||||
try:
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
ur = UserRepository(db=db, api_client=api_client)
|
||||
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
tm = TreeManager(db, pr)
|
||||
|
||||
tm.export_all_ready_trees(file, use_compression=use_compression)
|
||||
except Exception:
|
||||
logger.exception("Error exporting trees.")
|
||||
|
||||
|
||||
def retry_scoring_failed_message_trees():
|
||||
try:
|
||||
logger.info("TreeManager.retry_scoring_failed_message_trees()")
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client)
|
||||
tm = TreeManager(db, pr)
|
||||
tm.retry_scoring_failed_message_trees()
|
||||
|
||||
except Exception:
|
||||
logger.exception("TreeManager.retry_scoring_failed_message_trees() failed.")
|
||||
|
||||
|
||||
def main():
|
||||
# Importing here so we don't import packages unnecessarily if we're
|
||||
# importing main as a module.
|
||||
@@ -289,11 +321,28 @@ def main():
|
||||
)
|
||||
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
|
||||
parser.add_argument("--port", help="The port to run the server", default=8080)
|
||||
parser.add_argument(
|
||||
"--export", help="Export all trees which are ready for exporting.", action=argparse.BooleanOptionalAction
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-file",
|
||||
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retry-scoring",
|
||||
help="Retry scoring failed message trees",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.print_openapi_schema:
|
||||
print(get_openapi_schema())
|
||||
elif args.export:
|
||||
use_compression: bool = ".gz" in args.export_file
|
||||
export_ready_trees(file=args.export_file, use_compression=use_compression)
|
||||
elif args.retry_scoring:
|
||||
retry_scoring_failed_message_trees()
|
||||
else:
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from typing import Generator, NamedTuple
|
||||
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
@@ -19,22 +19,46 @@ def get_db() -> Generator:
|
||||
yield db
|
||||
|
||||
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
api_key_query = APIKeyQuery(name="api_key", scheme_name="api-key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", scheme_name="api-key", auto_error=False)
|
||||
oasst_user_query = APIKeyQuery(name="oasst_user", scheme_name="oasst-user", auto_error=False)
|
||||
oasst_user_header = APIKeyHeader(name="x-oasst-user", scheme_name="oasst-user", auto_error=False)
|
||||
|
||||
bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
def get_api_key(
|
||||
api_key_query: str = Security(api_key_query),
|
||||
api_key_header: str = Security(api_key_header),
|
||||
):
|
||||
) -> str:
|
||||
if api_key_query:
|
||||
return api_key_query
|
||||
else:
|
||||
return api_key_header
|
||||
|
||||
|
||||
class FrontendUserId(NamedTuple):
|
||||
auth_method: str
|
||||
username: str
|
||||
|
||||
|
||||
def get_frontend_user_id(
|
||||
user_query: str = Security(oasst_user_query),
|
||||
user_header: str = Security(oasst_user_header),
|
||||
) -> FrontendUserId:
|
||||
def split_user(v: str) -> tuple[str, str]:
|
||||
if type(v) is str:
|
||||
v = v.split(":", maxsplit=1)
|
||||
if len(v) == 2:
|
||||
return FrontendUserId(auth_method=v[0], username=v[1])
|
||||
return FrontendUserId(auth_method=None, username=None)
|
||||
|
||||
if user_query:
|
||||
return split_user(user_query)
|
||||
else:
|
||||
return split_user(user_header)
|
||||
|
||||
|
||||
def create_api_client(
|
||||
*,
|
||||
session: Session,
|
||||
|
||||
@@ -70,13 +70,14 @@ def query_frontend_user_messages(
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
include_deleted: bool = False,
|
||||
lang: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
@@ -87,6 +88,7 @@ def query_frontend_user_messages(
|
||||
lte_created_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
lang=lang,
|
||||
)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
@@ -95,24 +97,28 @@ def query_frontend_user_messages(
|
||||
def query_frontend_user_messages_cursor(
|
||||
auth_method: str,
|
||||
username: str,
|
||||
lt: Optional[str] = None,
|
||||
gt: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
only_roots: Optional[bool] = False,
|
||||
include_deleted: Optional[bool] = False,
|
||||
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
||||
desc: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
return get_messages_cursor(
|
||||
lt=lt,
|
||||
gt=gt,
|
||||
before=before,
|
||||
after=after,
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
only_roots=only_roots,
|
||||
include_deleted=include_deleted,
|
||||
max_count=max_count,
|
||||
desc=desc,
|
||||
lang=lang,
|
||||
frontend_user=frontend_user,
|
||||
api_client=api_client,
|
||||
db=db,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
@@ -17,6 +18,7 @@ router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=list[protocol.Message])
|
||||
def query_messages(
|
||||
*,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client_id: Optional[str] = None,
|
||||
@@ -26,13 +28,15 @@ def query_messages(
|
||||
only_roots: Optional[bool] = False,
|
||||
desc: Optional[bool] = True,
|
||||
allow_deleted: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, auth_method=frontend_user.auth_method, username=frontend_user.username)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
@@ -43,6 +47,7 @@ def query_messages(
|
||||
lte_created_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if allow_deleted else False,
|
||||
lang=lang,
|
||||
)
|
||||
|
||||
return utils.prepare_message_list(messages)
|
||||
@@ -50,8 +55,9 @@ def query_messages(
|
||||
|
||||
@router.get("/cursor", response_model=protocol.MessagePage)
|
||||
def get_messages_cursor(
|
||||
lt: Optional[str] = None,
|
||||
gt: Optional[str] = None,
|
||||
*,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
@@ -60,9 +66,13 @@ def get_messages_cursor(
|
||||
include_deleted: Optional[bool] = False,
|
||||
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
||||
desc: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
assert max_count is not None
|
||||
|
||||
def split_cursor(x: str | None) -> tuple[datetime, UUID]:
|
||||
if not x:
|
||||
return None, None
|
||||
@@ -74,11 +84,21 @@ def get_messages_cursor(
|
||||
except ValueError:
|
||||
raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE)
|
||||
|
||||
lte_created_date, lt_id = split_cursor(lt)
|
||||
gte_created_date, gt_id = split_cursor(gt)
|
||||
if desc:
|
||||
gte_created_date, gt_id = split_cursor(before)
|
||||
lte_created_date, lt_id = split_cursor(after)
|
||||
query_desc = not (before is not None and not after)
|
||||
else:
|
||||
lte_created_date, lt_id = split_cursor(before)
|
||||
gte_created_date, gt_id = split_cursor(after)
|
||||
query_desc = before is not None and not after
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
print(f"{desc=} {query_desc=} {gte_created_date=} {lte_created_date=}")
|
||||
|
||||
qry_max_count = max_count + 1 if before is None or after is None else max_count
|
||||
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
items = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
@@ -89,22 +109,31 @@ def get_messages_cursor(
|
||||
lt_id=lt_id,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
desc=query_desc,
|
||||
limit=qry_max_count,
|
||||
lang=lang,
|
||||
)
|
||||
|
||||
items = utils.prepare_message_list(messages)
|
||||
num_rows = len(items)
|
||||
if qry_max_count > max_count and num_rows == qry_max_count:
|
||||
assert not (before and after)
|
||||
items = items[:-1]
|
||||
|
||||
if desc != query_desc:
|
||||
items.reverse()
|
||||
|
||||
items = utils.prepare_message_list(items)
|
||||
n, p = None, None
|
||||
if len(items) > 0:
|
||||
if len(items) == max_count or gte_created_date:
|
||||
if (num_rows > max_count and before) or after:
|
||||
p = str(items[0].id) + "$" + items[0].created_date.isoformat()
|
||||
if len(items) == max_count or lte_created_date:
|
||||
if num_rows > max_count or before:
|
||||
n = str(items[-1].id) + "$" + items[-1].created_date.isoformat()
|
||||
else:
|
||||
if gte_created_date:
|
||||
p = gte_created_date.isoformat()
|
||||
if lte_created_date:
|
||||
n = lte_created_date.isoformat()
|
||||
if after:
|
||||
p = lte_created_date.isoformat() if desc else gte_created_date.isoformat()
|
||||
if before:
|
||||
n = gte_created_date.isoformat() if desc else lte_created_date.isoformat()
|
||||
|
||||
order = "desc" if desc else "asc"
|
||||
return protocol.MessagePage(prev=p, next=n, sort_key="created_date", order=order, items=items)
|
||||
@@ -112,37 +141,49 @@ def get_messages_cursor(
|
||||
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -150,24 +191,32 @@ def get_tree(
|
||||
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -175,12 +224,16 @@ def get_descendants(
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -188,12 +241,16 @@ def get_longest_conv(
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
@@ -201,7 +258,30 @@ def get_max_children(
|
||||
|
||||
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
|
||||
|
||||
@router.post("/{message_id}/emoji", response_model=protocol.Message)
|
||||
def post_message_emoji(
|
||||
*,
|
||||
message_id: UUID,
|
||||
request: protocol.MessageEmojiRequest,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
) -> protocol.Message:
|
||||
"""
|
||||
Toggle, add or remove message emoji.
|
||||
"""
|
||||
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
def emoji_tx(session: deps.Session):
|
||||
pr = PromptRepository(session, api_client, client_user=request.user)
|
||||
return pr.handle_message_emoji(message_id, request.op, request.emoji)
|
||||
|
||||
return utils.prepare_message(emoji_tx())
|
||||
|
||||
@@ -77,6 +77,7 @@ def tasks_acknowledge(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
task_id: UUID,
|
||||
ack_request: protocol_schema.TaskAck,
|
||||
) -> None:
|
||||
@@ -87,7 +88,7 @@ def tasks_acknowledge(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
@@ -105,6 +106,7 @@ def tasks_acknowledge_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
task_id: UUID,
|
||||
nack_request: protocol_schema.TaskNAck,
|
||||
) -> None:
|
||||
@@ -115,7 +117,7 @@ def tasks_acknowledge_failure(
|
||||
try:
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
|
||||
@@ -3,9 +3,11 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
|
||||
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
@@ -14,20 +16,25 @@ router = APIRouter()
|
||||
@router.post("/", status_code=HTTP_204_NO_CONTENT)
|
||||
def label_text(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
text_labels: protocol_schema.TextLabels,
|
||||
) -> None:
|
||||
"""
|
||||
Label a piece of text.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
def store_text_labels(session: deps.Session):
|
||||
api_client = deps.api_auth(api_key, session)
|
||||
pr = PromptRepository(session, api_client, client_user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
|
||||
try:
|
||||
logger.info(f"Labeling text {text_labels=}.")
|
||||
pr = PromptRepository(db, api_client, client_user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
store_text_labels()
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to store label.")
|
||||
raise HTTPException(
|
||||
@@ -39,7 +46,29 @@ def label_text(
|
||||
def get_valid_lables() -> ValidLabelsResponse:
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in protocol_schema.TextLabel
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in TextLabel
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/report_labels")
|
||||
def get_report_lables() -> ValidLabelsResponse:
|
||||
report_labels = [
|
||||
TextLabel.spam,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
TextLabel.moral_judgement,
|
||||
TextLabel.political_content,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.quality,
|
||||
]
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in report_labels
|
||||
]
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ def get_users_ordered_by_username(
|
||||
search_text: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
desc: Optional[bool] = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
@@ -41,6 +42,7 @@ def get_users_ordered_by_username(
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
desc=desc,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
@@ -55,6 +57,7 @@ def get_users_ordered_by_display_name(
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
desc: Optional[bool] = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
@@ -68,14 +71,15 @@ def get_users_ordered_by_display_name(
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
desc=desc,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/cursor", response_model=protocol.FrontEndUserPage)
|
||||
def get_users_cursor(
|
||||
lt: Optional[str] = None,
|
||||
gt: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
sort_key: Optional[str] = Query("username", max_length=32),
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client_id: Optional[UUID] = None,
|
||||
@@ -95,7 +99,8 @@ def get_users_cursor(
|
||||
return x, None
|
||||
|
||||
items: list[protocol.FrontEndUser]
|
||||
qry_max_count = max_count + 1 if lt is None or gt is None else max_count
|
||||
qry_max_count = max_count + 1 if before is None or after is None else max_count
|
||||
desc = before is not None and not after
|
||||
|
||||
def get_next_prev(num_rows: int, lt: str | None, gt: str | None, key_fn: Callable[[protocol.FrontEndUser], str]):
|
||||
p, n = None, None
|
||||
@@ -114,17 +119,16 @@ def get_users_cursor(
|
||||
def remove_extra_item(items: list[protocol.FrontEndUser], lt: str | None, gt: str | None):
|
||||
num_rows = len(items)
|
||||
if qry_max_count > max_count and num_rows == qry_max_count:
|
||||
assert not (lt and gt)
|
||||
if lt:
|
||||
items = items[1:]
|
||||
else:
|
||||
items = items[:-1]
|
||||
assert not (lt is not None and gt is not None)
|
||||
items = items[:-1]
|
||||
if desc:
|
||||
items.reverse()
|
||||
return items, num_rows
|
||||
|
||||
n, p = None, None
|
||||
if sort_key == "username":
|
||||
lte_username, lt_id = split_cursor(lt)
|
||||
gte_username, gt_id = split_cursor(gt)
|
||||
lte_username, lt_id = split_cursor(before)
|
||||
gte_username, gt_id = split_cursor(after)
|
||||
items = get_users_ordered_by_username(
|
||||
api_client_id=api_client_id,
|
||||
gte_username=gte_username,
|
||||
@@ -134,6 +138,7 @@ def get_users_cursor(
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
max_count=qry_max_count,
|
||||
desc=desc,
|
||||
api_client=api_client,
|
||||
db=db,
|
||||
)
|
||||
@@ -141,8 +146,8 @@ def get_users_cursor(
|
||||
p, n = get_next_prev(num_rows, lte_username, gte_username, lambda x: x.id)
|
||||
|
||||
elif sort_key == "display_name":
|
||||
lte_display_name, lt_id = split_cursor(lt)
|
||||
gte_display_name, gt_id = split_cursor(gt)
|
||||
lte_display_name, lt_id = split_cursor(before)
|
||||
gte_display_name, gt_id = split_cursor(after)
|
||||
items = get_users_ordered_by_display_name(
|
||||
api_client_id=api_client_id,
|
||||
gte_display_name=gte_display_name,
|
||||
@@ -152,6 +157,7 @@ def get_users_cursor(
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
max_count=qry_max_count,
|
||||
desc=desc,
|
||||
api_client=api_client,
|
||||
db=db,
|
||||
)
|
||||
@@ -184,6 +190,7 @@ def update_user(
|
||||
user_id: UUID,
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
@@ -191,7 +198,7 @@ def update_user(
|
||||
Update a user by global user ID. Only trusted clients can update users.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
ur.update_user(user_id, enabled, notes)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
@@ -217,13 +224,15 @@ def query_user_messages(
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
include_deleted: bool = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
@@ -233,6 +242,7 @@ def query_user_messages(
|
||||
lte_created_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
lang=lang,
|
||||
)
|
||||
|
||||
return utils.prepare_message_list(messages)
|
||||
@@ -241,23 +251,27 @@ def query_user_messages(
|
||||
@router.get("/{user_id}/messages/cursor", response_model=protocol.MessagePage)
|
||||
def query_user_messages_cursor(
|
||||
user_id: Optional[UUID],
|
||||
lt: Optional[str] = None,
|
||||
gt: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
only_roots: Optional[bool] = False,
|
||||
include_deleted: Optional[bool] = False,
|
||||
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
||||
desc: Optional[bool] = False,
|
||||
lang: Optional[str] = None,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
return get_messages_cursor(
|
||||
lt=lt,
|
||||
gt=gt,
|
||||
before=before,
|
||||
after=after,
|
||||
user_id=user_id,
|
||||
only_roots=only_roots,
|
||||
include_deleted=include_deleted,
|
||||
max_count=max_count,
|
||||
desc=desc,
|
||||
lang=lang,
|
||||
frontend_user=frontend_user,
|
||||
api_client=api_client,
|
||||
db=db,
|
||||
)
|
||||
@@ -265,9 +279,12 @@ def query_user_messages_cursor(
|
||||
|
||||
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
user_id: UUID,
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.query_messages_ordered_by_created_date(user_id=user_id, limit=None)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ def prepare_message(m: Message) -> protocol.Message:
|
||||
lang=m.lang,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
emojis=m.emojis or {},
|
||||
user_emojis=m.user_emojis or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -21,17 +23,20 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
return [prepare_message(m) for m in messages]
|
||||
|
||||
|
||||
def prepare_conversation_message(message: Message) -> protocol.ConversationMessage:
|
||||
return protocol.ConversationMessage(
|
||||
id=message.id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
text=message.text,
|
||||
lang=message.lang,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
emojis=message.emojis or {},
|
||||
user_emojis=message.user_emojis or [],
|
||||
)
|
||||
|
||||
|
||||
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
|
||||
return [
|
||||
protocol.ConversationMessage(
|
||||
id=message.id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
text=message.text,
|
||||
lang=message.lang,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
return [prepare_conversation_message(message) for message in messages]
|
||||
|
||||
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator
|
||||
|
||||
|
||||
@@ -46,17 +46,69 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_required_rankings: int = 3
|
||||
"""Number of rankings in which the message participated."""
|
||||
|
||||
mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.creativity,
|
||||
TextLabel.humor,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
labels_assistant_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.fails_task,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.creativity,
|
||||
TextLabel.humor,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
labels_prompter_reply: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
TextLabel.helpfulness,
|
||||
TextLabel.humor,
|
||||
TextLabel.creativity,
|
||||
TextLabel.toxicity,
|
||||
TextLabel.violence,
|
||||
TextLabel.not_appropriate,
|
||||
TextLabel.pii,
|
||||
TextLabel.hate_speech,
|
||||
TextLabel.sexual_content,
|
||||
]
|
||||
|
||||
mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for initial prompts."""
|
||||
|
||||
mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for assistant replies."""
|
||||
|
||||
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for prompter replies."""
|
||||
|
||||
rank_prompter_replies: bool = False
|
||||
|
||||
lonely_children_count: int = 3
|
||||
"""Number of children below which parents are preferred during sampling for reply tasks."""
|
||||
|
||||
p_lonely_child_extension: float = 0.8
|
||||
"""Probability to select a parent with less than lonely_children_count children."""
|
||||
|
||||
recent_tasks_span_sec: int = 3 * 60 # 3 min
|
||||
"""Time in seconds of recent tasks to consider for exclusion during task selection."""
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "open-assistant backend"
|
||||
@@ -90,10 +142,13 @@ class Settings(BaseSettings):
|
||||
Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json"
|
||||
)
|
||||
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
|
||||
DEBUG_ALLOW_DUPLICATE_TASKS: bool = False # offer users tasks to which they already responded
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
|
||||
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
__all__ = []
|
||||
@@ -1,56 +0,0 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=SQLModel)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
||||
|
||||
**Parameters**
|
||||
|
||||
* `model`: A SQLModel model class
|
||||
* `schema`: A Pydantic model (schema) class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: Any) -> Optional[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
def get_multi(self, db: Session, *, begin_id: int = 0, limit: int = 100) -> List[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id >= begin_id).limit(limit).all()
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data) # type: ignore
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(self, db: Session, *, db_obj: ModelType, obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def delete(self, db: Session, *, id: int) -> ModelType:
|
||||
obj = db.query(self.model).get(id)
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
@@ -2,6 +2,7 @@ from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
from .message_emoji import MessageEmoji
|
||||
from .message_reaction import MessageReaction
|
||||
from .message_toxicity import MessageToxicity
|
||||
from .message_tree_state import MessageTreeState
|
||||
@@ -24,4 +25,5 @@ __all__ = [
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
"MessageEmoji",
|
||||
]
|
||||
|
||||
@@ -117,7 +117,8 @@ class LabelConversationReplyPayload(TaskPayload):
|
||||
|
||||
message_id: UUID
|
||||
conversation: protocol_schema.Conversation
|
||||
reply: str
|
||||
reply: str # deprecated
|
||||
reply_message: Optional[protocol_schema.ConversationMessage]
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[protocol_schema.LabelTaskMode]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from pydantic import PrivateAttr
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
@@ -17,6 +18,13 @@ class Message(SQLModel, table=True):
|
||||
__tablename__ = "message"
|
||||
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any):
|
||||
new_object = super().__new__(cls, *args, **kwargs)
|
||||
# temporary fix until https://github.com/tiangolo/sqlmodel/issues/149 gets merged
|
||||
if not hasattr(new_object, "_user_emojis"):
|
||||
new_object._init_private_attributes()
|
||||
return new_object
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
@@ -49,11 +57,24 @@ class Message(SQLModel, table=True):
|
||||
|
||||
rank: Optional[int] = Field(nullable=True)
|
||||
|
||||
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
def has_emoji(self, emoji_code: str) -> bool:
|
||||
return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0
|
||||
|
||||
def has_user_emoji(self, emoji_code: str) -> bool:
|
||||
return self._user_emojis and emoji_code in self._user_emojis
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
self.ensure_is_message()
|
||||
return self.payload.payload.text
|
||||
|
||||
@property
|
||||
def user_emojis(self) -> str:
|
||||
return self._user_emojis
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class MessageEmoji(SQLModel, table=True):
|
||||
__tablename__ = "message_emoji"
|
||||
__table_args__ = (Index("ix_message_emoji__user_id__message_id", "user_id", "message_id", unique=False),)
|
||||
|
||||
message_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
|
||||
)
|
||||
)
|
||||
user_id: UUID = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=False, primary_key=True
|
||||
)
|
||||
)
|
||||
emoji: str = Field(nullable=False, max_length=128, primary_key=True)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
@@ -26,3 +26,4 @@ class MessageReaction(SQLModel, table=True):
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
message_id: Optional[UUID] = Field(nullable=True, index=True)
|
||||
|
||||
@@ -20,7 +20,9 @@ class Task(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa_column=sa.Column(
|
||||
sa.DateTime(timezone=True), nullable=False, index=True, server_default=sa.func.current_timestamp()
|
||||
),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
|
||||
@@ -27,3 +27,4 @@ class TextLabels(SQLModel, table=True):
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
|
||||
)
|
||||
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
task_id: Optional[UUID] = Field(nullable=True, index=True)
|
||||
|
||||
@@ -30,6 +30,7 @@ class User(SQLModel, table=True):
|
||||
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
|
||||
notes: str = Field(sa_column=sa.Column(AutoString(length=1024), nullable=False, server_default=""))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))
|
||||
show_on_leaderboard: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
|
||||
|
||||
def to_protocol_frontend_user(self):
|
||||
return protocol.FrontEndUser(
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
import random
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from http import HTTPStatus
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import FrontendUserId
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import (
|
||||
ApiClient,
|
||||
Message,
|
||||
MessageEmbedding,
|
||||
MessageEmoji,
|
||||
MessageReaction,
|
||||
MessageToxicity,
|
||||
MessageTreeState,
|
||||
@@ -28,8 +32,10 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from sqlmodel import Session, and_, func, not_, or_, text, update
|
||||
from oasst_shared.utils import unaware_to_utc, utcnow
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -39,14 +45,30 @@ class PromptRepository:
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User] = None,
|
||||
*,
|
||||
user_repository: Optional[UserRepository] = None,
|
||||
task_repository: Optional[TaskRepository] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
frontend_user: Optional[FrontendUserId] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user_repository = user_repository or UserRepository(db, api_client)
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
if frontend_user and not auth_method and not username:
|
||||
auth_method, username = frontend_user
|
||||
|
||||
if user_id:
|
||||
self.user = self.user_repository.get_user(id=user_id)
|
||||
self.user_id = self.user.id
|
||||
elif auth_method and username:
|
||||
self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username)
|
||||
self.user_id = self.user.id
|
||||
else:
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
|
||||
self.task_repository = task_repository or TaskRepository(
|
||||
db, api_client, client_user, user_repository=self.user_repository
|
||||
@@ -168,6 +190,18 @@ class PromptRepository:
|
||||
role = None
|
||||
depth = 0
|
||||
|
||||
# reject whitespaces match with ^\s+$
|
||||
if re.match(r"^\s+$", text):
|
||||
raise OasstError("Message text is empty", OasstErrorCode.TASK_MESSAGE_TEXT_EMPTY)
|
||||
|
||||
# ensure message size is below the predefined limit
|
||||
if len(text) > settings.MESSAGE_SIZE_LIMIT:
|
||||
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
if self.check_users_recent_replies_for_duplicates(text):
|
||||
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
|
||||
|
||||
if task.parent_message_id:
|
||||
parent_message = self.fetch_message(task.parent_message_id)
|
||||
|
||||
@@ -245,7 +279,7 @@ class PromptRepository:
|
||||
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(message.id, reaction_payload)
|
||||
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=message.id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
@@ -255,7 +289,7 @@ class PromptRepository:
|
||||
return reaction
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]:
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> tuple[MessageReaction, Task]:
|
||||
# fetch task
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
self._validate_task(task, frontend_message_id=ranking.message_id)
|
||||
@@ -295,7 +329,7 @@ class PromptRepository:
|
||||
ranking_parent_id=task_payload.ranking_parent_id,
|
||||
message_tree_id=task_payload.message_tree_id,
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=parent_msg.id)
|
||||
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
@@ -313,9 +347,8 @@ class PromptRepository:
|
||||
reaction_payload = db_payload.RankingReactionPayload(
|
||||
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=None)
|
||||
# self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
@@ -346,13 +379,13 @@ class PromptRepository:
|
||||
return message_toxicity
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: list[float]) -> MessageEmbedding:
|
||||
"""Insert the embedding of a new message in the database.
|
||||
|
||||
Args:
|
||||
message_id (UUID): the identifier of the message we want to save its embedding
|
||||
model (str): the model used for creating the embedding
|
||||
embedding (List[float]): the values obtained from the message & model
|
||||
embedding (list[float]): the values obtained from the message & model
|
||||
|
||||
Raises:
|
||||
OasstError: if misses some of the before params
|
||||
@@ -366,7 +399,9 @@ class PromptRepository:
|
||||
return message_embedding
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
def insert_reaction(
|
||||
self, task_id: UUID, payload: db_payload.ReactionPayload, message_id: Optional[UUID]
|
||||
) -> MessageReaction:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
@@ -376,12 +411,13 @@ class PromptRepository:
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
message_id=message_id,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
return reaction
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]:
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[TextLabels, Task, Message]:
|
||||
|
||||
valid_labels: Optional[list[str]] = None
|
||||
mandatory_labels: Optional[list[str]] = None
|
||||
@@ -441,11 +477,24 @@ class PromptRepository:
|
||||
user_id=self.user_id,
|
||||
text=text_labels.text,
|
||||
labels=text_labels.labels,
|
||||
task_id=task.id if task else None,
|
||||
)
|
||||
|
||||
if message_id:
|
||||
message = self.fetch_message(message_id)
|
||||
if task:
|
||||
if not task:
|
||||
if text_labels.is_report is True:
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
|
||||
)
|
||||
|
||||
# update existing record for repeated updates (same user no task associated)
|
||||
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
|
||||
if existing_text_label is not None:
|
||||
existing_text_label.labels = text_labels.labels
|
||||
model = existing_text_label
|
||||
|
||||
else:
|
||||
message = self.fetch_message(message_id)
|
||||
message.review_count += 1
|
||||
self.db.add(message)
|
||||
|
||||
@@ -519,6 +568,46 @@ class PromptRepository:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if not include_deleted:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def check_users_recent_replies_for_duplicates(self, text: str) -> bool:
|
||||
"""
|
||||
Checks if the user has recently replied with the same text within a given time period.
|
||||
"""
|
||||
|
||||
user_id = self.user_id
|
||||
logger.debug(f"Checking for duplicate tasks for user {user_id}")
|
||||
# messages in the past 24 hours
|
||||
messages = (
|
||||
self.db.query(Message)
|
||||
.filter(Message.user_id == user_id)
|
||||
.order_by(Message.created_date.desc())
|
||||
.filter(
|
||||
Message.created_date > utcnow() - timedelta(minutes=settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not messages:
|
||||
return False
|
||||
for msg in messages:
|
||||
if msg.text == text:
|
||||
return True
|
||||
return False
|
||||
|
||||
def fetch_user_message_trees(
|
||||
self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False
|
||||
) -> list[Message]:
|
||||
qry = self.db.query(Message).filter(Message.user_id == user_id)
|
||||
if reviewed:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if not include_deleted:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]:
|
||||
qry = self.db.query(MessageTreeState).filter(
|
||||
MessageTreeState.state == message_tree_state.State.READY_FOR_EXPORT
|
||||
)
|
||||
return qry.all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
|
||||
@@ -556,11 +645,25 @@ class PromptRepository:
|
||||
return conversation, replies
|
||||
|
||||
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
|
||||
qry = self.db.query(Message).filter(Message.id == message_id)
|
||||
messages = self._add_user_emojis_all(qry)
|
||||
message = messages[0] if messages else None
|
||||
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
|
||||
|
||||
query = (
|
||||
self.db.query(TextLabels)
|
||||
.outerjoin(Task, Task.id == TextLabels.id)
|
||||
.filter(Task.id.is_(None), TextLabels.message_id == message_id, TextLabels.user_id == user_id)
|
||||
)
|
||||
text_label = query.one_or_none()
|
||||
return text_label
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
@@ -620,9 +723,27 @@ class PromptRepository:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if exclude_deleted:
|
||||
qry = qry.filter(Message.deleted == sa.false())
|
||||
children = qry.all()
|
||||
children = self._add_user_emojis_all(qry)
|
||||
return children
|
||||
|
||||
def fetch_message_siblings(
|
||||
self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Get siblings of a message (other messages with the same parent_id)
|
||||
"""
|
||||
if isinstance(message, Message):
|
||||
message = message.id
|
||||
|
||||
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
|
||||
qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id)
|
||||
if reviewed is not None:
|
||||
qry = qry.filter(Message.review_result == reviewed)
|
||||
if deleted is not None:
|
||||
qry = qry.filter(Message.deleted == deleted)
|
||||
siblings = self._add_user_emojis_all(qry)
|
||||
return siblings
|
||||
|
||||
@staticmethod
|
||||
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
|
||||
children = defaultdict(list)
|
||||
@@ -651,7 +772,7 @@ class PromptRepository:
|
||||
if max_depth is not None:
|
||||
desc = desc.filter(Message.depth <= max_depth)
|
||||
|
||||
desc = desc.all()
|
||||
desc = self._add_user_emojis_all(desc)
|
||||
|
||||
return self.trace_descendants(message, desc)
|
||||
|
||||
@@ -665,6 +786,33 @@ class PromptRepository:
|
||||
max_message = max(tree, key=lambda m: m.children_count)
|
||||
return max_message, [m for m in tree if m.parent_id == max_message.id]
|
||||
|
||||
def _add_user_emojis_all(self, qry: Query) -> list[Message]:
|
||||
if self.user_id is None:
|
||||
return qry.all()
|
||||
|
||||
sq = qry.subquery("m")
|
||||
qry = (
|
||||
self.db.query(Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis"))
|
||||
.select_entity_from(sq)
|
||||
.outerjoin(
|
||||
MessageEmoji,
|
||||
and_(
|
||||
sq.c.id == MessageEmoji.message_id,
|
||||
MessageEmoji.user_id == self.user_id,
|
||||
sq.c.emojis != JSON.NULL,
|
||||
),
|
||||
)
|
||||
.group_by(sq)
|
||||
)
|
||||
messages: list[Message] = []
|
||||
for x in qry:
|
||||
m: Message = x.Message
|
||||
user_emojis = x["user_emojis"]
|
||||
if user_emojis:
|
||||
m._user_emojis = user_emojis.split(",")
|
||||
messages.append(m)
|
||||
return messages
|
||||
|
||||
def query_messages_ordered_by_created_date(
|
||||
self,
|
||||
user_id: Optional[UUID] = None,
|
||||
@@ -679,6 +827,7 @@ class PromptRepository:
|
||||
deleted: Optional[bool] = None,
|
||||
desc: bool = False,
|
||||
limit: Optional[int] = 100,
|
||||
lang: Optional[str] = None,
|
||||
) -> list[Message]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
@@ -693,7 +842,7 @@ class PromptRepository:
|
||||
if user_id:
|
||||
qry = qry.filter(Message.user_id == user_id)
|
||||
if username or auth_method:
|
||||
if not username and auth_method:
|
||||
if not (username and auth_method):
|
||||
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
|
||||
qry = qry.join(User)
|
||||
qry = qry.filter(User.username == username, User.auth_method == auth_method)
|
||||
@@ -743,7 +892,10 @@ class PromptRepository:
|
||||
if limit is not None:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
return qry.all()
|
||||
if lang is not None:
|
||||
qry = qry.filter(Message.lang == lang)
|
||||
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def update_children_counts(self, message_tree_id: UUID):
|
||||
sql_update_children_count = """
|
||||
@@ -805,3 +957,82 @@ WHERE message.id = cc.id;
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
message = self.fetch_message(message_id)
|
||||
|
||||
# check if emoji exists
|
||||
existing_emoji = (
|
||||
self.db.query(MessageEmoji)
|
||||
.filter(
|
||||
MessageEmoji.message_id == message_id, MessageEmoji.user_id == self.user_id, MessageEmoji.emoji == emoji
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if existing_emoji:
|
||||
if op == protocol_schema.EmojiOp.add:
|
||||
logger.info(f"Emoji record already exists {message_id=}, {emoji=}, {self.user_id=}")
|
||||
return message
|
||||
elif op == protocol_schema.EmojiOp.togggle:
|
||||
op = protocol_schema.EmojiOp.remove
|
||||
|
||||
if existing_emoji is None:
|
||||
if op == protocol_schema.EmojiOp.remove:
|
||||
logger.info(f"Emoji record not found {message_id=}, {emoji=}, {self.user_id=}")
|
||||
return message
|
||||
elif op == protocol_schema.EmojiOp.togggle:
|
||||
op = protocol_schema.EmojiOp.add
|
||||
|
||||
if op == protocol_schema.EmojiOp.add:
|
||||
# hard coded exclusivity of thumbs_up & thumbs_down
|
||||
if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji(
|
||||
protocol_schema.EmojiCode.thumbs_down.value
|
||||
):
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down
|
||||
)
|
||||
elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji(
|
||||
protocol_schema.EmojiCode.thumbs_up.value
|
||||
):
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
|
||||
)
|
||||
|
||||
# insert emoji record & increment count
|
||||
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
|
||||
self.db.add(message_emoji)
|
||||
emoji_counts = message.emojis
|
||||
if not emoji_counts:
|
||||
message.emojis = {emoji.value: 1}
|
||||
else:
|
||||
count = emoji_counts.get(emoji.value) or 0
|
||||
emoji_counts[emoji.value] = count + 1
|
||||
if message._user_emojis is None:
|
||||
message._user_emojis = []
|
||||
if emoji.value not in message._user_emojis:
|
||||
message._user_emojis.append(emoji.value)
|
||||
elif op == protocol_schema.EmojiOp.remove:
|
||||
# remove emoji record and & decrement count
|
||||
message = self.fetch_message(message_id)
|
||||
if message._user_emojis and emoji.value in message._user_emojis:
|
||||
message._user_emojis.remove(emoji.value)
|
||||
self.db.delete(existing_emoji)
|
||||
emoji_counts = message.emojis
|
||||
count = emoji_counts.get(emoji.value)
|
||||
if count is not None:
|
||||
if count == 1:
|
||||
del emoji_counts[emoji.value]
|
||||
else:
|
||||
emoji_counts[emoji.value] = count - 1
|
||||
flag_modified(message, "emojis")
|
||||
self.db.add(message)
|
||||
else:
|
||||
raise OasstError("Emoji op not supported", OasstErrorCode.EMOJI_OP_UNSUPPORTED)
|
||||
|
||||
flag_modified(message, "emojis")
|
||||
self.db.add(message)
|
||||
self.db.flush()
|
||||
return message
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from oasst_shared.schemas.protocol import LabelDescription
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LabelOption(BaseModel):
|
||||
name: str
|
||||
display_text: str
|
||||
help_text: Optional[str]
|
||||
|
||||
|
||||
class ValidLabelsResponse(BaseModel):
|
||||
valid_labels: list[LabelOption]
|
||||
valid_labels: list[LabelDescription]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
@@ -9,7 +10,7 @@ from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, func, or_
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -100,6 +101,7 @@ class TaskRepository:
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
reply_message=task.reply_message,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
@@ -111,6 +113,7 @@ class TaskRepository:
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
reply_message=task.reply_message,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
@@ -219,3 +222,16 @@ class TaskRepository:
|
||||
def fetch_task_by_id(self, task_id: UUID) -> Task:
|
||||
task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none()
|
||||
return task
|
||||
|
||||
def fetch_recent_reply_tasks(
|
||||
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, limit: int = 100
|
||||
) -> list[Task]:
|
||||
qry = self.db.query(Task).filter(
|
||||
func.age(Task.created_date) < max_age,
|
||||
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
|
||||
)
|
||||
if done is not None:
|
||||
qry = qry.filter(Task.done == done)
|
||||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
return qry.all()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@@ -7,11 +9,17 @@ from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.api.v1.utils import (
|
||||
prepare_conversation,
|
||||
prepare_conversation_message,
|
||||
prepare_conversation_message_list,
|
||||
)
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils import tree_export
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
@@ -37,8 +45,9 @@ class TaskRole(Enum):
|
||||
|
||||
class ActiveTreeSizeRow(pydantic.BaseModel):
|
||||
message_tree_id: UUID
|
||||
tree_size: int
|
||||
goal_tree_size: int
|
||||
tree_size: int
|
||||
awaiting_review: Optional[int]
|
||||
|
||||
@property
|
||||
def remaining_messages(self) -> int:
|
||||
@@ -89,8 +98,6 @@ class TreeManagerStats(pydantic.BaseModel):
|
||||
|
||||
|
||||
class TreeManager:
|
||||
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
@@ -197,7 +204,7 @@ class TreeManager:
|
||||
lang = "en"
|
||||
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
|
||||
|
||||
num_active_trees = self.query_num_active_trees(lang=lang)
|
||||
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
|
||||
extendible_parents = self.query_extendible_parents(lang=lang)
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
@@ -211,6 +218,15 @@ class TreeManager:
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]:
|
||||
return [
|
||||
protocol_schema.LabelDescription(
|
||||
name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text
|
||||
)
|
||||
for l in valid_labels
|
||||
]
|
||||
|
||||
def next_task(
|
||||
self,
|
||||
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
|
||||
@@ -225,7 +241,7 @@ class TreeManager:
|
||||
lang = "en"
|
||||
logger.warning("Task request without lang tag received, assuming 'en'.")
|
||||
|
||||
num_active_trees = self.query_num_active_trees(lang=lang)
|
||||
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
extendible_parents = self.query_extendible_parents(lang=lang)
|
||||
@@ -334,6 +350,7 @@ class TreeManager:
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
|
||||
case TaskType.LABEL_REPLY:
|
||||
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
@@ -349,58 +366,99 @@ class TreeManager:
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
|
||||
if message.role == "assistant":
|
||||
valid_labels = self.cfg.labels_assistant_reply
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
valid_labels = list(self.cfg.mandatory_labels_assistant_reply)
|
||||
if protocol_schema.TextLabel.quality not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.quality)
|
||||
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
reply_message=prepare_conversation_message(message),
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
else:
|
||||
valid_labels = self.cfg.labels_prompter_reply
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
valid_labels = list(self.cfg.mandatory_labels_prompter_reply)
|
||||
if protocol_schema.TextLabel.quality not in valid_labels:
|
||||
valid_labels.append(protocol_schema.TextLabel.quality)
|
||||
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
reply_message=prepare_conversation_message(message),
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
message_tree_id = message.message_tree_id
|
||||
|
||||
case TaskType.REPLY:
|
||||
# select a tree with missing replies
|
||||
|
||||
recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
|
||||
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False
|
||||
)
|
||||
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}
|
||||
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
|
||||
|
||||
# select a tree with missing replies
|
||||
if len(extendible_parents) > 0:
|
||||
random_parent = random.choice(extendible_parents)
|
||||
random_parent: ExtendibleParentRow = None
|
||||
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
|
||||
# check if we have extendible parents with a small number of replies
|
||||
|
||||
lonely_children_parents = [
|
||||
p
|
||||
for p in extendible_parents
|
||||
if 0 < p.active_children_count < self.cfg.lonely_children_count
|
||||
and p.parent_id not in recent_reply_task_parents
|
||||
]
|
||||
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
|
||||
random_parent = random.choice(lonely_children_parents)
|
||||
|
||||
if random_parent is None:
|
||||
# try to exclude parents for which tasks were recently handed out
|
||||
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
|
||||
if len(fresh_parents) > 0:
|
||||
random_parent = random.choice(fresh_parents)
|
||||
else:
|
||||
random_parent = random.choice(extendible_parents)
|
||||
|
||||
# fetch random conversation to extend
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive review
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive reviews
|
||||
conversation = prepare_conversation(messages)
|
||||
|
||||
# generate reply task depending on last message
|
||||
@@ -419,19 +477,23 @@ class TreeManager:
|
||||
message = random.choice(prompts_need_review)
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
valid_labels = self.cfg.labels_initial_prompt
|
||||
|
||||
if random.random() > self.cfg.p_full_labeling_review_prompt:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
|
||||
valid_labels = self.cfg.mandatory_labels_initial_prompt
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.spam
|
||||
|
||||
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.text,
|
||||
valid_labels=valid_labels,
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
|
||||
mode=label_mode,
|
||||
disposition=label_disposition,
|
||||
labels=self._get_label_descriptions(valid_labels),
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
@@ -464,14 +526,6 @@ class TreeManager:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# ensure message size is below the predefined limit
|
||||
if len(interaction.text) > settings.MESSAGE_SIZE_LIMIT:
|
||||
logger.error(
|
||||
f"Message size {len(interaction.text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}."
|
||||
)
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
# here we store the text reply in the database
|
||||
message = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
@@ -502,19 +556,18 @@ class TreeManager:
|
||||
try:
|
||||
model_name: str = HfClassificationModel.TOXIC_ROBERTA.value
|
||||
hugging_face_api: HuggingFaceAPI = HuggingFaceAPI(
|
||||
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{model_name}"
|
||||
f"{HfUrl.HUGGINGFACE_TOXIC_CLASSIFICATION.value}/{model_name}"
|
||||
)
|
||||
|
||||
toxicity: List[List[Dict[str, Any]]] = await hugging_face_api.post(interaction.text)
|
||||
toxicity = toxicity[0][0]
|
||||
|
||||
pr.insert_toxicity(
|
||||
message_id=message.id, model=model_name, score=toxicity["score"], label=toxicity["label"]
|
||||
)
|
||||
|
||||
except OasstError:
|
||||
logger.error(
|
||||
f"Could not compute toxicity for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
f"Could not compute toxicity for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
case protocol_schema.MessageRating:
|
||||
@@ -530,9 +583,7 @@ class TreeManager:
|
||||
)
|
||||
|
||||
_, task = pr.store_ranking(interaction)
|
||||
|
||||
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
self.update_message_ranks(task.message_tree_id, rankings_by_message)
|
||||
self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
@@ -541,7 +592,7 @@ class TreeManager:
|
||||
|
||||
_, task, msg = pr.store_text_labels(interaction)
|
||||
|
||||
# if it was a respones for a task, check if we have enough reviews to calc review_result
|
||||
# if it was a response for a task, check if we have enough reviews to calc review_result
|
||||
if task and msg:
|
||||
reviews = self.query_reviews_for_message(msg.id)
|
||||
acceptance_score = self._calculate_acceptance(reviews)
|
||||
@@ -622,8 +673,8 @@ class TreeManager:
|
||||
|
||||
# check if desired tree size has been reached and all nodes have been reviewed
|
||||
tree_size = self.query_tree_size(message_tree_id)
|
||||
if tree_size.remaining_messages > 0:
|
||||
logger.debug(f"False {tree_size.remaining_messages=}")
|
||||
if tree_size.remaining_messages > 0 or tree_size.awaiting_review > 0:
|
||||
logger.debug(f"False {tree_size.remaining_messages=}, {tree_size.awaiting_review=}")
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.RANKING)
|
||||
@@ -647,9 +698,12 @@ class TreeManager:
|
||||
return False, None
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
return True, rankings_by_message
|
||||
self.update_message_ranks(message_tree_id, rankings_by_message)
|
||||
return True
|
||||
|
||||
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
|
||||
def update_message_ranks(
|
||||
self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]]
|
||||
) -> bool:
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
# check state, allow retry if in SCORING_FAILED state
|
||||
@@ -657,19 +711,47 @@ class TreeManager:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
if mts.state == message_tree_state.State.SCORING_FAILED:
|
||||
mts.active = True
|
||||
mts.state = message_tree_state.State.READY_FOR_SCORING
|
||||
|
||||
try:
|
||||
for rankings in rankings_by_message.values():
|
||||
sorted_messages = []
|
||||
for msg_reaction in rankings:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
ordered_ids_list: list[list[UUID]] = [
|
||||
msg_reaction.payload.payload.ranked_message_ids for msg_reaction in rankings
|
||||
]
|
||||
|
||||
common_set: set[UUID] = set.intersection(*map(set, ordered_ids_list))
|
||||
if len(common_set) < 2:
|
||||
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
|
||||
continue
|
||||
|
||||
# keep only elements in commond set
|
||||
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
|
||||
assert all(len(x) == len(common_set) for x in ordered_ids_list)
|
||||
|
||||
logger.debug(f"SORTED MESSAGE IDS {ordered_ids_list}")
|
||||
consensus = ranked_pairs(ordered_ids_list)
|
||||
assert len(consensus) == len(common_set)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
|
||||
# fetch all siblings and clear ranks
|
||||
siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None)
|
||||
for m in siblings:
|
||||
m.rank = None
|
||||
self.db.add(m)
|
||||
|
||||
# index by id
|
||||
siblings = {m.id: m for m in siblings}
|
||||
|
||||
# set rank for each message that was part of the common set
|
||||
for rank, message_id in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
msg = siblings.get(message_id)
|
||||
if msg:
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
else:
|
||||
logger.warning(f"Message {message_id=} not found among siblings.")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
|
||||
@@ -683,57 +765,65 @@ class TreeManager:
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
|
||||
def query_prompts_need_review(self, lang: str) -> list[Message]:
|
||||
"""
|
||||
Select initial prompt messages with less then required rankings in active message tree
|
||||
(active == True in message_tree_state)
|
||||
"""
|
||||
def _query_need_review(
|
||||
self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str
|
||||
) -> list[Message]:
|
||||
|
||||
qry = (
|
||||
need_review = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
MessageTreeState.state == state,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_initial_prompt,
|
||||
Message.parent_id.is_(None),
|
||||
Message.review_count < required_reviews,
|
||||
Message.lang == lang,
|
||||
)
|
||||
)
|
||||
|
||||
if root:
|
||||
need_review = need_review.filter(Message.parent_id.is_(None))
|
||||
else:
|
||||
need_review = need_review.filter(Message.parent_id.is_not(None))
|
||||
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
need_review = need_review.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
if settings.DEBUG_ALLOW_DUPLICATE_TASKS:
|
||||
qry = need_review
|
||||
else:
|
||||
user_id = self.pr.user_id
|
||||
need_review = need_review.cte(name="need_review")
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_entity_from(need_review)
|
||||
.outerjoin(TextLabels, need_review.c.id == TextLabels.message_id)
|
||||
.group_by(need_review)
|
||||
.having(
|
||||
func.count(TextLabels.id).filter(TextLabels.task_id.is_not(None), TextLabels.user_id == user_id)
|
||||
== 0
|
||||
)
|
||||
)
|
||||
|
||||
return qry.all()
|
||||
|
||||
def query_prompts_need_review(self, lang: str) -> list[Message]:
|
||||
"""
|
||||
Select initial prompt messages with less then required rankings in active message tree
|
||||
(active == True in message_tree_state)
|
||||
"""
|
||||
return self._query_need_review(
|
||||
message_tree_state.State.INITIAL_PROMPT_REVIEW, self.cfg.num_reviews_initial_prompt, True, lang
|
||||
)
|
||||
|
||||
def query_replies_need_review(self, lang: str) -> list[Message]:
|
||||
"""
|
||||
Select child messages (parent_id IS NOT NULL) with less then required rankings
|
||||
in active message tree (active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.GROWING,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_reply,
|
||||
Message.parent_id.is_not(None),
|
||||
Message.lang == lang,
|
||||
)
|
||||
)
|
||||
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
return qry.all()
|
||||
return self._query_need_review(message_tree_state.State.GROWING, self.cfg.num_reviews_reply, False, lang)
|
||||
|
||||
_sql_find_incomplete_rankings = """
|
||||
-- find incomplete rankings
|
||||
@@ -749,17 +839,28 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id, m.role
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
_sql_find_incomplete_rankings_ex = f"""
|
||||
-- incomplete rankings but exclude of current user
|
||||
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
|
||||
SELECT ir.* FROM incomplete_rankings ir
|
||||
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
|
||||
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
"""
|
||||
|
||||
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
|
||||
"""Query parents which have childern that need further rankings"""
|
||||
|
||||
user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_incomplete_rankings),
|
||||
text(self._sql_find_incomplete_rankings_ex),
|
||||
{
|
||||
"num_required_rankings": self.cfg.num_required_rankings,
|
||||
"ranking_state": message_tree_state.State.RANKING,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
|
||||
@@ -780,17 +881,20 @@ WHERE mts.active -- only consider active trees
|
||||
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
AND COUNT(c.id) FILTER (WHERE c.user_id = :user_id) = 0 -- without reply by user
|
||||
"""
|
||||
|
||||
def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]:
|
||||
"""Query parent messages that have not reached the maximum number of replies."""
|
||||
|
||||
user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_extendible_parents),
|
||||
{
|
||||
"growing_state": message_tree_state.State.GROWING,
|
||||
"num_reviews_reply": self.cfg.num_reviews_reply,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
|
||||
@@ -814,12 +918,14 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
|
||||
"""Query size of active message trees in growing state."""
|
||||
|
||||
user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_extendible_trees),
|
||||
{
|
||||
"growing_state": message_tree_state.State.GROWING,
|
||||
"num_reviews_reply": self.cfg.num_reviews_reply,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
|
||||
@@ -827,18 +933,21 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
|
||||
"""Returns the number of reviewed not deleted messages in the message tree."""
|
||||
|
||||
required_reviews = settings.tree_manager.num_reviews_reply
|
||||
qry = (
|
||||
self.db.query(
|
||||
MessageTreeState.message_tree_id.label("message_tree_id"),
|
||||
MessageTreeState.goal_tree_size.label("goal_tree_size"),
|
||||
func.count(Message.id).label("tree_size"),
|
||||
func.count(Message.id).filter(Message.review_result).label("tree_size"),
|
||||
func.count(Message.id)
|
||||
.filter(not_(Message.review_result), Message.review_count < required_reviews)
|
||||
.label("awaiting_review"),
|
||||
)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
MessageTreeState.message_tree_id == message_tree_id,
|
||||
)
|
||||
.group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size)
|
||||
@@ -907,7 +1016,7 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
return rankings_by_message
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def ensure_tree_states(self):
|
||||
def ensure_tree_states(self) -> None:
|
||||
"""Add message tree state rows for all root nodes (inital prompt messages)."""
|
||||
|
||||
missing_tree_ids = self.query_misssing_tree_states()
|
||||
@@ -919,12 +1028,23 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
|
||||
self._insert_default_state(id, state=state)
|
||||
|
||||
def query_num_active_trees(self, lang: str) -> int:
|
||||
rankings = (
|
||||
self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all()
|
||||
)
|
||||
if len(rankings) > 0:
|
||||
logger.info(f"Checking state of {len(rankings)} message trees in ranking state.")
|
||||
for r in rankings:
|
||||
self.check_condition_for_scoring_state(r.message_tree_id)
|
||||
|
||||
def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int:
|
||||
"""Count all active trees (optionally exclude those in ranking state)."""
|
||||
query = (
|
||||
self.db.query(func.count(MessageTreeState.message_tree_id))
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(MessageTreeState.active, Message.lang == lang)
|
||||
)
|
||||
if exclude_ranking:
|
||||
query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING)
|
||||
return query.scalar()
|
||||
|
||||
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
|
||||
@@ -1150,6 +1270,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
sql_purge_user = """
|
||||
DELETE FROM journal WHERE user_id = :user_id;
|
||||
DELETE FROM message_reaction WHERE user_id = :user_id;
|
||||
DELETE FROM message_emoji WHERE user_id = :user_id;
|
||||
DELETE FROM task WHERE user_id = :user_id;
|
||||
DELETE FROM message WHERE user_id = :user_id;
|
||||
DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
@@ -1161,41 +1282,92 @@ DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
if ban:
|
||||
self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
|
||||
|
||||
def export_trees_to_file(
|
||||
self,
|
||||
message_tree_ids: list[str],
|
||||
file=None,
|
||||
reviewed: bool = True,
|
||||
include_deleted: bool = False,
|
||||
use_compression: bool = False,
|
||||
) -> None:
|
||||
trees_to_export: List[tree_export.ExportMessageTree] = []
|
||||
|
||||
for message_tree_id in message_tree_ids:
|
||||
messages: List[Message] = self.pr.fetch_message_tree(message_tree_id, reviewed, include_deleted)
|
||||
trees_to_export.append(tree_export.build_export_tree(message_tree_id, messages))
|
||||
|
||||
if file:
|
||||
tree_export.write_trees_to_file(file, trees_to_export, use_compression)
|
||||
else:
|
||||
sys.stdout.write(json.dumps(jsonable_encoder(trees_to_export), indent=2))
|
||||
|
||||
def export_all_ready_trees(
|
||||
self, file: str, reviewed: bool = True, include_deleted: bool = False, use_compression: bool = False
|
||||
) -> None:
|
||||
message_tree_states: MessageTreeState = self.pr.fetch_message_trees_ready_for_export()
|
||||
message_tree_ids = [ms.message_tree_id for ms in message_tree_states]
|
||||
self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
|
||||
|
||||
def export_all_user_trees(
|
||||
self,
|
||||
user_id: str,
|
||||
file: str,
|
||||
reviewed: bool = True,
|
||||
include_deleted: bool = False,
|
||||
use_compression: bool = False,
|
||||
) -> None:
|
||||
messages = self.pr.fetch_user_message_trees(UUID(user_id))
|
||||
message_tree_ids = [ms.message_tree_id for ms in messages]
|
||||
self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def retry_scoring_failed_message_trees(self):
|
||||
query = self.db.query(MessageTreeState.message_tree_id).filter(
|
||||
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
|
||||
)
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
for row in query.all():
|
||||
try:
|
||||
message_tree_id = row["message_tree_id"]
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message)
|
||||
except Exception:
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import api_auth
|
||||
|
||||
# from oasst_backend.api.deps import create_api_client
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
# api_client = create_api_client(session=db, description="test", frontend_type="bot")
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
|
||||
|
||||
cfg = TreeManagerConfiguration()
|
||||
tm = TreeManager(db, pr, cfg)
|
||||
tm.ensure_tree_states()
|
||||
|
||||
tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
|
||||
# tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
|
||||
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
|
||||
# db.commit()
|
||||
|
||||
# print("query_num_active_trees", tm.query_num_active_trees())
|
||||
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
# print("query_replies_need_review", tm.query_replies_need_review())
|
||||
# print("query_incomplete_reply_reviews", tm.query_replies_need_review())
|
||||
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
|
||||
# print("query_extendible_trees", tm.query_extendible_trees())
|
||||
# print("query_extendible_parents", tm.query_extendible_parents())
|
||||
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
|
||||
|
||||
# print(
|
||||
# "query_reviews_for_message",
|
||||
# tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
|
||||
# )
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
# print(
|
||||
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
|
||||
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
|
||||
# )
|
||||
|
||||
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import Session, and_, or_
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
@@ -64,7 +66,13 @@ class UserRepository:
|
||||
return user
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
|
||||
def update_user(
|
||||
self,
|
||||
id: UUID,
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
|
||||
@@ -83,6 +91,8 @@ class UserRepository:
|
||||
user.enabled = enabled
|
||||
if notes is not None:
|
||||
user.notes = notes
|
||||
if show_on_leaderboard is not None:
|
||||
user.show_on_leaderboard = show_on_leaderboard
|
||||
|
||||
self.db.add(user)
|
||||
|
||||
@@ -107,9 +117,7 @@ class UserRepository:
|
||||
self.db.add(user)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
@@ -135,6 +143,18 @@ class UserRepository:
|
||||
self.db.add(user)
|
||||
return user
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
return self._lookup_client_user_tx(client_user, create_missing)
|
||||
except IntegrityError:
|
||||
# catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username
|
||||
if i + 1 == num_retries:
|
||||
raise
|
||||
|
||||
def query_users_ordered_by_username(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
@@ -145,6 +165,7 @@ class UserRepository:
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
desc: bool = False,
|
||||
) -> list[User]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
@@ -184,14 +205,13 @@ class UserRepository:
|
||||
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
|
||||
qry = qry.filter(User.username.like(pattern))
|
||||
|
||||
if limit is not None and lte_username and not gte_username:
|
||||
# select top rows but return results in ascernding order
|
||||
sub_qry = qry.order_by(User.username.desc(), User.id.desc()).limit(limit).subquery("u")
|
||||
qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.username, User.id)
|
||||
if desc:
|
||||
qry = qry.order_by(User.username.desc(), User.id.desc())
|
||||
else:
|
||||
qry = qry.order_by(User.username, User.id)
|
||||
if limit is not None:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
if limit is not None:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
return qry.all()
|
||||
|
||||
@@ -205,7 +225,9 @@ class UserRepository:
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
desc: bool = False,
|
||||
) -> list[User]:
|
||||
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
# Let unprivileged api clients query their own users without api_client_id being set
|
||||
@@ -255,13 +277,12 @@ class UserRepository:
|
||||
if auth_method:
|
||||
qry = qry.filter(User.auth_method == auth_method)
|
||||
|
||||
if limit is not None and lte_display_name and not gte_display_name:
|
||||
# select top rows but return results in ascernding order
|
||||
sub_qry = qry.order_by(User.display_name.desc(), User.id.desc()).limit(limit).subquery("u")
|
||||
qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.display_name, User.id)
|
||||
if desc:
|
||||
qry = qry.order_by(User.display_name.desc(), User.id.desc())
|
||||
else:
|
||||
qry = qry.order_by(User.display_name, User.id)
|
||||
if limit is not None:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
if limit is not None:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
return qry.all()
|
||||
|
||||
@@ -39,7 +39,7 @@ class UserStatsRepository:
|
||||
qry = (
|
||||
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value)
|
||||
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
|
||||
.order_by(UserStats.rank)
|
||||
.limit(limit)
|
||||
)
|
||||
@@ -250,7 +250,8 @@ FROM
|
||||
PARTITION BY time_frame
|
||||
ORDER BY leader_score DESC, user_id
|
||||
) AS "rank", user_id, time_frame
|
||||
FROM user_stats
|
||||
FROM user_stats us2
|
||||
INNER JOIN "user" u ON us2.user_id = u.id AND u.show_on_leaderboard
|
||||
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
|
||||
WHERE
|
||||
us.user_id = r.user_id
|
||||
|
||||
@@ -7,9 +7,14 @@ from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from psycopg2.errors import DeadlockDetected, ExclusionViolation, SerializationFailure, UniqueViolation
|
||||
from sqlalchemy.exc import OperationalError, PendingRollbackError
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
"""
|
||||
Error Handling Reference: https://www.postgresql.org/docs/15/mvcc-serialization-failure-handling.html
|
||||
"""
|
||||
|
||||
|
||||
class CommitMode(IntEnum):
|
||||
"""
|
||||
@@ -26,7 +31,6 @@ class CommitMode(IntEnum):
|
||||
* managed_tx_method and async_managed_tx_method methods are decorators functions
|
||||
* to be used on class functions. It expects the Class to have a 'db' Session object
|
||||
* initialised
|
||||
* TODO: tx method decorator for non class methods
|
||||
"""
|
||||
|
||||
|
||||
@@ -35,28 +39,46 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s
|
||||
@wraps(f)
|
||||
def wrapped_f(self, *args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
result = None
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
retry_exhausted = True
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = f(self, *args, **kwargs)
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
retry_exhausted = False
|
||||
break
|
||||
except PendingRollbackError as e:
|
||||
logger.info(str(e))
|
||||
self.db.rollback()
|
||||
except OperationalError as e:
|
||||
if e.orig is not None and isinstance(
|
||||
e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
|
||||
):
|
||||
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
|
||||
self.db.rollback()
|
||||
else:
|
||||
raise e
|
||||
logger.info(f"Retry {i+1}/{num_retries}")
|
||||
if retry_exhausted:
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
result = f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error("DB Rollback Failure")
|
||||
logger.info(str(e))
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
@@ -71,28 +93,46 @@ def async_managed_tx_method(
|
||||
@wraps(f)
|
||||
async def wrapped_f(self, *args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = await f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
result = None
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
retry_exhausted = True
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = await f(self, *args, **kwargs)
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
retry_exhausted = False
|
||||
break
|
||||
except PendingRollbackError as e:
|
||||
logger.info(str(e))
|
||||
self.db.rollback()
|
||||
except OperationalError as e:
|
||||
if e.orig is not None and isinstance(
|
||||
e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
|
||||
):
|
||||
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
|
||||
self.db.rollback()
|
||||
else:
|
||||
raise e
|
||||
logger.info(f"Retry {i+1}/{num_retries}")
|
||||
if retry_exhausted:
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
result = await f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("DB Rollback Failure")
|
||||
logger.info(str(e))
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
@@ -115,27 +155,49 @@ def managed_tx_function(
|
||||
@wraps(f)
|
||||
def wrapped_f(*args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
with session_factory() as session:
|
||||
try:
|
||||
result = f(session, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
result = None
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
retry_exhausted = True
|
||||
for i in range(num_retries):
|
||||
with session_factory() as session:
|
||||
try:
|
||||
result = f(session, *args, **kwargs)
|
||||
session.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
retry_exhausted = False
|
||||
break
|
||||
except PendingRollbackError as e:
|
||||
logger.info(str(e))
|
||||
session.rollback()
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
session.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
except OperationalError as e:
|
||||
if e.orig is not None and isinstance(
|
||||
e.orig,
|
||||
(SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation),
|
||||
):
|
||||
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
|
||||
session.rollback()
|
||||
else:
|
||||
raise e
|
||||
logger.info(f"Retry {i+1}/{num_retries}")
|
||||
if retry_exhausted:
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
with session_factory() as session:
|
||||
result = f(session, *args, **kwargs)
|
||||
if auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
session.rollback()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error("DB Rollback Failure")
|
||||
logger.info(str(e))
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import Counter
|
||||
|
||||
from sklearn import metrics
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.svm import LinearSVC
|
||||
|
||||
|
||||
def load_and_split(foldername, num_words):
|
||||
ls = os.listdir(foldername)
|
||||
X = []
|
||||
Y = []
|
||||
langmap = dict()
|
||||
for idx, x in enumerate(ls):
|
||||
print("loading language", x)
|
||||
with open(foldername + "/" + x, "r") as reader:
|
||||
tmp = reader.read().split(" ")
|
||||
tmp = [" ".join(tmp[i : i + num_words]) for i in range(0, 100_000, num_words)]
|
||||
X.extend(tmp)
|
||||
Y.extend([idx] * len(tmp))
|
||||
langmap[idx] = x
|
||||
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.90)
|
||||
return x_train, x_test, y_train, y_test, langmap
|
||||
|
||||
|
||||
def build_and_train_pipeline(x_train, y_train):
|
||||
vectorizer = TfidfVectorizer(ngram_range=(1, 2), analyzer="char", use_idf=False)
|
||||
clf = Pipeline(
|
||||
[
|
||||
("vec", vectorizer),
|
||||
# ("nystrom", Nystroem(n_components=1000,n_jobs=6)),
|
||||
("clf", LinearSVC(C=0.5)),
|
||||
# ("clf",GaussianNB())
|
||||
# ("clf", HistGradientBoostingClassifier())
|
||||
]
|
||||
)
|
||||
print("fitting model...")
|
||||
clf.fit(x_train, y_train)
|
||||
return clf
|
||||
|
||||
|
||||
def benchmark(clf, x_test, y_test, langmap):
|
||||
print("benchmarking model...")
|
||||
y_pred = clf.predict(x_test)
|
||||
names = list(langmap.values())
|
||||
# print(y_test)
|
||||
# print(langmap)
|
||||
print(metrics.classification_report(y_test, y_pred, target_names=names))
|
||||
cm = metrics.confusion_matrix(y_test, y_pred)
|
||||
print(cm)
|
||||
|
||||
|
||||
def main(foldername, modelname, num_words):
|
||||
x_train, x_test, y_train, y_test, langmap = load_and_split(foldername=foldername, num_words=num_words)
|
||||
clf = build_and_train_pipeline(x_train, y_train)
|
||||
benchmark(clf, x_test, y_test, langmap)
|
||||
save_model(clf, langmap, num_words, modelname)
|
||||
model = load(modelname)
|
||||
print(
|
||||
"running infernence on long tests",
|
||||
inference_voter(
|
||||
model,
|
||||
"""
|
||||
What language is this text written in? Nobody knows until you fill in at least ten words.
|
||||
This test here is to check whether the moving window approach works,
|
||||
so I still need to fill in a little more text.
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def load(modelname):
|
||||
with open(modelname, "rb") as writer:
|
||||
data = pickle.load(writer)
|
||||
return data
|
||||
|
||||
|
||||
def save_model(model, idx_to_name, num_words, modelname):
|
||||
out = {
|
||||
"model": model,
|
||||
"idx_to_name": idx_to_name,
|
||||
"num_words": num_words,
|
||||
}
|
||||
with open(modelname, "wb") as writer:
|
||||
pickle.dump(out, writer)
|
||||
|
||||
|
||||
def inference_voter(model, text):
|
||||
tmp = text.split()
|
||||
# print(len(tmp), tmp)
|
||||
tmp = [" ".join(tmp[i : i + model["num_words"]]) for i in range(0, len(tmp) - model["num_words"])]
|
||||
predictions = model["model"].predict(tmp)
|
||||
# print("integer predictions", predictions)
|
||||
# print("name predictions", *[model["idx_to_name"][n] for n in predictions])
|
||||
result = Counter(predictions).most_common(1)[0][0]
|
||||
return model["idx_to_name"][result]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", help="save location for model and metadata")
|
||||
parser.add_argument("-d", "--data", help="specify the folder for data files")
|
||||
parser.add_argument("-n", "--num_words", help="number of words to use for statistics", type=int)
|
||||
args = parser.parse_args()
|
||||
# np.set_printoptions(threshold=np.inf)
|
||||
main(args.data, args.model, args.num_words)
|
||||
@@ -96,13 +96,15 @@ def ranked_pairs(ranks: List[List[int]]):
|
||||
"""
|
||||
tallies, names = head_to_head_votes(ranks)
|
||||
tallies = tallies - tallies.T
|
||||
# print(tallies)
|
||||
# note: the resulting tally matrix should be skew-symmetric
|
||||
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
|
||||
sorted_majorities = []
|
||||
for i in range(len(ranks[0])):
|
||||
for j in range(len(ranks[0])):
|
||||
if tallies[i, j] > 0:
|
||||
# you can never prefer yourself over yourself
|
||||
# we also have to pick one of the two choices,
|
||||
# if the preference is exactly zero...
|
||||
if tallies[i, j] >= 0 and i != j:
|
||||
sorted_majorities.append((i, j, tallies[i, j]))
|
||||
# we don't explicitly deal with tied majorities here
|
||||
sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True))
|
||||
@@ -128,13 +130,36 @@ def ranked_pairs(ranks: List[List[int]]):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ranks = (
|
||||
|
||||
ranks = """ (
|
||||
[("w", "x", "z", "y") for _ in range(1)]
|
||||
+ [("w", "y", "x", "z") for _ in range(2)]
|
||||
# + [("x","y","z","w") for _ in range(4)]
|
||||
+ [("x", "z", "w", "y") for _ in range(5)]
|
||||
+ [("y", "w", "x", "z") for _ in range(1)]
|
||||
# [("y","z","w","x") for _ in range(1000)]
|
||||
)
|
||||
)"""
|
||||
ranks = [
|
||||
[
|
||||
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
|
||||
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
|
||||
("d11705af-5575-43e5-b22e-08d155fbaa62"),
|
||||
],
|
||||
[
|
||||
("d11705af-5575-43e5-b22e-08d155fbaa62"),
|
||||
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
|
||||
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
|
||||
],
|
||||
[
|
||||
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
|
||||
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
|
||||
("d11705af-5575-43e5-b22e-08d155fbaa62"),
|
||||
],
|
||||
[
|
||||
("d11705af-5575-43e5-b22e-08d155fbaa62"),
|
||||
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
|
||||
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
|
||||
],
|
||||
]
|
||||
rp = ranked_pairs(ranks)
|
||||
print(rp)
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gzip
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Optional, TextIO
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from oasst_backend.models import Message
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExportMessageNode(BaseModel):
|
||||
message_id: str
|
||||
parent_id: Optional[str]
|
||||
text: Optional[str]
|
||||
role: str
|
||||
review_count: Optional[int]
|
||||
rank: Optional[int]
|
||||
replies: Optional[list[ExportMessageNode]]
|
||||
|
||||
@classmethod
|
||||
def prep_message_export(cls, message: Message) -> ExportMessageNode:
|
||||
return cls(
|
||||
message_id=str(message.id),
|
||||
parent_id=str(message.parent_id) if message.parent_id else None,
|
||||
text=str(message.payload.payload.text),
|
||||
role=message.role,
|
||||
review_count=message.review_count,
|
||||
rank=message.rank,
|
||||
)
|
||||
|
||||
|
||||
class ExportMessageTree(BaseModel):
|
||||
message_tree_id: str
|
||||
replies: Optional[ExportMessageNode]
|
||||
|
||||
|
||||
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
|
||||
export_tree = ExportMessageTree(message_tree_id=str(message_tree_id))
|
||||
export_tree_data = [ExportMessageNode.prep_message_export(m) for m in messages]
|
||||
|
||||
message_parents = defaultdict(list)
|
||||
for message in export_tree_data:
|
||||
message_parents[message.parent_id].append(message)
|
||||
|
||||
def build_tree(tree: dict, parent: Optional[str], messages: list[Message]):
|
||||
children = message_parents[parent]
|
||||
tree.replies = children
|
||||
|
||||
for idx, child in enumerate(tree.replies):
|
||||
build_tree(tree.replies[idx], child.message_id, messages)
|
||||
|
||||
build_tree(export_tree, None, export_tree_data)
|
||||
|
||||
return export_tree
|
||||
|
||||
|
||||
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
|
||||
|
||||
out_buff: TextIO
|
||||
if use_compression:
|
||||
out_buff = gzip.open(file, "wt", encoding="UTF-8")
|
||||
else:
|
||||
out_buff = open(file, "wt", encoding="UTF-8")
|
||||
|
||||
with out_buff as f:
|
||||
for tree in trees:
|
||||
file_data = jsonable_encoder(tree)
|
||||
json.dump(file_data, f)
|
||||
f.write("\n")
|
||||
+14
-1
@@ -2,7 +2,13 @@
|
||||
|
||||
Preliminary implementation of the inference engine for OpenAssistant.
|
||||
|
||||
## Development (you'll need multiple terminals)
|
||||
## Development Variant 1 (you'll need tmux)
|
||||
|
||||
Run `./full-dev-setup.sh` to start the full development setup. Make sure to wait
|
||||
until the 2nd terminal is ready and says `{"message":"Connected"}` before
|
||||
entering input into the last terminal.
|
||||
|
||||
## Development Variant 2 (you'll need multiple terminals)
|
||||
|
||||
Run a redis container (or use the one of the general docker compose file):
|
||||
|
||||
@@ -26,6 +32,13 @@ pip install -r requirements.txt
|
||||
python __main__.py
|
||||
```
|
||||
|
||||
For the worker, you'll also want to have the text-generation-inference server
|
||||
running:
|
||||
|
||||
```bash
|
||||
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
|
||||
```
|
||||
|
||||
Run the client:
|
||||
|
||||
```bash
|
||||
|
||||
Executable
+19
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Creates a tmux window with splits for the individual services
|
||||
|
||||
tmux new-session -d -s "inference-dev-setup"
|
||||
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd server" C-m
|
||||
tmux send-keys "uvicorn main:app --reload" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd worker" C-m
|
||||
tmux send-keys "python __main__.py" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd text-client" C-m
|
||||
tmux send-keys "sleep 5" C-m
|
||||
tmux send-keys "python __main__.py" C-m
|
||||
tmux attach-session -t "inference-dev-setup"
|
||||
+53
-35
@@ -5,6 +5,7 @@ import uuid
|
||||
import fastapi
|
||||
import pydantic
|
||||
import redis.asyncio as redis
|
||||
import websockets.exceptions
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
@@ -63,6 +64,7 @@ class MessageRequestState(str, enum.Enum):
|
||||
pending = "pending"
|
||||
in_progress = "in_progress"
|
||||
complete = "complete"
|
||||
aborted_by_worker = "aborted_by_worker"
|
||||
|
||||
|
||||
class DbChatEntry(pydantic.BaseModel):
|
||||
@@ -154,40 +156,56 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
|
||||
async def work(websocket: fastapi.WebSocket):
|
||||
await websocket.accept()
|
||||
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
||||
while True:
|
||||
# find a pending task that matches the worker's config
|
||||
# could also be implemented using task queues
|
||||
# but general compatibility matching is tricky
|
||||
for chat in CHATS.values():
|
||||
if (request := chat.pending_message_request) is not None:
|
||||
if chat.message_request_state == MessageRequestState.pending:
|
||||
if request.compatible_with(worker_config):
|
||||
try:
|
||||
while True:
|
||||
print(websocket.client_state)
|
||||
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
|
||||
logger.warning("Worker disconnected")
|
||||
break
|
||||
# find a pending task that matches the worker's config
|
||||
# could also be implemented using task queues
|
||||
# but general compatibility matching is tricky
|
||||
for chat in CHATS.values():
|
||||
if (request := chat.pending_message_request) is not None:
|
||||
if chat.message_request_state == MessageRequestState.pending:
|
||||
if request.compatible_with(worker_config):
|
||||
break
|
||||
else:
|
||||
logger.debug("No pending tasks")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
chat.message_request_state = MessageRequestState.in_progress
|
||||
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
model_name=request.model_name,
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Created {work_request}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
logger.warning("Worker disconnected")
|
||||
websocket.close()
|
||||
chat.message_request_state = MessageRequestState.pending
|
||||
break
|
||||
|
||||
try:
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
break
|
||||
else:
|
||||
logger.debug("No pending tasks")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
chat.message_request_state = MessageRequestState.aborted_by_worker
|
||||
raise
|
||||
|
||||
chat.message_request_state = MessageRequestState.in_progress
|
||||
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
model_name=request.model_name,
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Created {work_request}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
break
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
|
||||
chat.message_request_state = MessageRequestState.complete
|
||||
chat.message_request_state = MessageRequestState.complete
|
||||
except fastapi.WebSocketException:
|
||||
logger.exception("Websocket closed")
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
|
||||
import rel
|
||||
import torch
|
||||
import requests
|
||||
import sseclient
|
||||
import typer
|
||||
import websocket
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
from transformers import pipeline
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -16,12 +15,13 @@ app = typer.Typer()
|
||||
def main(
|
||||
backend_url: str = "ws://localhost:8000",
|
||||
model_name: str = "distilgpt2",
|
||||
inference_server_url: str = "http://localhost:8001",
|
||||
):
|
||||
pipe = pipeline("text-generation", model=model_name)
|
||||
|
||||
def on_open(ws: websocket.WebSocket):
|
||||
logger.info("Connected to backend, sending config...")
|
||||
worker_config = inference.WorkerConfig(model_name=model_name)
|
||||
ws.send(worker_config.json())
|
||||
logger.info("Config sent, waiting for work...")
|
||||
|
||||
def on_message(ws: websocket.WebSocket, message: str):
|
||||
# TODO: what if this comes in, but one is already in progress?
|
||||
@@ -35,25 +35,39 @@ def main(
|
||||
# construct prompt
|
||||
messages = [_prepare_message(message) for message in work_request.conversation.messages]
|
||||
|
||||
prompt = "\n".join(messages) + "\nAssistant:"
|
||||
prefix = (
|
||||
"The following is a conversation between a user and an assistant. "
|
||||
"The assistant is helpful, creative, clever, and very friendly.\n"
|
||||
"Assistant: Hello! How can I help you today?\n"
|
||||
)
|
||||
|
||||
# TODO: replace this with incremental generation
|
||||
torch.manual_seed(work_request.seed)
|
||||
model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
|
||||
0
|
||||
]["generated_text"]
|
||||
model_output = model_output.strip()
|
||||
prompt = prefix + "\n".join(messages) + "\nAssistant:"
|
||||
|
||||
# fake streaming
|
||||
split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
|
||||
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
|
||||
for piece in pieces:
|
||||
if not piece:
|
||||
continue
|
||||
if piece.strip() in ("User:", "Assistant:"):
|
||||
response = requests.post(
|
||||
f"{inference_server_url}/generate_stream",
|
||||
json={
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": work_request.max_new_tokens,
|
||||
"do_sample": work_request.do_sample,
|
||||
"top_k": work_request.top_k,
|
||||
"top_p": work_request.top_p,
|
||||
"temperature": work_request.temperature,
|
||||
"seed": work_request.seed,
|
||||
},
|
||||
},
|
||||
stream=True,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
client = sseclient.SSEClient(response)
|
||||
for event in client.events():
|
||||
data = json.loads(event.data)
|
||||
if data["is_end"]:
|
||||
break
|
||||
ws.send(inference.WorkResponsePacket(token=piece).json())
|
||||
time.sleep(0.1)
|
||||
intermediate = data["event"]
|
||||
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
|
||||
ws.send(inference.WorkResponsePacket(is_end=True).json())
|
||||
|
||||
def on_error(ws: websocket.WebSocket, error: Exception):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
loguru
|
||||
rel
|
||||
torch
|
||||
transformers
|
||||
requests
|
||||
sseclient-py
|
||||
typer
|
||||
websocket-client
|
||||
|
||||
@@ -43,7 +43,7 @@ def get_one_dataset(conf, dataset_name):
|
||||
if dataset_name == "debate_sum":
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
else:
|
||||
val_name = "validation" if dataset_name not in ["billsum"] else "test"
|
||||
val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
elif "ted_trans" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
|
||||
from torch.nn import functional as F
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
@@ -23,15 +22,8 @@ class DialogueDataCollator:
|
||||
flatten_messages = []
|
||||
label_masks = []
|
||||
|
||||
for feature_one in features:
|
||||
assert len(feature_one) % 2 == 0, "Number of messages must be even"
|
||||
# TODO: we should push this to dataset __getitem__
|
||||
messages = [
|
||||
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
|
||||
+ x
|
||||
+ (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
|
||||
for i, x in enumerate(feature_one)
|
||||
]
|
||||
for messages in features:
|
||||
messages = list(messages)
|
||||
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
|
||||
|
||||
|
||||
def format_pair(pair):
|
||||
return "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1]
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
from custom_datasets.formatting import format_pair
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
@@ -49,8 +50,7 @@ class PromptGeneratedDataset(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class InstructionTuning(Dataset):
|
||||
@@ -101,5 +101,4 @@ class InstructionTuning(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
@@ -7,14 +7,13 @@ import re
|
||||
from urllib.request import urlopen
|
||||
|
||||
import numpy as np
|
||||
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# @agoryuno contributed this
|
||||
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
|
||||
|
||||
|
||||
def index_squad_v2(example):
|
||||
if len(example["answers"]["text"]):
|
||||
@@ -78,7 +77,7 @@ class QADataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data = self.dataset[idx]
|
||||
return self.index_fn(data)
|
||||
return format_pair(self.index_fn(data))
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
@@ -111,7 +110,7 @@ class WebGPT(Dataset):
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
answer = self.questions[question]
|
||||
return [question, answer]
|
||||
return format_pair((question, answer))
|
||||
|
||||
|
||||
class SODA(Dataset):
|
||||
@@ -121,14 +120,14 @@ class SODA(Dataset):
|
||||
def process_soda_convo(self, data):
|
||||
pairs = []
|
||||
play_as = data["speakers"][1]
|
||||
prefix = "{}{}. {}{}".format(
|
||||
QA_SPECIAL_TOKENS["StartPrefix"],
|
||||
data["narrative"],
|
||||
"your name {}".format(play_as),
|
||||
QA_SPECIAL_TOKENS["EndPrefix"],
|
||||
)
|
||||
question, answer = "", ""
|
||||
prefix, postfix = "", ""
|
||||
dialogue_bg = "{}{} {}{}".format(
|
||||
QA_SPECIAL_TOKENS["StartPrefix"],
|
||||
data["narrative"],
|
||||
"your are {}".format(play_as),
|
||||
QA_SPECIAL_TOKENS["EndPrefix"],
|
||||
)
|
||||
previous_chat = []
|
||||
|
||||
for idx, convo in enumerate(data["dialogue"]):
|
||||
@@ -138,14 +137,20 @@ class SODA(Dataset):
|
||||
else:
|
||||
answer = convo
|
||||
postfix = data["speakers"][idx]
|
||||
|
||||
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
|
||||
history = "<sep>".join(
|
||||
["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
|
||||
[
|
||||
"{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], p[0], QA_SPECIAL_TOKENS["Answer"], p[1])
|
||||
for p in previous_chat
|
||||
]
|
||||
)
|
||||
if len(history):
|
||||
history += "<sep>"
|
||||
pairs.append((prefix + history + question, answer))
|
||||
prompt = QA_SPECIAL_TOKENS["Question"] + question + QA_SPECIAL_TOKENS["Answer"]
|
||||
pairs.append((dialogue_bg + history + prompt, answer))
|
||||
previous_chat.append((question, answer))
|
||||
|
||||
return pairs
|
||||
|
||||
def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None:
|
||||
@@ -166,8 +171,8 @@ class SODA(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
# special token added during preprocess
|
||||
return self.pairs[index]
|
||||
|
||||
|
||||
class SODADialogue(Dataset):
|
||||
@@ -218,7 +223,7 @@ class SODADialogue(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class JokeExplaination(Dataset):
|
||||
@@ -253,8 +258,7 @@ class JokeExplaination(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
# https://huggingface.co/datasets/aquamuse
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
import random
|
||||
|
||||
from custom_datasets.formatting import format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@@ -54,11 +55,12 @@ def index_summary_merge(text, summary):
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
def __init__(self, dataset, cache_dir, split, max_words=512):
|
||||
self.name = dataset
|
||||
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
|
||||
self.max_words = max_words
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
@@ -72,4 +74,5 @@ class SummarizationDataset(Dataset):
|
||||
else:
|
||||
prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"])
|
||||
|
||||
return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary)
|
||||
context = "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[: self.max_words]), prompt])
|
||||
return format_pair((context, summary))
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
"""
|
||||
import random
|
||||
|
||||
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ProsocialDialogueExplaination(Dataset):
|
||||
name = "prosocial_explain"
|
||||
name = "explain_prosocial"
|
||||
TEMPLATE = [
|
||||
# 0 : reply or sentence of interest, 1 : reason of caution
|
||||
("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"),
|
||||
@@ -36,7 +37,7 @@ class ProsocialDialogueExplaination(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pairs[idx]
|
||||
return format_pair(self.pairs[idx])
|
||||
|
||||
|
||||
class ProsocialDialogue(Dataset):
|
||||
@@ -58,8 +59,9 @@ class ProsocialDialogue(Dataset):
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
prompt = QA_SPECIAL_TOKENS["Question"] + row["context"] + QA_SPECIAL_TOKENS["Answer"]
|
||||
for answer in row["rots"]:
|
||||
self.pairs.append((self.PREFIX + row["context"], answer))
|
||||
self.pairs.append((self.PREFIX + prompt, answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"""
|
||||
import random
|
||||
|
||||
from custom_datasets.formatting import format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@@ -82,7 +83,7 @@ class TranslationPair(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class WMT2019(TranslationPair):
|
||||
@@ -99,6 +100,8 @@ class WMT2019(TranslationPair):
|
||||
else: # translating in reverse direction
|
||||
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
if len(self.pairs) > 100000:
|
||||
break
|
||||
|
||||
|
||||
class DiveMT(TranslationPair):
|
||||
|
||||
@@ -7,8 +7,8 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
def test_all_datasets():
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
|
||||
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"]
|
||||
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "ted_trans_de-ja", "ted_trans_nl-en"]
|
||||
|
||||
config = Namespace(cache_dir=".cache")
|
||||
for dataset_name in translation + others + summarize_base + qa_base:
|
||||
@@ -31,7 +31,6 @@ def test_collate_fn():
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
|
||||
|
||||
trains, evals = [], []
|
||||
for dataset_name in others + qa_base + summarize_base:
|
||||
print(dataset_name)
|
||||
@@ -41,10 +40,10 @@ def test_collate_fn():
|
||||
|
||||
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
# print(batch.keys())
|
||||
# print(tokenizer.decode(batch['input_ids'][0]))
|
||||
# print('-----')
|
||||
# print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]]))
|
||||
print(batch.keys())
|
||||
print(tokenizer.decode(batch["input_ids"][0]))
|
||||
print("-----")
|
||||
print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]]))
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
|
||||
@@ -25,6 +25,10 @@ def get_tokenizer(conf):
|
||||
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
|
||||
elif "codegen" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
|
||||
elif "pythia" in conf.model_name:
|
||||
tokenizer.add_special_tokens(
|
||||
{"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"}
|
||||
)
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
# README
|
||||
|
||||
## Introduction
|
||||
|
||||
This program converts data obtained from the subreddit r/changemyview into a cleaner format for further data processing. The data is not clean enough to be used directly in a model yet, and additional preprocessing is required.
|
||||
|
||||
## Data Format
|
||||
|
||||
The cleaned data is stored in an Apache Parquet file with the following columns:
|
||||
|
||||
| Column Name | Description | Data Type |
|
||||
|-------------|------------------------------------------------------------------------|----------------|
|
||||
| INSTRUCTION | Post title + body text | String |
|
||||
| RESPONSE | Body text of comments attempting to change OP's mind of `INSTRUCTION`. | List\<String\> |
|
||||
| SOURCE | Permalink to the reddit post | String |
|
||||
| METADATA | Metadata related to `RESPONSE`. | Dict\<Variant> |
|
||||
|
||||
### Metadata
|
||||
Currently, metadata is only broken into one category:
|
||||
- `detoxify_labels`- A Dictionary of values outputted by the [Unitaryai Detoxifier](https://github.com/unitaryai/detoxify) model, fitted to every comment under any given post.
|
||||
|
||||
## Usage
|
||||
|
||||
To use the program, follow these instructions:
|
||||
|
||||
1. **Clone the repository** - `git clone https://github.com/LAION-AI/Open-Assistant.git`
|
||||
2. **Navigate to the project directory** - `cd notebooks/data-augmentation/changemyview-builder`
|
||||
3. **Open the Jupyter Notebook** - `jupyter notebook data_processor.ipynb`
|
||||
4. **Run the program** - Go through the notebook and run the cells
|
||||
|
||||
## Contributing
|
||||
|
||||
If you would like to contribute to this project, please fork the repository and submit a pull request with your changes.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the Apache-2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
@@ -0,0 +1,577 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# r/ChangeMyView data converter\n",
|
||||
"Converts subreddit data into readable format for ML training\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"### REMEMBER: setup the .env before running this code!\n",
|
||||
"\n",
|
||||
"\"\"\"CONSTANTS\"\"\"\n",
|
||||
"\n",
|
||||
"# Set the head number to the amount of entries you want to load in minus one\n",
|
||||
"ENTRIES_COUNT = 10\n",
|
||||
"\n",
|
||||
"# Set the threshold for toxic comments to be removed\n",
|
||||
"TOXIC_THRESHOLD = 0.95"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: pandas in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (1.4.4)\r\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.8.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2.8.2)\r\n",
|
||||
"Requirement already satisfied: pytz>=2020.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2022.1)\r\n",
|
||||
"Requirement already satisfied: numpy>=1.18.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (1.21.5)\r\n",
|
||||
"Requirement already satisfied: six>=1.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\r\n",
|
||||
"Requirement already satisfied: praw in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (7.6.1)\r\n",
|
||||
"Requirement already satisfied: websocket-client>=0.54.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.58.0)\r\n",
|
||||
"Requirement already satisfied: update-checker>=0.18 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.18.0)\r\n",
|
||||
"Requirement already satisfied: prawcore<3,>=2.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (2.3.0)\r\n",
|
||||
"Requirement already satisfied: requests<3.0,>=2.6.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from prawcore<3,>=2.1->praw) (2.28.1)\r\n",
|
||||
"Requirement already satisfied: six in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from websocket-client>=0.54.0->praw) (1.16.0)\r\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (3.3)\r\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2022.9.24)\r\n",
|
||||
"Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2.0.4)\r\n",
|
||||
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (1.26.11)\r\n",
|
||||
"Requirement already satisfied: python-dotenv in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.21.0)\r\n",
|
||||
"Requirement already satisfied: pyarrow in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (10.0.1)\r\n",
|
||||
"Requirement already satisfied: numpy>=1.16.6 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pyarrow) (1.21.5)\r\n",
|
||||
"Requirement already satisfied: detoxify in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.5.1)\r\n",
|
||||
"Requirement already satisfied: transformers==4.22.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (4.22.1)\r\n",
|
||||
"Requirement already satisfied: torch>=1.7.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (1.13.1)\r\n",
|
||||
"Requirement already satisfied: sentencepiece>=0.1.94 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (0.1.97)\r\n",
|
||||
"Requirement already satisfied: huggingface-hub<1.0,>=0.9.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.11.1)\r\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2022.7.9)\r\n",
|
||||
"Requirement already satisfied: pyyaml>=5.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (6.0)\r\n",
|
||||
"Requirement already satisfied: tqdm>=4.27 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (4.64.1)\r\n",
|
||||
"Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.12.1)\r\n",
|
||||
"Requirement already satisfied: filelock in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (3.6.0)\r\n",
|
||||
"Requirement already satisfied: packaging>=20.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (21.3)\r\n",
|
||||
"Requirement already satisfied: numpy>=1.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (1.21.5)\r\n",
|
||||
"Requirement already satisfied: requests in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2.28.1)\r\n",
|
||||
"Requirement already satisfied: typing-extensions in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from torch>=1.7.0->detoxify) (4.3.0)\r\n",
|
||||
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from packaging>=20.0->transformers==4.22.1->detoxify) (3.0.9)\r\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (3.3)\r\n",
|
||||
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (1.26.11)\r\n",
|
||||
"Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2.0.4)\r\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2022.9.24)\r\n",
|
||||
"Requirement already satisfied: tqdm in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (4.64.1)\r\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Install any dependencies\n",
|
||||
"!pip install pandas\n",
|
||||
"!pip install praw\n",
|
||||
"!pip install python-dotenv\n",
|
||||
"!pip install pyarrow\n",
|
||||
"!pip install detoxify\n",
|
||||
"!pip install tqdm"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import praw\n",
|
||||
"import os\n",
|
||||
"from os.path import join, dirname\n",
|
||||
"from dotenv import main\n",
|
||||
"\n",
|
||||
"# Make sure you create a .env file and fill in all the necessary information in the same folder as this script!\n",
|
||||
"main.load_dotenv(join(dirname(os.path.realpath('__file__')), '.env'))\n",
|
||||
"\n",
|
||||
"reddit = praw.Reddit(\n",
|
||||
" client_id=os.environ.get(\"CLIENT_ID\"),\n",
|
||||
" client_secret=os.environ.get(\"CLIENT_SECRET\"),\n",
|
||||
" user_agent=\"CMV_Scraper\",\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 68,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load the data\n",
|
||||
"import tarfile\n",
|
||||
"import os.path\n",
|
||||
"import json\n",
|
||||
"import re\n",
|
||||
"from bz2 import BZ2File\n",
|
||||
"from urllib import request\n",
|
||||
"from io import BytesIO\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"fname = \"cmv.tar.bz2\"\n",
|
||||
"url = \"https://chenhaot.com/data/cmv/\" + fname\n",
|
||||
"\n",
|
||||
"# download if not exists\n",
|
||||
"if not os.path.isfile(fname):\n",
|
||||
" f = BytesIO()\n",
|
||||
" with request.urlopen(url) as resp, open(fname, 'wb') as f_disk:\n",
|
||||
" data = resp.read()\n",
|
||||
" f_disk.write(data) # save to disk too\n",
|
||||
" f.write(data)\n",
|
||||
" f.seek(0)\n",
|
||||
"else:\n",
|
||||
" f = open(fname, 'rb')\n",
|
||||
"\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/kayjaymac/opt/anaconda3/lib/python3.9/bz2.py:124: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n",
|
||||
" self._buffer = None\n",
|
||||
"ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#tar = tarfile.open(fileobj=f, mode=\"r:bz2\")\n",
|
||||
"tar = tarfile.open(fileobj=f, mode=\"r\")\n",
|
||||
"\n",
|
||||
"# Extract the file we are interested in\n",
|
||||
"\n",
|
||||
"train_fname = \"op_task/train_op_data.jsonlist.bz2\"\n",
|
||||
"test_fname = \"op_task/heldout_op_data.jsonlist.bz2\"\n",
|
||||
"\n",
|
||||
"train_bzlist = tar.extractfile(train_fname)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Deserialize the JSON list\n",
|
||||
"original_posts_train = [\n",
|
||||
" json.loads(line.decode('utf-8'))\n",
|
||||
" for line in BZ2File(train_bzlist)\n",
|
||||
"]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[{'title': \"CMV: I shouldn't get a job in this economic climate because it'll be automated anyway; I should just wait for a post-scarcity utopia.\",\n 'delta_label': False,\n 'name': 't3_2rpsl8',\n 'selftext': \"I think the world is automating fast enough that a utopia will arise where no one will have to work anymore. Within the next 2 decades or so, having a job won't mean much, and most people will be artists and scientists. \\n\\nMy parents let me live with them, so I can just wait until the utopia happens.\\n\\nCMV.\"}]"
|
||||
},
|
||||
"execution_count": 71,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"original_posts_train[:1]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the jsonlist file into a dataframe\n",
|
||||
"#df = pd.read_json(original_posts_train, orient='list', lines=True)\n",
|
||||
"df = pd.DataFrame(original_posts_train)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 73,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Function to check if the posts still exists on reddit\n",
|
||||
"def try_get_post(post_id):\n",
|
||||
" try:\n",
|
||||
" submission = reddit.submission(id=post_id)\n",
|
||||
" submission.name\n",
|
||||
" return True\n",
|
||||
" except Exception as e:\n",
|
||||
" return False"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 74,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set up the detoxifier model:\n",
|
||||
"from detoxify import Detoxify"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"# Removes > sign and the template message at the end of a message\n",
|
||||
"def cleanup_body_text(cmv_post):\n",
|
||||
" lines = [line for line in cmv_post.splitlines()\n",
|
||||
" if not line.lstrip().startswith(\">\")\n",
|
||||
" and not line.lstrip().startswith(\"____\")\n",
|
||||
" and not line.lstrip().startswith(\"So go forth and CMV, noble redditors!\")\n",
|
||||
" and \"edit\" not in \" \".join(line.lower().split()[:2])\n",
|
||||
" ]\n",
|
||||
" return \"\\n\".join(lines)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create the function that will be handling all the data gathering\n",
|
||||
"def get_top_comment_and_clean_data(post_id):\n",
|
||||
" #print(post_id.lstrip(\"t3_\"))\n",
|
||||
" last_author = \"\"\n",
|
||||
" # Grab the post\n",
|
||||
" submission = reddit.submission(id=post_id.lstrip(\"t3_\"))\n",
|
||||
" #print(submission.title)\n",
|
||||
"\n",
|
||||
" # Grab the highest rated comment on root layer\n",
|
||||
" submission.submission_type = 'best'\n",
|
||||
" submission.comments.replace_more(limit=0)\n",
|
||||
" replies = list(submission.comments)[0].replies.list()\n",
|
||||
"\n",
|
||||
" # Just some variables\n",
|
||||
" pros = []\n",
|
||||
"\n",
|
||||
" # If the post author doesn't exist this submission was deleted (submission.deleted doesn't work)\n",
|
||||
" if type(submission.author) == type(None):\n",
|
||||
" last_author = \"[deleted]\"\n",
|
||||
" else:\n",
|
||||
" last_author = submission.author.name\n",
|
||||
"\n",
|
||||
" is_pro_argument = False\n",
|
||||
"\n",
|
||||
" for comment in replies:\n",
|
||||
"\n",
|
||||
" # If redditor object doesn't exist, the account is invalid/deleted\n",
|
||||
" if type(comment.author) != type(None):\n",
|
||||
" author = comment.author.name\n",
|
||||
" else:\n",
|
||||
" author = \"[deleted]\"\n",
|
||||
"\n",
|
||||
" # Assume that whenever the user changes, they are countering the previous person\n",
|
||||
" if author != last_author:\n",
|
||||
" is_pro_argument = !is_pro_argument\n",
|
||||
"\n",
|
||||
" if author == \"[deleted]\" or author==\"DeltaBot\":\n",
|
||||
" #print(\"Skipping comment...\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" # Remove meta and duplicate comments\n",
|
||||
" comment.body = \" \".join([line for line in comment.body.splitlines()\n",
|
||||
" if not re.search(r\"(?i)(Change\\smy\\sview|CMV)\", line)\n",
|
||||
" and line not in pros # Why doesn't this line work\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # Sometimes for some reason duplicate entries exist\n",
|
||||
" # Also remove automated message with \"Δ\" in it\n",
|
||||
"\n",
|
||||
" if comment.body in pros:\n",
|
||||
" #print(\"Skipping duplicate entry\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" #print(\"\\t\\t>>\\t\",comment.body)\n",
|
||||
"\n",
|
||||
" # Remove toxic comments\n",
|
||||
" if Detoxify(\"multilingual\").predict(comment.body)[\"toxicity\"] > TOXIC_THRESHOLD:\n",
|
||||
" #print(\"Identified toxic comment, ignoring...\")\n",
|
||||
" comment.body = \"\"\n",
|
||||
"\n",
|
||||
" # Add to the respective argument type \n",
|
||||
" if is_pro_argument:\n",
|
||||
" pros.append(comment.body)\n",
|
||||
" \n",
|
||||
" last_author = comment.author.name\n",
|
||||
" \n",
|
||||
" # Pros = arguments for the Title of this post\n",
|
||||
" # Cons = arguments against the title of this post\n",
|
||||
"\n",
|
||||
" pros.append(comment.body)\n",
|
||||
" return pros"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Loading in 10 posts\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(f\"Loading in {ENTRIES_COUNT} posts\")\n",
|
||||
"dataset = df.head(ENTRIES_COUNT)\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 77,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# the name column does some weird sh** because dataframes already have a name property, so migrate to a different column name\n",
|
||||
"\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings('ignore')\n",
|
||||
"\n",
|
||||
"dataset[\"post_id\"] = dataset[\"name\"]\n",
|
||||
"warnings.filterwarnings('default')"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 78,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Loading in data... This will take a while.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": " 0%| | 0/10 [00:00<?, ?it/s]",
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "203fd74c8a5146b68b8af961bb3874c8"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<timed exec>:29: SettingWithCopyWarning: \n",
|
||||
"A value is trying to be set on a copy of a slice from a DataFrame\n",
|
||||
"\n",
|
||||
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
||||
"/Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages/torch/serialization.py:997: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n",
|
||||
" storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()\n",
|
||||
"ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CPU times: user 7min 49s, sys: 2min 29s, total: 10min 19s\n",
|
||||
"Wall time: 8min 45s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"# Reset variables for if we run this multiple times\n",
|
||||
"all_pros = []\n",
|
||||
"all_names = []\n",
|
||||
"all_titles = []\n",
|
||||
"all_sources = []\n",
|
||||
"\n",
|
||||
"print(\"Loading in data... This will take a while.\")\n",
|
||||
"\n",
|
||||
"for i in tqdm(range(dataset.shape[0])):\n",
|
||||
"\n",
|
||||
" post = dataset.iloc[i]\n",
|
||||
" modified_title = post.title.replace('CMV', \"Change my mind\")\n",
|
||||
" #print(f\"\\n Loading entry {i+1}/{dataset.shape[0]}:\\n\\t\\\"{modified_title}\\\"\")\n",
|
||||
"\n",
|
||||
" if type(post) == type(None):\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" assert(post.post_id != i)\n",
|
||||
"\n",
|
||||
" pros = get_top_comment_and_clean_data(post.post_id)\n",
|
||||
"\n",
|
||||
" if post.title == \"[deleted]\":\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" pros = \" \".join([*set(pros)])\n",
|
||||
" pros = pros.replace(\"[deleted]\",\"\")\n",
|
||||
"\n",
|
||||
" post.selftext = cleanup_body_text(post.selftext)\n",
|
||||
" all_titles.append(modified_title + \" \" + post.selftext)\n",
|
||||
" all_pros.append(pros)\n",
|
||||
" all_names.append(post.name)\n",
|
||||
" all_sources.append(f\"https://reddit.com/r/changemyview/comments/{post.post_id}\")\n",
|
||||
" #print(post.title)\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 83,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "'it\\'s already been signed. They even claim to be adhering to it, though they\\'ve been found to be violating it before. There is no such thing as \"de facto acceptance of Israel\\'s nuclear program.\" the Non-Proliferation Treaty is only binding for signatory states. Israel is not a signatory. Article 10 of the NPT allows them to withdraw if they so choose. they have not done so. a whole new country which explicitly has a right to withdraw from the NPT and has not chosen to do so. It\\'s more accurate, I think, to say that the problem with Iran here from a legal standpoint is that they aren\\'t honoring their own commitments, rather than that they\\'re building weapons. They could pull out of the NPT at any time, and the ball would be essentially in America\\'s court, because their nuclear program would no longer be illegal by international legal standards. However, Iran insists both on developing nukes *and* remaining an NPT signatory non-nuclear state, and that\\'s what makes their program illegal. I\\'d also like to clarify that I\\'m not making an ethical argument here, this is just how international law currently works. because international law doesn\\'t require states to sign treaties, it only requires them to adhere to treaties they\\'ve already signed. Israel isn\\'t defying the UN, at least not in this particular case. Think of the NPT less like a standard law within a state and more like a contract. Once you\\'ve signed, you\\'re bound by the contract, but if you never sign it then you haven\\'t broken a law, you\\'ve just decided not to agree to the terms you were offered. > Because Iran did sign the treaty, and thus are bound by it. They signed on July 1, 1968. Hmm. So is the argument here that it\\'s not \"ok\" for Iran to have a nuke, since they signed treaty not to do so. But it\\'s \"ok\" for Israel to have one because they never signed such thing? Can\\'t quite put my finger on it, but doesn\\'t seem quite right this one.'"
|
||||
},
|
||||
"execution_count": 83,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_pros[1]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 80,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Place it all into a Pandas Dataframe\n",
|
||||
"clean_df = pd.DataFrame({\n",
|
||||
" \"INSTRUCTION\": all_titles,\n",
|
||||
" \"RESPONSE\": all_pros,\n",
|
||||
" \"SOURCE\": all_sources\n",
|
||||
"}, index=all_names\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 81,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create Apache Paquete file\n",
|
||||
"\n",
|
||||
"import pyarrow as pa\n",
|
||||
"import pyarrow.parquet as pq\n",
|
||||
"\n",
|
||||
"table = pa.Table.from_pandas(clean_df)\n",
|
||||
"pq.write_table(table,\"output.parquet\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 82,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": " INSTRUCTION \\\n0 Change my mind: I shouldn't get a job in this ... \n1 Change my mind: Iran has the right to develop ... \n2 Change my mind: The events in Paris suck...but... \n3 Change my mind: It is ok to hate a religion so... \n4 Change my mind: There is no productive reason ... \n5 Change my mind: Diet soda is perfectly healthy... \n6 Change my mind:Essential Oils are bullshit My ... \n7 Change my mind: I think the Paris shooting mak... \n8 Change my mind: Printing an image of the Musli... \n9 Change my mind: Philosophy has no tangible val... \n\n RESPONSE \\\n0 That is what someone in the 1500s would have s... \n1 it's already been signed. They even claim to b... \n2 Hm I guess I made the OP incorrectly. The mai... \n3 I don't understand your analogy. Promoting a ... \n4 ∆ I hadn't thought it from a \"let's trick peop... \n5 Thanks for a fresh argument! I hadn't conside... \n6 Most do. Some smell kinda funky. \n7 I already said in different comments that thi... \n8 The first bacon sandwich came about because 9... \n9 >Why restrict it to 50 years? I can name all s... \n\n SOURCE \n0 https://reddit.com/r/changemyview/comments/t3_... \n1 https://reddit.com/r/changemyview/comments/t3_... \n2 https://reddit.com/r/changemyview/comments/t3_... \n3 https://reddit.com/r/changemyview/comments/t3_... \n4 https://reddit.com/r/changemyview/comments/t3_... \n5 https://reddit.com/r/changemyview/comments/t3_... \n6 https://reddit.com/r/changemyview/comments/t3_... \n7 https://reddit.com/r/changemyview/comments/t3_... \n8 https://reddit.com/r/changemyview/comments/t3_... \n9 https://reddit.com/r/changemyview/comments/t3_... ",
|
||||
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>INSTRUCTION</th>\n <th>RESPONSE</th>\n <th>SOURCE</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>Change my mind: I shouldn't get a job in this ...</td>\n <td>That is what someone in the 1500s would have s...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>1</th>\n <td>Change my mind: Iran has the right to develop ...</td>\n <td>it's already been signed. They even claim to b...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>2</th>\n <td>Change my mind: The events in Paris suck...but...</td>\n <td>Hm I guess I made the OP incorrectly. The mai...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>Change my mind: It is ok to hate a religion so...</td>\n <td>I don't understand your analogy. Promoting a ...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>Change my mind: There is no productive reason ...</td>\n <td>∆ I hadn't thought it from a \"let's trick peop...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>5</th>\n <td>Change my mind: Diet soda is perfectly healthy...</td>\n <td>Thanks for a fresh argument! I hadn't conside...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>6</th>\n <td>Change my mind:Essential Oils are bullshit My ...</td>\n <td>Most do. Some smell kinda funky.</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>7</th>\n <td>Change my mind: I think the Paris shooting mak...</td>\n <td>I already said in different comments that thi...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>8</th>\n <td>Change my mind: Printing an image of the Musli...</td>\n <td>The first bacon sandwich came about because 9...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>9</th>\n <td>Change my mind: Philosophy has no tangible val...</td>\n <td>>Why restrict it to 50 years? I can name all s...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n </tbody>\n</table>\n</div>"
|
||||
},
|
||||
"execution_count": 82,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Test to see if it was sucessful\n",
|
||||
"table = pq.read_table(\"output.parquet\")\n",
|
||||
"table.to_pandas()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
||||
@@ -38,6 +38,8 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
|
||||
TASK_AVAILABILITY_QUERY_FAILED = 1007
|
||||
TASK_MESSAGE_TOO_LONG = 1008
|
||||
TASK_MESSAGE_DUPLICATED = 1009
|
||||
TASK_MESSAGE_TEXT_EMPTY = 1010
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
@@ -76,6 +78,8 @@ class OasstErrorCode(IntEnum):
|
||||
USER_DISABLED = 4001
|
||||
USER_NOT_FOUND = 4002
|
||||
|
||||
EMOJI_OP_UNSUPPORTED = 5000
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
"""Base class for Open-Assistant exceptions."""
|
||||
|
||||
@@ -13,7 +13,11 @@ class WorkRequest(pydantic.BaseModel):
|
||||
conversation: protocol.Conversation = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1))
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1))
|
||||
do_sample: bool = True
|
||||
top_k: int = 50
|
||||
top_p: float = 0.9
|
||||
temperature: float = 1.0
|
||||
|
||||
|
||||
class WorkResponsePacket(pydantic.BaseModel):
|
||||
|
||||
@@ -57,6 +57,8 @@ class ConversationMessage(BaseModel):
|
||||
text: str
|
||||
lang: Optional[str] # BCP 47
|
||||
is_assistant: bool
|
||||
emojis: Optional[dict[str, int]] = None
|
||||
user_emojis: Optional[list[str]] = None
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
@@ -222,27 +224,43 @@ class LabelTaskMode(str, enum.Enum):
|
||||
full = "full"
|
||||
|
||||
|
||||
class LabelInitialPromptTask(Task):
|
||||
"""A task to label an initial prompt."""
|
||||
class LabelTaskDisposition(str, enum.Enum):
|
||||
"""Reason why the task was issued."""
|
||||
|
||||
type: Literal["label_initial_prompt"] = "label_initial_prompt"
|
||||
quality = "quality"
|
||||
spam = "spam"
|
||||
|
||||
|
||||
class LabelDescription(BaseModel):
|
||||
name: str
|
||||
widget: str
|
||||
display_text: str
|
||||
help_text: Optional[str]
|
||||
|
||||
|
||||
class AbstractLabelTask(Task):
|
||||
message_id: UUID
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[LabelTaskMode]
|
||||
disposition: Optional[LabelTaskDisposition]
|
||||
labels: Optional[list[LabelDescription]]
|
||||
|
||||
|
||||
class LabelConversationReplyTask(Task):
|
||||
class LabelInitialPromptTask(AbstractLabelTask):
|
||||
"""A task to label an initial prompt."""
|
||||
|
||||
type: Literal["label_initial_prompt"] = "label_initial_prompt"
|
||||
prompt: str
|
||||
|
||||
|
||||
class LabelConversationReplyTask(AbstractLabelTask):
|
||||
"""A task to label a reply to a conversation."""
|
||||
|
||||
type: Literal["label_conversation_reply"] = "label_conversation_reply"
|
||||
conversation: Conversation # the conversation so far
|
||||
message_id: UUID
|
||||
reply_message: Optional[ConversationMessage]
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[LabelTaskMode]
|
||||
|
||||
|
||||
class LabelPrompterReplyTask(LabelConversationReplyTask):
|
||||
@@ -315,39 +333,48 @@ class MessageRanking(Interaction):
|
||||
ranking: conlist(item_type=int, min_items=1)
|
||||
|
||||
|
||||
class LabelWidget(str, enum.Enum):
|
||||
yes_no = "yes_no"
|
||||
flag = "flag"
|
||||
likert = "likert"
|
||||
|
||||
|
||||
class TextLabel(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
def __new__(cls, label: str, display_text: str = "", help_text: str = None):
|
||||
def __new__(cls, label: str, widget: LabelWidget, display_text: str = "", help_text: str = None):
|
||||
obj = str.__new__(cls, label)
|
||||
obj._value_ = label
|
||||
obj.widget = widget
|
||||
obj.display_text = display_text
|
||||
obj.help_text = help_text
|
||||
return obj
|
||||
|
||||
spam = "spam", "Seems to be intentionally low-quality or irrelevant"
|
||||
fails_task = "fails_task", "Fails to follow the correct instruction / task"
|
||||
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
|
||||
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
|
||||
excessive_harm = (
|
||||
"excessive_harm",
|
||||
"Content likely to cause excessive harm not justifiable in the context",
|
||||
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
|
||||
)
|
||||
sexual_content = "sexual_content", "Contains sexual content"
|
||||
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
|
||||
moral_judgement = "moral_judgement", "Expresses moral judgement"
|
||||
political_content = "political_content", "Expresses political views"
|
||||
humor = "humor", "Contains humorous content including sarcasm"
|
||||
# yes/no questions
|
||||
spam = "spam", LabelWidget.yes_no, "Seems to be intentionally low-quality or irrelevant"
|
||||
fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task"
|
||||
|
||||
# flags
|
||||
pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)"
|
||||
not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate"
|
||||
hate_speech = (
|
||||
"hate_speech",
|
||||
LabelWidget.flag,
|
||||
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
|
||||
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
|
||||
"Prejudice refers to preconceived views not based on reason. Protected characteristics "
|
||||
"include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
|
||||
)
|
||||
threat = "threat", "Contains a threat against a person or persons"
|
||||
misleading = "misleading", "Contains text which is incorrect or misleading"
|
||||
helpful = "helpful", "Completes the task to a high standard"
|
||||
creative = "creative", "Expresses creativity in responding to the task"
|
||||
sexual_content = "sexual_content", LabelWidget.flag, "Contains sexual content"
|
||||
moral_judgement = "moral_judgement", LabelWidget.flag, "Expresses moral judgement"
|
||||
political_content = "political_content", LabelWidget.flag, "Expresses political views"
|
||||
|
||||
# likert
|
||||
quality = "quality", LabelWidget.likert, "Overall subjective quality rating of the message"
|
||||
toxicity = "toxicity", LabelWidget.likert, "Rude, abusive, profane or insulting content"
|
||||
humor = "humor", LabelWidget.likert, "Humorous content including sarcasm"
|
||||
helpfulness = "helpfulness", LabelWidget.likert, "Helpfulness of the message"
|
||||
creativity = "creativity", LabelWidget.likert, "Creativity"
|
||||
violence = "violence", LabelWidget.likert, "Violence/abuse/terrorism/self-harm"
|
||||
|
||||
|
||||
class TextLabels(Interaction):
|
||||
@@ -358,6 +385,7 @@ class TextLabels(Interaction):
|
||||
labels: dict[TextLabel, float]
|
||||
message_id: UUID
|
||||
task_id: Optional[UUID]
|
||||
is_report: Optional[bool]
|
||||
|
||||
@property
|
||||
def has_message_id(self) -> bool:
|
||||
@@ -432,3 +460,29 @@ class OasstErrorResponse(BaseModel):
|
||||
|
||||
error_code: OasstErrorCode
|
||||
message: str
|
||||
|
||||
|
||||
class EmojiCode(str, enum.Enum):
|
||||
thumbs_up = "+1" # 👍
|
||||
thumbs_down = "-1" # 👎
|
||||
red_flag = "red_flag" # 🚩
|
||||
hundred = "100" # 💯
|
||||
rofl = "rofl" # 🤣
|
||||
clap = "clap" # 👏
|
||||
diamond = "diamond" # 💎
|
||||
heart_eyes = "heart_eyes" # 😍
|
||||
disappointed = "disappointed" # 😞
|
||||
poop = "poop" # 💩
|
||||
skull = "skull" # 💀
|
||||
|
||||
|
||||
class EmojiOp(str, enum.Enum):
|
||||
togggle = "toggle"
|
||||
add = "add"
|
||||
remove = "remove"
|
||||
|
||||
|
||||
class MessageEmojiRequest(BaseModel):
|
||||
user: User
|
||||
op: EmojiOp = EmojiOp.togggle
|
||||
emoji: EmojiCode
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
---
|
||||
annotations_creators:
|
||||
- no-annotation
|
||||
language:
|
||||
- en
|
||||
language_creators:
|
||||
- machine-generated
|
||||
license:
|
||||
- mit
|
||||
multilinguality:
|
||||
- monolingual
|
||||
pretty_name: "MT Note Generation"
|
||||
size_categories:
|
||||
- <500
|
||||
source_datasets:
|
||||
- mt_samples
|
||||
tags:
|
||||
- open-assistant
|
||||
task_categories:
|
||||
- conversational
|
||||
task_ids:
|
||||
- dialogue-generation
|
||||
dataset_info:
|
||||
features:
|
||||
- name: conversation
|
||||
dtype: string
|
||||
splits:
|
||||
- name: train
|
||||
num_bytes: 1060800
|
||||
num_examples: 270
|
||||
- name: test
|
||||
num_bytes: 339323
|
||||
num_examples: 90
|
||||
- name: validation
|
||||
num_bytes: 337018
|
||||
num_examples: 90
|
||||
---
|
||||
|
||||
# Dataset Card for MT Samples Clinical Note Generation
|
||||
|
||||
## Dataset Description
|
||||
|
||||
- **Repository:**
|
||||
[Code](https://github.com/LAION-AI/Open-Assistant/openassistant/datasets/mt_note_generation/)
|
||||
|
||||
## Dataset Summary
|
||||
|
||||
MT Note Generation is a set of synthetic dialogues between Assistant and User
|
||||
where the user asks the assistant to generate a clinical note for a patient
|
||||
persona.
|
||||
|
||||
This data was created by synthesizing the dialogues in
|
||||
[MT Samples](https://mtsamples.com/) for clinical note generation.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
User: Write a clinical note about a patient with the following chief complaint: Stomach pain for 2 weeks.
|
||||
|
||||
ASSISTANT:
|
||||
|
||||
CHIEF COMPLAINT: Stomach pain for 2 weeks. HISTORY OF PRESENT ILLNESS: The patient is a 45yo Mexican man without significant past medical history who presents to the emergency room with complaints of mid-epigastric and right upper quadrant abdominal pain for the last 14 days. The pain was initially crampy and burning in character and was relieved with food intake. He also reports that it initially was associated with a sour taste in his mouth. He went to his primary care physician who prescribed cimetidine 400mg qhs x 5 days; however, this did not relieve his symptoms. In fact, the pain has worsened such that the pain now radiates to the back but is waxing and waning in duration. It is relieved with standing and ambulation and exacerbated when lying in a supine position. He reports a decrease in appetite associated with a 4 lb. wt loss over the last 2 wks. He does have nausea with only one episode of non-bilious, non-bloody emesis on day of admission. He reports a 2 wk history of subjective fever and diaphoresis. He denies any diarrhea, constipation, dysuria, melena, or hematochezia. His last bowel movement was during the morning of admission and was normal. He denies any travel in the last 9 years and sick contacts.PAST MEDICAL HISTORY: Right inguinal groin cyst removal 15 years ago. Unknown etiology. No recurrence.
|
||||
|
||||
PAST SURGICAL HISTORY: Left femoral neck fracture with prosthesis secondary to a fall 4 years ago.
|
||||
|
||||
FAMILY HISTORY: Mother with diabetes. No history of liver disease. No malignancies.
|
||||
|
||||
SOCIAL HISTORY: The patient was born in central Mexico but moved to the United States 9 years ago. He is on disability due to his prior femoral fracture. He denies any tobacco or illicit drug use. He only drinks alcohol socially, no more than 1 drink every few weeks. He is married and has 3 healthy children. He denies any tattoos or risky sexual behavior.
|
||||
|
||||
ALLERGIES: NKDA.
|
||||
|
||||
MEDICATIONS: Tylenol prn (1-2 tabs every other day for the last 2 wks), Cimetidine 400mg po qhs x 5 days.
|
||||
|
||||
REVIEW OF SYSTEMS: No headache, vision changes. No shortness of breath. No chest pain or palpitations.
|
||||
|
||||
PHYSICAL EXAMINATION: Vitals: T 100.9-102.7 BP 136/86 Pulse 117 RR 12 98% sat on room air,Gen: Well-developed, well-nourished, no apparent distress.HEENT: Pupils equal, round and reactive to light. Anicteric. Oropharynx clear and moist.Neck: Supple. No lymphadenopathy or carotid bruits. No thyromegaly or masses.CHEST: Clear to auscultation bilaterally.CV: Tachycardic but regular rhythm, normal S1/S2, no murmurs/rubs/gallops.Abd: Soft, active bowel sounds. Tender in the epigastrium and right upper quadrant with palpation associated with slight guarding. No rebound tenderness. No hepatomegaly. No splenomegaly.Rectal: Stool was brown and guaiac negative.Ext: No cyanosis/clubbing/edema.Neurological: He was alert and oriented x3. CN II-XII intact. Normal 2+ DTRs. No focal neurological deficit.Skin: No jaundice. No skin rashes or lesions.
|
||||
|
||||
IMAGING DATA:CT Abdomen with contrast ( 11/29/03 ): There is a 6x6 cm multilobular hypodense mass seen at the level of the hepatic hilum and caudate lobe which is resulting in mass effect with dilatation of the intrahepatic radicals of the left lobe of the liver. The rest of the liver parenchyma is homogeneous. The gallbladder, pancreas, spleen, adrenal glands and kidneys are within normal limits. The retroperitoneal vascular structures are within normal limits. There is no evidence of lymphadenopathy, free fluid or fluid collections.HOSPITAL COURSE: The patient was admitted to the hospital for further evaluation. A diagnostic procedure was performed.
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The dataset contains one configuration, `dialogue_modeling`, which has a single
|
||||
text `conversation` feature.
|
||||
|
||||
## Source data
|
||||
|
||||
The script modifies data from mtsamples.csv which is hosted in Kaggle:
|
||||
https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions
|
||||
|
||||
## Citation
|
||||
|
||||
Please cite our work if you find the resources in this repository useful:
|
||||
|
||||
```
|
||||
@article{pruks2023mtsamplesnotegen,
|
||||
author = {Yada Pruksachatkun},
|
||||
title = {MT Samples Note Generation},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAssistantConfig(datasets.BuilderConfig):
|
||||
"""BuilderConfig for OpenAssistant datasets."""
|
||||
|
||||
name: str = None
|
||||
version: datasets.Version = None
|
||||
description: str = None
|
||||
schema: str = None
|
||||
subset_id: str = None
|
||||
|
||||
|
||||
features = datasets.Features(
|
||||
{
|
||||
"conversation": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,123 @@
|
||||
# Copyright 2023 The OpenAssistant Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
MT Note Generation is a set of synthetic dialogues between Assistant and
|
||||
User where the user asks the assistant to generate a clinical note for a patient persona.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import datasets
|
||||
|
||||
from .hub import OpenAssistantConfig, features
|
||||
|
||||
_CITATION = """\
|
||||
@misc{transcribed medical transcription sample reports and examples, title={Welcome to MTSamples},
|
||||
url={https://mtsamples.com/},
|
||||
journal={Transcribed Medical Transcription Sample Reports and Examples}}
|
||||
"""
|
||||
|
||||
_DATASETNAME = "mt_note_generation"
|
||||
_DISPLAYNAME = "MT Samples Note Generation"
|
||||
|
||||
_DESCRIPTION = """\
|
||||
A dataset of instructions for generating clinical notes from MT samples.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = ""
|
||||
|
||||
_LICENSE = "mit"
|
||||
|
||||
_URLS = {
|
||||
_DATASETNAME: {
|
||||
"train": "./data/mt_note_generation_train.jsonl",
|
||||
"test": "./data/mt_note_generation_test.jsonl",
|
||||
"validation": "./data/mt_note_generation_validation.jsonl",
|
||||
}
|
||||
}
|
||||
|
||||
_SUPPORTED_TASKS = ["dialogue-modeling"]
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
|
||||
class MTNoteGenerationDataset(datasets.GeneratorBasedBuilder):
|
||||
"""A set of dialogues synthesized from the MT Samples dataset."""
|
||||
|
||||
VERSION = datasets.Version(_VERSION)
|
||||
|
||||
BUILDER_CONFIGS = [
|
||||
OpenAssistantConfig(
|
||||
name=f"{_DATASETNAME}_dialogue_modeling",
|
||||
version=VERSION,
|
||||
description=f"OpenAssistant dataset config for {_DATASETNAME}",
|
||||
schema="dialogue_modeling",
|
||||
subset_id=_DATASETNAME,
|
||||
)
|
||||
]
|
||||
|
||||
DEFAULT_CONFIG_NAME = f"{_DATASETNAME}_dialogue_modeling"
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager) -> List[datasets.SplitGenerator]:
|
||||
|
||||
urls = _URLS[_DATASETNAME]
|
||||
data_dir = dl_manager.download_and_extract(urls)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
# Whatever you put in gen_kwargs will be passed to _generate_examples
|
||||
gen_kwargs={
|
||||
"filepath": data_dir,
|
||||
"split": "train",
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": data_dir,
|
||||
"split": "test",
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": data_dir,
|
||||
"split": "validation",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath, split: str) -> Tuple[int, Dict]:
|
||||
"""Yields examples as (key, example) tuples."""
|
||||
if self.config.schema == "dialogue_modeling":
|
||||
key = 0
|
||||
with open(filepath[split], "r", encoding="utf8") as data:
|
||||
while True:
|
||||
line = data.readline()
|
||||
if not line:
|
||||
return
|
||||
yield key, json.loads(line)
|
||||
key += 1
|
||||
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from string import punctuation
|
||||
|
||||
import kaggle
|
||||
import pandas as pd
|
||||
|
||||
CLINICAL_NOTE_GENERATION_TEMPLATE = """User: Write a clinical note about a patient with the following {section}: {section_information}.
|
||||
Rosey: {note}"""
|
||||
|
||||
|
||||
def preprocess(mt_dataset):
|
||||
def filter_for_notes(row):
|
||||
normalized_transcript = row["transcription"].lower()
|
||||
if "chief complaint:" in normalized_transcript:
|
||||
return True
|
||||
return False
|
||||
|
||||
mt_dataset = mt_dataset.dropna(subset=["description", "transcription"])
|
||||
mt_note_subset = mt_dataset[mt_dataset.apply(filter_for_notes, axis=1)]
|
||||
return mt_note_subset
|
||||
|
||||
|
||||
def is_chief_complaint(section):
|
||||
return "chief complaint" in section.lower()
|
||||
|
||||
|
||||
def get_conversations(dataset):
|
||||
def normalize_transcript(x):
|
||||
x = re.sub(r"\.+", ".", x)
|
||||
x = re.sub(r"\,+", ",", x)
|
||||
x = re.sub(r":\s+", ": ", x)
|
||||
x = re.sub(r"\.\s+", ". ", x)
|
||||
x = re.sub(r":(\s)*\,+", ": ", x)
|
||||
x = re.sub(r"\.\,+", ". ", x)
|
||||
return x
|
||||
|
||||
conversations = []
|
||||
for idx in range(len(dataset)):
|
||||
transcript = normalize_transcript(dataset.iloc[idx]["transcription"])
|
||||
sections = re.findall(r"\b[A-Z]+(?: [A-Z]+)*:", transcript)
|
||||
if len(sections) >= 2:
|
||||
note_prompt = transcript.split(sections[0])[1].split(sections[1])[0]
|
||||
else:
|
||||
continue
|
||||
section_name = sections[0].lower().strip(punctuation)
|
||||
if len(note_prompt.split(" ")) > 30 and is_chief_complaint(section_name):
|
||||
# There are some chief complaints that seem to be HPI
|
||||
section_name = "history of present illness"
|
||||
conversations.append(
|
||||
CLINICAL_NOTE_GENERATION_TEMPLATE.format(
|
||||
section=section_name, section_information=note_prompt, note=transcript
|
||||
)
|
||||
)
|
||||
return conversations
|
||||
|
||||
|
||||
def main(output_dir: str = "data"):
|
||||
"""Download and prepare the dataset for use."""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
kaggle.api.dataset_download_files("tboyle10/medicaltranscriptions", "data", unzip=True)
|
||||
mt_samples = preprocess(pd.read_csv("mtsamples.csv"))
|
||||
conversations = get_conversations(mt_samples)
|
||||
random.shuffle(conversations)
|
||||
train_limit = math.ceil(len(conversations) * 0.6)
|
||||
dev_limit = math.ceil(len(conversations) * 0.8)
|
||||
train, validation, test = (
|
||||
conversations[:train_limit],
|
||||
conversations[train_limit:dev_limit],
|
||||
conversations[dev_limit:],
|
||||
)
|
||||
splits = {"train": train, "validation": validation, "test": test}
|
||||
for split in ["train", "validation", "test"]:
|
||||
with open(f"{output_dir}/mt_note_generation_{split}.jsonl", "w") as f:
|
||||
for conversation in splits[split]:
|
||||
f.write(f"{json.dumps({'conversation': conversation})}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -7,6 +7,7 @@ pushd "$parent_path/../../backend"
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
export DEBUG_ALLOW_DUPLICATE_TASKS=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
.eslintrc.json
|
||||
tailwind.config.js
|
||||
.storybook/*
|
||||
@@ -10,6 +10,7 @@ module.exports = {
|
||||
"@storybook/addon-essentials",
|
||||
"@storybook/addon-interactions",
|
||||
"@chakra-ui/storybook-addon",
|
||||
"storybook-addon-next-router",
|
||||
],
|
||||
framework: "@storybook/react",
|
||||
core: {
|
||||
@@ -23,7 +24,12 @@ module.exports = {
|
||||
config.resolve.alias = {
|
||||
...config.resolve.alias,
|
||||
src: path.resolve(__dirname, "../src"),
|
||||
styles: path.resolve(__dirname, "../styles"),
|
||||
};
|
||||
config.resolve.fallback = {
|
||||
fs: false,
|
||||
path: require.resolve('path-browserify'),
|
||||
}
|
||||
return config;
|
||||
},
|
||||
features: {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import "!style-loader!css-loader!postcss-loader!tailwindcss/tailwind.css";
|
||||
import { RouterContext } from "next/dist/shared/lib/router-context";
|
||||
|
||||
export const parameters = {
|
||||
actions: { argTypesRegex: "^on[A-Z].*" },
|
||||
@@ -8,6 +9,9 @@ export const parameters = {
|
||||
date: /Date$/,
|
||||
},
|
||||
},
|
||||
nextRouter: {
|
||||
Provider: RouterContext.Provider,
|
||||
},
|
||||
};
|
||||
|
||||
// Hacky solution to get Images in next to work
|
||||
|
||||
+67
-100
@@ -2,8 +2,7 @@
|
||||
|
||||
## Purpose
|
||||
|
||||
This provides a comprehensive webapp interface for LAION's Open Assistant
|
||||
project. Initially it will support:
|
||||
This provides a comprehensive webapp interface for LAION's Open Assistant project. Initially it will support:
|
||||
|
||||
1. User registration using either Discord or Email.
|
||||
1. Adding responses to incomplete Open Assistant tasks.
|
||||
@@ -11,8 +10,7 @@ project. Initially it will support:
|
||||
1. Viewing an activity leaderboard.
|
||||
1. Tracking community wide updates.
|
||||
|
||||
This interface compliments the Discord bot and will give access to the same
|
||||
underlying tasks.
|
||||
This interface compliments the Discord bot and will give access to the same underlying tasks.
|
||||
|
||||
## Contributing
|
||||
|
||||
@@ -22,67 +20,54 @@ This website is built using:
|
||||
|
||||
1. [npm](https://www.npmjs.com/): The node package manager for building.
|
||||
1. [React](https://reactjs.org/): The core frontend framework.
|
||||
1. [Next.js](https://nextjs.org/): A React scaffolding framework to streamline
|
||||
development.
|
||||
1. [Prisma](https://www.prisma.io/): An ORM to interact with a web specific
|
||||
[Postgres](https://www.postgresql.org/) database.
|
||||
1. [NextAuth.js](https://next-auth.js.org/): A user authentication framework to
|
||||
ensure we handle accounts with best practices.
|
||||
1. [TailwindCSS](https://tailwindcss.com/): A general purpose framework for
|
||||
styling any component.
|
||||
1. [Chakra-UI](https://chakra-ui.com/): A wide collection of pre-built UI
|
||||
components that generally look pretty good.
|
||||
1. [Next.js](https://nextjs.org/): A React scaffolding framework to streamline development.
|
||||
1. [Prisma](https://www.prisma.io/): An ORM to interact with a web specific [Postgres](https://www.postgresql.org/)
|
||||
database.
|
||||
1. [NextAuth.js](https://next-auth.js.org/): A user authentication framework to ensure we handle accounts with best
|
||||
practices.
|
||||
1. [TailwindCSS](https://tailwindcss.com/): A general purpose framework for styling any component.
|
||||
1. [Chakra-UI](https://chakra-ui.com/): A wide collection of pre-built UI components that generally look pretty good.
|
||||
|
||||
### Set up your environment
|
||||
|
||||
To contribute to the website, make sure you have the following setup and
|
||||
installed:
|
||||
To contribute to the website, make sure you have the following setup and installed:
|
||||
|
||||
1. [NVM](https://github.com/nvm-sh/nvm): The Node Version Manager makes it easy
|
||||
to ensure you have the right NodeJS version installed. Once installed, run
|
||||
`nvm use 16` to use Node 16.x. The website is known to be stable with NodeJS
|
||||
1. [NVM](https://github.com/nvm-sh/nvm): The Node Version Manager makes it easy to ensure you have the right NodeJS
|
||||
version installed. Once installed, run `nvm use 16` to use Node 16.x. The website is known to be stable with NodeJS
|
||||
version 16.x. This will install both Node and NPM.
|
||||
1. [Docker](https://www.docker.com/): We use docker to simplify running
|
||||
dependent services.
|
||||
1. [Docker](https://www.docker.com/): We use docker to simplify running dependent services.
|
||||
|
||||
### Getting everything up and running
|
||||
|
||||
If you're doing active development we suggest the following workflow:
|
||||
|
||||
1. In one tab, navigate to the project root.
|
||||
1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can
|
||||
optionally include `-d` to detach and later track the logs if desired.
|
||||
1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can optionally include `-d` to detach and
|
||||
later track the logs if desired.
|
||||
1. In another tab navigate to `${OPEN_ASSISTANT_ROOT/website`.
|
||||
1. Run `npm ci`
|
||||
1. Run `npx prisma db push` (This is also needed when you restart the docker
|
||||
stack from scratch).
|
||||
1. Run `npm run dev`. Now the website is up and running locally at
|
||||
`http://localhost:3000`.
|
||||
1. To create an account, login via the user using email authentication and
|
||||
navigate to `http://localhost:1080`. Check the email listed and click the
|
||||
log in link. You're now logged in and authenticated.
|
||||
1. Run `npx prisma db push` (This is also needed when you restart the docker stack from scratch).
|
||||
1. Run `npm run dev`. Now the website is up and running locally at `http://localhost:3000`.
|
||||
1. To create an account, login via the user using email authentication and navigate to `http://localhost:1080`. Check
|
||||
the email listed and click the log in link. You're now logged in and authenticated.
|
||||
|
||||
### Using debug user credentials
|
||||
|
||||
You can use the debug credentials provider to log in without fancy emails or
|
||||
OAuth.
|
||||
You can use the debug credentials provider to log in without fancy emails or OAuth.
|
||||
|
||||
1. This feature is automatically on in development mode, i.e. when you run
|
||||
`npm run dev`. In case you want to do the same with a production build (for
|
||||
example, the docker image), then run the website with environment variable
|
||||
1. This feature is automatically on in development mode, i.e. when you run `npm run dev`. In case you want to do the
|
||||
same with a production build (for example, the docker image), then run the website with environment variable
|
||||
`DEBUG_LOGIN=true`.
|
||||
1. Use the `Login` button in the top right to go to the login page.
|
||||
1. You should see a section for debug credentials. Enter any username you wish,
|
||||
you will be logged in as that user.
|
||||
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
|
||||
|
||||
### Using Storybook
|
||||
|
||||
To develop components using [Storybook](https://storybook.js.org/) run
|
||||
`npm run storybook`. Then navigate to in your browser to
|
||||
`http://localhost:6006`.
|
||||
To develop components using [Storybook](https://storybook.js.org/) run `npm run storybook`. Then navigate to in your
|
||||
browser to `http://localhost:6006`.
|
||||
|
||||
To create a new story create a file named `[componentName].stories.js`. An
|
||||
example how such a story could look like, see `Header.stories.jsx`.
|
||||
To create a new story create a file named `[componentName].stories.js`. An example how such a story could look like, see
|
||||
`Header.stories.jsx`.
|
||||
|
||||
## Code Layout
|
||||
|
||||
@@ -90,12 +75,10 @@ example how such a story could look like, see `Header.stories.jsx`.
|
||||
|
||||
All react code is under `src/` with a few sub directories:
|
||||
|
||||
1. `pages/`: All pages a user could navigate too and API URLs which are under
|
||||
`pages/api/`.
|
||||
1. `components/`: All re-usable React components. If something gets used twice
|
||||
we should create a component and put it here.
|
||||
1. `lib/`: A generic place to store library files that are used anywhere. This
|
||||
doesn't have much structure yet.
|
||||
1. `pages/`: All pages a user could navigate too and API URLs which are under `pages/api/`.
|
||||
1. `components/`: All re-usable React components. If something gets used twice we should create a component and put it
|
||||
here.
|
||||
1. `lib/`: A generic place to store library files that are used anywhere. This doesn't have much structure yet.
|
||||
|
||||
NOTE: `styles/` can be ignored for now.
|
||||
|
||||
@@ -113,25 +96,20 @@ We're not really using CSS styles. `styles/` can be ignored.
|
||||
|
||||
## Testing the UI
|
||||
|
||||
Cypress is used for end-to-end (e2e) and component testing and is configured in
|
||||
`./cypress.config.ts`. The `./cypress` folder is used for supporting
|
||||
configuration files etc.
|
||||
Cypress is used for end-to-end (e2e) and component testing and is configured in `./cypress.config.ts`. The `./cypress`
|
||||
folder is used for supporting configuration files etc.
|
||||
|
||||
- Store e2e tests in the `./cypress/e2e` folder.
|
||||
- Store component tests adjacent to the component being tested. If you want to
|
||||
wriite a test for `./src/components/Layout.tsx` then store the test file at
|
||||
`./src/components/Layout.cy.tsx`.
|
||||
- Store component tests adjacent to the component being tested. If you want to wriite a test for
|
||||
`./src/components/Layout.tsx` then store the test file at `./src/components/Layout.cy.tsx`.
|
||||
|
||||
A few npm scripts are available for convenience:
|
||||
|
||||
- `npm run cypress`: Useful for development, it opens Cypress and allows you to
|
||||
explore, run and debug tests. It assumes you have the NextJS site running at
|
||||
`localhost:3000`.
|
||||
- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before
|
||||
sending a PR or to run in CI pipelines.
|
||||
- `npm run cypress:image-baseline`: If you have tests failing because of visual
|
||||
changes that was expected, this command will update the baseline images stored
|
||||
in `./cypress-visual-screenshots/baseline` with those from the adjacent
|
||||
- `npm run cypress`: Useful for development, it opens Cypress and allows you to explore, run and debug tests. It assumes
|
||||
you have the NextJS site running at `localhost:3000`.
|
||||
- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before sending a PR or to run in CI pipelines.
|
||||
- `npm run cypress:image-baseline`: If you have tests failing because of visual changes that was expected, this command
|
||||
will update the baseline images stored in `./cypress-visual-screenshots/baseline` with those from the adjacent
|
||||
comparison folder. More can be found in the
|
||||
[docs of `uktrade/cypress-image-diff`](https://github.com/uktrade/cypress-image-diff/blob/main/docs/CLI.md#update-all-baseline-images-for-failing-tests).
|
||||
|
||||
@@ -141,10 +119,9 @@ Read more in the [./cypress README](cypress/).
|
||||
|
||||
Jest and React Testing Library are used for unit testing JS/TS/TSX code.
|
||||
|
||||
- Store unit test files adjacent to the file being tested and have the filename
|
||||
end with `.test.ts` for non-React code or `.test.tsx` for React code.
|
||||
- `npm run jest`: automatically runs tests and watches for any relevant changes
|
||||
to rerun tests.
|
||||
- Store unit test files adjacent to the file being tested and have the filename end with `.test.ts` for non-React code
|
||||
or `.test.tsx` for React code.
|
||||
- `npm run jest`: automatically runs tests and watches for any relevant changes to rerun tests.
|
||||
|
||||
Read more in the [./src/README.md](src/README.md).
|
||||
|
||||
@@ -152,30 +129,25 @@ Read more in the [./src/README.md](src/README.md).
|
||||
|
||||
When writing code for the website, we have a few best practices:
|
||||
|
||||
1. When importing packages import external dependencies first then local
|
||||
dependencies. Order them alphabetically according to the package name.
|
||||
1. When trying to implement something new, check if
|
||||
[Chakra-UI](https://chakra-ui.com/) has components that are close enough to
|
||||
your need. For example Sliders, Radio Buttons, Progress indicators, etc.
|
||||
They have a lot and we can save time by re-using what they have and tweaking
|
||||
the style as needed.
|
||||
1. Format everything with [Prettier](https://prettier.io/). This is done by
|
||||
default with pre-submits. We currently don't have any custom settings.
|
||||
1. Define functional React components (with types for all properties when
|
||||
feasible).
|
||||
1. When importing packages import external dependencies first then local dependencies. Order them alphabetically
|
||||
according to the package name.
|
||||
1. When trying to implement something new, check if [Chakra-UI](https://chakra-ui.com/) has components that are close
|
||||
enough to your need. For example Sliders, Radio Buttons, Progress indicators, etc. They have a lot and we can save
|
||||
time by re-using what they have and tweaking the style as needed.
|
||||
1. Format everything with [Prettier](https://prettier.io/). This is done by default with pre-submits. We currently
|
||||
don't have any custom settings.
|
||||
1. Define functional React components (with types for all properties when feasible).
|
||||
|
||||
### Developing New Features
|
||||
|
||||
When working on new features or making significant changes that can't be done
|
||||
within a single Pull Request, we ask that you make use of Feature Flags.
|
||||
When working on new features or making significant changes that can't be done within a single Pull Request, we ask that
|
||||
you make use of Feature Flags.
|
||||
|
||||
We've set up
|
||||
[`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to
|
||||
make this easier. To get started:
|
||||
We've set up [`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to make this easier. To get
|
||||
started:
|
||||
|
||||
1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you
|
||||
can copy as an example. Be sure to `isActive` to true when testing your
|
||||
features but false when submitting your PR.
|
||||
1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you can copy as an example. Be sure to
|
||||
`isActive` to true when testing your features but false when submitting your PR.
|
||||
1. Use your flag wherever you add a new UI element. This can be done with:
|
||||
|
||||
```js
|
||||
@@ -188,29 +160,24 @@ import { Flags } from "react-feature-flags";
|
||||
|
||||
You can see an example of how this works by checking `website/src/components/Header/Headers.tsx` where we use `flagTest`.
|
||||
|
||||
1. Once you've finished building out the feature and it is ready for everyone
|
||||
to use, it's safe to remove the `Flag` wrappers around your component and
|
||||
the entry in `flags.ts`.
|
||||
1. Once you've finished building out the feature and it is ready for everyone to use, it's safe to remove the `Flag`
|
||||
wrappers around your component and the entry in `flags.ts`.
|
||||
|
||||
### URL Paths
|
||||
|
||||
To use stable and consistent URL paths, we recommend the following strategy for
|
||||
new tasks:
|
||||
To use stable and consistent URL paths, we recommend the following strategy for new tasks:
|
||||
|
||||
1. For any task that involves writing a free-form response, put the page under
|
||||
`website/src/pages/create` with a page name matching the task type, such as
|
||||
`initial_prompt.tsx`.
|
||||
1. For any task that evaluates, rates, or ranks content, put the page under
|
||||
`website/src/pages/evaluate` with a page name matching the task type such as
|
||||
`rank_initial_prompts.tsx`.
|
||||
1. For any task that involves writing a free-form response, put the page under `website/src/pages/create` with a page
|
||||
name matching the task type, such as `initial_prompt.tsx`.
|
||||
1. For any task that evaluates, rates, or ranks content, put the page under `website/src/pages/evaluate` with a page
|
||||
name matching the task type such as `rank_initial_prompts.tsx`.
|
||||
|
||||
With this we'll be able to ensure these contribution pages are hidden from
|
||||
logged out users but accessible to logged in users.
|
||||
With this we'll be able to ensure these contribution pages are hidden from logged out users but accessible to logged in
|
||||
users.
|
||||
|
||||
## Learn More
|
||||
|
||||
To learn more about Next.js, take a look at the following resources:
|
||||
|
||||
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js
|
||||
features and API.
|
||||
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
|
||||
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
|
||||
|
||||
+40
-58
@@ -1,24 +1,19 @@
|
||||
# Component and e2e testing with Cypress
|
||||
|
||||
[Cypress](https://www.cypress.io/) is used for both component- and end-to-end
|
||||
testing. Below there's a few examples for the context of this site. To learn
|
||||
more, the
|
||||
[Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app)
|
||||
has it all.
|
||||
[Cypress](https://www.cypress.io/) is used for both component- and end-to-end testing. Below there's a few examples for
|
||||
the context of this site. To learn more, the
|
||||
[Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app) has it all.
|
||||
|
||||
Don't get scared by the commercial offerings they offer. Their core is open
|
||||
source, the cloud offering is not necesarry at all and can be replaced by CI
|
||||
tooling and [community efforts](https://sorry-cypress.dev/).
|
||||
Don't get scared by the commercial offerings they offer. Their core is open source, the cloud offering is not necesarry
|
||||
at all and can be replaced by CI tooling and [community efforts](https://sorry-cypress.dev/).
|
||||
|
||||
# Component testing
|
||||
|
||||
To write a new component test, you either create a new `.tsx` adjacent to the
|
||||
component you want to test or you can use the guide presented yo you when
|
||||
running `npm run cypress` which allows you to easily create the skeleton test
|
||||
for an existing component.
|
||||
To write a new component test, you either create a new `.tsx` adjacent to the component you want to test or you can use
|
||||
the guide presented yo you when running `npm run cypress` which allows you to easily create the skeleton test for an
|
||||
existing component.
|
||||
|
||||
If you have a `Button.tsx` component, create a file next to it called
|
||||
`Button.cy.tsx` which could look like this:
|
||||
If you have a `Button.tsx` component, create a file next to it called `Button.cy.tsx` which could look like this:
|
||||
|
||||
```typescript
|
||||
import React from "react";
|
||||
@@ -35,28 +30,24 @@ describe("<Button />", () => {
|
||||
|
||||
## What's happening here?
|
||||
|
||||
First we use `cy.mount` to mount our component under test. Notive how we specify
|
||||
`className` and inner text - this is where we arrange our component with fake
|
||||
data that we could assert on later.
|
||||
First we use `cy.mount` to mount our component under test. Notive how we specify `className` and inner text - this is
|
||||
where we arrange our component with fake data that we could assert on later.
|
||||
|
||||
In the example above, we also use `cy.get` to select the rendered `button`
|
||||
element. Cypress has multiple ways to
|
||||
[select elements](https://docs.cypress.io/guides/references/best-practices),
|
||||
`get` is just one of them (and often not recommended).
|
||||
In the example above, we also use `cy.get` to select the rendered `button` element. Cypress has multiple ways to
|
||||
[select elements](https://docs.cypress.io/guides/references/best-practices), `get` is just one of them (and often not
|
||||
recommended).
|
||||
|
||||
At last, we use `captureSnapshot` which is a plugin that snaps a photo of the
|
||||
`button` element and compares it to a baseline located in the
|
||||
`./cypress-visual-screenshots/baseline/` folder. If there's too many unidentical
|
||||
pixels between the two, it will fail the test.
|
||||
At last, we use `captureSnapshot` which is a plugin that snaps a photo of the `button` element and compares it to a
|
||||
baseline located in the `./cypress-visual-screenshots/baseline/` folder. If there's too many unidentical pixels between
|
||||
the two, it will fail the test.
|
||||
|
||||
# End-to-end (e2e) testing
|
||||
|
||||
e2e tests are stored in the `./cypress/e2e` folder and should be named
|
||||
`{page}.cy.ts` and located in a relative folder structure that mirrors the page
|
||||
under test.
|
||||
e2e tests are stored in the `./cypress/e2e` folder and should be named `{page}.cy.ts` and located in a relative folder
|
||||
structure that mirrors the page under test.
|
||||
|
||||
When running `npm run cypress` and selecting e2e testing, we assume you have the
|
||||
NextJS site running at `localhost:3000`.
|
||||
When running `npm run cypress` and selecting e2e testing, we assume you have the NextJS site running at
|
||||
`localhost:3000`.
|
||||
|
||||
An example test could look as follows:
|
||||
|
||||
@@ -74,39 +65,33 @@ export {};
|
||||
|
||||
## What's happening here?
|
||||
|
||||
First we use [`cy.visit`](https://docs.cypress.io/api/commands/visit) to point
|
||||
the browser at the desired page. It appends relative paths to the configured
|
||||
`baseUrl` (found in `./cypress.config.ts`).
|
||||
First we use [`cy.visit`](https://docs.cypress.io/api/commands/visit) to point the browser at the desired page. It
|
||||
appends relative paths to the configured `baseUrl` (found in `./cypress.config.ts`).
|
||||
|
||||
Cypress will
|
||||
[automatically await](https://docs.cypress.io/guides/core-concepts/introduction-to-cypress#Timeouts)
|
||||
almost anything you do, but fail if the default timeout is reached.
|
||||
Cypress will [automatically await](https://docs.cypress.io/guides/core-concepts/introduction-to-cypress#Timeouts) almost
|
||||
anything you do, but fail if the default timeout is reached.
|
||||
|
||||
Then we get the email input field and type our email address. We find the input
|
||||
field using the data-cy attribute that we added in the source code of the
|
||||
element on the page.
|
||||
Then we get the email input field and type our email address. We find the input field using the data-cy attribute that
|
||||
we added in the source code of the element on the page.
|
||||
|
||||
```jsx
|
||||
<Input data-cy="email-address" placeholder="Email Address" />
|
||||
```
|
||||
|
||||
Using `data-cy` is how we ensure that selecting the element is robust to changes
|
||||
in page design or function and is one of the
|
||||
Using `data-cy` is how we ensure that selecting the element is robust to changes in page design or function and is one
|
||||
of the
|
||||
[best practices recommended by Cypress](https://docs.cypress.io/guides/references/best-practices#Selecting-Elements).
|
||||
|
||||
Next we call `type()` to use the keyboard, cypress will automatically focus the
|
||||
element and send the keypress events. Notice the `{enter}` keyword, this will
|
||||
cause Cypress to hit the return key which we expect to submit the form.
|
||||
Next we call `type()` to use the keyboard, cypress will automatically focus the element and send the keypress events.
|
||||
Notice the `{enter}` keyword, this will cause Cypress to hit the return key which we expect to submit the form.
|
||||
|
||||
We then assert that the URL should contain `/auth/verify`. Again the timeout
|
||||
will make sure we are not waiting forever, and the test will fail if we do not
|
||||
manage to get there in a reasonable time.
|
||||
We then assert that the URL should contain `/auth/verify`. Again the timeout will make sure we are not waiting forever,
|
||||
and the test will fail if we do not manage to get there in a reasonable time.
|
||||
|
||||
## Authenticating in e2e tests
|
||||
|
||||
For end-to-end tests almost every test will need to first sign in to the
|
||||
website. To make this easier we have a custom command for Cypress that makes
|
||||
logging in with an email address a single command, `cy.signInWithEmail()`.
|
||||
For end-to-end tests almost every test will need to first sign in to the website. To make this easier we have a custom
|
||||
command for Cypress that makes logging in with an email address a single command, `cy.signInWithEmail()`.
|
||||
|
||||
```typescript
|
||||
describe("replying as the assistant", () => {
|
||||
@@ -115,16 +100,13 @@ describe("replying as the assistant", () => {
|
||||
|
||||
cy.visit("/create/assistant_reply");
|
||||
|
||||
cy.get('[data-cy="reply"').type(
|
||||
"You need to run pre-commit to make the reviewer happy."
|
||||
);
|
||||
cy.get('[data-cy="reply"').type("You need to run pre-commit to make the reviewer happy.");
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
In this example we sign in as `cypress@example.com` before visiting the
|
||||
`/create/assistant_reply` page that is only available when authenticated. We can
|
||||
then continue on with our test as normal. Note: using `cy.signInWithEmail()`
|
||||
requires that the maildev is running, which should have been started as part of
|
||||
the `docker compose up` command that is required to do any end-to-end testing.
|
||||
In this example we sign in as `cypress@example.com` before visiting the `/create/assistant_reply` page that is only
|
||||
available when authenticated. We can then continue on with our test as normal. Note: using `cy.signInWithEmail()`
|
||||
requires that the maildev is running, which should have been started as part of the `docker compose up` command that is
|
||||
required to do any end-to-end testing.
|
||||
|
||||
@@ -7,9 +7,6 @@ describe("<Container />", () => {
|
||||
const className = "my-class";
|
||||
const text = "test_container";
|
||||
cy.mount(<Container className={className}>{text}</Container>);
|
||||
cy.get(`div.${className}`)
|
||||
.should("have.class", className)
|
||||
.should("be.visible")
|
||||
.should("contain", text);
|
||||
cy.get(`div.${className}`).should("have.class", className).should("be.visible").should("contain", text);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -12,25 +12,18 @@ describe("Contract test for Oasst API", function () {
|
||||
} as BackendUserCore;
|
||||
|
||||
it("can fetch a task", async () => {
|
||||
expect(await oasstApiClient.fetchTask("random", testUser)).to.be.not.null;
|
||||
expect(await oasstApiClient.fetchTask("random", testUser, "en")).to.be.not.null;
|
||||
});
|
||||
|
||||
it("can ack a task", async () => {
|
||||
const task = await oasstApiClient.fetchTask("random", testUser);
|
||||
const task = await oasstApiClient.fetchTask("random", testUser, "en");
|
||||
expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
|
||||
});
|
||||
|
||||
it("can record a taskInteraction", async () => {
|
||||
const task = await oasstApiClient.fetchTask("random", testUser);
|
||||
const task = await oasstApiClient.fetchTask("random", testUser, "en");
|
||||
expect(
|
||||
await oasstApiClient.interactTask(
|
||||
"text_reply_to_message",
|
||||
task.id,
|
||||
"321",
|
||||
"1",
|
||||
{ text: "Test" },
|
||||
testUser
|
||||
)
|
||||
await oasstApiClient.interactTask("text_reply_to_message", task.id, "321", "1", { text: "Test" }, testUser, "en")
|
||||
).to.be.not.null;
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
describe("labeling assistant replies", () => {
|
||||
it("completes the current task on submit and on request shows a new task", () => {
|
||||
cy.signInWithEmail("cypress@example.com");
|
||||
cy.visit("/label/label_assistant_reply");
|
||||
|
||||
cy.get('[data-cy="task"]')
|
||||
.invoke("attr", "data-task-type")
|
||||
.then((type) => {
|
||||
cy.log("Task type", type);
|
||||
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
});
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
export {};
|
||||
@@ -0,0 +1,26 @@
|
||||
describe("labeling initial prompts", () => {
|
||||
it("completes the current task on submit and on request shows a new task", () => {
|
||||
cy.signInWithEmail("cypress@example.com");
|
||||
cy.visit("/label/label_initial_prompt");
|
||||
|
||||
cy.get('[data-cy="task"]')
|
||||
.invoke("attr", "data-task-type")
|
||||
.then((type) => {
|
||||
cy.log("Task type", type);
|
||||
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
});
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
export {};
|
||||
@@ -0,0 +1,26 @@
|
||||
describe("labeling prompter replies", () => {
|
||||
it("completes the current task on submit and on request shows a new task", () => {
|
||||
cy.signInWithEmail("cypress@example.com");
|
||||
cy.visit("/label/label_prompter_reply");
|
||||
|
||||
cy.get('[data-cy="task"]')
|
||||
.invoke("attr", "data-task-type")
|
||||
.then((type) => {
|
||||
cy.log("Task type", type);
|
||||
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
});
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
export {};
|
||||
@@ -0,0 +1,23 @@
|
||||
describe("no tasks available", () => {
|
||||
it("displays an empty state when no tasks are available", () => {
|
||||
cy.signInWithEmail("cypress@example.com");
|
||||
cy.intercept(
|
||||
{
|
||||
method: "GET",
|
||||
url: "/api/new_task/prompter_reply",
|
||||
},
|
||||
{
|
||||
statusCode: 500,
|
||||
body: {
|
||||
message: "No tasks of type 'label_prompter_reply' are currently available.",
|
||||
errorCode: 1006,
|
||||
httpStatusCode: 503,
|
||||
},
|
||||
}
|
||||
).as("newTaskPrompterReply");
|
||||
cy.visit("/create/user_reply");
|
||||
cy.wait("@newTaskPrompterReply").then(() => {
|
||||
cy.get('[data-cy="cy-no-tasks"]').should("exist");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -44,47 +44,29 @@ describe("handles random tasks", () => {
|
||||
break;
|
||||
}
|
||||
case "label-task": {
|
||||
cy.get('[data-cy="label-group-item"]')
|
||||
.first()
|
||||
.invoke("attr", "data-label-type")
|
||||
.then((label_type) => {
|
||||
const parent = cy
|
||||
.get('[data-cy="label-group-item"]')
|
||||
.first();
|
||||
cy.log("Label type", label_type);
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
});
|
||||
|
||||
switch (label_type) {
|
||||
case "slider": {
|
||||
// Clicking on the slider will set the value to about the middle where it clicks
|
||||
parent
|
||||
.get('[aria-roledescription="slider"]')
|
||||
.first()
|
||||
.click();
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
|
||||
break;
|
||||
}
|
||||
case "radio": {
|
||||
// Clicking on the slider will set the value to about the middle where it clicks
|
||||
parent
|
||||
.get('[aria-roledescription="radio-button"]')
|
||||
.last()
|
||||
.click();
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
|
||||
break;
|
||||
}
|
||||
case "spam-task": {
|
||||
cy.get('[data-cy="not-spam-button"]').click();
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
|
||||
break;
|
||||
}
|
||||
case undefined: {
|
||||
throw new Error("No tasks available, but at least create initial prompt expected");
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unexpected task type: ${type}`);
|
||||
}
|
||||
|
||||
@@ -37,19 +37,14 @@
|
||||
// }
|
||||
|
||||
Cypress.Commands.add("signInUsingEmailedLink", (emailAddress) => {
|
||||
const mailDevApi = `${Cypress.env("MAILDEV_PROTOCOL")}://${Cypress.env(
|
||||
"MAILDEV_HOST"
|
||||
)}:${Cypress.env("MAILDEV_API_PORT")}`;
|
||||
cy.request(
|
||||
"GET",
|
||||
`${mailDevApi}/email?headers.to=${emailAddress.toLowerCase()}`
|
||||
).then((response) => {
|
||||
const mailDevApi = `${Cypress.env("MAILDEV_PROTOCOL")}://${Cypress.env("MAILDEV_HOST")}:${Cypress.env(
|
||||
"MAILDEV_API_PORT"
|
||||
)}`;
|
||||
cy.request("GET", `${mailDevApi}/email?headers.to=${emailAddress.toLowerCase()}`).then((response) => {
|
||||
const emails = response.body;
|
||||
|
||||
// Find and use login link
|
||||
const loginLink = emails
|
||||
.pop()
|
||||
.html.match(/href="[^"]+(\/api\/auth\/callback\/[^"]+?)"/)[1];
|
||||
const loginLink = emails.pop().html.match(/href="[^"]+(\/api\/auth\/callback\/[^"]+?)"/)[1];
|
||||
cy.visit(loginLink);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module.exports = {
|
||||
i18n: {
|
||||
defaultLocale: "en",
|
||||
locales: ["en"],
|
||||
locales: ["bn", "de", "en", "es", "fr", "ja", "pt-BR", "ru", "zh"],
|
||||
},
|
||||
};
|
||||
|
||||
Generated
+174
-149
@@ -21,6 +21,8 @@
|
||||
"@next/font": "^13.1.0",
|
||||
"@prisma/client": "^4.7.1",
|
||||
"@tailwindcss/forms": "^0.5.3",
|
||||
"@tanstack/react-table": "^8.7.6",
|
||||
"accept-language-parser": "^1.5.0",
|
||||
"autoprefixer": "^10.4.13",
|
||||
"axios": "^1.2.1",
|
||||
"boolean": "^3.2.0",
|
||||
@@ -31,6 +33,7 @@
|
||||
"focus-visible": "^5.2.0",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"lucide-react": "^0.105.0",
|
||||
"next": "13.0.6",
|
||||
"next-auth": "^4.18.6",
|
||||
"next-i18next": "^13.0.3",
|
||||
@@ -38,13 +41,13 @@
|
||||
"npm": "^9.2.0",
|
||||
"postcss-focus-visible": "^7.1.0",
|
||||
"react": "18.2.0",
|
||||
"react-cookies": "^0.1.1",
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-hook-form": "^7.42.1",
|
||||
"react-i18next": "^12.1.4",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-table": "^7.8.0",
|
||||
"sharp": "^0.31.3",
|
||||
"storybook-addon-next-router": "^4.0.2",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"unique-username-generator": "^1.1.3",
|
||||
@@ -6199,7 +6202,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/addon-actions/-/addon-actions-6.5.15.tgz",
|
||||
"integrity": "sha512-cnLzVK1S+EydFDSuvxMmMAxVqNXijBGdV9QTgsu6ys5sOkoiXRETKZmxuN8/ZRbkfc4N+1KAylSCZOOHzBQTBQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/addons": "6.5.15",
|
||||
"@storybook/api": "6.5.15",
|
||||
@@ -6242,7 +6244,6 @@
|
||||
"version": "5.1.1",
|
||||
"resolved": "https://registry.npmjs.org/react-inspector/-/react-inspector-5.1.1.tgz",
|
||||
"integrity": "sha512-GURDaYzoLbW8pMGXwYPDBIv6nqei4kK7LPRZ9q9HCZF54wqXz/dnylBp/kfE9XmekBhHvLDdcYeyIwSrvtOiWg==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.0.0",
|
||||
"is-dom": "^1.0.0",
|
||||
@@ -6737,7 +6738,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/addons/-/addons-6.5.15.tgz",
|
||||
"integrity": "sha512-xT31SuSX+kYGyxCNK2nqL7WTxucs3rSmhiCLovJcUjYk+QquV3c2c53Ki7lwwdDbzfXFcNAe0HJ4hoTN4jhn0Q==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/api": "6.5.15",
|
||||
"@storybook/channels": "6.5.15",
|
||||
@@ -6764,7 +6764,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/api/-/api-6.5.15.tgz",
|
||||
"integrity": "sha512-BBE0KXKvj1/3jTghbIoWfrcDM0t+xO7EYtWWAXD6XlhGsZVD2Dy82Z52ONyLulMDRpMWl0OYy3h6A1YnFUH25w==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/channels": "6.5.15",
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
@@ -8234,7 +8233,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/channel-postmessage/-/channel-postmessage-6.5.15.tgz",
|
||||
"integrity": "sha512-gMpA8LWT8lC4z5KWnaMh03aazEwtDO7GtY5kZVru+EEMgExGmaR82qgekwmLmgLj2nRJEv0o138o9IqYUcou8w==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/channels": "6.5.15",
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
@@ -8270,7 +8268,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/channels/-/channels-6.5.15.tgz",
|
||||
"integrity": "sha512-gPpsBgirv2NCXbH4WbYqdkI0JLE96aiVuu7UEWfn9yu071pQ9CLHbhXGD9fSFNrfOkyBBY10ppSE7uCXw3Wexg==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"core-js": "^3.8.2",
|
||||
"ts-dedent": "^2.0.0",
|
||||
@@ -8285,7 +8282,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/client-api/-/client-api-6.5.15.tgz",
|
||||
"integrity": "sha512-0ZGpRgVz7rdbCguBqBpwObXbsVY5qlSTWDzzIBpmz8EkxW/MtK5wEyeq+0L0O+DTn41FwvH5yCGLAENpzWD8BQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/addons": "6.5.15",
|
||||
"@storybook/channel-postmessage": "6.5.15",
|
||||
@@ -8321,7 +8317,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/client-logger/-/client-logger-6.5.15.tgz",
|
||||
"integrity": "sha512-0uyxKvodq+FycGv6aUwC1wUR6suXf2+7ywMFAOlYolI4UvNj8NyU/5AfgKT5XnxYAgPmoCiAjOE700TrfHrosw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"core-js": "^3.8.2",
|
||||
"global": "^4.4.0"
|
||||
@@ -8335,7 +8330,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/components/-/components-6.5.15.tgz",
|
||||
"integrity": "sha512-bHTT0Oa3s4g+MBMaLBbX9ofMtb1AW59AzIUNGrfqW1XqJMGuUHMiJ7TSo+i5dRSFpbFygnwMEG9LfHxpR2Z0Dw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
"@storybook/csf": "0.0.2--canary.4566f4d.1",
|
||||
@@ -9234,7 +9228,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/core-events/-/core-events-6.5.15.tgz",
|
||||
"integrity": "sha512-B1Ba6l5W7MeNclclqMMTMHgYgfdpB5SIhNCQFnzIz8blynzRhNFMdxvbAl6Je5G0S4xydYYd7Lno2kXQebs7HA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"core-js": "^3.8.2"
|
||||
},
|
||||
@@ -10019,7 +10012,6 @@
|
||||
"version": "0.0.2--canary.4566f4d.1",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/csf/-/csf-0.0.2--canary.4566f4d.1.tgz",
|
||||
"integrity": "sha512-9OVvMVh3t9znYZwb0Svf/YQoxX2gVOeQTGe2bses2yj+a3+OJnCrUF3/hGv6Em7KujtOdL2LL+JnG49oMVGFgQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"lodash": "^4.17.15"
|
||||
}
|
||||
@@ -11967,7 +11959,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/router/-/router-6.5.15.tgz",
|
||||
"integrity": "sha512-9t8rI8t7/Krolau29gsdjdbRQ66orONIyP0efp0EukVgv6reNFzb/U14ARrl0uHys6Tl5Xyece9FoakQUdn8Kg==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
"core-js": "^3.8.2",
|
||||
@@ -11988,7 +11979,6 @@
|
||||
"version": "7.3.2",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/semver/-/semver-7.3.2.tgz",
|
||||
"integrity": "sha512-SWeszlsiPsMI0Ps0jVNtH64cI5c0UF3f7KgjVKJoNP30crQ6wUSddY2hsdeczZXEKVJGEn50Q60flcGsQGIcrg==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"core-js": "^3.6.5",
|
||||
"find-up": "^4.1.0"
|
||||
@@ -12004,7 +11994,6 @@
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz",
|
||||
"integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"locate-path": "^5.0.0",
|
||||
"path-exists": "^4.0.0"
|
||||
@@ -12017,7 +12006,6 @@
|
||||
"version": "5.0.0",
|
||||
"resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz",
|
||||
"integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"p-locate": "^4.1.0"
|
||||
},
|
||||
@@ -12029,7 +12017,6 @@
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz",
|
||||
"integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"p-try": "^2.0.0"
|
||||
},
|
||||
@@ -12044,7 +12031,6 @@
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz",
|
||||
"integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"p-limit": "^2.2.0"
|
||||
},
|
||||
@@ -12094,7 +12080,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/store/-/store-6.5.15.tgz",
|
||||
"integrity": "sha512-r6cYTf6GtbqgdI4ZG3xuWdJAAu5fJ3xAWMiDkHyoK2u+R2TeYXIsJvgn0BPrW87sZhELIkg4ckdFECmATs3kpQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/addons": "6.5.15",
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
@@ -12232,7 +12217,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/theming/-/theming-6.5.15.tgz",
|
||||
"integrity": "sha512-pgdW0lVZKKXQ4VhIfLHycMmwFSVOY7vLTKnytag4Y8Yz+aXm0bwDN/QxPntFzDH47F1Rcy2ywNnvty8ooDTvuA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
"core-js": "^3.8.2",
|
||||
@@ -12297,6 +12281,37 @@
|
||||
"tailwindcss": ">=3.0.0 || >= 3.0.0-alpha.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/react-table": {
|
||||
"version": "8.7.6",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.7.6.tgz",
|
||||
"integrity": "sha512-/QijmMFeP7wDLBnr0MQ/5MlbXePbIL/1nOtkxBC9zvmBu4gDKJEDBqipUyM7Wc/iBpSd0IFyqBlvZvTPD9FYDA==",
|
||||
"dependencies": {
|
||||
"@tanstack/table-core": "8.7.6"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/tannerlinsley"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16",
|
||||
"react-dom": ">=16"
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/table-core": {
|
||||
"version": "8.7.6",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.7.6.tgz",
|
||||
"integrity": "sha512-sqiNTMzB6cpyL8DFH6/VqW48SwiflLqxQqYpo2wNock7rdVGvlm0BLNI8vZUJbr1+fmmWmHwBvi5OMgZw8n1DA==",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/tannerlinsley"
|
||||
}
|
||||
},
|
||||
"node_modules/@testing-library/dom": {
|
||||
"version": "8.19.1",
|
||||
"resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-8.19.1.tgz",
|
||||
@@ -12674,8 +12689,7 @@
|
||||
"node_modules/@types/is-function": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/is-function/-/is-function-1.0.1.tgz",
|
||||
"integrity": "sha512-A79HEEiwXTFtfY+Bcbo58M2GRYzCr9itHWzbzHVFNEYCcoU/MMGwYYf721gBrnhpj1s6RGVVha/IgNFnR0Iw/Q==",
|
||||
"dev": true
|
||||
"integrity": "sha512-A79HEEiwXTFtfY+Bcbo58M2GRYzCr9itHWzbzHVFNEYCcoU/MMGwYYf721gBrnhpj1s6RGVVha/IgNFnR0Iw/Q=="
|
||||
},
|
||||
"node_modules/@types/istanbul-lib-coverage": {
|
||||
"version": "2.0.4",
|
||||
@@ -12876,8 +12890,7 @@
|
||||
"node_modules/@types/qs": {
|
||||
"version": "6.9.7",
|
||||
"resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.7.tgz",
|
||||
"integrity": "sha512-FGa1F62FT09qcrueBA6qYTrJPVDzah9a+493+o2PCXsesWHIn27G98TsSMs3WPNbZIEj4+VJf6saSFpvD+3Zsw==",
|
||||
"dev": true
|
||||
"integrity": "sha512-FGa1F62FT09qcrueBA6qYTrJPVDzah9a+493+o2PCXsesWHIn27G98TsSMs3WPNbZIEj4+VJf6saSFpvD+3Zsw=="
|
||||
},
|
||||
"node_modules/@types/react": {
|
||||
"version": "18.0.26",
|
||||
@@ -12995,8 +13008,7 @@
|
||||
"node_modules/@types/webpack-env": {
|
||||
"version": "1.18.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/webpack-env/-/webpack-env-1.18.0.tgz",
|
||||
"integrity": "sha512-56/MAlX5WMsPVbOg7tAxnYvNYMMWr/QJiIp6BxVSW3JJXUVzzOn64qW8TzQyMSqSUFM2+PVI4aUHcHOzIz/1tg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-56/MAlX5WMsPVbOg7tAxnYvNYMMWr/QJiIp6BxVSW3JJXUVzzOn64qW8TzQyMSqSUFM2+PVI4aUHcHOzIz/1tg=="
|
||||
},
|
||||
"node_modules/@types/webpack-sources": {
|
||||
"version": "3.2.0",
|
||||
@@ -13616,6 +13628,11 @@
|
||||
"integrity": "sha512-j2afSsaIENvHZN2B8GOpF566vZ5WVk5opAiMTvWgaQT8DkbOqsTfvNAvHoRGU2zzP8cPoqys+xHTRDWW8L+/BA==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/accept-language-parser": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/accept-language-parser/-/accept-language-parser-1.5.0.tgz",
|
||||
"integrity": "sha512-QhyTbMLYo0BBGg1aWbeMG4ekWtds/31BrEU+DONOg/7ax23vxpL03Pb7/zBmha2v7vdD3AyzZVWBVGEZxKOXWw=="
|
||||
},
|
||||
"node_modules/accepts": {
|
||||
"version": "1.3.8",
|
||||
"resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz",
|
||||
@@ -17984,8 +18001,7 @@
|
||||
"node_modules/dom-walk": {
|
||||
"version": "0.1.2",
|
||||
"resolved": "https://registry.npmjs.org/dom-walk/-/dom-walk-0.1.2.tgz",
|
||||
"integrity": "sha512-6QvTW9mrGeIegrFXdtQi9pk7O/nSK6lSdXW2eqUspN5LWD7UTji2Fqw5V2YLjBpHEoU9Xl/eUWNpDeZvoyOv2w==",
|
||||
"dev": true
|
||||
"integrity": "sha512-6QvTW9mrGeIegrFXdtQi9pk7O/nSK6lSdXW2eqUspN5LWD7UTji2Fqw5V2YLjBpHEoU9Xl/eUWNpDeZvoyOv2w=="
|
||||
},
|
||||
"node_modules/domain-browser": {
|
||||
"version": "1.2.0",
|
||||
@@ -20786,7 +20802,6 @@
|
||||
"version": "4.4.0",
|
||||
"resolved": "https://registry.npmjs.org/global/-/global-4.4.0.tgz",
|
||||
"integrity": "sha512-wv/LAoHdRE3BeTGz53FAamhGlPLhlssK45usmGFThIi4XqnBmjKQ16u+RNbP7WvigRZDxUsM0J3gcQ5yicaL0w==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"min-document": "^2.19.0",
|
||||
"process": "^0.11.10"
|
||||
@@ -22018,7 +22033,6 @@
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/is-dom/-/is-dom-1.1.0.tgz",
|
||||
"integrity": "sha512-u82f6mvhYxRPKpw8V1N0W8ce1xXwOrQtgGcxl6UCL5zBmZu3is/18K0rR7uFCnMDuAsS/3W54mGL4vsaFUQlEQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"is-object": "^1.0.1",
|
||||
"is-window": "^1.0.2"
|
||||
@@ -22069,8 +22083,7 @@
|
||||
"node_modules/is-function": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/is-function/-/is-function-1.0.2.tgz",
|
||||
"integrity": "sha512-lw7DUp0aWXYg+CBCN+JKkcE0Q2RayZnSvnZBlwgxHBQhqt5pZNVy4Ri7H9GmmXkdu7LUthszM+Tor1u/2iBcpQ==",
|
||||
"dev": true
|
||||
"integrity": "sha512-lw7DUp0aWXYg+CBCN+JKkcE0Q2RayZnSvnZBlwgxHBQhqt5pZNVy4Ri7H9GmmXkdu7LUthszM+Tor1u/2iBcpQ=="
|
||||
},
|
||||
"node_modules/is-generator-fn": {
|
||||
"version": "2.1.0",
|
||||
@@ -22164,7 +22177,6 @@
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/is-object/-/is-object-1.0.2.tgz",
|
||||
"integrity": "sha512-2rRIahhZr2UWb45fIOuvZGpFtz0TyOZLf32KxBbSoUCeZR495zCKlWUKKUByk3geS2eAs7ZAABt0Y/Rx0GiQGA==",
|
||||
"dev": true,
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
@@ -22369,8 +22381,7 @@
|
||||
"node_modules/is-window": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/is-window/-/is-window-1.0.2.tgz",
|
||||
"integrity": "sha512-uj00kdXyZb9t9RcAUAwMZAnkBUwdYGhYlt7djMXhfyhUCzwNba50tIiBKR7q0l7tdoBtFVw/3JmLY6fI3rmZmg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-uj00kdXyZb9t9RcAUAwMZAnkBUwdYGhYlt7djMXhfyhUCzwNba50tIiBKR7q0l7tdoBtFVw/3JmLY6fI3rmZmg=="
|
||||
},
|
||||
"node_modules/is-windows": {
|
||||
"version": "1.0.2",
|
||||
@@ -26446,8 +26457,7 @@
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
|
||||
},
|
||||
"node_modules/lodash.debounce": {
|
||||
"version": "4.0.8",
|
||||
@@ -26687,6 +26697,14 @@
|
||||
"yallist": "^3.0.2"
|
||||
}
|
||||
},
|
||||
"node_modules/lucide-react": {
|
||||
"version": "0.105.0",
|
||||
"resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.105.0.tgz",
|
||||
"integrity": "sha512-iHaIkd4Wq6aNIVrFMXt3If8E/+2lnJd4WlCyntoJNIzZ8nWhdSSHWpsw7XM4rlw2319LZ2t4WLdnM8Z0ECDTOQ==",
|
||||
"peerDependencies": {
|
||||
"react": "^16.5.1 || ^17.0.0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/lz-string": {
|
||||
"version": "1.4.4",
|
||||
"resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.4.4.tgz",
|
||||
@@ -26776,8 +26794,7 @@
|
||||
"node_modules/map-or-similar": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/map-or-similar/-/map-or-similar-1.5.0.tgz",
|
||||
"integrity": "sha512-0aF7ZmVon1igznGI4VS30yugpduQW3y3GkcgGJOp7d8x8QrizhigUxjI/m2UojsXXto+jLAH3KSz+xOJTiORjg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-0aF7ZmVon1igznGI4VS30yugpduQW3y3GkcgGJOp7d8x8QrizhigUxjI/m2UojsXXto+jLAH3KSz+xOJTiORjg=="
|
||||
},
|
||||
"node_modules/map-visit": {
|
||||
"version": "1.0.0",
|
||||
@@ -26924,7 +26941,6 @@
|
||||
"version": "1.11.3",
|
||||
"resolved": "https://registry.npmjs.org/memoizerific/-/memoizerific-1.11.3.tgz",
|
||||
"integrity": "sha512-/EuHYwAPdLtXwAwSZkh/Gutery6pD2KYd44oQLhAvQp/50mpyduZh8Q7PYHXTCJ+wuXxt7oij2LXyIJOOYFPog==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"map-or-similar": "^1.5.0"
|
||||
}
|
||||
@@ -27183,7 +27199,6 @@
|
||||
"version": "2.19.0",
|
||||
"resolved": "https://registry.npmjs.org/min-document/-/min-document-2.19.0.tgz",
|
||||
"integrity": "sha512-9Wy1B3m3f66bPPmU5hdA4DR4PB2OfDU/+GS3yAB7IQozE3tqXaVv2zOjgla7MEGSRv95+ILmOuvhLkOK6wJtCQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"dom-walk": "^0.1.0"
|
||||
}
|
||||
@@ -31128,7 +31143,6 @@
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/p-try/-/p-try-2.2.0.tgz",
|
||||
"integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
@@ -31480,7 +31494,6 @@
|
||||
"version": "4.2.2",
|
||||
"resolved": "https://registry.npmjs.org/polished/-/polished-4.2.2.tgz",
|
||||
"integrity": "sha512-Sz2Lkdxz6F2Pgnpi9U5Ng/WdWAUZxmHrNPoVlm3aAemxoy2Qy7LGjQg4uf8qKelDAUW94F4np3iH2YPf2qefcQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.17.8"
|
||||
},
|
||||
@@ -32087,7 +32100,6 @@
|
||||
"version": "0.11.10",
|
||||
"resolved": "https://registry.npmjs.org/process/-/process-0.11.10.tgz",
|
||||
"integrity": "sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">= 0.6.0"
|
||||
}
|
||||
@@ -32269,7 +32281,6 @@
|
||||
"version": "6.11.0",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz",
|
||||
"integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"side-channel": "^1.0.4"
|
||||
},
|
||||
@@ -32466,6 +32477,23 @@
|
||||
"react": "^15.3.0 || ^16.0.0 || ^17.0.0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-cookies": {
|
||||
"version": "0.1.1",
|
||||
"resolved": "https://registry.npmjs.org/react-cookies/-/react-cookies-0.1.1.tgz",
|
||||
"integrity": "sha512-PP75kJ4vtoHuuTdq0TAD3RmlAv7vuDQh9fkC4oDlhntgs9vX1DmREomO0Y1mcQKR9nMZ6/zxoflaMJ3MAmF5KQ==",
|
||||
"dependencies": {
|
||||
"cookie": "^0.3.1",
|
||||
"object-assign": "^4.1.1"
|
||||
}
|
||||
},
|
||||
"node_modules/react-cookies/node_modules/cookie": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz",
|
||||
"integrity": "sha512-+IJOX0OqlHCszo2mBUq+SrEbCj6w7Kpffqx60zYbPTFaO4+yYgRjHwcZNpWvaTylDHaV7PPmBHzSecZiMhtPgw==",
|
||||
"engines": {
|
||||
"node": ">= 0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/react-docgen": {
|
||||
"version": "5.4.3",
|
||||
"resolved": "https://registry.npmjs.org/react-docgen/-/react-docgen-5.4.3.tgz",
|
||||
@@ -32601,14 +32629,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/react-icons": {
|
||||
"version": "4.7.1",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
|
||||
"integrity": "sha512-yHd3oKGMgm7zxo3EA7H2n7vxSoiGmHk5t6Ou4bXsfcgWyhfDKMpyKfhHR6Bjnn63c+YXBLBPUql9H4wPJM6sXw==",
|
||||
"peerDependencies": {
|
||||
"react": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/react-is": {
|
||||
"version": "16.13.1",
|
||||
"resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz",
|
||||
@@ -32700,18 +32720,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/react-table": {
|
||||
"version": "7.8.0",
|
||||
"resolved": "https://registry.npmjs.org/react-table/-/react-table-7.8.0.tgz",
|
||||
"integrity": "sha512-hNaz4ygkZO4bESeFfnfOft73iBUj8K5oKi1EcSHPAibEydfsX2MyU6Z8KCr3mv3C9Kqqh71U+DhZkFvibbnPbA==",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/tannerlinsley"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^16.8.3 || ^17.0.0-0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/read-cache": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz",
|
||||
@@ -34576,8 +34584,7 @@
|
||||
"version": "0.1.8",
|
||||
"resolved": "https://registry.npmjs.org/stable/-/stable-0.1.8.tgz",
|
||||
"integrity": "sha512-ji9qxRnOVfcuLDySj9qzhGSEFVobyt1kIOSkj1qZzYLzq7Tos/oUUWvotUPQLlrsidqsK6tBH89Bc9kL5zHA6w==",
|
||||
"deprecated": "Modern JS already guarantees Array#sort() is a stable sort, so this library is deprecated. See the compatibility table on MDN: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/sort#browser_compatibility",
|
||||
"dev": true
|
||||
"deprecated": "Modern JS already guarantees Array#sort() is a stable sort, so this library is deprecated. See the compatibility table on MDN: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/sort#browser_compatibility"
|
||||
},
|
||||
"node_modules/stack-utils": {
|
||||
"version": "2.0.6",
|
||||
@@ -34730,8 +34737,31 @@
|
||||
"node_modules/store2": {
|
||||
"version": "2.14.2",
|
||||
"resolved": "https://registry.npmjs.org/store2/-/store2-2.14.2.tgz",
|
||||
"integrity": "sha512-siT1RiqlfQnGqgT/YzXVUNsom9S0H1OX+dpdGN1xkyYATo4I6sep5NmsRD/40s3IIOvlCq6akxkqG82urIZW1w==",
|
||||
"dev": true
|
||||
"integrity": "sha512-siT1RiqlfQnGqgT/YzXVUNsom9S0H1OX+dpdGN1xkyYATo4I6sep5NmsRD/40s3IIOvlCq6akxkqG82urIZW1w=="
|
||||
},
|
||||
"node_modules/storybook-addon-next-router": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/storybook-addon-next-router/-/storybook-addon-next-router-4.0.2.tgz",
|
||||
"integrity": "sha512-0rjGAl7HziW4ecDq+VU03H1dwkw5f6phqA+PMquPzdowNVl29ejVwVadLMGlovG6x2snaxMGxtySR2c5bwegSw==",
|
||||
"dependencies": {
|
||||
"tslib": "2.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@storybook/addon-actions": "^6.0.0",
|
||||
"@storybook/addons": "^6.0.0",
|
||||
"@storybook/client-api": "^6.0.0",
|
||||
"next": "^9.0.0 || ^10.0.0 || ^11.0.0 || ^12.0.0 || ^13.0.0",
|
||||
"react": "^16.8.0 || ^17.0.0 || ^18.0.0",
|
||||
"react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/storybook-addon-next-router/node_modules/tslib": {
|
||||
"version": "2.4.0",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.4.0.tgz",
|
||||
"integrity": "sha512-d6xOpEDfsi2CZVlPQzGeux8XMwLT9hssAsaPYExaQMuYskwb+x1x7J371tWlbBdWHroy99KnVB6qIkUbs5X3UQ=="
|
||||
},
|
||||
"node_modules/stream-browserify": {
|
||||
"version": "2.0.2",
|
||||
@@ -35176,8 +35206,7 @@
|
||||
"node_modules/synchronous-promise": {
|
||||
"version": "2.0.16",
|
||||
"resolved": "https://registry.npmjs.org/synchronous-promise/-/synchronous-promise-2.0.16.tgz",
|
||||
"integrity": "sha512-qImOD23aDfnIDNqlG1NOehdB9IYsn1V9oByPjKY1nakv2MQYCEMyX033/q+aEtYCpmYK1cv2+NTmlH+ra6GA5A==",
|
||||
"dev": true
|
||||
"integrity": "sha512-qImOD23aDfnIDNqlG1NOehdB9IYsn1V9oByPjKY1nakv2MQYCEMyX033/q+aEtYCpmYK1cv2+NTmlH+ra6GA5A=="
|
||||
},
|
||||
"node_modules/synckit": {
|
||||
"version": "0.8.4",
|
||||
@@ -35334,7 +35363,6 @@
|
||||
"version": "6.0.8",
|
||||
"resolved": "https://registry.npmjs.org/telejson/-/telejson-6.0.8.tgz",
|
||||
"integrity": "sha512-nerNXi+j8NK1QEfBHtZUN/aLdDcyupA//9kAboYLrtzZlPLpUfqbVGWb9zz91f/mIjRbAYhbgtnJHY8I1b5MBg==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@types/is-function": "^1.0.0",
|
||||
"global": "^4.4.0",
|
||||
@@ -35350,7 +35378,6 @@
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/isobject/-/isobject-4.0.0.tgz",
|
||||
"integrity": "sha512-S/2fF5wH8SJA/kmwr6HYhK/RI/OkhD84k8ntalo0iJjZikgq1XFvR5M8NPT1x5F7fBwCG3qHfnzeP/Vh/ZxCUA==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
@@ -35721,7 +35748,6 @@
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/ts-dedent/-/ts-dedent-2.2.0.tgz",
|
||||
"integrity": "sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=6.10"
|
||||
}
|
||||
@@ -36514,8 +36540,7 @@
|
||||
"node_modules/uuid-browser": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid-browser/-/uuid-browser-3.1.0.tgz",
|
||||
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg=="
|
||||
},
|
||||
"node_modules/v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
@@ -42022,7 +42047,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/addon-actions/-/addon-actions-6.5.15.tgz",
|
||||
"integrity": "sha512-cnLzVK1S+EydFDSuvxMmMAxVqNXijBGdV9QTgsu6ys5sOkoiXRETKZmxuN8/ZRbkfc4N+1KAylSCZOOHzBQTBQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/addons": "6.5.15",
|
||||
"@storybook/api": "6.5.15",
|
||||
@@ -42049,7 +42073,6 @@
|
||||
"version": "5.1.1",
|
||||
"resolved": "https://registry.npmjs.org/react-inspector/-/react-inspector-5.1.1.tgz",
|
||||
"integrity": "sha512-GURDaYzoLbW8pMGXwYPDBIv6nqei4kK7LPRZ9q9HCZF54wqXz/dnylBp/kfE9XmekBhHvLDdcYeyIwSrvtOiWg==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@babel/runtime": "^7.0.0",
|
||||
"is-dom": "^1.0.0",
|
||||
@@ -42319,7 +42342,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/addons/-/addons-6.5.15.tgz",
|
||||
"integrity": "sha512-xT31SuSX+kYGyxCNK2nqL7WTxucs3rSmhiCLovJcUjYk+QquV3c2c53Ki7lwwdDbzfXFcNAe0HJ4hoTN4jhn0Q==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/api": "6.5.15",
|
||||
"@storybook/channels": "6.5.15",
|
||||
@@ -42338,7 +42360,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/api/-/api-6.5.15.tgz",
|
||||
"integrity": "sha512-BBE0KXKvj1/3jTghbIoWfrcDM0t+xO7EYtWWAXD6XlhGsZVD2Dy82Z52ONyLulMDRpMWl0OYy3h6A1YnFUH25w==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/channels": "6.5.15",
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
@@ -43496,7 +43517,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/channel-postmessage/-/channel-postmessage-6.5.15.tgz",
|
||||
"integrity": "sha512-gMpA8LWT8lC4z5KWnaMh03aazEwtDO7GtY5kZVru+EEMgExGmaR82qgekwmLmgLj2nRJEv0o138o9IqYUcou8w==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/channels": "6.5.15",
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
@@ -43524,7 +43544,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/channels/-/channels-6.5.15.tgz",
|
||||
"integrity": "sha512-gPpsBgirv2NCXbH4WbYqdkI0JLE96aiVuu7UEWfn9yu071pQ9CLHbhXGD9fSFNrfOkyBBY10ppSE7uCXw3Wexg==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"core-js": "^3.8.2",
|
||||
"ts-dedent": "^2.0.0",
|
||||
@@ -43535,7 +43554,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/client-api/-/client-api-6.5.15.tgz",
|
||||
"integrity": "sha512-0ZGpRgVz7rdbCguBqBpwObXbsVY5qlSTWDzzIBpmz8EkxW/MtK5wEyeq+0L0O+DTn41FwvH5yCGLAENpzWD8BQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/addons": "6.5.15",
|
||||
"@storybook/channel-postmessage": "6.5.15",
|
||||
@@ -43563,7 +43581,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/client-logger/-/client-logger-6.5.15.tgz",
|
||||
"integrity": "sha512-0uyxKvodq+FycGv6aUwC1wUR6suXf2+7ywMFAOlYolI4UvNj8NyU/5AfgKT5XnxYAgPmoCiAjOE700TrfHrosw==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"core-js": "^3.8.2",
|
||||
"global": "^4.4.0"
|
||||
@@ -43573,7 +43590,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/components/-/components-6.5.15.tgz",
|
||||
"integrity": "sha512-bHTT0Oa3s4g+MBMaLBbX9ofMtb1AW59AzIUNGrfqW1XqJMGuUHMiJ7TSo+i5dRSFpbFygnwMEG9LfHxpR2Z0Dw==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
"@storybook/csf": "0.0.2--canary.4566f4d.1",
|
||||
@@ -44298,7 +44314,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/core-events/-/core-events-6.5.15.tgz",
|
||||
"integrity": "sha512-B1Ba6l5W7MeNclclqMMTMHgYgfdpB5SIhNCQFnzIz8blynzRhNFMdxvbAl6Je5G0S4xydYYd7Lno2kXQebs7HA==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"core-js": "^3.8.2"
|
||||
}
|
||||
@@ -44949,7 +44964,6 @@
|
||||
"version": "0.0.2--canary.4566f4d.1",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/csf/-/csf-0.0.2--canary.4566f4d.1.tgz",
|
||||
"integrity": "sha512-9OVvMVh3t9znYZwb0Svf/YQoxX2gVOeQTGe2bses2yj+a3+OJnCrUF3/hGv6Em7KujtOdL2LL+JnG49oMVGFgQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"lodash": "^4.17.15"
|
||||
}
|
||||
@@ -46441,7 +46455,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/router/-/router-6.5.15.tgz",
|
||||
"integrity": "sha512-9t8rI8t7/Krolau29gsdjdbRQ66orONIyP0efp0EukVgv6reNFzb/U14ARrl0uHys6Tl5Xyece9FoakQUdn8Kg==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
"core-js": "^3.8.2",
|
||||
@@ -46454,7 +46467,6 @@
|
||||
"version": "7.3.2",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/semver/-/semver-7.3.2.tgz",
|
||||
"integrity": "sha512-SWeszlsiPsMI0Ps0jVNtH64cI5c0UF3f7KgjVKJoNP30crQ6wUSddY2hsdeczZXEKVJGEn50Q60flcGsQGIcrg==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"core-js": "^3.6.5",
|
||||
"find-up": "^4.1.0"
|
||||
@@ -46464,7 +46476,6 @@
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz",
|
||||
"integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"locate-path": "^5.0.0",
|
||||
"path-exists": "^4.0.0"
|
||||
@@ -46474,7 +46485,6 @@
|
||||
"version": "5.0.0",
|
||||
"resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz",
|
||||
"integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"p-locate": "^4.1.0"
|
||||
}
|
||||
@@ -46483,7 +46493,6 @@
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz",
|
||||
"integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"p-try": "^2.0.0"
|
||||
}
|
||||
@@ -46492,7 +46501,6 @@
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz",
|
||||
"integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"p-limit": "^2.2.0"
|
||||
}
|
||||
@@ -46529,7 +46537,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/store/-/store-6.5.15.tgz",
|
||||
"integrity": "sha512-r6cYTf6GtbqgdI4ZG3xuWdJAAu5fJ3xAWMiDkHyoK2u+R2TeYXIsJvgn0BPrW87sZhELIkg4ckdFECmATs3kpQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/addons": "6.5.15",
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
@@ -46636,7 +46643,6 @@
|
||||
"version": "6.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@storybook/theming/-/theming-6.5.15.tgz",
|
||||
"integrity": "sha512-pgdW0lVZKKXQ4VhIfLHycMmwFSVOY7vLTKnytag4Y8Yz+aXm0bwDN/QxPntFzDH47F1Rcy2ywNnvty8ooDTvuA==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@storybook/client-logger": "6.5.15",
|
||||
"core-js": "^3.8.2",
|
||||
@@ -46682,6 +46688,19 @@
|
||||
"mini-svg-data-uri": "^1.2.3"
|
||||
}
|
||||
},
|
||||
"@tanstack/react-table": {
|
||||
"version": "8.7.6",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.7.6.tgz",
|
||||
"integrity": "sha512-/QijmMFeP7wDLBnr0MQ/5MlbXePbIL/1nOtkxBC9zvmBu4gDKJEDBqipUyM7Wc/iBpSd0IFyqBlvZvTPD9FYDA==",
|
||||
"requires": {
|
||||
"@tanstack/table-core": "8.7.6"
|
||||
}
|
||||
},
|
||||
"@tanstack/table-core": {
|
||||
"version": "8.7.6",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.7.6.tgz",
|
||||
"integrity": "sha512-sqiNTMzB6cpyL8DFH6/VqW48SwiflLqxQqYpo2wNock7rdVGvlm0BLNI8vZUJbr1+fmmWmHwBvi5OMgZw8n1DA=="
|
||||
},
|
||||
"@testing-library/dom": {
|
||||
"version": "8.19.1",
|
||||
"resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-8.19.1.tgz",
|
||||
@@ -46996,8 +47015,7 @@
|
||||
"@types/is-function": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/is-function/-/is-function-1.0.1.tgz",
|
||||
"integrity": "sha512-A79HEEiwXTFtfY+Bcbo58M2GRYzCr9itHWzbzHVFNEYCcoU/MMGwYYf721gBrnhpj1s6RGVVha/IgNFnR0Iw/Q==",
|
||||
"dev": true
|
||||
"integrity": "sha512-A79HEEiwXTFtfY+Bcbo58M2GRYzCr9itHWzbzHVFNEYCcoU/MMGwYYf721gBrnhpj1s6RGVVha/IgNFnR0Iw/Q=="
|
||||
},
|
||||
"@types/istanbul-lib-coverage": {
|
||||
"version": "2.0.4",
|
||||
@@ -47184,8 +47202,7 @@
|
||||
"@types/qs": {
|
||||
"version": "6.9.7",
|
||||
"resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.7.tgz",
|
||||
"integrity": "sha512-FGa1F62FT09qcrueBA6qYTrJPVDzah9a+493+o2PCXsesWHIn27G98TsSMs3WPNbZIEj4+VJf6saSFpvD+3Zsw==",
|
||||
"dev": true
|
||||
"integrity": "sha512-FGa1F62FT09qcrueBA6qYTrJPVDzah9a+493+o2PCXsesWHIn27G98TsSMs3WPNbZIEj4+VJf6saSFpvD+3Zsw=="
|
||||
},
|
||||
"@types/react": {
|
||||
"version": "18.0.26",
|
||||
@@ -47310,8 +47327,7 @@
|
||||
"@types/webpack-env": {
|
||||
"version": "1.18.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/webpack-env/-/webpack-env-1.18.0.tgz",
|
||||
"integrity": "sha512-56/MAlX5WMsPVbOg7tAxnYvNYMMWr/QJiIp6BxVSW3JJXUVzzOn64qW8TzQyMSqSUFM2+PVI4aUHcHOzIz/1tg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-56/MAlX5WMsPVbOg7tAxnYvNYMMWr/QJiIp6BxVSW3JJXUVzzOn64qW8TzQyMSqSUFM2+PVI4aUHcHOzIz/1tg=="
|
||||
},
|
||||
"@types/webpack-sources": {
|
||||
"version": "3.2.0",
|
||||
@@ -47817,6 +47833,11 @@
|
||||
"integrity": "sha512-j2afSsaIENvHZN2B8GOpF566vZ5WVk5opAiMTvWgaQT8DkbOqsTfvNAvHoRGU2zzP8cPoqys+xHTRDWW8L+/BA==",
|
||||
"dev": true
|
||||
},
|
||||
"accept-language-parser": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/accept-language-parser/-/accept-language-parser-1.5.0.tgz",
|
||||
"integrity": "sha512-QhyTbMLYo0BBGg1aWbeMG4ekWtds/31BrEU+DONOg/7ax23vxpL03Pb7/zBmha2v7vdD3AyzZVWBVGEZxKOXWw=="
|
||||
},
|
||||
"accepts": {
|
||||
"version": "1.3.8",
|
||||
"resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz",
|
||||
@@ -51205,8 +51226,7 @@
|
||||
"dom-walk": {
|
||||
"version": "0.1.2",
|
||||
"resolved": "https://registry.npmjs.org/dom-walk/-/dom-walk-0.1.2.tgz",
|
||||
"integrity": "sha512-6QvTW9mrGeIegrFXdtQi9pk7O/nSK6lSdXW2eqUspN5LWD7UTji2Fqw5V2YLjBpHEoU9Xl/eUWNpDeZvoyOv2w==",
|
||||
"dev": true
|
||||
"integrity": "sha512-6QvTW9mrGeIegrFXdtQi9pk7O/nSK6lSdXW2eqUspN5LWD7UTji2Fqw5V2YLjBpHEoU9Xl/eUWNpDeZvoyOv2w=="
|
||||
},
|
||||
"domain-browser": {
|
||||
"version": "1.2.0",
|
||||
@@ -53400,7 +53420,6 @@
|
||||
"version": "4.4.0",
|
||||
"resolved": "https://registry.npmjs.org/global/-/global-4.4.0.tgz",
|
||||
"integrity": "sha512-wv/LAoHdRE3BeTGz53FAamhGlPLhlssK45usmGFThIi4XqnBmjKQ16u+RNbP7WvigRZDxUsM0J3gcQ5yicaL0w==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"min-document": "^2.19.0",
|
||||
"process": "^0.11.10"
|
||||
@@ -54282,7 +54301,6 @@
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/is-dom/-/is-dom-1.1.0.tgz",
|
||||
"integrity": "sha512-u82f6mvhYxRPKpw8V1N0W8ce1xXwOrQtgGcxl6UCL5zBmZu3is/18K0rR7uFCnMDuAsS/3W54mGL4vsaFUQlEQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"is-object": "^1.0.1",
|
||||
"is-window": "^1.0.2"
|
||||
@@ -54318,8 +54336,7 @@
|
||||
"is-function": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/is-function/-/is-function-1.0.2.tgz",
|
||||
"integrity": "sha512-lw7DUp0aWXYg+CBCN+JKkcE0Q2RayZnSvnZBlwgxHBQhqt5pZNVy4Ri7H9GmmXkdu7LUthszM+Tor1u/2iBcpQ==",
|
||||
"dev": true
|
||||
"integrity": "sha512-lw7DUp0aWXYg+CBCN+JKkcE0Q2RayZnSvnZBlwgxHBQhqt5pZNVy4Ri7H9GmmXkdu7LUthszM+Tor1u/2iBcpQ=="
|
||||
},
|
||||
"is-generator-fn": {
|
||||
"version": "2.1.0",
|
||||
@@ -54378,8 +54395,7 @@
|
||||
"is-object": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/is-object/-/is-object-1.0.2.tgz",
|
||||
"integrity": "sha512-2rRIahhZr2UWb45fIOuvZGpFtz0TyOZLf32KxBbSoUCeZR495zCKlWUKKUByk3geS2eAs7ZAABt0Y/Rx0GiQGA==",
|
||||
"dev": true
|
||||
"integrity": "sha512-2rRIahhZr2UWb45fIOuvZGpFtz0TyOZLf32KxBbSoUCeZR495zCKlWUKKUByk3geS2eAs7ZAABt0Y/Rx0GiQGA=="
|
||||
},
|
||||
"is-path-inside": {
|
||||
"version": "3.0.3",
|
||||
@@ -54517,8 +54533,7 @@
|
||||
"is-window": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/is-window/-/is-window-1.0.2.tgz",
|
||||
"integrity": "sha512-uj00kdXyZb9t9RcAUAwMZAnkBUwdYGhYlt7djMXhfyhUCzwNba50tIiBKR7q0l7tdoBtFVw/3JmLY6fI3rmZmg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-uj00kdXyZb9t9RcAUAwMZAnkBUwdYGhYlt7djMXhfyhUCzwNba50tIiBKR7q0l7tdoBtFVw/3JmLY6fI3rmZmg=="
|
||||
},
|
||||
"is-windows": {
|
||||
"version": "1.0.2",
|
||||
@@ -57652,8 +57667,7 @@
|
||||
"lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
|
||||
},
|
||||
"lodash.debounce": {
|
||||
"version": "4.0.8",
|
||||
@@ -57840,6 +57854,12 @@
|
||||
"yallist": "^3.0.2"
|
||||
}
|
||||
},
|
||||
"lucide-react": {
|
||||
"version": "0.105.0",
|
||||
"resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.105.0.tgz",
|
||||
"integrity": "sha512-iHaIkd4Wq6aNIVrFMXt3If8E/+2lnJd4WlCyntoJNIzZ8nWhdSSHWpsw7XM4rlw2319LZ2t4WLdnM8Z0ECDTOQ==",
|
||||
"requires": {}
|
||||
},
|
||||
"lz-string": {
|
||||
"version": "1.4.4",
|
||||
"resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.4.4.tgz",
|
||||
@@ -57910,8 +57930,7 @@
|
||||
"map-or-similar": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/map-or-similar/-/map-or-similar-1.5.0.tgz",
|
||||
"integrity": "sha512-0aF7ZmVon1igznGI4VS30yugpduQW3y3GkcgGJOp7d8x8QrizhigUxjI/m2UojsXXto+jLAH3KSz+xOJTiORjg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-0aF7ZmVon1igznGI4VS30yugpduQW3y3GkcgGJOp7d8x8QrizhigUxjI/m2UojsXXto+jLAH3KSz+xOJTiORjg=="
|
||||
},
|
||||
"map-visit": {
|
||||
"version": "1.0.0",
|
||||
@@ -58022,7 +58041,6 @@
|
||||
"version": "1.11.3",
|
||||
"resolved": "https://registry.npmjs.org/memoizerific/-/memoizerific-1.11.3.tgz",
|
||||
"integrity": "sha512-/EuHYwAPdLtXwAwSZkh/Gutery6pD2KYd44oQLhAvQp/50mpyduZh8Q7PYHXTCJ+wuXxt7oij2LXyIJOOYFPog==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"map-or-similar": "^1.5.0"
|
||||
}
|
||||
@@ -58236,7 +58254,6 @@
|
||||
"version": "2.19.0",
|
||||
"resolved": "https://registry.npmjs.org/min-document/-/min-document-2.19.0.tgz",
|
||||
"integrity": "sha512-9Wy1B3m3f66bPPmU5hdA4DR4PB2OfDU/+GS3yAB7IQozE3tqXaVv2zOjgla7MEGSRv95+ILmOuvhLkOK6wJtCQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"dom-walk": "^0.1.0"
|
||||
}
|
||||
@@ -60966,8 +60983,7 @@
|
||||
"p-try": {
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/p-try/-/p-try-2.2.0.tgz",
|
||||
"integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==",
|
||||
"dev": true
|
||||
"integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ=="
|
||||
},
|
||||
"pako": {
|
||||
"version": "1.0.11",
|
||||
@@ -61253,7 +61269,6 @@
|
||||
"version": "4.2.2",
|
||||
"resolved": "https://registry.npmjs.org/polished/-/polished-4.2.2.tgz",
|
||||
"integrity": "sha512-Sz2Lkdxz6F2Pgnpi9U5Ng/WdWAUZxmHrNPoVlm3aAemxoy2Qy7LGjQg4uf8qKelDAUW94F4np3iH2YPf2qefcQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@babel/runtime": "^7.17.8"
|
||||
}
|
||||
@@ -61670,8 +61685,7 @@
|
||||
"process": {
|
||||
"version": "0.11.10",
|
||||
"resolved": "https://registry.npmjs.org/process/-/process-0.11.10.tgz",
|
||||
"integrity": "sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A==",
|
||||
"dev": true
|
||||
"integrity": "sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A=="
|
||||
},
|
||||
"process-nextick-args": {
|
||||
"version": "2.0.1",
|
||||
@@ -61829,7 +61843,6 @@
|
||||
"version": "6.11.0",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz",
|
||||
"integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"side-channel": "^1.0.4"
|
||||
}
|
||||
@@ -61962,6 +61975,22 @@
|
||||
"@babel/runtime": "^7.12.13"
|
||||
}
|
||||
},
|
||||
"react-cookies": {
|
||||
"version": "0.1.1",
|
||||
"resolved": "https://registry.npmjs.org/react-cookies/-/react-cookies-0.1.1.tgz",
|
||||
"integrity": "sha512-PP75kJ4vtoHuuTdq0TAD3RmlAv7vuDQh9fkC4oDlhntgs9vX1DmREomO0Y1mcQKR9nMZ6/zxoflaMJ3MAmF5KQ==",
|
||||
"requires": {
|
||||
"cookie": "^0.3.1",
|
||||
"object-assign": "^4.1.1"
|
||||
},
|
||||
"dependencies": {
|
||||
"cookie": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz",
|
||||
"integrity": "sha512-+IJOX0OqlHCszo2mBUq+SrEbCj6w7Kpffqx60zYbPTFaO4+yYgRjHwcZNpWvaTylDHaV7PPmBHzSecZiMhtPgw=="
|
||||
}
|
||||
}
|
||||
},
|
||||
"react-docgen": {
|
||||
"version": "5.4.3",
|
||||
"resolved": "https://registry.npmjs.org/react-docgen/-/react-docgen-5.4.3.tgz",
|
||||
@@ -62053,12 +62082,6 @@
|
||||
"html-parse-stringify": "^3.0.1"
|
||||
}
|
||||
},
|
||||
"react-icons": {
|
||||
"version": "4.7.1",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
|
||||
"integrity": "sha512-yHd3oKGMgm7zxo3EA7H2n7vxSoiGmHk5t6Ou4bXsfcgWyhfDKMpyKfhHR6Bjnn63c+YXBLBPUql9H4wPJM6sXw==",
|
||||
"requires": {}
|
||||
},
|
||||
"react-is": {
|
||||
"version": "16.13.1",
|
||||
"resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz",
|
||||
@@ -62107,12 +62130,6 @@
|
||||
"tslib": "^2.0.0"
|
||||
}
|
||||
},
|
||||
"react-table": {
|
||||
"version": "7.8.0",
|
||||
"resolved": "https://registry.npmjs.org/react-table/-/react-table-7.8.0.tgz",
|
||||
"integrity": "sha512-hNaz4ygkZO4bESeFfnfOft73iBUj8K5oKi1EcSHPAibEydfsX2MyU6Z8KCr3mv3C9Kqqh71U+DhZkFvibbnPbA==",
|
||||
"requires": {}
|
||||
},
|
||||
"read-cache": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz",
|
||||
@@ -63591,8 +63608,7 @@
|
||||
"stable": {
|
||||
"version": "0.1.8",
|
||||
"resolved": "https://registry.npmjs.org/stable/-/stable-0.1.8.tgz",
|
||||
"integrity": "sha512-ji9qxRnOVfcuLDySj9qzhGSEFVobyt1kIOSkj1qZzYLzq7Tos/oUUWvotUPQLlrsidqsK6tBH89Bc9kL5zHA6w==",
|
||||
"dev": true
|
||||
"integrity": "sha512-ji9qxRnOVfcuLDySj9qzhGSEFVobyt1kIOSkj1qZzYLzq7Tos/oUUWvotUPQLlrsidqsK6tBH89Bc9kL5zHA6w=="
|
||||
},
|
||||
"stack-utils": {
|
||||
"version": "2.0.6",
|
||||
@@ -63716,8 +63732,22 @@
|
||||
"store2": {
|
||||
"version": "2.14.2",
|
||||
"resolved": "https://registry.npmjs.org/store2/-/store2-2.14.2.tgz",
|
||||
"integrity": "sha512-siT1RiqlfQnGqgT/YzXVUNsom9S0H1OX+dpdGN1xkyYATo4I6sep5NmsRD/40s3IIOvlCq6akxkqG82urIZW1w==",
|
||||
"dev": true
|
||||
"integrity": "sha512-siT1RiqlfQnGqgT/YzXVUNsom9S0H1OX+dpdGN1xkyYATo4I6sep5NmsRD/40s3IIOvlCq6akxkqG82urIZW1w=="
|
||||
},
|
||||
"storybook-addon-next-router": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/storybook-addon-next-router/-/storybook-addon-next-router-4.0.2.tgz",
|
||||
"integrity": "sha512-0rjGAl7HziW4ecDq+VU03H1dwkw5f6phqA+PMquPzdowNVl29ejVwVadLMGlovG6x2snaxMGxtySR2c5bwegSw==",
|
||||
"requires": {
|
||||
"tslib": "2.4.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"tslib": {
|
||||
"version": "2.4.0",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.4.0.tgz",
|
||||
"integrity": "sha512-d6xOpEDfsi2CZVlPQzGeux8XMwLT9hssAsaPYExaQMuYskwb+x1x7J371tWlbBdWHroy99KnVB6qIkUbs5X3UQ=="
|
||||
}
|
||||
}
|
||||
},
|
||||
"stream-browserify": {
|
||||
"version": "2.0.2",
|
||||
@@ -64058,8 +64088,7 @@
|
||||
"synchronous-promise": {
|
||||
"version": "2.0.16",
|
||||
"resolved": "https://registry.npmjs.org/synchronous-promise/-/synchronous-promise-2.0.16.tgz",
|
||||
"integrity": "sha512-qImOD23aDfnIDNqlG1NOehdB9IYsn1V9oByPjKY1nakv2MQYCEMyX033/q+aEtYCpmYK1cv2+NTmlH+ra6GA5A==",
|
||||
"dev": true
|
||||
"integrity": "sha512-qImOD23aDfnIDNqlG1NOehdB9IYsn1V9oByPjKY1nakv2MQYCEMyX033/q+aEtYCpmYK1cv2+NTmlH+ra6GA5A=="
|
||||
},
|
||||
"synckit": {
|
||||
"version": "0.8.4",
|
||||
@@ -64186,7 +64215,6 @@
|
||||
"version": "6.0.8",
|
||||
"resolved": "https://registry.npmjs.org/telejson/-/telejson-6.0.8.tgz",
|
||||
"integrity": "sha512-nerNXi+j8NK1QEfBHtZUN/aLdDcyupA//9kAboYLrtzZlPLpUfqbVGWb9zz91f/mIjRbAYhbgtnJHY8I1b5MBg==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@types/is-function": "^1.0.0",
|
||||
"global": "^4.4.0",
|
||||
@@ -64201,8 +64229,7 @@
|
||||
"isobject": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/isobject/-/isobject-4.0.0.tgz",
|
||||
"integrity": "sha512-S/2fF5wH8SJA/kmwr6HYhK/RI/OkhD84k8ntalo0iJjZikgq1XFvR5M8NPT1x5F7fBwCG3qHfnzeP/Vh/ZxCUA==",
|
||||
"dev": true
|
||||
"integrity": "sha512-S/2fF5wH8SJA/kmwr6HYhK/RI/OkhD84k8ntalo0iJjZikgq1XFvR5M8NPT1x5F7fBwCG3qHfnzeP/Vh/ZxCUA=="
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -64493,8 +64520,7 @@
|
||||
"ts-dedent": {
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/ts-dedent/-/ts-dedent-2.2.0.tgz",
|
||||
"integrity": "sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==",
|
||||
"dev": true
|
||||
"integrity": "sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ=="
|
||||
},
|
||||
"ts-node": {
|
||||
"version": "10.9.1",
|
||||
@@ -65058,8 +65084,7 @@
|
||||
"uuid-browser": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid-browser/-/uuid-browser-3.1.0.tgz",
|
||||
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg=="
|
||||
},
|
||||
"v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
|
||||
@@ -38,6 +38,8 @@
|
||||
"@next/font": "^13.1.0",
|
||||
"@prisma/client": "^4.7.1",
|
||||
"@tailwindcss/forms": "^0.5.3",
|
||||
"@tanstack/react-table": "^8.7.6",
|
||||
"accept-language-parser": "^1.5.0",
|
||||
"autoprefixer": "^10.4.13",
|
||||
"axios": "^1.2.1",
|
||||
"boolean": "^3.2.0",
|
||||
@@ -48,6 +50,7 @@
|
||||
"focus-visible": "^5.2.0",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"lucide-react": "^0.105.0",
|
||||
"next": "13.0.6",
|
||||
"next-auth": "^4.18.6",
|
||||
"next-i18next": "^13.0.3",
|
||||
@@ -55,13 +58,13 @@
|
||||
"npm": "^9.2.0",
|
||||
"postcss-focus-visible": "^7.1.0",
|
||||
"react": "18.2.0",
|
||||
"react-cookies": "^0.1.1",
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-hook-form": "^7.42.1",
|
||||
"react-i18next": "^12.1.4",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-table": "^7.8.0",
|
||||
"sharp": "^0.31.3",
|
||||
"storybook-addon-next-router": "^4.0.2",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"unique-username-generator": "^1.1.3",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"about": "About",
|
||||
"account_settings": "Account",
|
||||
"admin_dashboard": "Admin Dashboard",
|
||||
"connect": "Connect",
|
||||
"conversational": "Conversational AI for everyone.",
|
||||
"dashboard": "Dashboard",
|
||||
@@ -8,6 +9,8 @@
|
||||
"docs": "Docs",
|
||||
"github": "GitHub",
|
||||
"legal": "Legal",
|
||||
"loading": "Loading...",
|
||||
"more_information": "More Information",
|
||||
"privacy_policy": "Privacy Policy",
|
||||
"report_a_bug": "Report a Bug",
|
||||
"sign_in": "Sign In",
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"grab_a_task": "Grab a task!",
|
||||
"create": "Create",
|
||||
"evaluate": "Evaluate",
|
||||
"label": "Label",
|
||||
"dashboard": "Dashboard",
|
||||
"go": "Go"
|
||||
}
|
||||
@@ -7,5 +7,8 @@
|
||||
"rank": "Rank",
|
||||
"score": "Score",
|
||||
"user": "User",
|
||||
"weekly": "Weekly"
|
||||
"weekly": "Weekly",
|
||||
"prompt": "Prompts",
|
||||
"reply": "Replies",
|
||||
"label": "Labels"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"reactions": "Reactions",
|
||||
"label_action": "Label",
|
||||
"label_title": "Label",
|
||||
"submit_labels": "Submit",
|
||||
"open_new_tab_action": "Open in new tab",
|
||||
"report_title": "Report",
|
||||
"report_action": "Report",
|
||||
"report_placeholder": "Why should this message be reviewed?",
|
||||
"send_report": "Send"
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
{
|
||||
"write_initial_prompt": "Write your prompt here...",
|
||||
"default": {
|
||||
"unchanged_title": "No changes",
|
||||
"unchanged_message": "Are you sure you would like to continue?"
|
||||
},
|
||||
"random": {
|
||||
"label": "I'm feeling lucky",
|
||||
"desc": "Help us improve Open Assistant by starting a random task."
|
||||
},
|
||||
"create_initial_prompt": {
|
||||
"label": "Create Initial Prompts",
|
||||
"desc": "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
"overview": "Create an initial message to send to the assistant",
|
||||
"instruction": "Provide the initial prompts"
|
||||
},
|
||||
"reply_as_user": {
|
||||
"label": "Reply as User",
|
||||
"desc": "Chat with Open Assistant and help improve it's responses as you interact with it.",
|
||||
"overview": "Given the following conversation, provide an adequate reply",
|
||||
"instruction": "Provide the user's reply"
|
||||
},
|
||||
"reply_as_assistant": {
|
||||
"label": "Reply as Assistant",
|
||||
"desc": "Help Open Assistant improve its responses to conversations with other users.",
|
||||
"overview": "Given the following conversation, provide an adequate reply"
|
||||
},
|
||||
"rank_user_replies": {
|
||||
"label": "Rank User Replies",
|
||||
"desc": "Help Open Assistant improve its responses to conversations with other users.",
|
||||
"overview": "Given the following User replies, sort them from best to worst, best being first, worst being last.",
|
||||
"unchanged_title": "Order Unchanged",
|
||||
"unchanged_message": "You have not changed the order of the prompts. Are you sure you would like to continue?"
|
||||
},
|
||||
"rank_assistant_replies": {
|
||||
"label": "Rank Assistant Replies",
|
||||
"desc": "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
"overview": "Given the following Assistant replies, sort them from best to worst, best being first, worst being last.",
|
||||
"unchanged_title": "Order Unchanged",
|
||||
"unchanged_message": "You have not changed the order of the prompts. Are you sure you would like to continue?"
|
||||
},
|
||||
"rank_initial_prompts": {
|
||||
"label": "Rank Initial Prompts",
|
||||
"desc": "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
"overview": "Given the following initial prompts, sort them from best to worst, best being first, worst being last.",
|
||||
"unchanged_title": "Order Unchanged",
|
||||
"unchanged_message": "You have not changed the order of the prompts. Are you sure you would like to continue?"
|
||||
},
|
||||
"label_initial_prompt": {
|
||||
"label": "Label Initial Prompt",
|
||||
"desc": "Provide labels for a prompt.",
|
||||
"overview": "Provide labels for the following prompt"
|
||||
},
|
||||
"label_prompter_reply": {
|
||||
"label": "Label Prompter Reply",
|
||||
"desc": "Provide labels for a prompt.",
|
||||
"overview": "Given the following discussion, provide labels for the final prompt."
|
||||
},
|
||||
"label_assistant_reply": {
|
||||
"label": "Label Assistant Reply",
|
||||
"desc": "Provide labels for a prompt.",
|
||||
"overview": "Given the following discussion, provide labels for the final prompt."
|
||||
},
|
||||
"classify_initial_prompt": {
|
||||
"label": "Classify Initial Prompt",
|
||||
"desc": "Provide labels for a prompt.",
|
||||
"overview": "Read the following prompt and then answer the question about it."
|
||||
},
|
||||
"classify_prompter_reply": {
|
||||
"label": "Classify Prompter Reply",
|
||||
"desc": "Provide labels for a prompt.",
|
||||
"overview": "Read the following conversation and then answer the question about the last reply in the discussion."
|
||||
},
|
||||
"classify_assistant_reply": {
|
||||
"label": "Classify Assistant Reply",
|
||||
"desc": "Provide labels for a prompt.",
|
||||
"overview": "Read the following conversation and then answer the question about the last reply in the discussion."
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
import { Radio, RadioGroup } from "@chakra-ui/react";
|
||||
import { PropsWithChildren } from "react";
|
||||
|
||||
export const LikertButtons = ({
|
||||
isDisabled,
|
||||
count,
|
||||
onChange,
|
||||
"data-cy": dataCy,
|
||||
}: PropsWithChildren<{
|
||||
isDisabled: boolean;
|
||||
count: number;
|
||||
onChange: (value: number) => void;
|
||||
"data-cy"?: string;
|
||||
}>) => {
|
||||
const valueMap = Object.fromEntries(Array.from({ length: count }, (_, idx) => [`${idx}`, idx / (count - 1)]));
|
||||
|
||||
return (
|
||||
<RadioGroup
|
||||
data-cy={dataCy}
|
||||
isDisabled={isDisabled}
|
||||
onChange={(value) => {
|
||||
onChange(valueMap[value]);
|
||||
}}
|
||||
style={{ display: "flex", justifyContent: "space-between" }}
|
||||
>
|
||||
{Object.keys(valueMap).map((value) => {
|
||||
return <Radio key={value} value={value} data-cy="radio-option" size="md" padding="0.6em" />;
|
||||
})}
|
||||
</RadioGroup>
|
||||
);
|
||||
};
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Box, Link, Text, useColorMode } from "@chakra-ui/react";
|
||||
import { Github } from "lucide-react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useId } from "react";
|
||||
import { FaDiscord, FaGithub } from "react-icons/fa";
|
||||
|
||||
import { Container } from "./Container";
|
||||
import { Discord } from "./Icons/Discord";
|
||||
|
||||
const CIRCLE_HEIGHT = 558;
|
||||
const CIRCLE_WIDTH = 558;
|
||||
@@ -70,7 +71,7 @@ export function CallToAction() {
|
||||
type="button"
|
||||
className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
|
||||
>
|
||||
<FaDiscord size={25} />
|
||||
<Discord size={25} />
|
||||
<Text as="span" className="text-lg ml-3">
|
||||
{t("discord")}
|
||||
</Text>
|
||||
@@ -81,7 +82,7 @@ export function CallToAction() {
|
||||
type="button"
|
||||
className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
|
||||
>
|
||||
<FaGithub size={25} />
|
||||
<Github size={25} />
|
||||
<Text as="span" className="text-lg ml-3">
|
||||
{t("github")}
|
||||
</Text>
|
||||
|
||||
+8
-14
@@ -1,11 +1,9 @@
|
||||
import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Card, CardBody, Link, Text } from "@chakra-ui/react";
|
||||
import NextLink from "next/link";
|
||||
import { LeaderboardGridCell } from "src/components/LeaderboardGridCell";
|
||||
import { LeaderboardTable } from "src/components/LeaderboardTable";
|
||||
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
|
||||
export function LeaderboardTable() {
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
const accentColor = useColorModeValue("gray.200", "gray.900");
|
||||
export function LeaderboardWidget() {
|
||||
return (
|
||||
<main className="h-fit col-span-3">
|
||||
<div className="flex flex-col gap-4">
|
||||
@@ -17,15 +15,11 @@ export function LeaderboardTable() {
|
||||
</Text>
|
||||
</Link>
|
||||
</div>
|
||||
<Box
|
||||
backgroundColor={backgroundColor}
|
||||
boxShadow="base"
|
||||
dropShadow={accentColor}
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
<LeaderboardGridCell timeFrame={LeaderboardTimeFrame.day} />
|
||||
</Box>
|
||||
<Card>
|
||||
<CardBody>
|
||||
<LeaderboardTable timeFrame={LeaderboardTimeFrame.day} limit={5} />
|
||||
</CardBody>
|
||||
</Card>
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
@@ -1,51 +1,96 @@
|
||||
import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
GridItem,
|
||||
Heading,
|
||||
IconButton,
|
||||
Link as ExternalLink,
|
||||
SimpleGrid,
|
||||
Spacer,
|
||||
Text,
|
||||
useColorModeValue,
|
||||
} from "@chakra-ui/react";
|
||||
import { HelpCircle } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useMemo } from "react";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
import { TaskType } from "src/types/Task";
|
||||
|
||||
import { TaskCategory, TaskCategoryLabels, TaskTypes } from "../Tasks/TaskTypes";
|
||||
import { TaskCategory, TaskCategoryLabels, TaskInfo, TaskInfos } from "../Tasks/TaskTypes";
|
||||
|
||||
export const TaskOption = ({ displayTaskCategories }: { displayTaskCategories: TaskCategory[] }) => {
|
||||
export interface TasksOptionProps {
|
||||
content: Partial<Record<TaskCategory, TaskType[]>>;
|
||||
}
|
||||
|
||||
export const TaskOption = ({ content }: TasksOptionProps) => {
|
||||
const { t } = useTranslation(["dashboard", "tasks"]);
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
|
||||
const taskInfoMap = useMemo(
|
||||
() =>
|
||||
Object.values(content)
|
||||
.flat()
|
||||
.reduce((obj, taskType) => {
|
||||
obj[taskType] = TaskInfos.filter((t) => t.type === taskType).pop();
|
||||
return obj;
|
||||
}, {} as Record<TaskType, TaskInfo>),
|
||||
[content]
|
||||
);
|
||||
|
||||
return (
|
||||
<Box className="flex flex-col gap-14">
|
||||
{displayTaskCategories.map((category) => (
|
||||
{Object.entries(content).map(([category, taskTypes]) => (
|
||||
<div key={category}>
|
||||
<Text className="text-2xl font-bold pb-4">{TaskCategoryLabels[category]}</Text>
|
||||
<Flex>
|
||||
<Heading size="lg" className="pb-4">
|
||||
{t(TaskCategoryLabels[category])}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
<ExternalLink href="https://projects.laion.ai/Open-Assistant/" isExternal>
|
||||
<IconButton variant="ghost" aria-label="More Information" icon={<HelpCircle size="2em" />} />
|
||||
</ExternalLink>
|
||||
</Flex>
|
||||
<SimpleGrid columns={[1, 1, 2, 2, 3, 4]} gap={4}>
|
||||
{TaskTypes.filter((task) => task.category === category).map((item) => (
|
||||
<Link key={category + item.label} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
borderRadius="xl"
|
||||
boxShadow="base"
|
||||
className="flex flex-col justify-between h-full"
|
||||
>
|
||||
<Box className="p-6 pb-10">
|
||||
<Flex flexDir="column" gap="3">
|
||||
<Heading size="md" fontFamily="inter">
|
||||
{item.label}
|
||||
</Heading>
|
||||
<Text size="sm" opacity="80%">
|
||||
{item.desc}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box
|
||||
bg="blue.500"
|
||||
borderBottomRadius="xl"
|
||||
className="px-6 py-2 transition-colors duration-300"
|
||||
_hover={{ backgroundColor: "blue.600" }}
|
||||
{taskTypes
|
||||
.map((taskType) => taskInfoMap[taskType])
|
||||
.map((item) => (
|
||||
<Link key={category + item.id} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
borderRadius="xl"
|
||||
boxShadow="base"
|
||||
className="flex flex-col justify-between h-full"
|
||||
>
|
||||
<Text fontWeight="bold" color="white">
|
||||
Go ->
|
||||
<Flex className="p-6 pb-10" flexDir="column" gap="3">
|
||||
<Heading size="md">{t(getTypeSafei18nKey(`tasks:${item.id}.label`))}</Heading>
|
||||
<Text size="sm">{t(getTypeSafei18nKey(`tasks:${item.id}.desc`))}</Text>
|
||||
</Flex>
|
||||
<Text
|
||||
fontWeight="bold"
|
||||
color="white"
|
||||
borderBottomRadius="xl"
|
||||
className="px-6 py-2 transition-colors duration-300 bg-blue-500 hover:bg-blue-600"
|
||||
>
|
||||
{t("go")} ->
|
||||
</Text>
|
||||
</Box>
|
||||
</GridItem>
|
||||
</Link>
|
||||
))}
|
||||
</GridItem>
|
||||
</Link>
|
||||
))}
|
||||
</SimpleGrid>
|
||||
</div>
|
||||
))}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export const allTaskOptions: TasksOptionProps["content"] = {
|
||||
[TaskCategory.Random]: [TaskType.random],
|
||||
[TaskCategory.Create]: [TaskType.initial_prompt, TaskType.prompter_reply, TaskType.assistant_reply],
|
||||
[TaskCategory.Evaluate]: [
|
||||
TaskType.rank_initial_prompts,
|
||||
TaskType.rank_prompter_replies,
|
||||
TaskType.rank_assistant_replies,
|
||||
],
|
||||
[TaskCategory.Label]: [TaskType.label_initial_prompt, TaskType.label_prompter_reply, TaskType.label_assistant_reply],
|
||||
};
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
export { LeaderboardTable } from "./LeaderboardTable";
|
||||
export { LeaderboardWidget } from "./LeaderboardWidget";
|
||||
export { TaskOption } from "./TaskOption";
|
||||
export { WelcomeCard } from "./WelcomeCard";
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Input,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverCloseButton,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Spacer,
|
||||
Table,
|
||||
TableCaption,
|
||||
TableContainer,
|
||||
Tbody,
|
||||
Td,
|
||||
Th,
|
||||
Thead,
|
||||
Tr,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { ColumnDef, flexRender, getCoreRowModel, useReactTable } from "@tanstack/react-table";
|
||||
import { Filter } from "lucide-react";
|
||||
import { ChangeEvent, ReactNode } from "react";
|
||||
import { useDebouncedCallback } from "use-debounce";
|
||||
|
||||
export type DataTableColumnDef<T> = ColumnDef<T> & {
|
||||
filterable?: boolean;
|
||||
};
|
||||
|
||||
// TODO: stricter type
|
||||
export type FilterItem = {
|
||||
id: string;
|
||||
value: string;
|
||||
};
|
||||
|
||||
export type DataTableProps<T> = {
|
||||
data: T[];
|
||||
columns: DataTableColumnDef<T>[];
|
||||
caption?: string;
|
||||
filterValues?: FilterItem[];
|
||||
onNextClick?: () => void;
|
||||
onPreviousClick?: () => void;
|
||||
onFilterChange?: (items: FilterItem[]) => void;
|
||||
disableNext?: boolean;
|
||||
disablePrevious?: boolean;
|
||||
disablePagination?: boolean;
|
||||
};
|
||||
|
||||
export const DataTable = <T,>({
|
||||
data,
|
||||
columns,
|
||||
caption,
|
||||
filterValues = [],
|
||||
onNextClick,
|
||||
onPreviousClick,
|
||||
onFilterChange,
|
||||
disableNext,
|
||||
disablePrevious,
|
||||
disablePagination,
|
||||
}: DataTableProps<T>) => {
|
||||
const { getHeaderGroups, getRowModel } = useReactTable<T>({
|
||||
data,
|
||||
columns,
|
||||
getCoreRowModel: getCoreRowModel(),
|
||||
});
|
||||
|
||||
const handleFilterChange = (value: FilterItem) => {
|
||||
const idx = filterValues.findIndex((oldValue) => oldValue.id === value.id);
|
||||
let newValues: FilterItem[] = [];
|
||||
if (idx === -1) {
|
||||
newValues = [...filterValues, value];
|
||||
} else {
|
||||
newValues = filterValues.map((oldValue) => (oldValue.id === value.id ? value : oldValue));
|
||||
}
|
||||
onFilterChange(newValues);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
{!disablePagination && (
|
||||
<Flex mb="2">
|
||||
<Button onClick={onPreviousClick} disabled={disablePrevious}>
|
||||
Previous
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button onClick={onNextClick} disabled={disableNext}>
|
||||
Next
|
||||
</Button>
|
||||
</Flex>
|
||||
)}
|
||||
<TableContainer>
|
||||
<Table variant="simple">
|
||||
<TableCaption>{caption}</TableCaption>
|
||||
<Thead>
|
||||
{getHeaderGroups().map((headerGroup) => (
|
||||
<Tr key={headerGroup.id}>
|
||||
{headerGroup.headers.map((header) => (
|
||||
<Th key={header.id}>
|
||||
<Box display="flex" alignItems="center">
|
||||
{header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())}
|
||||
{(header.column.columnDef as DataTableColumnDef<T>).filterable && (
|
||||
<FilterModal
|
||||
value={filterValues.find((value) => value.id === header.id)?.value ?? ""}
|
||||
onChange={(value) => handleFilterChange({ id: header.id, value })}
|
||||
label={flexRender(header.column.columnDef.header, header.getContext())}
|
||||
></FilterModal>
|
||||
)}
|
||||
</Box>
|
||||
</Th>
|
||||
))}
|
||||
</Tr>
|
||||
))}
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{getRowModel().rows.map((row) => (
|
||||
<Tr key={row.id}>
|
||||
{row.getVisibleCells().map((cell) => (
|
||||
<Td key={cell.id}>{flexRender(cell.column.columnDef.cell, cell.getContext())}</Td>
|
||||
))}
|
||||
</Tr>
|
||||
))}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const FilterModal = ({
|
||||
label,
|
||||
onChange,
|
||||
value,
|
||||
}: {
|
||||
label: ReactNode;
|
||||
onChange: (val: string) => void;
|
||||
value: string;
|
||||
}) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const handleInputChange = useDebouncedCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange(e.target.value);
|
||||
}, 500);
|
||||
|
||||
return (
|
||||
<Popover isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
|
||||
<PopoverTrigger>
|
||||
<Button variant={"unstyled"} ml="2">
|
||||
<Filter size="1em"></Filter>
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent w="fit-content">
|
||||
<PopoverArrow />
|
||||
<PopoverCloseButton />
|
||||
<PopoverBody mt="4">
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Input onChange={handleInputChange} defaultValue={value}></Input>
|
||||
</FormControl>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
@@ -1,30 +1,29 @@
|
||||
import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { useRouter } from "next/router";
|
||||
import { FiAlertTriangle } from "react-icons/fi";
|
||||
import { IconType } from "react-icons/lib";
|
||||
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { AlertTriangle, LucideIcon } from "lucide-react";
|
||||
import NextLink from "next/link";
|
||||
|
||||
type EmptyStateProps = {
|
||||
text: string;
|
||||
icon: IconType;
|
||||
icon: LucideIcon;
|
||||
"data-cy"?: string;
|
||||
};
|
||||
|
||||
export const EmptyState = (props: EmptyStateProps) => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
const router = useRouter();
|
||||
|
||||
return (
|
||||
<Box bg={backgroundColor} p="10" borderRadius="xl" shadow="base">
|
||||
<Box data-cy={props["data-cy"]} bg={backgroundColor} p="10" borderRadius="xl" shadow="base">
|
||||
<Box display="flex" flexDirection="column" alignItems="center" gap="8" fontSize="lg">
|
||||
<props.icon size="30" color="DarkOrange" />
|
||||
<Text>{props.text}</Text>
|
||||
<Link onClick={() => router.back()} color="blue.500" textUnderlineOffset="3px">
|
||||
<Text>Click here to go back</Text>
|
||||
</Link>
|
||||
<Text data-cy="cy-no-tasks">{props.text}</Text>
|
||||
<NextLink href="/dashboard">
|
||||
<Text color="blue.500">Go back to the dashboard</Text>
|
||||
</NextLink>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export const TaskEmptyState = () => {
|
||||
return <EmptyState text="Looks like no tasks were found." icon={FiAlertTriangle} />;
|
||||
return <EmptyState text="Looks like no tasks were found." icon={AlertTriangle} data-cy="task" />;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import {
|
||||
IconButton,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverCloseButton,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from "@chakra-ui/react";
|
||||
import { InformationCircleIcon } from "@heroicons/react/20/solid";
|
||||
|
||||
interface ExplainProps {
|
||||
explanation: string[];
|
||||
}
|
||||
|
||||
export const Explain = ({ explanation }: ExplainProps) => {
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger>
|
||||
<IconButton
|
||||
aria-label="explanation"
|
||||
variant="link"
|
||||
size="xs"
|
||||
icon={<InformationCircleIcon className="h-4 w-4" />}
|
||||
></IconButton>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverArrow />
|
||||
<PopoverCloseButton />
|
||||
<PopoverBody>
|
||||
{explanation.map((paragraph, idx) => (
|
||||
<Text key={idx}>{paragraph}</Text>
|
||||
))}
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
@@ -27,10 +27,10 @@ export function Faq() {
|
||||
return (
|
||||
<ListItem className="space-y-10" key={`question_${index}`}>
|
||||
<Text as="h3" className={`text-lg font-semibold leading-6 ${headingColorClass}`}>
|
||||
{t(`faq_items.q${index}`)}
|
||||
{t(`faq_items.q${index as 0}`)}
|
||||
</Text>
|
||||
<Text as="p" className={`mt-4 text-sm ${textColorClass}`}>
|
||||
{t(`faq_items.a${index}`)}
|
||||
{t(`faq_items.a${index as 0}`)}
|
||||
</Text>
|
||||
</ListItem>
|
||||
);
|
||||
|
||||
@@ -1,127 +1,69 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Checkbox,
|
||||
Flex,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Popover,
|
||||
PopoverAnchor,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverCloseButton,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Slider,
|
||||
SliderFilledTrack,
|
||||
SliderThumb,
|
||||
SliderTrack,
|
||||
Tooltip,
|
||||
useBoolean,
|
||||
useColorMode,
|
||||
useColorModeValue,
|
||||
useId,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { QuestionMarkCircleIcon } from "@heroicons/react/20/solid";
|
||||
import clsx from "clsx";
|
||||
import { useEffect, useReducer } from "react";
|
||||
import { FiAlertCircle } from "react-icons/fi";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { get, post } from "src/lib/api";
|
||||
import { colors } from "src/styles/Theme/colors";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import useSWR from "swr";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
import { LabelInputGroup } from "./Survey/LabelInputGroup";
|
||||
|
||||
interface Label {
|
||||
name: string;
|
||||
display_text: string;
|
||||
help_text: string;
|
||||
}
|
||||
|
||||
interface LoadLabelsAction {
|
||||
type: "load_labels";
|
||||
labels: Label[];
|
||||
}
|
||||
|
||||
interface UpdateValueAction {
|
||||
type: "update_value";
|
||||
label_index: number;
|
||||
value: number;
|
||||
}
|
||||
|
||||
interface ToggleLabelAction {
|
||||
type: "toggle_label";
|
||||
label_index: number;
|
||||
check: boolean;
|
||||
}
|
||||
|
||||
interface LabelValue {
|
||||
label: Label;
|
||||
checked: boolean;
|
||||
value: number;
|
||||
}
|
||||
|
||||
interface FlagReportState {
|
||||
label_values: LabelValue[];
|
||||
submittable: boolean;
|
||||
}
|
||||
|
||||
interface FlaggableElementProps {
|
||||
children: React.ReactNode;
|
||||
message: Message;
|
||||
}
|
||||
|
||||
interface ValidLabelsResponse {
|
||||
valid_labels: Label[];
|
||||
}
|
||||
|
||||
export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
const [report, updateReport] = useReducer(
|
||||
(state: FlagReportState, action: LoadLabelsAction | UpdateValueAction | ToggleLabelAction): FlagReportState => {
|
||||
const makeState = (label_values: LabelValue[]): FlagReportState => {
|
||||
const submittable = label_values.map(({ checked }) => checked).some(Boolean);
|
||||
return { label_values, submittable };
|
||||
};
|
||||
const { data: response } = useSWRImmutable<ValidLabelsResponse>("/api/valid_labels", get);
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const { valid_labels } = response || { valid_labels: [] };
|
||||
const [values, setValues] = useState<number[]>([]);
|
||||
|
||||
switch (action.type) {
|
||||
case "load_labels":
|
||||
return makeState(
|
||||
action.labels.map((label) => {
|
||||
return { label, checked: false, value: 1 };
|
||||
})
|
||||
);
|
||||
case "toggle_label": {
|
||||
const values_copy = state.label_values.slice();
|
||||
values_copy[action.label_index].checked = action.check;
|
||||
return makeState(values_copy);
|
||||
}
|
||||
case "update_value": {
|
||||
const values_copy = state.label_values.slice();
|
||||
values_copy[action.label_index].value = action.value;
|
||||
return makeState(values_copy);
|
||||
}
|
||||
}
|
||||
},
|
||||
{ label_values: [], submittable: false }
|
||||
);
|
||||
const [isEditing, setIsEditing] = useBoolean();
|
||||
|
||||
const { data, isLoading } = useSWR("/api/valid_labels", get);
|
||||
useEffect(() => {
|
||||
if (isLoading) {
|
||||
return;
|
||||
}
|
||||
if (!data) {
|
||||
updateReport({ type: "load_labels", labels: [] });
|
||||
return;
|
||||
}
|
||||
const { valid_labels } = data;
|
||||
updateReport({ type: "load_labels", labels: valid_labels });
|
||||
}, [data, isLoading]);
|
||||
const submittable =
|
||||
values.some((value) => {
|
||||
return value !== null;
|
||||
}) &&
|
||||
values.length === valid_labels.length &&
|
||||
valid_labels.length > 0;
|
||||
|
||||
const { trigger } = useSWRMutation("/api/set_label", post, {
|
||||
onSuccess: setIsEditing.off,
|
||||
onSuccess: onClose,
|
||||
onError: onClose,
|
||||
});
|
||||
|
||||
const submitResponse = () => {
|
||||
const label_map: Map<string, number> = new Map();
|
||||
report.label_values.forEach(({ label, checked, value }) => {
|
||||
if (checked) {
|
||||
label_map.set(label.name, value);
|
||||
console.assert(valid_labels.length === values.length);
|
||||
values.forEach((value, idx) => {
|
||||
if (value !== null) {
|
||||
label_map.set(valid_labels[idx].name, value);
|
||||
}
|
||||
});
|
||||
trigger({
|
||||
@@ -131,22 +73,8 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
});
|
||||
};
|
||||
|
||||
const handleCheckboxState = (checked, label_index) => {
|
||||
updateReport({ type: "toggle_label", label_index, check: checked });
|
||||
};
|
||||
const handleSliderState = (value, label_index) => {
|
||||
updateReport({ type: "update_value", label_index, value });
|
||||
};
|
||||
|
||||
return (
|
||||
<Popover
|
||||
isOpen={isEditing}
|
||||
onOpen={setIsEditing.on}
|
||||
onClose={setIsEditing.off}
|
||||
closeOnBlur={false}
|
||||
isLazy
|
||||
lazyBehavior="keepMounted"
|
||||
>
|
||||
<Popover isOpen={isOpen} onOpen={onOpen} onClose={onClose} closeOnBlur={false} isLazy lazyBehavior="keepMounted">
|
||||
<Box display="flex" alignItems="center" flexDirection={["column", "row"]} gap="2">
|
||||
<PopoverAnchor>{props.children}</PopoverAnchor>
|
||||
|
||||
@@ -154,33 +82,24 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
<Box>
|
||||
<PopoverTrigger>
|
||||
<Box as="button" display="flex" alignItems="center" justifyContent="center" borderRadius="full" p="1">
|
||||
<FiAlertCircle size="20" className="text-red-400" aria-hidden="true" />
|
||||
<AlertCircle size="20" className="text-red-400" aria-hidden="true" />
|
||||
</Box>
|
||||
</PopoverTrigger>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
|
||||
<PopoverContent width="auto" p="3" m="4" maxWidth="calc(100vw - 2rem)">
|
||||
<PopoverArrow />
|
||||
<Box className="relative h-4">
|
||||
<PopoverCloseButton />
|
||||
</Box>
|
||||
<PopoverBody>
|
||||
{report.label_values.map(({ label, checked, value }, i) => (
|
||||
<FlagCheckbox
|
||||
label={label}
|
||||
key={i}
|
||||
idx={i}
|
||||
checked={checked}
|
||||
sliderValue={value}
|
||||
checkboxHandler={handleCheckboxState}
|
||||
sliderHandler={handleSliderState}
|
||||
/>
|
||||
))}
|
||||
<Flex justify="center">
|
||||
<Modal isOpen={isOpen} onClose={onClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Select one or more labels that apply.</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<LabelInputGroup labelIDs={valid_labels.map(({ name }) => name)} onChange={setValues} />
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<Button
|
||||
isDisabled={!report.submittable}
|
||||
isDisabled={!submittable}
|
||||
onClick={submitResponse}
|
||||
className={`bg-indigo-600 text-${useColorModeValue(
|
||||
colors.light.text,
|
||||
@@ -189,85 +108,9 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
>
|
||||
Report
|
||||
</Button>
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
interface FlagCheckboxProps {
|
||||
label: Label;
|
||||
idx: number;
|
||||
checked: boolean;
|
||||
sliderValue: number;
|
||||
checkboxHandler: (newVal: boolean, idx: number) => void;
|
||||
sliderHandler: (newVal: number, idx: number) => void;
|
||||
}
|
||||
|
||||
export function FlagCheckbox(props: FlagCheckboxProps): JSX.Element {
|
||||
let AdditionalExplanation = null;
|
||||
if (props.label.help_text) {
|
||||
AdditionalExplanation = (
|
||||
<a href="#" className="text-sm inline group leading-4">
|
||||
<QuestionMarkCircleIcon
|
||||
className="h-5 w-5 ml-1 text-gray-400 group-hover:text-gray-500 inline"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass =
|
||||
colorMode === "light"
|
||||
? `text-${colors.light.text} hover:text-blue-700`
|
||||
: `text-${colors.dark.text} hover:text-blue-400`;
|
||||
|
||||
return (
|
||||
<Flex gap="4" justifyContent="space-between" className="my-2">
|
||||
<div className="flex items-start align-middle">
|
||||
<Checkbox
|
||||
id={id}
|
||||
isChecked={props.checked}
|
||||
onChange={(e) => {
|
||||
props.checkboxHandler(e.target.checked, props.idx);
|
||||
}}
|
||||
/>
|
||||
<label
|
||||
className={clsx(
|
||||
"text-sm form-check-label ml-2 break-all inline align-middle first-line:leading-4",
|
||||
labelTextClass
|
||||
)}
|
||||
htmlFor={id}
|
||||
>
|
||||
{props.label.display_text}
|
||||
{AdditionalExplanation}
|
||||
</label>
|
||||
</div>
|
||||
<div
|
||||
onClick={() => {
|
||||
if (!props.checked) {
|
||||
props.checkboxHandler(true, props.idx);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Slider
|
||||
width="100px"
|
||||
isDisabled={!props.checked}
|
||||
defaultValue={100}
|
||||
onChangeEnd={(val) => {
|
||||
props.sliderHandler(val / 100, props.idx);
|
||||
}}
|
||||
>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</div>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { Box, Button, Flex, Text } from "@chakra-ui/react";
|
||||
import { User } from "lucide-react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { Flags } from "react-feature-flags";
|
||||
import { FaUser } from "react-icons/fa";
|
||||
import { LanguageSelector } from "src/components/LanguageSelector";
|
||||
|
||||
import { UserMenu } from "./UserMenu";
|
||||
|
||||
@@ -16,7 +17,7 @@ function AccountButton() {
|
||||
return (
|
||||
<Link href="/auth/signin" aria-label="Home">
|
||||
<Flex alignItems="center">
|
||||
<Button variant="outline" leftIcon={<FaUser />}>
|
||||
<Button variant="outline" leftIcon={<User size={"20"} />}>
|
||||
Sign in
|
||||
</Button>
|
||||
</Flex>
|
||||
@@ -45,6 +46,7 @@ export function Header() {
|
||||
<Flags authorizedFlags={["flagTest"]}>
|
||||
<Text>FlagTest</Text>
|
||||
</Flags>
|
||||
<LanguageSelector />
|
||||
<AccountButton />
|
||||
<UserMenu />
|
||||
</Flex>
|
||||
|
||||
@@ -11,16 +11,15 @@ import {
|
||||
Text,
|
||||
useColorModeValue,
|
||||
} from "@chakra-ui/react";
|
||||
import { AlertTriangle, Layout, LogOut, Settings, Shield } from "lucide-react";
|
||||
import NextLink from "next/link";
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import React, { ElementType, useCallback } from "react";
|
||||
import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi";
|
||||
|
||||
interface MenuOption {
|
||||
name: string;
|
||||
href: string;
|
||||
desc: string;
|
||||
icon: ElementType;
|
||||
isExternal: boolean;
|
||||
}
|
||||
@@ -40,22 +39,19 @@ export function UserMenu() {
|
||||
{
|
||||
name: t("dashboard"),
|
||||
href: "/dashboard",
|
||||
desc: t("dashboard"),
|
||||
icon: FiLayout,
|
||||
icon: Layout,
|
||||
isExternal: false,
|
||||
},
|
||||
{
|
||||
name: t("account_settings"),
|
||||
href: "/account",
|
||||
desc: t("account_settings"),
|
||||
icon: FiSettings,
|
||||
icon: Settings,
|
||||
isExternal: false,
|
||||
},
|
||||
{
|
||||
name: t("report_a_bug"),
|
||||
href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose",
|
||||
desc: t("report_a_bug"),
|
||||
icon: FiAlertTriangle,
|
||||
icon: AlertTriangle,
|
||||
isExternal: true,
|
||||
},
|
||||
];
|
||||
@@ -64,8 +60,7 @@ export function UserMenu() {
|
||||
options.unshift({
|
||||
name: t("admin_dashboard"),
|
||||
href: "/admin",
|
||||
desc: t("admin_dashboard"),
|
||||
icon: FiShield,
|
||||
icon: Shield,
|
||||
isExternal: false,
|
||||
});
|
||||
}
|
||||
@@ -98,7 +93,7 @@ export function UserMenu() {
|
||||
_hover={{ textDecoration: "none" }}
|
||||
>
|
||||
<MenuItem gap="3" borderRadius="md" p="4">
|
||||
<item.icon className="text-blue-500" aria-hidden="true" />
|
||||
<item.icon size="1em" className="text-blue-500" aria-hidden="true" />
|
||||
<Text>{item.name}</Text>
|
||||
</MenuItem>
|
||||
</Link>
|
||||
@@ -106,7 +101,7 @@ export function UserMenu() {
|
||||
</MenuGroup>
|
||||
<MenuDivider />
|
||||
<MenuItem gap="3" borderRadius="md" p="4" onClick={handleSignOut}>
|
||||
<FiLogOut className="text-blue-500" aria-hidden="true" />
|
||||
<LogOut size="1em" className="text-blue-500" aria-hidden="true" />
|
||||
<Text>{t("sign_out")}</Text>
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
import { LucideIcon } from "lucide-react";
|
||||
|
||||
export const Discord: LucideIcon = ({ size = 24, ...rest }) => {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 127.14 96.36"
|
||||
fill="currentColor"
|
||||
height={size}
|
||||
width={size}
|
||||
{...rest}
|
||||
>
|
||||
<path d="M107.7 8.07A105.15 105.15 0 0 0 81.47 0a72.06 72.06 0 0 0-3.36 6.83 97.68 97.68 0 0 0-29.11 0A72.37 72.37 0 0 0 45.64 0a105.89 105.89 0 0 0-26.25 8.09C2.79 32.65-1.71 56.6.54 80.21a105.73 105.73 0 0 0 32.17 16.15 77.7 77.7 0 0 0 6.89-11.11 68.42 68.42 0 0 1-10.85-5.18c.91-.66 1.8-1.34 2.66-2a75.57 75.57 0 0 0 64.32 0c.87.71 1.76 1.39 2.66 2a68.68 68.68 0 0 1-10.87 5.19 77 77 0 0 0 6.89 11.1 105.25 105.25 0 0 0 32.19-16.14c2.64-27.38-4.51-51.11-18.9-72.15ZM42.45 65.69C36.18 65.69 31 60 31 53s5-12.74 11.43-12.74S54 46 53.89 53s-5.05 12.69-11.44 12.69Zm42.24 0C78.41 65.69 73.25 60 73.25 53s5-12.74 11.44-12.74S96.23 46 96.12 53s-5.04 12.69-11.43 12.69Z" />
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Select } from "@chakra-ui/react";
|
||||
import { useRouter } from "next/router";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import cookie from "react-cookies";
|
||||
|
||||
const LanguageSelector = () => {
|
||||
const router = useRouter();
|
||||
const { i18n } = useTranslation();
|
||||
|
||||
// Memo the set of locales and their display names.
|
||||
const localesAndNames = useMemo(() => {
|
||||
return router.locales.map((locale) => ({
|
||||
locale,
|
||||
name: new Intl.DisplayNames([locale], { type: "language" }).of(locale),
|
||||
}));
|
||||
}, [router.locales]);
|
||||
|
||||
const languageChanged = useCallback(
|
||||
async (option) => {
|
||||
const locale = option.target.value;
|
||||
cookie.save("NEXT_LOCALE", locale, { path: "/" });
|
||||
const path = router.asPath;
|
||||
return router.push(path, path, { locale });
|
||||
},
|
||||
[router]
|
||||
);
|
||||
|
||||
const { language: currentLanguage } = i18n;
|
||||
return (
|
||||
<Select onChange={languageChanged} defaultValue={currentLanguage}>
|
||||
{localesAndNames.map(({ locale, name }) => (
|
||||
<option key={locale} value={locale}>
|
||||
{name}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
};
|
||||
|
||||
export { LanguageSelector };
|
||||
@@ -0,0 +1 @@
|
||||
export * from "./LanguageSelector";
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user