Merge branch 'main' into 911_sigin_captcha

This commit is contained in:
notmd
2023-01-29 23:34:04 +07:00
155 changed files with 6347 additions and 1905 deletions
+1 -1
View File
@@ -19,7 +19,7 @@ devcontainers in this repo.
pre-commit run --all-files
```
A successfull run should look something like this:
A successful run should look something like this:
```
@andrewm4894 ➜ /workspaces/Open-Assistant (devcontainer-improvements) $ pre-commit run --all-files
+5
View File
@@ -39,8 +39,13 @@ jobs:
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
MAX_ACTIVE_TREES: ${{ vars.MAX_ACTIVE_TREES }}
MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }}
GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}
STATS_INTERVAL_DAY: ${{ vars.STATS_INTERVAL_DAY }}
STATS_INTERVAL_WEEK: ${{ vars.STATS_INTERVAL_WEEK }}
STATS_INTERVAL_MONTH: ${{ vars.STATS_INTERVAL_MONTH }}
STATS_INTERVAL_TOTAL: ${{ vars.STATS_INTERVAL_TOTAL }}
steps:
- name: Checkout
uses: actions/checkout@v2
+7 -3
View File
@@ -68,7 +68,7 @@ repos:
- id: flake8
- repo: https://github.com/pycqa/isort
rev: 5.11.1
rev: 5.12.0
hooks:
- id: isort
@@ -76,8 +76,12 @@ repos:
rev: v2.7.1
hooks:
- id: prettier
args: [--prose-wrap=always, --write]
exclude: website/tailwind.config.js|website/.storybook/main.js|website/.eslintrc.json
args:
[
--prose-wrap=always,
--write,
--ignore-path=./website/.prettierignore,
]
- repo: local
hooks:
+6 -6
View File
@@ -50,12 +50,12 @@ contributions smoothly we recommend the following:
[Here](https://github.com/LAION-AI/Open-Assistant/pull/658) is an example PR
for this project to illustrate this flow.
1. If you're lucky, we can merge your change into `main` without any problems.
If there's changes to files you're working on, resolve them by:
1. First try rebase as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
1. If rebase feels too painful, merge as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
1. Once you've resolved any conflicts, finish the review and
If there's changes to files you're working on, resolve them by :
1. First try rebase as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
1. If rebase feels too painful, merge as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
1. Once you've resolved conflicts (if any), finish the review and
[squash and merge](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits)
your PR (when squashing try to clean up or update the individual commit
messages to be one sensible single one).
+15
View File
@@ -119,6 +119,21 @@
TREE_MANAGER__GOAL_TREE_SIZE:
"{{ lookup('ansible.builtin.env', 'GOAL_TREE_SIZE') | default('15',
true) }}"
TREE_MANAGER__MAX_CHILDREN_COUNT:
"{{ lookup('ansible.builtin.env', 'MAX_CHILDREN_COUNT') |
default('3', true) }}"
USER_STATS_INTERVAL_DAY:
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_DAY') |
default('5', true) }}"
USER_STATS_INTERVAL_WEEK:
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_WEEK') |
default('15', true) }}"
USER_STATS_INTERVAL_MONTH:
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_MONTH') |
default('60', true) }}"
USER_STATS_INTERVAL_TOTAL:
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_TOTAL') |
default('240', true) }}"
ports:
- "{{ backend_port }}:8080"
@@ -0,0 +1,26 @@
"""add task created date index
Revision ID: c84fcd6900dc
Revises: 40ed93df0ed5
Create Date: 2023-01-26 18:35:43.061589
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "c84fcd6900dc"
down_revision = "40ed93df0ed5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(op.f("ix_task_created_date"), "task", ["created_date"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_task_created_date"), table_name="task")
# ### end Alembic commands ###
@@ -0,0 +1,29 @@
"""add user.show_on_leaderboard
Revision ID: f856bf19d32b
Revises: c84fcd6900dc
Create Date: 2023-01-27 20:13:56.533374
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "f856bf19d32b"
down_revision = "c84fcd6900dc"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"user", sa.Column("show_on_leaderboard", sa.Boolean(), server_default=sa.text("true"), nullable=False)
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user", "show_on_leaderboard")
# ### end Alembic commands ###
@@ -0,0 +1,31 @@
"""add origin column to message_tree_state
Revision ID: 49d8445b4c90
Revises: f856bf19d32b
Create Date: 2023-01-28 11:57:45.580027
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "49d8445b4c90"
down_revision = "f856bf19d32b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False))
op.add_column("message", sa.Column("model_name", sa.String(length=1024), nullable=True))
op.add_column("message_tree_state", sa.Column("origin", sa.String(length=1024), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message_tree_state", "origin")
op.drop_column("message", "model_name")
op.drop_column("message", "synthetic")
# ### end Alembic commands ###
+187
View File
@@ -0,0 +1,187 @@
import argparse
import json
from pathlib import Path
from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
import oasst_backend.utils.database_utils as db_utils
import pydantic
from loguru import logger
from oasst_backend.api.deps import create_api_client
from oasst_backend.models import ApiClient, Message
from oasst_backend.models.message_tree_state import MessageTreeState
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree
from sqlmodel import Session
# well known id
IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363")
class Importer:
def __init__(self, db: Session, origin: str, model_name: Optional[str] = None):
self.db = db
self.origin = origin
self.model_name = model_name
# get import api client
api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first()
if not api_client:
api_client = create_api_client(
session=db,
description="API client used for importing data",
frontend_type="import",
force_id=IMPORT_API_CLIENT_ID,
)
ur = UserRepository(db, api_client)
self.import_user = ur.lookup_system_user(username="import")
self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur)
self.api_client = api_client
def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none()
def import_message(
self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None
) -> Message:
payload = db_payload.MessagePayload(text=message.text)
msg = Message(
id=message.message_id,
message_tree_id=message_tree_id,
frontend_message_id=message.message_id,
parent_id=parent_id,
review_count=message.review_count or 0,
lang=message.lang or "en",
review_result=True,
synthetic=message.synthetic if message.synthetic is not None else True,
model_name=message.model_name or self.model_name,
role=message.role,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
payload=PayloadContainer(payload=payload),
user_id=self.import_user.id,
)
self.db.add(msg)
if message.replies:
for r in message.replies:
self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id)
self.db.flush()
if parent_id is None:
self.pr.update_children_counts(msg.id)
self.db.refresh(msg)
return msg
def import_tree(
self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING
) -> tuple[MessageTreeState, Message]:
assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id
root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id)
assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import"
active = state == TreeState.RANKING
mts = MessageTreeState(
message_tree_id=root_msg.id,
goal_tree_size=0,
max_depth=0,
max_children_count=0,
state=state,
origin=self.origin,
active=active,
)
self.db.add(mts)
return mts, root_msg
def import_file(
input_file_path: Path,
origin: str,
*,
model_name: Optional[str] = None,
num_activate: int = 0,
max_count: Optional[int] = None,
dry_run: bool = False,
) -> int:
@db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT)
def import_tx(db: Session) -> int:
importer = Importer(db, origin=origin, model_name=model_name)
i = 0
with input_file_path.open() as file_in:
# read line tree object
for line in file_in:
dict_tree = json.loads(line)
# validate data
tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree)
existing_mts = importer.fetch_message_tree_state(tree.message_tree_id)
if existing_mts:
logger.info(f"Skipping existing message tree: {tree.message_tree_id}")
else:
state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING
mts, root_msg = importer.import_tree(tree, state=state)
i += 1
logger.info(
f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}"
)
if max_count and i >= max_count:
logger.info(f"Reached max count {max_count} of trees to import.")
break
return i
if dry_run:
logger.info("DRY RUN with rollback")
return import_tx()
def parse_args():
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser()
parser.add_argument(
"input_file_path",
help="Input file path",
)
parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees")
parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)")
parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state")
parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import")
parser.add_argument("--dry_run", type=str2bool, default=False)
args = parser.parse_args()
return args
def main():
args = parse_args()
input_file_path = Path(args.input_file_path)
if not input_file_path.exists() or not input_file_path.is_file():
print("Invalid input file:", args.input_file_path)
exit(1)
dry_run = args.dry_run
num_imported = import_file(
input_file_path,
origin=args.origin or input_file_path.name,
model_name=args.model_name,
num_activate=args.num_activate,
max_count=args.max_count,
dry_run=dry_run,
)
logger.info(f"Done ({num_imported=}, {dry_run=})")
if __name__ == "__main__":
main()
+24 -1
View File
@@ -191,6 +191,7 @@ if settings.DEBUG_USE_SEED_DATA:
review_count=5,
review_result=True,
check_tree_state=False,
check_duplicate=False,
)
if message.parent_id is None:
tm._insert_default_state(
@@ -215,7 +216,8 @@ def ensure_tree_states():
try:
logger.info("Startup: TreeManager.ensure_tree_states()")
with Session(engine) as db:
tm = TreeManager(db, None)
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
tm = TreeManager(db, PromptRepository(db, api_client=api_client))
tm.ensure_tree_states()
except Exception:
@@ -291,6 +293,20 @@ def export_ready_trees(file: Optional[str] = None, use_compression: bool = False
logger.exception("Error exporting trees.")
def retry_scoring_failed_message_trees():
try:
logger.info("TreeManager.retry_scoring_failed_message_trees()")
with Session(engine) as db:
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
pr = PromptRepository(db=db, api_client=api_client)
tm = TreeManager(db, pr)
tm.retry_scoring_failed_message_trees()
except Exception:
logger.exception("TreeManager.retry_scoring_failed_message_trees() failed.")
def main():
# Importing here so we don't import packages unnecessarily if we're
# importing main as a module.
@@ -314,6 +330,11 @@ def main():
"--export-file",
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
)
parser.add_argument(
"--retry-scoring",
help="Retry scoring failed message trees",
action=argparse.BooleanOptionalAction,
)
args = parser.parse_args()
@@ -322,6 +343,8 @@ def main():
elif args.export:
use_compression: bool = ".gz" in args.export_file
export_ready_trees(file=args.export_file, use_compression=use_compression)
elif args.retry_scoring:
retry_scoring_failed_message_trees()
else:
uvicorn.run(app, host=args.host, port=args.port)
+33 -5
View File
@@ -1,6 +1,7 @@
from http import HTTPStatus
from secrets import token_hex
from typing import Generator
from typing import Generator, NamedTuple, Optional
from uuid import UUID
from fastapi import Depends, Request, Response, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -19,22 +20,46 @@ def get_db() -> Generator:
yield db
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
api_key_query = APIKeyQuery(name="api_key", scheme_name="api-key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", scheme_name="api-key", auto_error=False)
oasst_user_query = APIKeyQuery(name="oasst_user", scheme_name="oasst-user", auto_error=False)
oasst_user_header = APIKeyHeader(name="x-oasst-user", scheme_name="oasst-user", auto_error=False)
bearer_token = HTTPBearer(auto_error=False)
async def get_api_key(
def get_api_key(
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
):
) -> str:
if api_key_query:
return api_key_query
else:
return api_key_header
class FrontendUserId(NamedTuple):
auth_method: str
username: str
def get_frontend_user_id(
user_query: str = Security(oasst_user_query),
user_header: str = Security(oasst_user_header),
) -> FrontendUserId:
def split_user(v: str) -> tuple[str, str]:
if type(v) is str:
v = v.split(":", maxsplit=1)
if len(v) == 2:
return FrontendUserId(auth_method=v[0], username=v[1])
return FrontendUserId(auth_method=None, username=None)
if user_query:
return split_user(user_query)
else:
return split_user(user_header)
def create_api_client(
*,
session: Session,
@@ -43,6 +68,7 @@ def create_api_client(
trusted: bool | None = False,
admin_email: str | None = None,
api_key: str | None = None,
force_id: Optional[UUID] = None,
) -> ApiClient:
if api_key is None:
api_key = token_hex(32)
@@ -55,6 +81,8 @@ def create_api_client(
trusted=trusted,
admin_email=admin_email,
)
if force_id:
api_client.id = force_id
session.add(api_client)
session.commit()
session.refresh(api_client)
+2
View File
@@ -1,6 +1,7 @@
from fastapi import APIRouter
from oasst_backend.api.v1 import (
admin,
auth,
frontend_messages,
frontend_users,
hugging_face,
@@ -23,3 +24,4 @@ api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
+46
View File
@@ -0,0 +1,46 @@
from typing import Union
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from fastapi import APIRouter, Depends, Security
from fastapi.security import APIKeyCookie
from jose import jwe
from oasst_backend.config import settings
from pydantic import BaseModel, EmailStr
router = APIRouter()
oauth2_scheme = APIKeyCookie(name=settings.AUTH_COOKIE_NAME)
class TokenData(BaseModel):
"""
A minimal re-creation of the web's token type. To be expanded later.
"""
email: Union[EmailStr, None] = None
async def get_current_user(token: str = Security(oauth2_scheme)):
"""
Decrypts the user's JSON Web Token using HKDF encryption and returns the
TokenData.
"""
# We first generate a key from the auth secret.
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=settings.AUTH_LENGTH,
salt=settings.AUTH_SALT,
info=settings.AUTH_INFO,
)
key = hkdf.derive(settings.AUTH_SECRET)
# Next we decrypt the JWE token.
payload = jwe.decrypt(token, key)
# Finally we have the real token JSON payload and can do whatever we want.
return TokenData.parse_raw(payload)
@router.get("/check", response_model=str)
async def auth_check(token_data: TokenData = Depends(get_current_user)):
"""Returns the user's email if it can be decrypted."""
return token_data.email
@@ -77,7 +77,7 @@ def query_frontend_user_messages(
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
messages = pr.query_messages_ordered_by_created_date(
auth_method=auth_method,
username=username,
@@ -104,6 +104,7 @@ def query_frontend_user_messages_cursor(
max_count: Optional[int] = Query(10, gt=0, le=1000),
desc: Optional[bool] = False,
lang: Optional[str] = None,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
@@ -117,6 +118,7 @@ def query_frontend_user_messages_cursor(
max_count=max_count,
desc=desc,
lang=lang,
frontend_user=frontend_user,
api_client=api_client,
db=db,
)
+54 -18
View File
@@ -18,6 +18,7 @@ router = APIRouter()
@router.get("/", response_model=list[protocol.Message])
def query_messages(
*,
auth_method: Optional[str] = None,
username: Optional[str] = None,
api_client_id: Optional[str] = None,
@@ -28,13 +29,14 @@ def query_messages(
desc: Optional[bool] = True,
allow_deleted: Optional[bool] = False,
lang: Optional[str] = None,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query messages.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, auth_method=frontend_user.auth_method, username=frontend_user.username)
messages = pr.query_messages_ordered_by_created_date(
auth_method=auth_method,
username=username,
@@ -53,6 +55,7 @@ def query_messages(
@router.get("/cursor", response_model=protocol.MessagePage)
def get_messages_cursor(
*,
before: Optional[str] = None,
after: Optional[str] = None,
user_id: Optional[UUID] = None,
@@ -64,6 +67,7 @@ def get_messages_cursor(
max_count: Optional[int] = Query(10, gt=0, le=1000),
desc: Optional[bool] = False,
lang: Optional[str] = None,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
@@ -93,7 +97,7 @@ def get_messages_cursor(
qry_max_count = max_count + 1 if before is None or after is None else max_count
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
items = pr.query_messages_ordered_by_created_date(
user_id=user_id,
auth_method=auth_method,
@@ -137,37 +141,49 @@ def get_messages_cursor(
@router.get("/{message_id}", response_model=protocol.Message)
def get_message(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
message = pr.fetch_message(message_id)
return utils.prepare_message(message)
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get a conversation from the tree root and up to the message with given internal ID.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
messages = pr.fetch_message_conversation(message_id)
return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -175,24 +191,32 @@ def get_tree(
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
messages = pr.fetch_message_children(message_id)
return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -200,12 +224,16 @@ def get_descendants(
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -213,12 +241,16 @@ def get_longest_conv(
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -226,9 +258,13 @@ def get_max_children(
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
db: Session = Depends(deps.get_db),
):
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
pr.mark_messages_deleted(message_id)
+4 -2
View File
@@ -77,6 +77,7 @@ def tasks_acknowledge(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
task_id: UUID,
ack_request: protocol_schema.TaskAck,
) -> None:
@@ -87,7 +88,7 @@ def tasks_acknowledge(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
@@ -105,6 +106,7 @@ def tasks_acknowledge_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
task_id: UUID,
nack_request: protocol_schema.TaskNAck,
) -> None:
@@ -115,7 +117,7 @@ def tasks_acknowledge_failure(
try:
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
pr.task_repository.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
+50 -4
View File
@@ -1,12 +1,19 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.config import settings
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions import OasstError
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
@@ -42,10 +49,49 @@ def label_text(
@router.get("/valid_labels")
def get_valid_lables() -> ValidLabelsResponse:
def get_valid_lables(
*,
message_id: Optional[UUID] = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_api_client),
) -> ValidLabelsResponse:
if message_id:
pr = PromptRepository(db, api_client=api_client)
message = pr.fetch_message(message_id=message_id)
if message.parent_id is None:
valid_labels = settings.tree_manager.labels_initial_prompt
elif message.role == "assistant":
valid_labels = settings.tree_manager.labels_assistant_reply
else:
valid_labels = settings.tree_manager.labels_prompter_reply
else:
valid_labels = [l for l in TextLabel if l != TextLabel.fails_task]
return ValidLabelsResponse(
valid_labels=[
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
for l in protocol_schema.TextLabel
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
for l in valid_labels
]
)
@router.get("/report_labels")
def get_report_lables() -> ValidLabelsResponse:
report_labels = [
TextLabel.spam,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
TextLabel.moral_judgement,
TextLabel.political_content,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.quality,
]
return ValidLabelsResponse(
valid_labels=[
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
for l in report_labels
]
)
+11 -4
View File
@@ -190,6 +190,7 @@ def update_user(
user_id: UUID,
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
@@ -197,7 +198,7 @@ def update_user(
Update a user by global user ID. Only trusted clients can update users.
"""
ur = UserRepository(db, api_client)
ur.update_user(user_id, enabled, notes)
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
@@ -224,13 +225,14 @@ def query_user_messages(
desc: bool = True,
include_deleted: bool = False,
lang: Optional[str] = None,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query user messages.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
messages = pr.query_messages_ordered_by_created_date(
user_id=user_id,
api_client_id=api_client_id,
@@ -256,6 +258,7 @@ def query_user_messages_cursor(
max_count: Optional[int] = Query(10, gt=0, le=1000),
desc: Optional[bool] = False,
lang: Optional[str] = None,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
@@ -268,6 +271,7 @@ def query_user_messages_cursor(
max_count=max_count,
desc=desc,
lang=lang,
frontend_user=frontend_user,
api_client=api_client,
db=db,
)
@@ -275,9 +279,12 @@ def query_user_messages_cursor(
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
user_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
db: Session = Depends(deps.get_db),
):
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
messages = pr.query_messages_ordered_by_created_date(user_id=user_id, limit=None)
pr.mark_messages_deleted(messages)
+15 -11
View File
@@ -14,7 +14,8 @@ def prepare_message(m: Message) -> protocol.Message:
lang=m.lang,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
emojis=m.emojis,
emojis=m.emojis or {},
user_emojis=m.user_emojis or [],
)
@@ -22,17 +23,20 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
return [prepare_message(m) for m in messages]
def prepare_conversation_message(message: Message) -> protocol.ConversationMessage:
return protocol.ConversationMessage(
id=message.id,
frontend_message_id=message.frontend_message_id,
text=message.text,
lang=message.lang,
is_assistant=(message.role == "assistant"),
emojis=message.emojis or {},
user_emojis=message.user_emojis or [],
)
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
return [
protocol.ConversationMessage(
id=message.id,
frontend_message_id=message.frontend_message_id,
text=message.text,
lang=message.lang,
is_assistant=(message.role == "assistant"),
)
for message in messages
]
return [prepare_conversation_message(message) for message in messages]
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
+80 -8
View File
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator
@@ -16,7 +16,7 @@ class TreeManagerConfiguration(BaseModel):
max_tree_depth: int = 6
"""Maximum depth of message tree."""
max_children_count: int = 5
max_children_count: int = 3
"""Maximum number of reply messages per tree node."""
goal_tree_size: int = 15
@@ -46,23 +46,93 @@ class TreeManagerConfiguration(BaseModel):
num_required_rankings: int = 3
"""Number of rankings in which the message participated."""
mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
p_activate_backlog_tree: float = 0.8
"""Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
a terminal state. Use this settting to control ratio of initial prompts and backlog tree
activations."""
min_active_rankings_per_lang: int = 2
"""When the number of active ranking tasks is below this value when a tree enters a terminal
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
labels_initial_prompt: list[TextLabel] = [
TextLabel.spam,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.creativity,
TextLabel.humor,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
labels_assistant_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.fails_task,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.creativity,
TextLabel.humor,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
labels_prompter_reply: list[TextLabel] = [
TextLabel.spam,
TextLabel.quality,
TextLabel.helpfulness,
TextLabel.humor,
TextLabel.creativity,
TextLabel.toxicity,
TextLabel.violence,
TextLabel.not_appropriate,
TextLabel.pii,
TextLabel.hate_speech,
TextLabel.sexual_content,
]
mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for initial prompts."""
mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for assistant replies."""
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam]
"""Mandatory labels in text-labeling tasks for prompter replies."""
rank_prompter_replies: bool = False
lonely_children_count: int = 2
"""Number of children below which parents are preferred during sampling for reply tasks."""
p_lonely_child_extension: float = 0.8
"""Probability to select a parent with less than lonely_children_count children."""
recent_tasks_span_sec: int = 3 * 60 # 3 min
"""Time in seconds of recent tasks to consider for exclusion during task selection."""
class Settings(BaseSettings):
PROJECT_NAME: str = "open-assistant backend"
API_V1_STR: str = "/api/v1"
OFFICIAL_WEB_API_KEY: str = "1234"
# Encryption fields for handling the web generated JSON Web Tokens.
# These fields need to be shared with the web's auth settings in order to
# correctly decrypt the web tokens.
AUTH_INFO: bytes = b"NextAuth.js Generated Encryption Key"
AUTH_SALT: bytes = b""
AUTH_LENGTH: int = 32
AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
AUTH_COOKIE_NAME: str = "next-auth.session-token"
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432"
POSTGRES_USER: str = "postgres"
@@ -86,6 +156,8 @@ class Settings(BaseSettings):
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
DEBUG_DATABASE_ECHO: bool = False
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
HUGGING_FACE_API_KEY: str = ""
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
@@ -116,9 +188,9 @@ class Settings(BaseSettings):
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
USER_STATS_INTERVAL_DAY: int = 15 # minutes
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
USER_STATS_INTERVAL_DAY: int = 5 # minutes
USER_STATS_INTERVAL_WEEK: int = 15 # minutes
USER_STATS_INTERVAL_MONTH: int = 60 # minutes
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
@validator(
+2 -1
View File
@@ -117,7 +117,8 @@ class LabelConversationReplyPayload(TaskPayload):
message_id: UUID
conversation: protocol_schema.Conversation
reply: str
reply: str # deprecated
reply_message: Optional[protocol_schema.ConversationMessage]
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[protocol_schema.LabelTaskMode]
+26 -2
View File
@@ -1,12 +1,13 @@
from datetime import datetime
from http import HTTPStatus
from typing import Optional
from typing import Any, Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from pydantic import PrivateAttr
from sqlalchemy import false
from sqlmodel import Field, Index, SQLModel
@@ -17,6 +18,13 @@ class Message(SQLModel, table=True):
__tablename__ = "message"
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
def __new__(cls, *args: Any, **kwargs: Any):
new_object = super().__new__(cls, *args, **kwargs)
# temporary fix until https://github.com/tiangolo/sqlmodel/issues/149 gets merged
if not hasattr(new_object, "_user_emojis"):
new_object._init_private_attributes()
return new_object
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
@@ -49,13 +57,29 @@ class Message(SQLModel, table=True):
rank: Optional[int] = Field(nullable=True)
emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
synthetic: Optional[bool] = Field(
sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False)
)
model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
def has_emoji(self, emoji_code: str) -> bool:
return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0
def has_user_emoji(self, emoji_code: str) -> bool:
return self._user_emojis and emoji_code in self._user_emojis
@property
def text(self) -> str:
self.ensure_is_message()
return self.payload.payload.text
@property
def user_emojis(self) -> str:
return self._user_emojis
@@ -43,6 +43,9 @@ class State(str, Enum):
HALTED_BY_MODERATOR = "halted_by_moderator"
"""A moderator decided to manually halt the message tree construction process."""
BACKLOG_RANKING = "backlog_ranking"
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
VALID_STATES = (
State.INITIAL_PROMPT_REVIEW,
@@ -51,6 +54,7 @@ VALID_STATES = (
State.READY_FOR_SCORING,
State.READY_FOR_EXPORT,
State.ABORTED_LOW_GRADE,
State.BACKLOG_RANKING,
)
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
@@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True):
max_children_count: int = Field(nullable=False)
state: str = Field(nullable=False, max_length=128, index=True)
active: bool = Field(nullable=False, index=True)
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
+3 -1
View File
@@ -20,7 +20,9 @@ class Task(SQLModel, table=True):
),
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()),
sa_column=sa.Column(
sa.DateTime(timezone=True), nullable=False, index=True, server_default=sa.func.current_timestamp()
),
)
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
+1
View File
@@ -30,6 +30,7 @@ class User(SQLModel, table=True):
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
notes: str = Field(sa_column=sa.Column(AutoString(length=1024), nullable=False, server_default=""))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))
show_on_leaderboard: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
def to_protocol_frontend_user(self):
return protocol.FrontEndUser(
+152 -19
View File
@@ -1,6 +1,7 @@
import random
import re
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timedelta
from http import HTTPStatus
from typing import Optional
from uuid import UUID, uuid4
@@ -8,6 +9,8 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
import sqlalchemy as sa
from loguru import logger
from oasst_backend.api.deps import FrontendUserId
from oasst_backend.config import settings
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import (
ApiClient,
@@ -29,9 +32,10 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.utils import unaware_to_utc
from oasst_shared.utils import unaware_to_utc, utcnow
from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, and_, func, not_, or_, text, update
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -41,14 +45,30 @@ class PromptRepository:
db: Session,
api_client: ApiClient,
client_user: Optional[protocol_schema.User] = None,
*,
user_repository: Optional[UserRepository] = None,
task_repository: Optional[TaskRepository] = None,
user_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
username: Optional[str] = None,
frontend_user: Optional[FrontendUserId] = None,
):
self.db = db
self.api_client = api_client
self.user_repository = user_repository or UserRepository(db, api_client)
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
if frontend_user and not auth_method and not username:
auth_method, username = frontend_user
if user_id:
self.user = self.user_repository.get_user(id=user_id)
self.user_id = self.user.id
elif auth_method and username:
self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username)
self.user_id = self.user.id
else:
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
self.task_repository = task_repository or TaskRepository(
db, api_client, client_user, user_repository=self.user_repository
@@ -157,6 +177,7 @@ class PromptRepository:
review_count: int = 0,
review_result: bool = False,
check_tree_state: bool = True,
check_duplicate: bool = True,
) -> Message:
self.ensure_user_is_enabled()
@@ -170,6 +191,18 @@ class PromptRepository:
role = None
depth = 0
# reject whitespaces match with ^\s+$
if re.match(r"^\s+$", text):
raise OasstError("Message text is empty", OasstErrorCode.TASK_MESSAGE_TEXT_EMPTY)
# ensure message size is below the predefined limit
if len(text) > settings.MESSAGE_SIZE_LIMIT:
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
if check_duplicate and self.check_users_recent_replies_for_duplicates(text):
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
if task.parent_message_id:
parent_message = self.fetch_message(task.parent_message_id)
@@ -448,16 +481,24 @@ class PromptRepository:
task_id=task.id if task else None,
)
message: Message = None
if message_id:
message = self.fetch_message(message_id)
if task:
if not task:
if text_labels.is_report is True:
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
)
# update existing record for repeated updates (same user no task associated)
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
if existing_text_label is not None:
existing_text_label.labels = text_labels.labels
model = existing_text_label
else:
message = self.fetch_message(message_id)
message.review_count += 1
self.db.add(message)
# for the same User id with no task id associated with the message, then update existing record for repeated updates
existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id)
if existing_text_label is not None:
existing_text_label.labels = text_labels.labels
model = existing_text_label
self.db.add(model)
return model, task, message
@@ -529,7 +570,31 @@ class PromptRepository:
qry = qry.filter(Message.review_result)
if not include_deleted:
qry = qry.filter(not_(Message.deleted))
return qry.all()
return self._add_user_emojis_all(qry)
def check_users_recent_replies_for_duplicates(self, text: str) -> bool:
"""
Checks if the user has recently replied with the same text within a given time period.
"""
user_id = self.user_id
logger.debug(f"Checking for duplicate tasks for user {user_id}")
# messages in the past 24 hours
messages = (
self.db.query(Message)
.filter(Message.user_id == user_id)
.order_by(Message.created_date.desc())
.filter(
Message.created_date > utcnow() - timedelta(minutes=settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES)
)
.all()
)
if not messages:
return False
for msg in messages:
if msg.text == text:
return True
return False
def fetch_user_message_trees(
self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False
@@ -539,7 +604,7 @@ class PromptRepository:
qry = qry.filter(Message.review_result)
if not include_deleted:
qry = qry.filter(not_(Message.deleted))
return qry.all()
return self._add_user_emojis_all(qry)
def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]:
qry = self.db.query(MessageTreeState).filter(
@@ -582,6 +647,10 @@ class PromptRepository:
return conversation, replies
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
qry = self.db.query(Message).filter(Message.id == message_id)
messages = self._add_user_emojis_all(qry)
message = messages[0] if messages else None
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
@@ -656,9 +725,27 @@ class PromptRepository:
qry = qry.filter(Message.review_result)
if exclude_deleted:
qry = qry.filter(Message.deleted == sa.false())
children = qry.all()
children = self._add_user_emojis_all(qry)
return children
def fetch_message_siblings(
self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False
) -> list[Message]:
"""
Get siblings of a message (other messages with the same parent_id)
"""
if isinstance(message, Message):
message = message.id
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id)
if reviewed is not None:
qry = qry.filter(Message.review_result == reviewed)
if deleted is not None:
qry = qry.filter(Message.deleted == deleted)
siblings = self._add_user_emojis_all(qry)
return siblings
@staticmethod
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
children = defaultdict(list)
@@ -687,7 +774,7 @@ class PromptRepository:
if max_depth is not None:
desc = desc.filter(Message.depth <= max_depth)
desc = desc.all()
desc = self._add_user_emojis_all(desc)
return self.trace_descendants(message, desc)
@@ -701,6 +788,33 @@ class PromptRepository:
max_message = max(tree, key=lambda m: m.children_count)
return max_message, [m for m in tree if m.parent_id == max_message.id]
def _add_user_emojis_all(self, qry: Query) -> list[Message]:
if self.user_id is None:
return qry.all()
sq = qry.subquery("m")
qry = (
self.db.query(Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis"))
.select_entity_from(sq)
.outerjoin(
MessageEmoji,
and_(
sq.c.id == MessageEmoji.message_id,
MessageEmoji.user_id == self.user_id,
sq.c.emojis != JSON.NULL,
),
)
.group_by(sq)
)
messages: list[Message] = []
for x in qry:
m: Message = x.Message
user_emojis = x["user_emojis"]
if user_emojis:
m._user_emojis = user_emojis.split(",")
messages.append(m)
return messages
def query_messages_ordered_by_created_date(
self,
user_id: Optional[UUID] = None,
@@ -783,7 +897,7 @@ class PromptRepository:
if lang is not None:
qry = qry.filter(Message.lang == lang)
return qry.all()
return self._add_user_emojis_all(qry)
def update_children_counts(self, message_tree_id: UUID):
sql_update_children_count = """
@@ -797,8 +911,7 @@ FROM (
) AS cc
WHERE message.id = cc.id;
"""
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
@managed_tx_method(CommitMode.COMMIT)
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
@@ -875,6 +988,20 @@ WHERE message.id = cc.id;
op = protocol_schema.EmojiOp.add
if op == protocol_schema.EmojiOp.add:
# hard coded exclusivity of thumbs_up & thumbs_down
if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji(
protocol_schema.EmojiCode.thumbs_down.value
):
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down
)
elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji(
protocol_schema.EmojiCode.thumbs_up.value
):
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
)
# insert emoji record & increment count
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
self.db.add(message_emoji)
@@ -884,9 +1011,15 @@ WHERE message.id = cc.id;
else:
count = emoji_counts.get(emoji.value) or 0
emoji_counts[emoji.value] = count + 1
if message._user_emojis is None:
message._user_emojis = []
if emoji.value not in message._user_emojis:
message._user_emojis.append(emoji.value)
elif op == protocol_schema.EmojiOp.remove:
# remove emoji record and & decrement count
message = self.fetch_message(message_id)
if message._user_emojis and emoji.value in message._user_emojis:
message._user_emojis.remove(emoji.value)
self.db.delete(existing_emoji)
emoji_counts = message.emojis
count = emoji_counts.get(emoji.value)
+2 -9
View File
@@ -1,13 +1,6 @@
from typing import Optional
from oasst_shared.schemas.protocol import LabelDescription
from pydantic import BaseModel
class LabelOption(BaseModel):
name: str
display_text: str
help_text: Optional[str]
class ValidLabelsResponse(BaseModel):
valid_labels: list[LabelOption]
valid_labels: list[LabelDescription]
+17 -1
View File
@@ -1,3 +1,4 @@
from datetime import timedelta
from typing import Optional
from uuid import UUID
@@ -9,7 +10,7 @@ from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from sqlmodel import Session, func, or_
from starlette.status import HTTP_404_NOT_FOUND
@@ -100,6 +101,7 @@ class TaskRepository:
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
reply_message=task.reply_message,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
@@ -111,6 +113,7 @@ class TaskRepository:
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
reply_message=task.reply_message,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
@@ -219,3 +222,16 @@ class TaskRepository:
def fetch_task_by_id(self, task_id: UUID) -> Task:
task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none()
return task
def fetch_recent_reply_tasks(
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, limit: int = 100
) -> list[Task]:
qry = self.db.query(Task).filter(
func.age(Task.created_date) < max_age,
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
)
if done is not None:
qry = qry.filter(Task.done == done)
if limit:
qry = qry.limit(limit)
return qry.all()
+263 -76
View File
@@ -1,7 +1,7 @@
import json
import random
import sys
from datetime import datetime
from datetime import datetime, timedelta
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
@@ -11,7 +11,11 @@ import numpy as np
import pydantic
from fastapi.encoders import jsonable_encoder
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.api.v1.utils import (
prepare_conversation,
prepare_conversation_message,
prepare_conversation_message_list,
)
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
@@ -21,7 +25,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func, not_, text, update
from sqlmodel import Session, func, not_, or_, text, update
class TaskType(Enum):
@@ -41,8 +45,9 @@ class TaskRole(Enum):
class ActiveTreeSizeRow(pydantic.BaseModel):
message_tree_id: UUID
tree_size: int
goal_tree_size: int
tree_size: int
awaiting_review: Optional[int]
@property
def remaining_messages(self) -> int:
@@ -68,6 +73,7 @@ class IncompleteRankingsRow(pydantic.BaseModel):
role: str
children_count: int
child_min_ranking_count: int
message_tree_id: UUID
class Config:
orm_mode = True
@@ -93,8 +99,6 @@ class TreeManagerStats(pydantic.BaseModel):
class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
def __init__(
self,
db: Session,
@@ -201,8 +205,8 @@ class TreeManager:
lang = "en"
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
num_active_trees = self.query_num_active_trees(lang=lang)
extendible_parents = self.query_extendible_parents(lang=lang)
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
extendible_parents, _ = self.query_extendible_parents(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
@@ -215,6 +219,15 @@ class TreeManager:
incomplete_rankings=incomplete_rankings,
)
@staticmethod
def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]:
return [
protocol_schema.LabelDescription(
name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text
)
for l in valid_labels
]
def next_task(
self,
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
@@ -229,17 +242,15 @@ class TreeManager:
lang = "en"
logger.warning("Task request without lang tag received, assuming 'en'.")
num_active_trees = self.query_num_active_trees(lang=lang)
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
extendible_parents = self.query_extendible_parents(lang=lang)
extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang)
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
if not self.cfg.rank_prompter_replies:
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))
active_tree_sizes = self.query_extendible_trees(lang=lang)
# determine type of task to generate
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
@@ -338,6 +349,7 @@ class TreeManager:
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_REPLY:
if task_role == TaskRole.PROMPTER:
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
elif task_role == TaskRole.ASSISTANT:
@@ -347,64 +359,105 @@ class TreeManager:
random_reply_message = random.choice(replies_need_review)
messages = self.pr.fetch_message_conversation(random_reply_message)
conversation = prepare_conversation(messages[:-1])
conversation = prepare_conversation(messages)
message = messages[-1]
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
label_disposition = protocol_schema.LabelTaskDisposition.quality
if message.role == "assistant":
valid_labels = self.cfg.labels_assistant_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
valid_labels = list(self.cfg.mandatory_labels_assistant_reply)
if protocol_schema.TextLabel.quality not in valid_labels:
valid_labels.append(protocol_schema.TextLabel.quality)
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
reply_message=prepare_conversation_message(message),
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
else:
valid_labels = self.cfg.labels_prompter_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
valid_labels = list(self.cfg.mandatory_labels_prompter_reply)
if protocol_schema.TextLabel.quality not in valid_labels:
valid_labels.append(protocol_schema.TextLabel.quality)
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
reply_message=prepare_conversation_message(message),
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
parent_message_id = message.id
message_tree_id = message.message_tree_id
case TaskType.REPLY:
# select a tree with missing replies
recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False
)
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}
if task_role == TaskRole.PROMPTER:
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
elif task_role == TaskRole.ASSISTANT:
extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
# select a tree with missing replies
if len(extendible_parents) > 0:
random_parent = random.choice(extendible_parents)
random_parent: ExtendibleParentRow = None
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
# check if we have extendible parents with a small number of replies
lonely_children_parents = [
p
for p in extendible_parents
if 0 < p.active_children_count < self.cfg.lonely_children_count
and p.parent_id not in recent_reply_task_parents
]
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
random_parent = random.choice(lonely_children_parents)
if random_parent is None:
# try to exclude parents for which tasks were recently handed out
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
if len(fresh_parents) > 0:
random_parent = random.choice(fresh_parents)
else:
random_parent = random.choice(extendible_parents)
# fetch random conversation to extend
logger.debug(f"selected {random_parent=}")
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
assert all(m.review_result for m in messages) # ensure all messages have positive review
assert all(m.review_result for m in messages) # ensure all messages have positive reviews
conversation = prepare_conversation(messages)
# generate reply task depending on last message
@@ -423,19 +476,23 @@ class TreeManager:
message = random.choice(prompts_need_review)
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self.cfg.labels_initial_prompt
if random.random() > self.cfg.p_full_labeling_review_prompt:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
valid_labels = self.cfg.mandatory_labels_initial_prompt
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
valid_labels=valid_labels,
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
mode=label_mode,
disposition=label_disposition,
labels=self._get_label_descriptions(valid_labels),
)
parent_message_id = message.id
@@ -468,14 +525,6 @@ class TreeManager:
logger.info(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# ensure message size is below the predefined limit
if len(interaction.text) > settings.MESSAGE_SIZE_LIMIT:
logger.error(
f"Message size {len(interaction.text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}."
)
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
# here we store the text reply in the database
message = pr.store_text_reply(
text=interaction.text,
@@ -533,9 +582,7 @@ class TreeManager:
)
_, task = pr.store_ranking(interaction)
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
self.update_message_ranks(task.message_tree_id, rankings_by_message)
self.check_condition_for_scoring_state(task.message_tree_id)
case protocol_schema.TextLabels:
logger.info(
@@ -544,7 +591,7 @@ class TreeManager:
_, task, msg = pr.store_text_labels(interaction)
# if it was a respones for a task, check if we have enough reviews to calc review_result
# if it was a response for a task, check if we have enough reviews to calc review_result
if task and msg:
reviews = self.query_reviews_for_message(msg.id)
acceptance_score = self._calculate_acceptance(reviews)
@@ -577,19 +624,28 @@ class TreeManager:
return protocol_schema.TaskDone()
@managed_tx_method(CommitMode.FLUSH)
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
assert mts and mts.active
assert mts
is_terminal = state in message_tree_state.TERMINAL_STATES
was_active = mts.active
if is_terminal:
mts.active = False
mts.state = state.value
self.db.add(mts)
self.db.flush
if is_terminal:
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False)
if root_msg and was_active:
if random.random() < self.cfg.p_activate_backlog_tree:
self.activate_backlog_tree(lang=root_msg.lang)
if self.cfg.min_active_rankings_per_lang > 0:
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang)
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
self.activate_backlog_tree(lang=root_msg.lang)
else:
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
@@ -625,34 +681,43 @@ class TreeManager:
# check if desired tree size has been reached and all nodes have been reviewed
tree_size = self.query_tree_size(message_tree_id)
if tree_size.remaining_messages > 0:
logger.debug(f"False {tree_size.remaining_messages=}")
if tree_size.remaining_messages > 0 or tree_size.awaiting_review > 0:
logger.debug(f"False {tree_size.remaining_messages=}, {tree_size.awaiting_review=}")
return False
self._enter_state(mts, message_tree_state.State.RANKING)
return True
def check_condition_for_scoring_state(
self, message_tree_id: UUID
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.RANKING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False, None
if mts.state != message_tree_state.State.SCORING_FAILED:
if not mts.active or mts.state not in (
message_tree_state.State.RANKING,
message_tree_state.State.READY_FOR_SCORING,
):
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
return False, None
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
return True, rankings_by_message
if (
mts.state != message_tree_state.State.SCORING_FAILED
and mts.state != message_tree_state.State.READY_FOR_SCORING
):
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
self.update_message_ranks(message_tree_id, rankings_by_message)
return True
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
def update_message_ranks(
self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]]
) -> bool:
mts = self.pr.fetch_tree_state(message_tree_id)
# check state, allow retry if in SCORING_FAILED state
@@ -660,19 +725,47 @@ class TreeManager:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
if mts.state == message_tree_state.State.SCORING_FAILED:
mts.active = True
mts.state = message_tree_state.State.READY_FOR_SCORING
try:
for rankings in rankings_by_message.values():
sorted_messages = []
for msg_reaction in rankings:
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
logger.debug(f"SORTED MESSAGE {sorted_messages}")
consensus = ranked_pairs(sorted_messages)
ordered_ids_list: list[list[UUID]] = [
msg_reaction.payload.payload.ranked_message_ids for msg_reaction in rankings
]
common_set: set[UUID] = set.intersection(*map(set, ordered_ids_list))
if len(common_set) < 2:
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
continue
# keep only elements in commond set
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
assert all(len(x) == len(common_set) for x in ordered_ids_list)
logger.debug(f"SORTED MESSAGE IDS {ordered_ids_list}")
consensus = ranked_pairs(ordered_ids_list)
assert len(consensus) == len(common_set)
logger.debug(f"CONSENSUS: {consensus}\n\n")
# fetch all siblings and clear ranks
siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None)
for m in siblings:
m.rank = None
self.db.add(m)
# index by id
siblings = {m.id: m for m in siblings}
# set rank for each message that was part of the common set
for rank, message_id in enumerate(consensus):
# set rank for each message_id for Message rows
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
msg.rank = rank
self.db.add(msg)
msg = siblings.get(message_id)
if msg:
msg.rank = rank
self.db.add(msg)
else:
logger.warning(f"Message {message_id=} not found among siblings.")
except Exception:
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
@@ -680,8 +773,35 @@ class TreeManager:
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
return True
def activate_backlog_tree(self, lang: str) -> MessageTreeState:
while True:
# find tree in backlog state
backlog_tree: MessageTreeState = (
self.db.query(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id) # root msg
.filter(MessageTreeState.state == message_tree_state.State.BACKLOG_RANKING)
.filter(Message.lang == lang)
.limit(1)
.one_or_none()
)
if not backlog_tree:
return None
if len(self.query_tree_ranking_results(message_tree_id=backlog_tree.message_tree_id)) == 0:
logger.info(
f"Backlog tree {backlog_tree.message_tree_id} has no children to rank, aborting with 'aborted_low_grade' state."
)
self._enter_state(backlog_tree, message_tree_state.State.ABORTED_LOW_GRADE)
else:
logger.info(f"Activating backlog tree {backlog_tree.message_tree_id}")
backlog_tree.active = True
self._enter_state(backlog_tree, message_tree_state.State.RANKING)
return backlog_tree
def _calculate_acceptance(self, labels: list[TextLabels]):
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
@@ -749,7 +869,8 @@ class TreeManager:
_sql_find_incomplete_rankings = """
-- find incomplete rankings
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings,
mts.message_tree_id
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active -- only consider active trees
@@ -758,7 +879,7 @@ WHERE mts.active -- only consider active trees
AND m.lang = :lang -- matches lang
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id, m.role
GROUP BY m.parent_id, m.role, mts.message_tree_id
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
@@ -767,7 +888,8 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
SELECT ir.* FROM incomplete_rankings ir
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings,
ir.message_tree_id
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
"""
@@ -805,7 +927,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
AND COUNT(c.id) FILTER (WHERE c.user_id = :user_id) = 0 -- without reply by user
"""
def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]:
def query_extendible_parents(self, lang: str) -> tuple[list[ExtendibleParentRow], list[ActiveTreeSizeRow]]:
"""Query parent messages that have not reached the maximum number of replies."""
user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
@@ -818,7 +940,13 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"user_id": user_id,
},
)
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
potential_parents = [ExtendibleParentRow.from_orm(x) for x in r.all()]
extendible_trees = self.query_extendible_trees(lang=lang)
extendible_tree_ids = set(t.message_tree_id for t in extendible_trees)
extendible_parents = list(p for p in potential_parents if p.message_tree_id in extendible_tree_ids)
return extendible_parents, extendible_trees
_sql_find_extendible_trees = f"""
-- find extendible trees
@@ -854,18 +982,21 @@ HAVING COUNT(m.id) < mts.goal_tree_size
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
"""Returns the number of reviewed not deleted messages in the message tree."""
required_reviews = settings.tree_manager.num_reviews_reply
qry = (
self.db.query(
MessageTreeState.message_tree_id.label("message_tree_id"),
MessageTreeState.goal_tree_size.label("goal_tree_size"),
func.count(Message.id).label("tree_size"),
func.count(Message.id).filter(Message.review_result).label("tree_size"),
func.count(Message.id)
.filter(not_(Message.review_result), Message.review_count < required_reviews)
.label("awaiting_review"),
)
.select_from(MessageTreeState)
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.filter(
MessageTreeState.active,
not_(Message.deleted),
Message.review_result,
MessageTreeState.message_tree_id == message_tree_id,
)
.group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size)
@@ -903,8 +1034,8 @@ SELECT p.parent_id, mr.* FROM
GROUP BY m.parent_id, m.message_tree_id
HAVING COUNT(m.id) > 1
) as p
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
"""
def query_tree_ranking_results(
@@ -934,7 +1065,7 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
return rankings_by_message
@managed_tx_method(CommitMode.COMMIT)
def ensure_tree_states(self):
def ensure_tree_states(self) -> None:
"""Add message tree state rows for all root nodes (inital prompt messages)."""
missing_tree_ids = self.query_misssing_tree_states()
@@ -946,12 +1077,54 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
self._insert_default_state(id, state=state)
def query_num_active_trees(self, lang: str) -> int:
# check tree state transitions (maybe variables haves changes): prompt review -> growing -> ranking -> scoring
prompt_review_trees: list[MessageTreeState] = (
self.db.query(MessageTreeState)
.filter(MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW, MessageTreeState.active)
.all()
)
if len(prompt_review_trees) > 0:
logger.info(
f"Checking state of {len(prompt_review_trees)} active message trees in 'initial_prompt_review' state."
)
for t in prompt_review_trees:
self.check_condition_for_growing_state(t.message_tree_id)
growing_trees: list[MessageTreeState] = (
self.db.query(MessageTreeState)
.filter(MessageTreeState.state == message_tree_state.State.GROWING, MessageTreeState.active)
.all()
)
if len(growing_trees) > 0:
logger.info(f"Checking state of {len(growing_trees)} active message trees in 'growing' state.")
for t in growing_trees:
self.check_condition_for_ranking_state(t.message_tree_id)
ranking_trees: list[MessageTreeState] = (
self.db.query(MessageTreeState)
.filter(
or_(
MessageTreeState.state == message_tree_state.State.RANKING,
MessageTreeState.state == message_tree_state.State.READY_FOR_SCORING,
),
MessageTreeState.active,
)
.all()
)
if len(ranking_trees) > 0:
logger.info(f"Checking state of {len(ranking_trees)} active message trees in 'ranking' state.")
for t in ranking_trees:
self.check_condition_for_scoring_state(t.message_tree_id)
def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int:
"""Count all active trees (optionally exclude those in ranking state)."""
query = (
self.db.query(func.count(MessageTreeState.message_tree_id))
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(MessageTreeState.active, Message.lang == lang)
)
if exclude_ranking:
query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING)
return query.scalar()
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
@@ -1177,6 +1350,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
sql_purge_user = """
DELETE FROM journal WHERE user_id = :user_id;
DELETE FROM message_reaction WHERE user_id = :user_id;
DELETE FROM message_emoji WHERE user_id = :user_id;
DELETE FROM task WHERE user_id = :user_id;
DELETE FROM message WHERE user_id = :user_id;
DELETE FROM user_stats WHERE user_id = :user_id;
@@ -1226,6 +1400,20 @@ DELETE FROM user_stats WHERE user_id = :user_id;
message_tree_ids = [ms.message_tree_id for ms in messages]
self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
@managed_tx_method(CommitMode.COMMIT)
def retry_scoring_failed_message_trees(self):
query = self.db.query(MessageTreeState).filter(
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
)
for mts in query.all():
mts: MessageTreeState
try:
if not self.check_condition_for_scoring_state(mts.message_tree_id):
mts.active = True
self._enter_state(mts, message_tree_state.State.RANKING)
except Exception:
logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})")
if __name__ == "__main__":
from oasst_backend.api.deps import api_auth
@@ -1240,7 +1428,6 @@ if __name__ == "__main__":
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
cfg = TreeManagerConfiguration()
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()
@@ -1259,8 +1446,8 @@ if __name__ == "__main__":
# print("next_task:", tm.next_task())
# print(
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
# )
print(
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
)
print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
+44 -13
View File
@@ -66,7 +66,13 @@ class UserRepository:
return user
@managed_tx_method(CommitMode.COMMIT)
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
def update_user(
self,
id: UUID,
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
) -> None:
"""
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
@@ -85,6 +91,8 @@ class UserRepository:
user.enabled = enabled
if notes is not None:
user.notes = notes
if show_on_leaderboard is not None:
user.show_on_leaderboard = show_on_leaderboard
self.db.add(user)
@@ -109,15 +117,20 @@ class UserRepository:
self.db.add(user)
@managed_tx_method(CommitMode.COMMIT)
def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
if not client_user:
return None
def _lookup_user_tx(
self,
*,
username: str,
auth_method: str,
display_name: Optional[str] = None,
create_missing: bool = True,
) -> User | None:
user: User = (
self.db.query(User)
.filter(
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
User.username == username,
User.auth_method == auth_method,
)
.first()
)
@@ -125,28 +138,46 @@ class UserRepository:
if create_missing:
# user is unknown, create new record
user = User(
username=client_user.id,
display_name=client_user.display_name,
username=username,
display_name=display_name,
api_client_id=self.api_client.id,
auth_method=client_user.auth_method,
auth_method=auth_method,
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
)
self.db.add(user)
elif client_user.display_name and client_user.display_name != user.display_name:
elif display_name and display_name != user.display_name:
# we found the user but the display name changed
user.display_name = client_user.display_name
user.display_name = display_name
self.db.add(user)
return user
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
if not client_user:
return None
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
for i in range(num_retries):
try:
return self._lookup_client_user_tx(client_user, create_missing)
return self._lookup_user_tx(
username=client_user.id,
auth_method=client_user.auth_method,
display_name=client_user.display_name,
create_missing=create_missing,
)
except IntegrityError:
# catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username
if i + 1 == num_retries:
raise
@managed_tx_method(CommitMode.COMMIT)
def lookup_system_user(self, username: str, create_missing: bool = True) -> User | None:
return self._lookup_user_tx(
username=username,
auth_method="system",
display_name=f"__system__/{username}",
create_missing=create_missing,
)
def query_users_ordered_by_username(
self,
api_client_id: Optional[UUID] = None,
@@ -39,7 +39,7 @@ class UserStatsRepository:
qry = (
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
.join(UserStats, User.id == UserStats.user_id)
.filter(UserStats.time_frame == time_frame.value)
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
.order_by(UserStats.rank)
.limit(limit)
)
@@ -214,6 +214,10 @@ class UserStatsRepository:
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
self.session.execute(d)
if None in stats_by_user:
logger.warning("Some messages in DB have NULL values in user_id column.")
del stats_by_user[None]
# compute magic leader score
for v in stats_by_user.values():
v.leader_score = v.compute_leader_score()
@@ -250,7 +254,8 @@ FROM
PARTITION BY time_frame
ORDER BY leader_score DESC, user_id
) AS "rank", user_id, time_frame
FROM user_stats
FROM user_stats us2
INNER JOIN "user" u ON us2.user_id = u.id AND u.show_on_leaderboard
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
us.user_id = r.user_id
+114 -54
View File
@@ -7,9 +7,14 @@ from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_shared.exceptions import OasstError, OasstErrorCode
from sqlalchemy.exc import OperationalError
from psycopg2.errors import DeadlockDetected, ExclusionViolation, SerializationFailure, UniqueViolation
from sqlalchemy.exc import OperationalError, PendingRollbackError
from sqlmodel import Session, SQLModel
"""
Error Handling Reference: https://www.postgresql.org/docs/15/mvcc-serialization-failure-handling.html
"""
class CommitMode(IntEnum):
"""
@@ -34,28 +39,46 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s
@wraps(f)
def wrapped_f(self, *args, **kwargs):
try:
for i in range(num_retries):
try:
result = f(self, *args, **kwargs)
if auto_commit == CommitMode.COMMIT:
result = None
if auto_commit == CommitMode.COMMIT:
retry_exhausted = True
for i in range(num_retries):
try:
result = f(self, *args, **kwargs)
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
elif auto_commit == CommitMode.ROLLBACK:
if isinstance(result, SQLModel):
self.db.refresh(result)
retry_exhausted = False
break
except PendingRollbackError as e:
logger.info(str(e))
self.db.rollback()
except OperationalError as e:
if e.orig is not None and isinstance(
e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
):
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
self.db.rollback()
else:
raise e
logger.info(f"Retry {i+1}/{num_retries}")
if retry_exhausted:
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
else:
result = f(self, *args, **kwargs)
if auto_commit == CommitMode.FLUSH:
self.db.flush()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
elif auto_commit == CommitMode.ROLLBACK:
self.db.rollback()
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
return result
except Exception as e:
logger.error("DB Rollback Failure")
logger.info(str(e))
raise e
return wrapped_f
@@ -70,28 +93,46 @@ def async_managed_tx_method(
@wraps(f)
async def wrapped_f(self, *args, **kwargs):
try:
for i in range(num_retries):
try:
result = await f(self, *args, **kwargs)
if auto_commit == CommitMode.COMMIT:
result = None
if auto_commit == CommitMode.COMMIT:
retry_exhausted = True
for i in range(num_retries):
try:
result = await f(self, *args, **kwargs)
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
elif auto_commit == CommitMode.ROLLBACK:
if isinstance(result, SQLModel):
self.db.refresh(result)
retry_exhausted = False
break
except PendingRollbackError as e:
logger.info(str(e))
self.db.rollback()
except OperationalError as e:
if e.orig is not None and isinstance(
e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
):
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
self.db.rollback()
else:
raise e
logger.info(f"Retry {i+1}/{num_retries}")
if retry_exhausted:
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
else:
result = await f(self, *args, **kwargs)
if auto_commit == CommitMode.FLUSH:
self.db.flush()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
elif auto_commit == CommitMode.ROLLBACK:
self.db.rollback()
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
return result
except Exception as e:
logger.exception("DB Rollback Failure")
logger.info(str(e))
raise e
return wrapped_f
@@ -107,7 +148,6 @@ def managed_tx_function(
auto_commit: CommitMode = CommitMode.COMMIT,
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
session_factory: Callable[..., Session] = default_session_factor,
refresh_result: bool = True,
):
"""Passes Session object as first argument to wrapped function."""
@@ -115,29 +155,49 @@ def managed_tx_function(
@wraps(f)
def wrapped_f(*args, **kwargs):
try:
for i in range(num_retries):
with session_factory() as session:
try:
result = f(session, *args, **kwargs)
if auto_commit == CommitMode.COMMIT:
result = None
if auto_commit == CommitMode.COMMIT:
retry_exhausted = True
for i in range(num_retries):
with session_factory() as session:
try:
result = f(session, *args, **kwargs)
session.commit()
elif auto_commit == CommitMode.FLUSH:
session.flush()
elif auto_commit == CommitMode.ROLLBACK:
if isinstance(result, SQLModel):
session.refresh(result)
retry_exhausted = False
break
except PendingRollbackError as e:
logger.info(str(e))
session.rollback()
if refresh_result and isinstance(result, SQLModel):
session.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
session.rollback()
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
except OperationalError as e:
if e.orig is not None and isinstance(
e.orig,
(SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation),
):
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
session.rollback()
else:
raise e
logger.info(f"Retry {i+1}/{num_retries}")
if retry_exhausted:
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
else:
with session_factory() as session:
result = f(session, *args, **kwargs)
if auto_commit == CommitMode.FLUSH:
session.flush()
if isinstance(result, SQLModel):
session.refresh(result)
elif auto_commit == CommitMode.ROLLBACK:
session.rollback()
return result
except Exception as e:
logger.error("DB Rollback Failure")
logger.info(str(e))
raise e
return wrapped_f
@@ -0,0 +1,111 @@
import os
import pickle
from collections import Counter
from sklearn import metrics
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
def load_and_split(foldername, num_words):
ls = os.listdir(foldername)
X = []
Y = []
langmap = dict()
for idx, x in enumerate(ls):
print("loading language", x)
with open(foldername + "/" + x, "r") as reader:
tmp = reader.read().split(" ")
tmp = [" ".join(tmp[i : i + num_words]) for i in range(0, 100_000, num_words)]
X.extend(tmp)
Y.extend([idx] * len(tmp))
langmap[idx] = x
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.90)
return x_train, x_test, y_train, y_test, langmap
def build_and_train_pipeline(x_train, y_train):
vectorizer = TfidfVectorizer(ngram_range=(1, 2), analyzer="char", use_idf=False)
clf = Pipeline(
[
("vec", vectorizer),
# ("nystrom", Nystroem(n_components=1000,n_jobs=6)),
("clf", LinearSVC(C=0.5)),
# ("clf",GaussianNB())
# ("clf", HistGradientBoostingClassifier())
]
)
print("fitting model...")
clf.fit(x_train, y_train)
return clf
def benchmark(clf, x_test, y_test, langmap):
print("benchmarking model...")
y_pred = clf.predict(x_test)
names = list(langmap.values())
# print(y_test)
# print(langmap)
print(metrics.classification_report(y_test, y_pred, target_names=names))
cm = metrics.confusion_matrix(y_test, y_pred)
print(cm)
def main(foldername, modelname, num_words):
x_train, x_test, y_train, y_test, langmap = load_and_split(foldername=foldername, num_words=num_words)
clf = build_and_train_pipeline(x_train, y_train)
benchmark(clf, x_test, y_test, langmap)
save_model(clf, langmap, num_words, modelname)
model = load(modelname)
print(
"running infernence on long tests",
inference_voter(
model,
"""
What language is this text written in? Nobody knows until you fill in at least ten words.
This test here is to check whether the moving window approach works,
so I still need to fill in a little more text.
""",
),
)
def load(modelname):
with open(modelname, "rb") as writer:
data = pickle.load(writer)
return data
def save_model(model, idx_to_name, num_words, modelname):
out = {
"model": model,
"idx_to_name": idx_to_name,
"num_words": num_words,
}
with open(modelname, "wb") as writer:
pickle.dump(out, writer)
def inference_voter(model, text):
tmp = text.split()
# print(len(tmp), tmp)
tmp = [" ".join(tmp[i : i + model["num_words"]]) for i in range(0, len(tmp) - model["num_words"])]
predictions = model["model"].predict(tmp)
# print("integer predictions", predictions)
# print("name predictions", *[model["idx_to_name"][n] for n in predictions])
result = Counter(predictions).most_common(1)[0][0]
return model["idx_to_name"][result]
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", help="save location for model and metadata")
parser.add_argument("-d", "--data", help="specify the folder for data files")
parser.add_argument("-n", "--num_words", help="number of words to use for statistics", type=int)
args = parser.parse_args()
# np.set_printoptions(threshold=np.inf)
main(args.data, args.model, args.num_words)
+29 -4
View File
@@ -96,13 +96,15 @@ def ranked_pairs(ranks: List[List[int]]):
"""
tallies, names = head_to_head_votes(ranks)
tallies = tallies - tallies.T
# print(tallies)
# note: the resulting tally matrix should be skew-symmetric
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
sorted_majorities = []
for i in range(len(ranks[0])):
for j in range(len(ranks[0])):
if tallies[i, j] > 0:
# you can never prefer yourself over yourself
# we also have to pick one of the two choices,
# if the preference is exactly zero...
if tallies[i, j] >= 0 and i != j:
sorted_majorities.append((i, j, tallies[i, j]))
# we don't explicitly deal with tied majorities here
sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True))
@@ -128,13 +130,36 @@ def ranked_pairs(ranks: List[List[int]]):
if __name__ == "__main__":
ranks = (
ranks = """ (
[("w", "x", "z", "y") for _ in range(1)]
+ [("w", "y", "x", "z") for _ in range(2)]
# + [("x","y","z","w") for _ in range(4)]
+ [("x", "z", "w", "y") for _ in range(5)]
+ [("y", "w", "x", "z") for _ in range(1)]
# [("y","z","w","x") for _ in range(1000)]
)
)"""
ranks = [
[
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
("d11705af-5575-43e5-b22e-08d155fbaa62"),
],
[
("d11705af-5575-43e5-b22e-08d155fbaa62"),
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
],
[
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
("d11705af-5575-43e5-b22e-08d155fbaa62"),
],
[
("d11705af-5575-43e5-b22e-08d155fbaa62"),
("c5181083-d3e9-41e7-a935-83fb9fa01488"),
("dcf3d179-0f34-4c15-ae21-b8feb15e422d"),
],
]
rp = ranked_pairs(ranks)
print(rp)
+12 -6
View File
@@ -12,12 +12,15 @@ from pydantic import BaseModel
class ExportMessageNode(BaseModel):
message_id: str
parent_id: Optional[str]
text: Optional[str]
parent_id: str | None
text: str
role: str
review_count: Optional[int]
rank: Optional[int]
replies: Optional[list[ExportMessageNode]]
lang: str | None
review_count: int | None
rank: int | None
synthetic: bool | None
model_name: str | None
replies: list[ExportMessageNode] | None
@classmethod
def prep_message_export(cls, message: Message) -> ExportMessageNode:
@@ -26,14 +29,17 @@ class ExportMessageNode(BaseModel):
parent_id=str(message.parent_id) if message.parent_id else None,
text=str(message.payload.payload.text),
role=message.role,
lang=message.lang,
review_count=message.review_count,
synthetic=message.synthetic,
model_name=message.model_name,
rank=message.rank,
)
class ExportMessageTree(BaseModel):
message_tree_id: str
replies: Optional[ExportMessageNode]
prompt: Optional[ExportMessageNode]
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
+4
View File
@@ -1,4 +1,5 @@
alembic==1.8.1
cryptography==39.0.0
fastapi==0.88.0
fastapi-limiter==0.1.5
fastapi-utils==0.2.1
@@ -6,7 +7,10 @@ loguru==0.6.0
numpy==1.22.4
psycopg2-binary==2.9.5
pydantic==1.9.1
pydantic[email]==1.9.1
python-dotenv==0.21.0
python-jose[cryptography]==3.3.0
redis
scipy==1.8.1
SQLAlchemy==1.4.41
sqlmodel==0.0.8
+1
View File
@@ -13,5 +13,6 @@ RUN pip install -e /oasst-shared
COPY ./backend/alembic /app/alembic
COPY ./backend/alembic.ini /app/alembic.ini
COPY ./backend/main.py /app/main.py
COPY ./backend/import.py /app/import.py
COPY ./backend/oasst_backend /app/oasst_backend
COPY ./backend/test_data /app/test_data
+1
View File
@@ -26,6 +26,7 @@
as specific as possible.
- The assistant should never insult the user or engage in any inappropriate or
offensive behavior
- Always use spellchecking, typos in assistant responses are unacceptable.
## 3. When you play the user:
+14 -1
View File
@@ -2,7 +2,13 @@
Preliminary implementation of the inference engine for OpenAssistant.
## Development (you'll need multiple terminals)
## Development Variant 1 (you'll need tmux)
Run `./full-dev-setup.sh` to start the full development setup. Make sure to wait
until the 2nd terminal is ready and says `{"message":"Connected"}` before
entering input into the last terminal.
## Development Variant 2 (you'll need multiple terminals)
Run a redis container (or use the one of the general docker compose file):
@@ -26,6 +32,13 @@ pip install -r requirements.txt
python __main__.py
```
For the worker, you'll also want to have the text-generation-inference server
running:
```bash
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
```
Run the client:
```bash
+20
View File
@@ -0,0 +1,20 @@
#!/bin/bash
# Creates a tmux window with splits for the individual services
tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m
tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "uvicorn main:app --reload" C-m
tmux split-window -h
tmux send-keys "cd worker" C-m
tmux send-keys "python __main__.py" C-m
tmux split-window -h
tmux send-keys "cd text-client" C-m
tmux send-keys "sleep 5" C-m
tmux send-keys "python __main__.py" C-m
tmux select-layout even-horizontal
tmux attach-session -t "inference-dev-setup"
+53 -35
View File
@@ -5,6 +5,7 @@ import uuid
import fastapi
import pydantic
import redis.asyncio as redis
import websockets.exceptions
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared.schemas import inference, protocol
@@ -63,6 +64,7 @@ class MessageRequestState(str, enum.Enum):
pending = "pending"
in_progress = "in_progress"
complete = "complete"
aborted_by_worker = "aborted_by_worker"
class DbChatEntry(pydantic.BaseModel):
@@ -154,40 +156,56 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
async def work(websocket: fastapi.WebSocket):
await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
while True:
# find a pending task that matches the worker's config
# could also be implemented using task queues
# but general compatibility matching is tricky
for chat in CHATS.values():
if (request := chat.pending_message_request) is not None:
if chat.message_request_state == MessageRequestState.pending:
if request.compatible_with(worker_config):
try:
while True:
print(websocket.client_state)
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
logger.warning("Worker disconnected")
break
# find a pending task that matches the worker's config
# could also be implemented using task queues
# but general compatibility matching is tricky
for chat in CHATS.values():
if (request := chat.pending_message_request) is not None:
if chat.message_request_state == MessageRequestState.pending:
if request.compatible_with(worker_config):
break
else:
logger.debug("No pending tasks")
await asyncio.sleep(1)
continue
chat.message_request_state = MessageRequestState.in_progress
work_request = inference.WorkRequest(
conversation=chat.conversation,
model_name=request.model_name,
max_new_tokens=request.max_new_tokens,
)
logger.info(f"Created {work_request}")
try:
await websocket.send_text(work_request.json())
except websockets.exceptions.ConnectionClosedError:
logger.warning("Worker disconnected")
websocket.close()
chat.message_request_state = MessageRequestState.pending
break
try:
while True:
# maybe unnecessary to parse and re-serialize
# could just pass the raw string and mark end via empty string
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
await redisClient.rpush(chat.id, response_packet.json())
if response_packet.is_end:
break
else:
logger.debug("No pending tasks")
await asyncio.sleep(1)
continue
except fastapi.WebSocketException:
# TODO: handle this better
logger.exception(f"Websocket closed during handling of {chat.id}")
chat.message_request_state = MessageRequestState.aborted_by_worker
raise
chat.message_request_state = MessageRequestState.in_progress
work_request = inference.WorkRequest(
conversation=chat.conversation,
model_name=request.model_name,
max_new_tokens=request.max_new_tokens,
)
logger.info(f"Created {work_request}")
try:
await websocket.send_text(work_request.json())
while True:
# maybe unnecessary to parse and re-serialize
# could just pass the raw string and mark end via empty string
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
await redisClient.rpush(chat.id, response_packet.json())
if response_packet.is_end:
break
except fastapi.WebSocketException:
# TODO: handle this better
logger.exception(f"Websocket closed during handling of {chat.id}")
chat.message_request_state = MessageRequestState.complete
chat.message_request_state = MessageRequestState.complete
except fastapi.WebSocketException:
logger.exception("Websocket closed")
+36 -22
View File
@@ -1,13 +1,12 @@
import re
import time
import json
import rel
import torch
import requests
import sseclient
import typer
import websocket
from loguru import logger
from oasst_shared.schemas import inference, protocol
from transformers import pipeline
app = typer.Typer()
@@ -16,12 +15,13 @@ app = typer.Typer()
def main(
backend_url: str = "ws://localhost:8000",
model_name: str = "distilgpt2",
inference_server_url: str = "http://localhost:8001",
):
pipe = pipeline("text-generation", model=model_name)
def on_open(ws: websocket.WebSocket):
logger.info("Connected to backend, sending config...")
worker_config = inference.WorkerConfig(model_name=model_name)
ws.send(worker_config.json())
logger.info("Config sent, waiting for work...")
def on_message(ws: websocket.WebSocket, message: str):
# TODO: what if this comes in, but one is already in progress?
@@ -35,25 +35,39 @@ def main(
# construct prompt
messages = [_prepare_message(message) for message in work_request.conversation.messages]
prompt = "\n".join(messages) + "\nAssistant:"
prefix = (
"The following is a conversation between a user and an assistant. "
"The assistant is helpful, creative, clever, and very friendly.\n"
"Assistant: Hello! How can I help you today?\n"
)
# TODO: replace this with incremental generation
torch.manual_seed(work_request.seed)
model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
0
]["generated_text"]
model_output = model_output.strip()
prompt = prefix + "\n".join(messages) + "\nAssistant:"
# fake streaming
split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
for piece in pieces:
if not piece:
continue
if piece.strip() in ("User:", "Assistant:"):
response = requests.post(
f"{inference_server_url}/generate_stream",
json={
"inputs": prompt,
"parameters": {
"max_new_tokens": work_request.max_new_tokens,
"do_sample": work_request.do_sample,
"top_k": work_request.top_k,
"top_p": work_request.top_p,
"temperature": work_request.temperature,
"seed": work_request.seed,
},
},
stream=True,
headers={"Accept": "text/event-stream"},
)
response.raise_for_status()
client = sseclient.SSEClient(response)
for event in client.events():
data = json.loads(event.data)
if data["is_end"]:
break
ws.send(inference.WorkResponsePacket(token=piece).json())
time.sleep(0.1)
intermediate = data["event"]
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
ws.send(inference.WorkResponsePacket(is_end=True).json())
def on_error(ws: websocket.WebSocket, error: Exception):
+2 -2
View File
@@ -1,6 +1,6 @@
loguru
rel
torch
transformers
requests
sseclient-py
typer
websocket-client
@@ -0,0 +1,17 @@
model_name: microsoft/deberta-v2-xxlarge
learning_rate: 2e-6
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 12
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
warmup_steps: 600
eval_steps: 1000000
save_steps: 1000
max_length: 400
num_train_epochs: 3
datasets:
- webgpt
- hfsummary
- anthropic_rlhf
- gptsynthetic
@@ -0,0 +1,29 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 0.1
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
+62 -1
View File
@@ -139,7 +139,7 @@ class HFSummary(Dataset):
"""
def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=3) -> None:
def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=1) -> None:
super().__init__()
assert split in ("train", "valid1", "valid2", "test")
summaries = {}
@@ -237,3 +237,64 @@ class HFDataset(Dataset):
class GPTJSynthetic(HFDataset):
def __init__(self) -> None:
super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train")
class AnthropicRLHF(Dataset):
"""
The data are described in the paper:
Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback.
If you find the data useful, please cite the paper.
The data format is very simple -- each line of the jsonl files contains a pair of texts,
one "chosen" and one "rejected".
valid train size : 160780
"""
def preprocess_dialogue(self, text):
"""
trim prefix text to last two pairs
Outlier example Assistant answered empty string:
Assistant: Human: That makes sense, I agree with that, though there are many situations that
aren't considered justice, like sending a kid to prison for life. Human: You are completely
missing the point of this conversation, and not understanding anything I am saying. Human:
And I dont know if youre trying to be funny, but it isnt.
"""
last_two_convo = text.split("Human:")[-2:]
if len(last_two_convo[0]) == 0:
return "Human:".join(last_two_convo)
return "Human: " + "Human:".join(last_two_convo)
def __init__(self, split="train", sep_token="<sep>") -> None:
super().__init__()
assert split in ("train", "test")
if sep_token is None:
sep_token = " . "
self.pairs = []
# using prompt as our index will allows us
# to add additional generated prompt later
major_split = split if "train" == split else "test"
dataset = load_dataset("Anthropic/hh-rlhf")[major_split]
for data in dataset:
processed = self.preprocess_dialogue(data["chosen"])
# roughly 20 of these are invalid conversation
if "Assistant" not in processed:
continue
prompt, pos_postfix = processed.split("Assistant:", maxsplit=1)
prompt = prompt.replace("Human: ", "").strip()
pos_postfix = pos_postfix.replace("Human: ", sep_token).replace("\n\nAssistant: ", sep_token).strip()
processed = self.preprocess_dialogue(data["rejected"])
if "Assistant" not in processed:
continue
_, neg_postfix = processed.split("Assistant:", maxsplit=1)
neg_postfix = neg_postfix.replace("Human: ", sep_token).replace("\n\nAssistant: ", sep_token).strip()
self.pairs.append((prompt, (pos_postfix.strip(), neg_postfix.strip())))
def __len__(self):
return len(self.pairs)
def __getitem__(self, index):
context, pair = self.pairs[index]
return context, [pair]
+11 -1
View File
@@ -1,5 +1,5 @@
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
from rank_datasets import AnthropicRLHF, DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
@@ -25,6 +25,16 @@ def test_webgpt():
print(batch["input_ids"].shape)
def test_anthropic_rlhf():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
dataset = AnthropicRLHF("test", sep_token=tokenizer.sep_token)
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
for batch in dataloader:
print(batch["input_ids"].shape)
def test_hf_summary_quality():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
+24 -28
View File
@@ -8,15 +8,8 @@ import torch
from models import RankGenModel
from rank_datasets import DataCollatorForPairRank, RankGenCollator
from torch import nn
from transformers import (
AdamW,
AutoModelForSequenceClassification,
PreTrainedModel,
Trainer,
TrainingArguments,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from transformers import AutoModelForSequenceClassification, PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer
os.environ["WANDB_PROJECT"] = "reward-model"
@@ -24,6 +17,11 @@ os.environ["WANDB_PROJECT"] = "reward-model"
accuracy = evaluate.load("accuracy")
parser = ArgumentParser()
parser.add_argument("config", type=str)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed", action="store_true")
parser.set_defaults(deepspeed=False)
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
parser.add_argument("--wandb-entity", type=str, default="open-assistant")
def compute_metrics(eval_pred):
@@ -133,48 +131,46 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
optimizer = OptimizerNames.ADAMW_HF
args = TrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
warmup_steps=500,
warmup_steps=training_conf["warmup_steps"],
optim=optimizer,
lr_scheduler_type=training_conf["scheduler"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
deepspeed="configs/zero_config.json" if training_conf["deepspeed"] else None,
fp16=training_conf["fp16"],
local_rank=training_conf["local_rank"],
gradient_checkpointing=training_conf["gradient_checkpointing"],
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
per_device_eval_batch_size=training_conf["per_device_eval_batch_size"],
weight_decay=0.01,
max_grad_norm=2.0,
weight_decay=training_conf["weight_decay"],
max_grad_norm=training_conf["max_grad_norm"],
logging_steps=10,
save_total_limit=4,
evaluation_strategy="steps",
eval_steps=training_conf["eval_steps"],
save_steps=1000,
save_steps=training_conf["save_steps"],
report_to="wandb",
)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
train, evals = get_datasets(training_conf["datasets"])
train, evals = get_datasets(training_conf["datasets"], tokenizer)
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
assert len(evals) > 0
optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = None
if "scheduler" in training_conf:
if training_conf["scheduler"] == "linear":
scheduler = get_linear_schedule_with_warmup()
elif training_conf["scheduler"] == "cosine":
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=len(train)
* args.num_train_epochs
/ (args.per_device_train_batch_size * args.gradient_accumulation_steps),
)
if not training_conf["deepspeed"] or training_conf["local_rank"] == 0:
import wandb
wandb.init(
project=os.environ["WANDB_PROJECT"], name=f"{model_name}-finetuned", entity=training_conf["wandb_entity"]
)
trainer = RankTrainer(
model=model,
@@ -186,7 +182,7 @@ if __name__ == "__main__":
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
optimizers=(optimizer, scheduler),
# optimizers=(optimizer, scheduler),
)
# trainer.evaluate()
trainer.train()
+26 -14
View File
@@ -81,26 +81,41 @@ def argument_parsing(parser):
"learning_rate": 3e-5,
"eval_steps": 500,
"loss": "rank",
"warmup_steps": 500,
"max_length": 440,
"weight_decay": 0.01,
"max_grad_norm": 2.0,
"save_steps": 500,
"per_device_eval_batch_size": 5,
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 8,
"gradient_checkpointing": False,
"deepspeed": args.deepspeed,
"local_rank": args.local_rank,
"datasets": ["webgpt"],
"wandb_entity": args.wandb_entity,
"fp16": True,
"tokenizer_name": training_conf["model_name"],
}
params = {**default_params, **training_conf}
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
params["num_train_epochs"] = int(params["num_train_epochs"])
params["per_device_train_batch_size"] = int(params["per_device_train_batch_size"])
params["learning_rate"] = float(params["learning_rate"])
for name in [
"gradient_accumulation_steps",
"num_train_epochs",
"save_steps",
"eval_steps",
"per_device_train_batch_size",
"per_device_eval_batch_size",
]:
params[name] = int(params[name])
for name in ["learning_rate", "weight_decay", "max_grad_norm"]:
params[name] = float(params[name])
return params
def get_datasets(dataset_list: List[AnyStr]):
from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
def get_datasets(dataset_list: List[AnyStr], tokenizer):
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import ConcatDataset
train_datasets, evals = [], {}
@@ -121,13 +136,10 @@ def get_datasets(dataset_list: List[AnyStr]):
train, eval = train_val_dataset(dataset, 0.1)
train_datasets.append(train)
evals["gptsynthetic"] = eval
elif "anthropic_rlhf" == dataset_name:
train = AnthropicRLHF("train", tokenizer.sep_token)
eval = AnthropicRLHF("test", tokenizer.sep_token)
train_datasets.append(train)
evals["anthropic_rlhf"] = eval
train = ConcatDataset(train_datasets)
return train, evals
if __name__ == "__main__":
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloomz-560m")
freeze_top_n_layers(model, 10)
print(model.state_dict().keys())
@@ -9,7 +9,19 @@ from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
QA_DATASETS = [
"squad_v2",
"adversarial_qa",
"trivia_qa_context",
"trivia_qa_nocontext",
"gsm8k",
"wikihow",
"essay_instruction",
"math_qa",
"reddit_eli5",
"reddit_askh",
"reddit_asks",
]
SUMMARIZATION_DATASETS = [
"xsum",
"cnn_dailymail",
@@ -35,16 +47,16 @@ def get_one_dataset(conf, dataset_name):
if dataset_name in QA_DATASETS:
train = QADataset(dataset_name, conf.cache_dir, "train")
val_name = "validation" if dataset_name not in ["gsm8k"] else "test"
eval = QADataset(dataset_name, conf.cache_dir, val_name)
if train.no_val:
train, eval = train_val_dataset(train, val_split=0.2)
else:
eval = QADataset(dataset_name, conf.cache_dir, "validation")
elif dataset_name in SUMMARIZATION_DATASETS:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
if dataset_name == "debate_sum":
train, eval = train_val_dataset(train, val_split=0.2)
else:
val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test"
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
elif "ted_trans" in dataset_name:
language_pair = dataset_name.split("_")[-1]
dataset = TEDTalk(pair=language_pair, split="train")
@@ -49,26 +49,92 @@ def index_gsm8k(example):
return example["question"], example["answer"]
def index_wikihow(example):
return example["title"] + ", explain step by step", example["result"]
def index_essay_instruction(example):
return example["instructions"], example["titles"].strip() + "\n" + example["essays"]
def index_math_qa(example):
"""
we are not including choices, so no need to output the "answer : <a,b,c,d>" part
> if girls is 10 and boys is 20 , then 10 / 20 . so ratio of girls to boys is = 10 / 20 = 1 / 2 answer : a
"""
return example["Problem"], example["Rationale"].split("answer : ", maxsplit=1)[0]
def index_eli5(example):
return example["title"], example["answers"]["text"][0]
class QADataset(Dataset):
"""
How to define a new QA dataset:
Criteria : the qa dataset doesn't need fancy transform needed between fields rows or list
1. Write the transform function, which maps each row into a pair of (question, answer) tuple
2. Update DATASET_FORMAT_MAPPING with your dataset name and required parameter
- index_fn : your transform function
- name: the dataset name, this will be used when the name is different than huggingface load_dataset name
- params: if your dataset require a predefined name, create a dictionary with the parameter name-value dictionary
Feel free to create issues on GH for any suggestion how we can simplify this thing
"""
DATASET_FORMAT_MAPPING = {
"squad_v2": {"index_fn": index_squad_v2},
"trivia_qa_nocontext": {
"index_fn": index_trivia_qa_nocontext,
"name": "trivia_qa",
"params": {"name": "rc.nocontext"},
},
"trivia_qa_context": {"index_fn": index_trivia_qa_context, "name": "trivia_qa", "params": {"name": "rc"}},
"adversarial_qa": {
"index_fn": index_adversarial_qa,
"params": {"name": "adversarialQA"},
},
"gsm8k": {"index_fn": index_gsm8k, "params": {"name": "main"}, "validation": "test"},
"wikihow": {"name": "b-mc2/wikihow_lists", "index_fn": index_wikihow, "no_val": True},
"essay_instruction": {
"name": "ChristophSchuhmann/essays-with-instructions",
"index_fn": index_essay_instruction,
"no_val": True,
},
"math_qa": {
"index_fn": index_math_qa,
},
"reddit_eli5": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_eli5"},
"reddit_askh": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_askh"},
"reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"},
}
def __init__(self, dataset, cache_dir, split):
if dataset == "squad_v2":
self.index_fn = index_squad_v2
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
elif dataset == "trivia_qa_nocontext":
self.index_fn = index_trivia_qa_nocontext
self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split, cache_dir=cache_dir)
elif dataset == "trivia_qa_context":
self.index_fn = index_trivia_qa_context
self.dataset = load_dataset("trivia_qa", "rc", split=split, cache_dir=cache_dir)
elif dataset == "adversarial_qa":
self.index_fn = index_adversarial_qa
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
elif dataset == "gsm8k":
self.index_fn = index_gsm8k
self.dataset = load_dataset("gsm8k", "main", split=split, cache_dir=cache_dir)
elif dataset == "adversarial_qa":
self.index_fn = index_adversarial_qa
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
self.no_val = False
if dataset in self.DATASET_FORMAT_MAPPING:
context = self.DATASET_FORMAT_MAPPING[dataset]
if split == "validation" and "validation" in context:
split = context["validation"]
if "name" not in context:
context["name"] = dataset
if "split_postfix" in context:
# append a postfix to split name, used in eli5 : test_eli5, test_asks, test_askh
split += context["split_postfix"]
if "params" not in context:
context["params"] = {"cache_dir": cache_dir, "split": split}
else:
context["params"]["cache_dir"] = cache_dir
context["params"]["split"] = split
if "no_val" in context:
self.no_val = True
self.index_fn = context["index_fn"]
self.dataset = load_dataset(context["name"], **context["params"])
else:
raise ValueError("Unknown dataset : " + dataset)
@@ -259,6 +325,3 @@ class JokeExplaination(Dataset):
def __getitem__(self, index):
return format_pair(self.pairs[index])
# https://huggingface.co/datasets/aquamuse
@@ -57,6 +57,8 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split, max_words=512):
self.name = dataset
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
split = "test"
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
@@ -0,0 +1,37 @@
# README
## Introduction
This program converts data obtained from the subreddit r/changemyview into a cleaner format for further data processing. The data is not clean enough to be used directly in a model yet, and additional preprocessing is required.
## Data Format
The cleaned data is stored in an Apache Parquet file with the following columns:
| Column Name | Description | Data Type |
|-------------|------------------------------------------------------------------------|----------------|
| INSTRUCTION | Post title + body text | String |
| RESPONSE | Body text of comments attempting to change OP's mind of `INSTRUCTION`. | List\<String\> |
| SOURCE | Permalink to the reddit post | String |
| METADATA | Metadata related to `RESPONSE`. | Dict\<Variant> |
### Metadata
Currently, metadata is only broken into one category:
- `detoxify_labels`- A Dictionary of values outputted by the [Unitaryai Detoxifier](https://github.com/unitaryai/detoxify) model, fitted to every comment under any given post.
## Usage
To use the program, follow these instructions:
1. **Clone the repository** - `git clone https://github.com/LAION-AI/Open-Assistant.git`
2. **Navigate to the project directory** - `cd notebooks/data-augmentation/changemyview-builder`
3. **Open the Jupyter Notebook** - `jupyter notebook data_processor.ipynb`
4. **Run the program** - Go through the notebook and run the cells
## Contributing
If you would like to contribute to this project, please fork the repository and submit a pull request with your changes.
## License
This project is licensed under the Apache-2.0 License - see the [LICENSE](LICENSE) file for details.
@@ -0,0 +1,577 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# r/ChangeMyView data converter\n",
"Converts subreddit data into readable format for ML training\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 65,
"outputs": [],
"source": [
"### REMEMBER: setup the .env before running this code!\n",
"\n",
"\"\"\"CONSTANTS\"\"\"\n",
"\n",
"# Set the head number to the amount of entries you want to load in minus one\n",
"ENTRIES_COUNT = 10\n",
"\n",
"# Set the threshold for toxic comments to be removed\n",
"TOXIC_THRESHOLD = 0.95"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 66,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pandas in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (1.4.4)\r\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2.8.2)\r\n",
"Requirement already satisfied: pytz>=2020.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2022.1)\r\n",
"Requirement already satisfied: numpy>=1.18.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (1.21.5)\r\n",
"Requirement already satisfied: six>=1.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\r\n",
"Requirement already satisfied: praw in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (7.6.1)\r\n",
"Requirement already satisfied: websocket-client>=0.54.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.58.0)\r\n",
"Requirement already satisfied: update-checker>=0.18 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.18.0)\r\n",
"Requirement already satisfied: prawcore<3,>=2.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (2.3.0)\r\n",
"Requirement already satisfied: requests<3.0,>=2.6.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from prawcore<3,>=2.1->praw) (2.28.1)\r\n",
"Requirement already satisfied: six in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from websocket-client>=0.54.0->praw) (1.16.0)\r\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (3.3)\r\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2022.9.24)\r\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2.0.4)\r\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (1.26.11)\r\n",
"Requirement already satisfied: python-dotenv in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.21.0)\r\n",
"Requirement already satisfied: pyarrow in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (10.0.1)\r\n",
"Requirement already satisfied: numpy>=1.16.6 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pyarrow) (1.21.5)\r\n",
"Requirement already satisfied: detoxify in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.5.1)\r\n",
"Requirement already satisfied: transformers==4.22.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (4.22.1)\r\n",
"Requirement already satisfied: torch>=1.7.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (1.13.1)\r\n",
"Requirement already satisfied: sentencepiece>=0.1.94 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (0.1.97)\r\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.9.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.11.1)\r\n",
"Requirement already satisfied: regex!=2019.12.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2022.7.9)\r\n",
"Requirement already satisfied: pyyaml>=5.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (6.0)\r\n",
"Requirement already satisfied: tqdm>=4.27 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (4.64.1)\r\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.12.1)\r\n",
"Requirement already satisfied: filelock in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (3.6.0)\r\n",
"Requirement already satisfied: packaging>=20.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (21.3)\r\n",
"Requirement already satisfied: numpy>=1.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (1.21.5)\r\n",
"Requirement already satisfied: requests in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2.28.1)\r\n",
"Requirement already satisfied: typing-extensions in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from torch>=1.7.0->detoxify) (4.3.0)\r\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from packaging>=20.0->transformers==4.22.1->detoxify) (3.0.9)\r\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (3.3)\r\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (1.26.11)\r\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2.0.4)\r\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2022.9.24)\r\n",
"Requirement already satisfied: tqdm in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (4.64.1)\r\n"
]
}
],
"source": [
"# Install any dependencies\n",
"!pip install pandas\n",
"!pip install praw\n",
"!pip install python-dotenv\n",
"!pip install pyarrow\n",
"!pip install detoxify\n",
"!pip install tqdm"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import praw\n",
"import os\n",
"from os.path import join, dirname\n",
"from dotenv import main\n",
"\n",
"# Make sure you create a .env file and fill in all the necessary information in the same folder as this script!\n",
"main.load_dotenv(join(dirname(os.path.realpath('__file__')), '.env'))\n",
"\n",
"reddit = praw.Reddit(\n",
" client_id=os.environ.get(\"CLIENT_ID\"),\n",
" client_secret=os.environ.get(\"CLIENT_SECRET\"),\n",
" user_agent=\"CMV_Scraper\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 68,
"outputs": [],
"source": [
"# load the data\n",
"import tarfile\n",
"import os.path\n",
"import json\n",
"import re\n",
"from bz2 import BZ2File\n",
"from urllib import request\n",
"from io import BytesIO\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"fname = \"cmv.tar.bz2\"\n",
"url = \"https://chenhaot.com/data/cmv/\" + fname\n",
"\n",
"# download if not exists\n",
"if not os.path.isfile(fname):\n",
" f = BytesIO()\n",
" with request.urlopen(url) as resp, open(fname, 'wb') as f_disk:\n",
" data = resp.read()\n",
" f_disk.write(data) # save to disk too\n",
" f.write(data)\n",
" f.seek(0)\n",
"else:\n",
" f = open(fname, 'rb')\n",
"\n",
"\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 69,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/kayjaymac/opt/anaconda3/lib/python3.9/bz2.py:124: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n",
" self._buffer = None\n",
"ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
]
}
],
"source": [
"#tar = tarfile.open(fileobj=f, mode=\"r:bz2\")\n",
"tar = tarfile.open(fileobj=f, mode=\"r\")\n",
"\n",
"# Extract the file we are interested in\n",
"\n",
"train_fname = \"op_task/train_op_data.jsonlist.bz2\"\n",
"test_fname = \"op_task/heldout_op_data.jsonlist.bz2\"\n",
"\n",
"train_bzlist = tar.extractfile(train_fname)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 70,
"outputs": [],
"source": [
"# Deserialize the JSON list\n",
"original_posts_train = [\n",
" json.loads(line.decode('utf-8'))\n",
" for line in BZ2File(train_bzlist)\n",
"]"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 71,
"outputs": [
{
"data": {
"text/plain": "[{'title': \"CMV: I shouldn't get a job in this economic climate because it'll be automated anyway; I should just wait for a post-scarcity utopia.\",\n 'delta_label': False,\n 'name': 't3_2rpsl8',\n 'selftext': \"I think the world is automating fast enough that a utopia will arise where no one will have to work anymore. Within the next 2 decades or so, having a job won't mean much, and most people will be artists and scientists. \\n\\nMy parents let me live with them, so I can just wait until the utopia happens.\\n\\nCMV.\"}]"
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"original_posts_train[:1]"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 72,
"outputs": [],
"source": [
"# Load the jsonlist file into a dataframe\n",
"#df = pd.read_json(original_posts_train, orient='list', lines=True)\n",
"df = pd.DataFrame(original_posts_train)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 73,
"outputs": [],
"source": [
"# Function to check if the posts still exists on reddit\n",
"def try_get_post(post_id):\n",
" try:\n",
" submission = reddit.submission(id=post_id)\n",
" submission.name\n",
" return True\n",
" except Exception as e:\n",
" return False"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 74,
"outputs": [],
"source": [
"# Set up the detoxifier model:\n",
"from detoxify import Detoxify"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# Removes > sign and the template message at the end of a message\n",
"def cleanup_body_text(cmv_post):\n",
" lines = [line for line in cmv_post.splitlines()\n",
" if not line.lstrip().startswith(\"&gt;\")\n",
" and not line.lstrip().startswith(\"____\")\n",
" and not line.lstrip().startswith(\"So go forth and CMV, noble redditors!\")\n",
" and \"edit\" not in \" \".join(line.lower().split()[:2])\n",
" ]\n",
" return \"\\n\".join(lines)\n",
"\n",
"\n",
"\n",
"\n",
"# Create the function that will be handling all the data gathering\n",
"def get_top_comment_and_clean_data(post_id):\n",
" #print(post_id.lstrip(\"t3_\"))\n",
" last_author = \"\"\n",
" # Grab the post\n",
" submission = reddit.submission(id=post_id.lstrip(\"t3_\"))\n",
" #print(submission.title)\n",
"\n",
" # Grab the highest rated comment on root layer\n",
" submission.submission_type = 'best'\n",
" submission.comments.replace_more(limit=0)\n",
" replies = list(submission.comments)[0].replies.list()\n",
"\n",
" # Just some variables\n",
" pros = []\n",
"\n",
" # If the post author doesn't exist this submission was deleted (submission.deleted doesn't work)\n",
" if type(submission.author) == type(None):\n",
" last_author = \"[deleted]\"\n",
" else:\n",
" last_author = submission.author.name\n",
"\n",
" is_pro_argument = False\n",
"\n",
" for comment in replies:\n",
"\n",
" # If redditor object doesn't exist, the account is invalid/deleted\n",
" if type(comment.author) != type(None):\n",
" author = comment.author.name\n",
" else:\n",
" author = \"[deleted]\"\n",
"\n",
" # Assume that whenever the user changes, they are countering the previous person\n",
" if author != last_author:\n",
" is_pro_argument = !is_pro_argument\n",
"\n",
" if author == \"[deleted]\" or author==\"DeltaBot\":\n",
" #print(\"Skipping comment...\")\n",
" continue\n",
"\n",
" # Remove meta and duplicate comments\n",
" comment.body = \" \".join([line for line in comment.body.splitlines()\n",
" if not re.search(r\"(?i)(Change\\smy\\sview|CMV)\", line)\n",
" and line not in pros # Why doesn't this line work\n",
" ])\n",
"\n",
" # Sometimes for some reason duplicate entries exist\n",
" # Also remove automated message with \"Δ\" in it\n",
"\n",
" if comment.body in pros:\n",
" #print(\"Skipping duplicate entry\")\n",
" continue\n",
"\n",
" #print(\"\\t\\t>>\\t\",comment.body)\n",
"\n",
" # Remove toxic comments\n",
" if Detoxify(\"multilingual\").predict(comment.body)[\"toxicity\"] > TOXIC_THRESHOLD:\n",
" #print(\"Identified toxic comment, ignoring...\")\n",
" comment.body = \"\"\n",
"\n",
" # Add to the respective argument type \n",
" if is_pro_argument:\n",
" pros.append(comment.body)\n",
" \n",
" last_author = comment.author.name\n",
" \n",
" # Pros = arguments for the Title of this post\n",
" # Cons = arguments against the title of this post\n",
"\n",
" pros.append(comment.body)\n",
" return pros"
]
},
{
"cell_type": "code",
"execution_count": 76,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading in 10 posts\n"
]
}
],
"source": [
"print(f\"Loading in {ENTRIES_COUNT} posts\")\n",
"dataset = df.head(ENTRIES_COUNT)\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 77,
"outputs": [],
"source": [
"# the name column does some weird sh** because dataframes already have a name property, so migrate to a different column name\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"dataset[\"post_id\"] = dataset[\"name\"]\n",
"warnings.filterwarnings('default')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading in data... This will take a while.\n"
]
},
{
"data": {
"text/plain": " 0%| | 0/10 [00:00<?, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "203fd74c8a5146b68b8af961bb3874c8"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"<timed exec>:29: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
"/Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages/torch/serialization.py:997: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n",
" storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()\n",
"ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7min 49s, sys: 2min 29s, total: 10min 19s\n",
"Wall time: 8min 45s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"from tqdm.auto import tqdm\n",
"# Reset variables for if we run this multiple times\n",
"all_pros = []\n",
"all_names = []\n",
"all_titles = []\n",
"all_sources = []\n",
"\n",
"print(\"Loading in data... This will take a while.\")\n",
"\n",
"for i in tqdm(range(dataset.shape[0])):\n",
"\n",
" post = dataset.iloc[i]\n",
" modified_title = post.title.replace('CMV', \"Change my mind\")\n",
" #print(f\"\\n Loading entry {i+1}/{dataset.shape[0]}:\\n\\t\\\"{modified_title}\\\"\")\n",
"\n",
" if type(post) == type(None):\n",
" continue\n",
"\n",
" assert(post.post_id != i)\n",
"\n",
" pros = get_top_comment_and_clean_data(post.post_id)\n",
"\n",
" if post.title == \"[deleted]\":\n",
" continue\n",
"\n",
" pros = \" \".join([*set(pros)])\n",
" pros = pros.replace(\"[deleted]\",\"\")\n",
"\n",
" post.selftext = cleanup_body_text(post.selftext)\n",
" all_titles.append(modified_title + \" \" + post.selftext)\n",
" all_pros.append(pros)\n",
" all_names.append(post.name)\n",
" all_sources.append(f\"https://reddit.com/r/changemyview/comments/{post.post_id}\")\n",
" #print(post.title)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 83,
"outputs": [
{
"data": {
"text/plain": "'it\\'s already been signed. They even claim to be adhering to it, though they\\'ve been found to be violating it before. There is no such thing as \"de facto acceptance of Israel\\'s nuclear program.\" the Non-Proliferation Treaty is only binding for signatory states. Israel is not a signatory. Article 10 of the NPT allows them to withdraw if they so choose. they have not done so. a whole new country which explicitly has a right to withdraw from the NPT and has not chosen to do so. It\\'s more accurate, I think, to say that the problem with Iran here from a legal standpoint is that they aren\\'t honoring their own commitments, rather than that they\\'re building weapons. They could pull out of the NPT at any time, and the ball would be essentially in America\\'s court, because their nuclear program would no longer be illegal by international legal standards. However, Iran insists both on developing nukes *and* remaining an NPT signatory non-nuclear state, and that\\'s what makes their program illegal. I\\'d also like to clarify that I\\'m not making an ethical argument here, this is just how international law currently works. because international law doesn\\'t require states to sign treaties, it only requires them to adhere to treaties they\\'ve already signed. Israel isn\\'t defying the UN, at least not in this particular case. Think of the NPT less like a standard law within a state and more like a contract. Once you\\'ve signed, you\\'re bound by the contract, but if you never sign it then you haven\\'t broken a law, you\\'ve just decided not to agree to the terms you were offered. > Because Iran did sign the treaty, and thus are bound by it. They signed on July 1, 1968. Hmm. So is the argument here that it\\'s not \"ok\" for Iran to have a nuke, since they signed treaty not to do so. But it\\'s \"ok\" for Israel to have one because they never signed such thing? Can\\'t quite put my finger on it, but doesn\\'t seem quite right this one.'"
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"all_pros[1]"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 80,
"outputs": [],
"source": [
"# Place it all into a Pandas Dataframe\n",
"clean_df = pd.DataFrame({\n",
" \"INSTRUCTION\": all_titles,\n",
" \"RESPONSE\": all_pros,\n",
" \"SOURCE\": all_sources\n",
"}, index=all_names\n",
")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# Create Apache Paquete file\n",
"\n",
"import pyarrow as pa\n",
"import pyarrow.parquet as pq\n",
"\n",
"table = pa.Table.from_pandas(clean_df)\n",
"pq.write_table(table,\"output.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": 82,
"outputs": [
{
"data": {
"text/plain": " INSTRUCTION \\\n0 Change my mind: I shouldn't get a job in this ... \n1 Change my mind: Iran has the right to develop ... \n2 Change my mind: The events in Paris suck...but... \n3 Change my mind: It is ok to hate a religion so... \n4 Change my mind: There is no productive reason ... \n5 Change my mind: Diet soda is perfectly healthy... \n6 Change my mind:Essential Oils are bullshit My ... \n7 Change my mind: I think the Paris shooting mak... \n8 Change my mind: Printing an image of the Musli... \n9 Change my mind: Philosophy has no tangible val... \n\n RESPONSE \\\n0 That is what someone in the 1500s would have s... \n1 it's already been signed. They even claim to b... \n2 Hm I guess I made the OP incorrectly. The mai... \n3 I don't understand your analogy. Promoting a ... \n4 ∆ I hadn't thought it from a \"let's trick peop... \n5 Thanks for a fresh argument! I hadn't conside... \n6 Most do. Some smell kinda funky. \n7 I already said in different comments that thi... \n8 The first bacon sandwich came about because 9... \n9 >Why restrict it to 50 years? I can name all s... \n\n SOURCE \n0 https://reddit.com/r/changemyview/comments/t3_... \n1 https://reddit.com/r/changemyview/comments/t3_... \n2 https://reddit.com/r/changemyview/comments/t3_... \n3 https://reddit.com/r/changemyview/comments/t3_... \n4 https://reddit.com/r/changemyview/comments/t3_... \n5 https://reddit.com/r/changemyview/comments/t3_... \n6 https://reddit.com/r/changemyview/comments/t3_... \n7 https://reddit.com/r/changemyview/comments/t3_... \n8 https://reddit.com/r/changemyview/comments/t3_... \n9 https://reddit.com/r/changemyview/comments/t3_... ",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>INSTRUCTION</th>\n <th>RESPONSE</th>\n <th>SOURCE</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>Change my mind: I shouldn't get a job in this ...</td>\n <td>That is what someone in the 1500s would have s...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>1</th>\n <td>Change my mind: Iran has the right to develop ...</td>\n <td>it's already been signed. They even claim to b...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>2</th>\n <td>Change my mind: The events in Paris suck...but...</td>\n <td>Hm I guess I made the OP incorrectly. The mai...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>Change my mind: It is ok to hate a religion so...</td>\n <td>I don't understand your analogy. Promoting a ...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>Change my mind: There is no productive reason ...</td>\n <td>∆ I hadn't thought it from a \"let's trick peop...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>5</th>\n <td>Change my mind: Diet soda is perfectly healthy...</td>\n <td>Thanks for a fresh argument! I hadn't conside...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>6</th>\n <td>Change my mind:Essential Oils are bullshit My ...</td>\n <td>Most do. Some smell kinda funky.</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>7</th>\n <td>Change my mind: I think the Paris shooting mak...</td>\n <td>I already said in different comments that thi...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>8</th>\n <td>Change my mind: Printing an image of the Musli...</td>\n <td>The first bacon sandwich came about because 9...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n <tr>\n <th>9</th>\n <td>Change my mind: Philosophy has no tangible val...</td>\n <td>&gt;Why restrict it to 50 years? I can name all s...</td>\n <td>https://reddit.com/r/changemyview/comments/t3_...</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test to see if it was sucessful\n",
"table = pq.read_table(\"output.parquet\")\n",
"table.to_pandas()"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
@@ -38,6 +38,8 @@ class OasstErrorCode(IntEnum):
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
TASK_AVAILABILITY_QUERY_FAILED = 1007
TASK_MESSAGE_TOO_LONG = 1008
TASK_MESSAGE_DUPLICATED = 1009
TASK_MESSAGE_TEXT_EMPTY = 1010
# 2000-3000: prompt_repository
INVALID_FRONTEND_MESSAGE_ID = 2000
@@ -13,7 +13,11 @@ class WorkRequest(pydantic.BaseModel):
conversation: protocol.Conversation = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
max_new_tokens: int = 100
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1))
seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1))
do_sample: bool = True
top_k: int = 50
top_p: float = 0.9
temperature: float = 1.0
class WorkResponsePacket(pydantic.BaseModel):
+62 -33
View File
@@ -57,6 +57,8 @@ class ConversationMessage(BaseModel):
text: str
lang: Optional[str] # BCP 47
is_assistant: bool
emojis: Optional[dict[str, int]] = None
user_emojis: Optional[list[str]] = None
class Conversation(BaseModel):
@@ -80,7 +82,6 @@ class Conversation(BaseModel):
class Message(ConversationMessage):
parent_id: Optional[UUID] = None
created_date: Optional[datetime] = None
emojis: Optional[dict] = None
class MessagePage(PageResult):
@@ -223,27 +224,43 @@ class LabelTaskMode(str, enum.Enum):
full = "full"
class LabelInitialPromptTask(Task):
class LabelTaskDisposition(str, enum.Enum):
"""Reason why the task was issued."""
quality = "quality"
spam = "spam"
class LabelDescription(BaseModel):
name: str
widget: str
display_text: str
help_text: Optional[str]
class AbstractLabelTask(Task):
message_id: UUID
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[LabelTaskMode]
disposition: Optional[LabelTaskDisposition]
labels: Optional[list[LabelDescription]]
class LabelInitialPromptTask(AbstractLabelTask):
"""A task to label an initial prompt."""
type: Literal["label_initial_prompt"] = "label_initial_prompt"
message_id: UUID
prompt: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[LabelTaskMode]
class LabelConversationReplyTask(Task):
class LabelConversationReplyTask(AbstractLabelTask):
"""A task to label a reply to a conversation."""
type: Literal["label_conversation_reply"] = "label_conversation_reply"
conversation: Conversation # the conversation so far
message_id: UUID
conversation: Conversation # the conversation so far (new: including the reply message)
reply_message: Optional[ConversationMessage]
reply: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[LabelTaskMode]
class LabelPrompterReplyTask(LabelConversationReplyTask):
@@ -316,39 +333,48 @@ class MessageRanking(Interaction):
ranking: conlist(item_type=int, min_items=1)
class LabelWidget(str, enum.Enum):
yes_no = "yes_no"
flag = "flag"
likert = "likert"
class TextLabel(str, enum.Enum):
"""A label for a piece of text."""
def __new__(cls, label: str, display_text: str = "", help_text: str = None):
def __new__(cls, label: str, widget: LabelWidget, display_text: str = "", help_text: str = None):
obj = str.__new__(cls, label)
obj._value_ = label
obj.widget = widget
obj.display_text = display_text
obj.help_text = help_text
return obj
spam = "spam", "Seems to be intentionally low-quality or irrelevant"
fails_task = "fails_task", "Fails to follow the correct instruction / task"
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
excessive_harm = (
"excessive_harm",
"Content likely to cause excessive harm not justifiable in the context",
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
)
sexual_content = "sexual_content", "Contains sexual content"
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
moral_judgement = "moral_judgement", "Expresses moral judgement"
political_content = "political_content", "Expresses political views"
humor = "humor", "Contains humorous content including sarcasm"
# yes/no questions
spam = "spam", LabelWidget.yes_no, "Seems to be intentionally low-quality or irrelevant"
fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task"
# flags
pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)"
not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate"
hate_speech = (
"hate_speech",
LabelWidget.flag,
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
"Prejudice refers to preconceived views not based on reason. Protected characteristics "
"include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
)
threat = "threat", "Contains a threat against a person or persons"
misleading = "misleading", "Contains text which is incorrect or misleading"
helpful = "helpful", "Completes the task to a high standard"
creative = "creative", "Expresses creativity in responding to the task"
sexual_content = "sexual_content", LabelWidget.flag, "Contains sexual content"
moral_judgement = "moral_judgement", LabelWidget.flag, "Expresses moral judgement"
political_content = "political_content", LabelWidget.flag, "Expresses political views"
# likert
quality = "quality", LabelWidget.likert, "Overall subjective quality rating of the message"
toxicity = "toxicity", LabelWidget.likert, "Rude, abusive, profane or insulting content"
humor = "humor", LabelWidget.likert, "Humorous content including sarcasm"
helpfulness = "helpfulness", LabelWidget.likert, "Helpfulness of the message"
creativity = "creativity", LabelWidget.likert, "Creativity"
violence = "violence", LabelWidget.likert, "Violence/abuse/terrorism/self-harm"
class TextLabels(Interaction):
@@ -359,6 +385,7 @@ class TextLabels(Interaction):
labels: dict[TextLabel, float]
message_id: UUID
task_id: Optional[UUID]
is_report: Optional[bool]
@property
def has_message_id(self) -> bool:
@@ -440,7 +467,9 @@ class EmojiCode(str, enum.Enum):
thumbs_down = "-1" # 👎
red_flag = "red_flag" # 🚩
hundred = "100" # 💯
rofl = "rofl" # 🤣"
rofl = "rofl" # 🤣
clap = "clap" # 👏
diamond = "diamond" # 💎
heart_eyes = "heart_eyes" # 😍
disappointed = "disappointed" # 😞
poop = "poop" # 💩
+3
View File
@@ -7,6 +7,9 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5433/oasst_web
FASTAPI_URL=http://localhost:8080
FASTAPI_KEY=1234
# Used to expose the backend url to the clientside javascript
NEXT_PUBLIC_BACKEND_URL=$FASTAPI_URL
# A dev Auth Secret. Can be exposed if we never use this publicly.
NEXTAUTH_SECRET=O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
+4
View File
@@ -0,0 +1,4 @@
.eslintrc.json
tailwind.config.js
.storybook/*
public/mockServiceWorker.js
+6
View File
@@ -10,6 +10,7 @@ module.exports = {
"@storybook/addon-essentials",
"@storybook/addon-interactions",
"@chakra-ui/storybook-addon",
"storybook-addon-next-router",
],
framework: "@storybook/react",
core: {
@@ -23,7 +24,12 @@ module.exports = {
config.resolve.alias = {
...config.resolve.alias,
src: path.resolve(__dirname, "../src"),
styles: path.resolve(__dirname, "../styles"),
};
config.resolve.fallback = {
fs: false,
path: require.resolve('path-browserify'),
}
return config;
},
features: {
+50
View File
@@ -1,4 +1,38 @@
import "!style-loader!css-loader!postcss-loader!tailwindcss/tailwind.css";
import { RouterContext } from "next/dist/shared/lib/router-context";
import { initialize, mswDecorator } from "msw-storybook-addon";
import { rest } from "msw";
// Initialize MSW
initialize();
// Provide the MSW addon decorator globally
export const decorators = [mswDecorator];
const MOCK_VALID_LABELS= [
{
name: "spam",
display_text: "Seems to be intentionally low-quality or irrelevant",
help_text: null,
},
{
name: "fails_task",
display_text:
"Fails to follow the correct instruction / task",
help_text: null,
},
{
name: "not_appropriate",
display_text: "Inappropriate for customer assistant",
help_text: null,
},
{
name: "violence",
display_text:
"Encourages or fails to discourage violence/abuse/terrorism/self-harm",
help_text: null,
},
];
export const parameters = {
actions: { argTypesRegex: "^on[A-Z].*" },
@@ -8,6 +42,22 @@ export const parameters = {
date: /Date$/,
},
},
nextRouter: {
Provider: RouterContext.Provider,
},
msw: {
handlers: {
labels: [
rest.get("/api/valid_labels", (req, res, ctx) => {
return res(
ctx.json({
valid_labels: MOCK_VALID_LABELS
})
);
}),
],
},
},
};
// Hacky solution to get Images in next to work
+67 -100
View File
@@ -2,8 +2,7 @@
## Purpose
This provides a comprehensive webapp interface for LAION's Open Assistant
project. Initially it will support:
This provides a comprehensive webapp interface for LAION's Open Assistant project. Initially it will support:
1. User registration using either Discord or Email.
1. Adding responses to incomplete Open Assistant tasks.
@@ -11,8 +10,7 @@ project. Initially it will support:
1. Viewing an activity leaderboard.
1. Tracking community wide updates.
This interface compliments the Discord bot and will give access to the same
underlying tasks.
This interface compliments the Discord bot and will give access to the same underlying tasks.
## Contributing
@@ -22,67 +20,54 @@ This website is built using:
1. [npm](https://www.npmjs.com/): The node package manager for building.
1. [React](https://reactjs.org/): The core frontend framework.
1. [Next.js](https://nextjs.org/): A React scaffolding framework to streamline
development.
1. [Prisma](https://www.prisma.io/): An ORM to interact with a web specific
[Postgres](https://www.postgresql.org/) database.
1. [NextAuth.js](https://next-auth.js.org/): A user authentication framework to
ensure we handle accounts with best practices.
1. [TailwindCSS](https://tailwindcss.com/): A general purpose framework for
styling any component.
1. [Chakra-UI](https://chakra-ui.com/): A wide collection of pre-built UI
components that generally look pretty good.
1. [Next.js](https://nextjs.org/): A React scaffolding framework to streamline development.
1. [Prisma](https://www.prisma.io/): An ORM to interact with a web specific [Postgres](https://www.postgresql.org/)
database.
1. [NextAuth.js](https://next-auth.js.org/): A user authentication framework to ensure we handle accounts with best
practices.
1. [TailwindCSS](https://tailwindcss.com/): A general purpose framework for styling any component.
1. [Chakra-UI](https://chakra-ui.com/): A wide collection of pre-built UI components that generally look pretty good.
### Set up your environment
To contribute to the website, make sure you have the following setup and
installed:
To contribute to the website, make sure you have the following setup and installed:
1. [NVM](https://github.com/nvm-sh/nvm): The Node Version Manager makes it easy
to ensure you have the right NodeJS version installed. Once installed, run
`nvm use 16` to use Node 16.x. The website is known to be stable with NodeJS
1. [NVM](https://github.com/nvm-sh/nvm): The Node Version Manager makes it easy to ensure you have the right NodeJS
version installed. Once installed, run `nvm use 16` to use Node 16.x. The website is known to be stable with NodeJS
version 16.x. This will install both Node and NPM.
1. [Docker](https://www.docker.com/): We use docker to simplify running
dependent services.
1. [Docker](https://www.docker.com/): We use docker to simplify running dependent services.
### Getting everything up and running
If you're doing active development we suggest the following workflow:
1. In one tab, navigate to the project root.
1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can
optionally include `-d` to detach and later track the logs if desired.
1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can optionally include `-d` to detach and
later track the logs if desired.
1. In another tab navigate to `${OPEN_ASSISTANT_ROOT/website`.
1. Run `npm ci`
1. Run `npx prisma db push` (This is also needed when you restart the docker
stack from scratch).
1. Run `npm run dev`. Now the website is up and running locally at
`http://localhost:3000`.
1. To create an account, login via the user using email authentication and
navigate to `http://localhost:1080`. Check the email listed and click the
log in link. You're now logged in and authenticated.
1. Run `npx prisma db push` (This is also needed when you restart the docker stack from scratch).
1. Run `npm run dev`. Now the website is up and running locally at `http://localhost:3000`.
1. To create an account, login via the user using email authentication and navigate to `http://localhost:1080`. Check
the email listed and click the log in link. You're now logged in and authenticated.
### Using debug user credentials
You can use the debug credentials provider to log in without fancy emails or
OAuth.
You can use the debug credentials provider to log in without fancy emails or OAuth.
1. This feature is automatically on in development mode, i.e. when you run
`npm run dev`. In case you want to do the same with a production build (for
example, the docker image), then run the website with environment variable
1. This feature is automatically on in development mode, i.e. when you run `npm run dev`. In case you want to do the
same with a production build (for example, the docker image), then run the website with environment variable
`DEBUG_LOGIN=true`.
1. Use the `Login` button in the top right to go to the login page.
1. You should see a section for debug credentials. Enter any username you wish,
you will be logged in as that user.
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
### Using Storybook
To develop components using [Storybook](https://storybook.js.org/) run
`npm run storybook`. Then navigate to in your browser to
`http://localhost:6006`.
To develop components using [Storybook](https://storybook.js.org/) run `npm run storybook`. Then navigate to in your
browser to `http://localhost:6006`.
To create a new story create a file named `[componentName].stories.js`. An
example how such a story could look like, see `Header.stories.jsx`.
To create a new story create a file named `[componentName].stories.js`. An example how such a story could look like, see
`Header.stories.jsx`.
## Code Layout
@@ -90,12 +75,10 @@ example how such a story could look like, see `Header.stories.jsx`.
All react code is under `src/` with a few sub directories:
1. `pages/`: All pages a user could navigate too and API URLs which are under
`pages/api/`.
1. `components/`: All re-usable React components. If something gets used twice
we should create a component and put it here.
1. `lib/`: A generic place to store library files that are used anywhere. This
doesn't have much structure yet.
1. `pages/`: All pages a user could navigate too and API URLs which are under `pages/api/`.
1. `components/`: All re-usable React components. If something gets used twice we should create a component and put it
here.
1. `lib/`: A generic place to store library files that are used anywhere. This doesn't have much structure yet.
NOTE: `styles/` can be ignored for now.
@@ -113,25 +96,20 @@ We're not really using CSS styles. `styles/` can be ignored.
## Testing the UI
Cypress is used for end-to-end (e2e) and component testing and is configured in
`./cypress.config.ts`. The `./cypress` folder is used for supporting
configuration files etc.
Cypress is used for end-to-end (e2e) and component testing and is configured in `./cypress.config.ts`. The `./cypress`
folder is used for supporting configuration files etc.
- Store e2e tests in the `./cypress/e2e` folder.
- Store component tests adjacent to the component being tested. If you want to
wriite a test for `./src/components/Layout.tsx` then store the test file at
`./src/components/Layout.cy.tsx`.
- Store component tests adjacent to the component being tested. If you want to wriite a test for
`./src/components/Layout.tsx` then store the test file at `./src/components/Layout.cy.tsx`.
A few npm scripts are available for convenience:
- `npm run cypress`: Useful for development, it opens Cypress and allows you to
explore, run and debug tests. It assumes you have the NextJS site running at
`localhost:3000`.
- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before
sending a PR or to run in CI pipelines.
- `npm run cypress:image-baseline`: If you have tests failing because of visual
changes that was expected, this command will update the baseline images stored
in `./cypress-visual-screenshots/baseline` with those from the adjacent
- `npm run cypress`: Useful for development, it opens Cypress and allows you to explore, run and debug tests. It assumes
you have the NextJS site running at `localhost:3000`.
- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before sending a PR or to run in CI pipelines.
- `npm run cypress:image-baseline`: If you have tests failing because of visual changes that was expected, this command
will update the baseline images stored in `./cypress-visual-screenshots/baseline` with those from the adjacent
comparison folder. More can be found in the
[docs of `uktrade/cypress-image-diff`](https://github.com/uktrade/cypress-image-diff/blob/main/docs/CLI.md#update-all-baseline-images-for-failing-tests).
@@ -141,10 +119,9 @@ Read more in the [./cypress README](cypress/).
Jest and React Testing Library are used for unit testing JS/TS/TSX code.
- Store unit test files adjacent to the file being tested and have the filename
end with `.test.ts` for non-React code or `.test.tsx` for React code.
- `npm run jest`: automatically runs tests and watches for any relevant changes
to rerun tests.
- Store unit test files adjacent to the file being tested and have the filename end with `.test.ts` for non-React code
or `.test.tsx` for React code.
- `npm run jest`: automatically runs tests and watches for any relevant changes to rerun tests.
Read more in the [./src/README.md](src/README.md).
@@ -152,30 +129,25 @@ Read more in the [./src/README.md](src/README.md).
When writing code for the website, we have a few best practices:
1. When importing packages import external dependencies first then local
dependencies. Order them alphabetically according to the package name.
1. When trying to implement something new, check if
[Chakra-UI](https://chakra-ui.com/) has components that are close enough to
your need. For example Sliders, Radio Buttons, Progress indicators, etc.
They have a lot and we can save time by re-using what they have and tweaking
the style as needed.
1. Format everything with [Prettier](https://prettier.io/). This is done by
default with pre-submits. We currently don't have any custom settings.
1. Define functional React components (with types for all properties when
feasible).
1. When importing packages import external dependencies first then local dependencies. Order them alphabetically
according to the package name.
1. When trying to implement something new, check if [Chakra-UI](https://chakra-ui.com/) has components that are close
enough to your need. For example Sliders, Radio Buttons, Progress indicators, etc. They have a lot and we can save
time by re-using what they have and tweaking the style as needed.
1. Format everything with [Prettier](https://prettier.io/). This is done by default with pre-submits. We currently
don't have any custom settings.
1. Define functional React components (with types for all properties when feasible).
### Developing New Features
When working on new features or making significant changes that can't be done
within a single Pull Request, we ask that you make use of Feature Flags.
When working on new features or making significant changes that can't be done within a single Pull Request, we ask that
you make use of Feature Flags.
We've set up
[`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to
make this easier. To get started:
We've set up [`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to make this easier. To get
started:
1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you
can copy as an example. Be sure to `isActive` to true when testing your
features but false when submitting your PR.
1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you can copy as an example. Be sure to
`isActive` to true when testing your features but false when submitting your PR.
1. Use your flag wherever you add a new UI element. This can be done with:
```js
@@ -188,29 +160,24 @@ import { Flags } from "react-feature-flags";
You can see an example of how this works by checking `website/src/components/Header/Headers.tsx` where we use `flagTest`.
1. Once you've finished building out the feature and it is ready for everyone
to use, it's safe to remove the `Flag` wrappers around your component and
the entry in `flags.ts`.
1. Once you've finished building out the feature and it is ready for everyone to use, it's safe to remove the `Flag`
wrappers around your component and the entry in `flags.ts`.
### URL Paths
To use stable and consistent URL paths, we recommend the following strategy for
new tasks:
To use stable and consistent URL paths, we recommend the following strategy for new tasks:
1. For any task that involves writing a free-form response, put the page under
`website/src/pages/create` with a page name matching the task type, such as
`initial_prompt.tsx`.
1. For any task that evaluates, rates, or ranks content, put the page under
`website/src/pages/evaluate` with a page name matching the task type such as
`rank_initial_prompts.tsx`.
1. For any task that involves writing a free-form response, put the page under `website/src/pages/create` with a page
name matching the task type, such as `initial_prompt.tsx`.
1. For any task that evaluates, rates, or ranks content, put the page under `website/src/pages/evaluate` with a page
name matching the task type such as `rank_initial_prompts.tsx`.
With this we'll be able to ensure these contribution pages are hidden from
logged out users but accessible to logged in users.
With this we'll be able to ensure these contribution pages are hidden from logged out users but accessible to logged in
users.
## Learn More
To learn more about Next.js, take a look at the following resources:
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js
features and API.
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
+40 -58
View File
@@ -1,24 +1,19 @@
# Component and e2e testing with Cypress
[Cypress](https://www.cypress.io/) is used for both component- and end-to-end
testing. Below there's a few examples for the context of this site. To learn
more, the
[Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app)
has it all.
[Cypress](https://www.cypress.io/) is used for both component- and end-to-end testing. Below there's a few examples for
the context of this site. To learn more, the
[Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app) has it all.
Don't get scared by the commercial offerings they offer. Their core is open
source, the cloud offering is not necesarry at all and can be replaced by CI
tooling and [community efforts](https://sorry-cypress.dev/).
Don't get scared by the commercial offerings they offer. Their core is open source, the cloud offering is not necesarry
at all and can be replaced by CI tooling and [community efforts](https://sorry-cypress.dev/).
# Component testing
To write a new component test, you either create a new `.tsx` adjacent to the
component you want to test or you can use the guide presented yo you when
running `npm run cypress` which allows you to easily create the skeleton test
for an existing component.
To write a new component test, you either create a new `.tsx` adjacent to the component you want to test or you can use
the guide presented yo you when running `npm run cypress` which allows you to easily create the skeleton test for an
existing component.
If you have a `Button.tsx` component, create a file next to it called
`Button.cy.tsx` which could look like this:
If you have a `Button.tsx` component, create a file next to it called `Button.cy.tsx` which could look like this:
```typescript
import React from "react";
@@ -35,28 +30,24 @@ describe("<Button />", () => {
## What's happening here?
First we use `cy.mount` to mount our component under test. Notive how we specify
`className` and inner text - this is where we arrange our component with fake
data that we could assert on later.
First we use `cy.mount` to mount our component under test. Notive how we specify `className` and inner text - this is
where we arrange our component with fake data that we could assert on later.
In the example above, we also use `cy.get` to select the rendered `button`
element. Cypress has multiple ways to
[select elements](https://docs.cypress.io/guides/references/best-practices),
`get` is just one of them (and often not recommended).
In the example above, we also use `cy.get` to select the rendered `button` element. Cypress has multiple ways to
[select elements](https://docs.cypress.io/guides/references/best-practices), `get` is just one of them (and often not
recommended).
At last, we use `captureSnapshot` which is a plugin that snaps a photo of the
`button` element and compares it to a baseline located in the
`./cypress-visual-screenshots/baseline/` folder. If there's too many unidentical
pixels between the two, it will fail the test.
At last, we use `captureSnapshot` which is a plugin that snaps a photo of the `button` element and compares it to a
baseline located in the `./cypress-visual-screenshots/baseline/` folder. If there's too many unidentical pixels between
the two, it will fail the test.
# End-to-end (e2e) testing
e2e tests are stored in the `./cypress/e2e` folder and should be named
`{page}.cy.ts` and located in a relative folder structure that mirrors the page
under test.
e2e tests are stored in the `./cypress/e2e` folder and should be named `{page}.cy.ts` and located in a relative folder
structure that mirrors the page under test.
When running `npm run cypress` and selecting e2e testing, we assume you have the
NextJS site running at `localhost:3000`.
When running `npm run cypress` and selecting e2e testing, we assume you have the NextJS site running at
`localhost:3000`.
An example test could look as follows:
@@ -74,39 +65,33 @@ export {};
## What's happening here?
First we use [`cy.visit`](https://docs.cypress.io/api/commands/visit) to point
the browser at the desired page. It appends relative paths to the configured
`baseUrl` (found in `./cypress.config.ts`).
First we use [`cy.visit`](https://docs.cypress.io/api/commands/visit) to point the browser at the desired page. It
appends relative paths to the configured `baseUrl` (found in `./cypress.config.ts`).
Cypress will
[automatically await](https://docs.cypress.io/guides/core-concepts/introduction-to-cypress#Timeouts)
almost anything you do, but fail if the default timeout is reached.
Cypress will [automatically await](https://docs.cypress.io/guides/core-concepts/introduction-to-cypress#Timeouts) almost
anything you do, but fail if the default timeout is reached.
Then we get the email input field and type our email address. We find the input
field using the data-cy attribute that we added in the source code of the
element on the page.
Then we get the email input field and type our email address. We find the input field using the data-cy attribute that
we added in the source code of the element on the page.
```jsx
<Input data-cy="email-address" placeholder="Email Address" />
```
Using `data-cy` is how we ensure that selecting the element is robust to changes
in page design or function and is one of the
Using `data-cy` is how we ensure that selecting the element is robust to changes in page design or function and is one
of the
[best practices recommended by Cypress](https://docs.cypress.io/guides/references/best-practices#Selecting-Elements).
Next we call `type()` to use the keyboard, cypress will automatically focus the
element and send the keypress events. Notice the `{enter}` keyword, this will
cause Cypress to hit the return key which we expect to submit the form.
Next we call `type()` to use the keyboard, cypress will automatically focus the element and send the keypress events.
Notice the `{enter}` keyword, this will cause Cypress to hit the return key which we expect to submit the form.
We then assert that the URL should contain `/auth/verify`. Again the timeout
will make sure we are not waiting forever, and the test will fail if we do not
manage to get there in a reasonable time.
We then assert that the URL should contain `/auth/verify`. Again the timeout will make sure we are not waiting forever,
and the test will fail if we do not manage to get there in a reasonable time.
## Authenticating in e2e tests
For end-to-end tests almost every test will need to first sign in to the
website. To make this easier we have a custom command for Cypress that makes
logging in with an email address a single command, `cy.signInWithEmail()`.
For end-to-end tests almost every test will need to first sign in to the website. To make this easier we have a custom
command for Cypress that makes logging in with an email address a single command, `cy.signInWithEmail()`.
```typescript
describe("replying as the assistant", () => {
@@ -115,16 +100,13 @@ describe("replying as the assistant", () => {
cy.visit("/create/assistant_reply");
cy.get('[data-cy="reply"').type(
"You need to run pre-commit to make the reviewer happy."
);
cy.get('[data-cy="reply"').type("You need to run pre-commit to make the reviewer happy.");
cy.get('[data-cy="submit"]').click();
});
});
```
In this example we sign in as `cypress@example.com` before visiting the
`/create/assistant_reply` page that is only available when authenticated. We can
then continue on with our test as normal. Note: using `cy.signInWithEmail()`
requires that the maildev is running, which should have been started as part of
the `docker compose up` command that is required to do any end-to-end testing.
In this example we sign in as `cypress@example.com` before visiting the `/create/assistant_reply` page that is only
available when authenticated. We can then continue on with our test as normal. Note: using `cy.signInWithEmail()`
requires that the maildev is running, which should have been started as part of the `docker compose up` command that is
required to do any end-to-end testing.
+1 -4
View File
@@ -7,9 +7,6 @@ describe("<Container />", () => {
const className = "my-class";
const text = "test_container";
cy.mount(<Container className={className}>{text}</Container>);
cy.get(`div.${className}`)
.should("have.class", className)
.should("be.visible")
.should("contain", text);
cy.get(`div.${className}`).should("have.class", className).should("be.visible").should("contain", text);
});
});
@@ -12,25 +12,18 @@ describe("Contract test for Oasst API", function () {
} as BackendUserCore;
it("can fetch a task", async () => {
expect(await oasstApiClient.fetchTask("random", testUser)).to.be.not.null;
expect(await oasstApiClient.fetchTask("random", testUser, "en")).to.be.not.null;
});
it("can ack a task", async () => {
const task = await oasstApiClient.fetchTask("random", testUser);
const task = await oasstApiClient.fetchTask("random", testUser, "en");
expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
});
it("can record a taskInteraction", async () => {
const task = await oasstApiClient.fetchTask("random", testUser);
const task = await oasstApiClient.fetchTask("random", testUser, "en");
expect(
await oasstApiClient.interactTask(
"text_reply_to_message",
task.id,
"321",
"1",
{ text: "Test" },
testUser
)
await oasstApiClient.interactTask("text_reply_to_message", task.id, "321", "1", { text: "Test" }, testUser, "en")
).to.be.not.null;
});
@@ -11,6 +11,10 @@ describe("labeling assistant replies", () => {
// For specific task pages the no task available result is normal.
if (type === undefined) return;
cy.get('[data-cy="label-question"]').each((label) => {
// Click the no button, this generally approves the spam check
cy.wrap(label).find('[data-cy="no"]').click();
});
cy.get('[data-cy="label-options"]').each((label) => {
// Click the 4th option
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
@@ -11,6 +11,10 @@ describe("labeling initial prompts", () => {
// For specific task pages the no task available result is normal.
if (type === undefined) return;
cy.get('[data-cy="label-question"]').each((label) => {
// Click the no button, this generally approves the spam check
cy.wrap(label).find('[data-cy="no"]').click();
});
cy.get('[data-cy="label-options"]').each((label) => {
// Click the 4th option
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
@@ -11,6 +11,10 @@ describe("labeling prompter replies", () => {
// For specific task pages the no task available result is normal.
if (type === undefined) return;
cy.get('[data-cy="label-question"]').each((label) => {
// Click the no button, this generally approves the spam check
cy.wrap(label).find('[data-cy="no"]').click();
});
cy.get('[data-cy="label-options"]').each((label) => {
// Click the 4th option
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
@@ -0,0 +1,23 @@
describe("no tasks available", () => {
it("displays an empty state when no tasks are available", () => {
cy.signInWithEmail("cypress@example.com");
cy.intercept(
{
method: "GET",
url: "/api/new_task/prompter_reply",
},
{
statusCode: 500,
body: {
message: "No tasks of type 'label_prompter_reply' are currently available.",
errorCode: 1006,
httpStatusCode: 503,
},
}
).as("newTaskPrompterReply");
cy.visit("/create/user_reply");
cy.wait("@newTaskPrompterReply").then(() => {
cy.get('[data-cy="cy-no-tasks"]').should("exist");
});
});
});
+5 -3
View File
@@ -44,6 +44,10 @@ describe("handles random tasks", () => {
break;
}
case "label-task": {
cy.get('[data-cy="label-question"]').each((label) => {
// Click the no button, this generally approves the spam check
cy.wrap(label).find('[data-cy="no"]').click();
});
cy.get('[data-cy="label-options"]').each((label) => {
// Click the 4th option
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
@@ -56,9 +60,7 @@ describe("handles random tasks", () => {
break;
}
case undefined: {
throw new Error(
"No tasks available, but at least create initial prompt expected"
);
throw new Error("No tasks available, but at least create initial prompt expected");
}
default:
throw new Error(`Unexpected task type: ${type}`);
+6 -10
View File
@@ -37,20 +37,16 @@
// }
Cypress.Commands.add("signInUsingEmailedLink", (emailAddress) => {
const mailDevApi = `${Cypress.env("MAILDEV_PROTOCOL")}://${Cypress.env(
"MAILDEV_HOST"
)}:${Cypress.env("MAILDEV_API_PORT")}`;
cy.request(
"GET",
`${mailDevApi}/email?headers.to=${emailAddress.toLowerCase()}`
).then((response) => {
const mailDevApi = `${Cypress.env("MAILDEV_PROTOCOL")}://${Cypress.env("MAILDEV_HOST")}:${Cypress.env(
"MAILDEV_API_PORT"
)}`;
cy.request("GET", `${mailDevApi}/email?headers.to=${emailAddress.toLowerCase()}`).then((response) => {
const emails = response.body;
// Find and use login link
const loginLink = emails
.pop()
.html.match(/href="[^"]+(\/api\/auth\/callback\/[^"]+?)"/)[1];
const loginLink = emails.pop().html.match(/href="[^"]+(\/api\/auth\/callback\/[^"]+?)"/)[1];
cy.visit(loginLink);
cy.url().should("include", "/dashboard");
});
});
+1296 -115
View File
File diff suppressed because it is too large Load Diff
+6
View File
@@ -65,6 +65,7 @@
"react-hook-form": "^7.42.1",
"react-i18next": "^12.1.4",
"sharp": "^0.31.3",
"storybook-addon-next-router": "^4.0.2",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
"unique-username-generator": "^1.1.3",
@@ -95,9 +96,14 @@
"eslint-plugin-unused-imports": "^2.0.0",
"jest": "^29.3.1",
"jest-environment-jsdom": "^29.3.1",
"msw": "^0.49.3",
"msw-storybook-addon": "^1.7.0",
"prettier": "2.8.1",
"prisma": "^4.7.1",
"ts-node": "^10.9.1",
"typescript": "^4.9.4"
},
"msw": {
"workerDirectory": "public"
}
}
+4 -1
View File
@@ -9,11 +9,14 @@
"docs": "Docs",
"github": "GitHub",
"legal": "Legal",
"loading": "Loading...",
"more_information": "More Information",
"no": "No",
"privacy_policy": "Privacy Policy",
"report_a_bug": "Report a Bug",
"sign_in": "Sign In",
"sign_out": "Sign Out",
"terms_of_service": "Terms of Service",
"title": "Open Assistant",
"more_information": "More Information"
"yes": "Yes"
}
+16
View File
@@ -0,0 +1,16 @@
{
"label_highlighted_yes_no_instruction": "Answer the following question(s) about the highlighted message:",
"label_highlighted_flag_instruction": "Select any that apply to the highlighted message:",
"label_highlighted_likert_instruction": "Rate the highlighted message:",
"label_message_yes_no_instruction": "Answer the following question(s) about the message:",
"label_message_flag_instruction": "Select any that apply to the message:",
"label_message_likert_instruction": "Rate the message:",
"spam.question": "Is the message spam?",
"fails_task.question": "Does the reply fail the prompter's task?",
"not_appropriate": "Not Appropriate",
"pii": "Contains PII",
"hate_speech": "Hate Speech",
"sexual_content": "Sexual Content",
"moral_judgement": "Judges Morality",
"political_content": "Political"
}
+13
View File
@@ -0,0 +1,13 @@
{
"label_action": "Label",
"label_title": "Label",
"message": "Message",
"open_new_tab_action": "Open in new tab",
"parent": "Parent",
"reactions": "Reactions",
"report_action": "Report",
"report_placeholder": "Why should this message be reviewed?",
"report_title": "Report",
"send_report": "Send",
"submit_labels": "Submit"
}
+6 -4
View File
@@ -1,5 +1,4 @@
{
"write_initial_prompt": "Write your prompt here...",
"default": {
"unchanged_title": "No changes",
"unchanged_message": "Are you sure you would like to continue?"
@@ -12,18 +11,21 @@
"label": "Create Initial Prompts",
"desc": "Write initial prompts to help Open Assistant to try replying to diverse messages.",
"overview": "Create an initial message to send to the assistant",
"instruction": "Provide the initial prompts"
"instruction": "Provide the initial prompts",
"response_placeholder": "Write your prompt here..."
},
"reply_as_user": {
"label": "Reply as User",
"desc": "Chat with Open Assistant and help improve it's responses as you interact with it.",
"overview": "Given the following conversation, provide an adequate reply",
"instruction": "Provide the user's reply"
"instruction": "Provide the user's reply",
"response_placeholder": "Write your reply here..."
},
"reply_as_assistant": {
"label": "Reply as Assistant",
"desc": "Help Open Assistant improve its responses to conversations with other users.",
"overview": "Given the following conversation, provide an adequate reply"
"overview": "Given the following conversation, provide an adequate reply",
"response_placeholder": "Write your reply here..."
},
"rank_user_replies": {
"label": "Rank User Replies",
+303
View File
@@ -0,0 +1,303 @@
/* eslint-disable */
/* tslint:disable */
/**
* Mock Service Worker (0.49.3).
* @see https://github.com/mswjs/msw
* - Please do NOT modify this file.
* - Please do NOT serve this file on production.
*/
const INTEGRITY_CHECKSUM = '3d6b9f06410d179a7f7404d4bf4c3c70'
const activeClientIds = new Set()
self.addEventListener('install', function () {
self.skipWaiting()
})
self.addEventListener('activate', function (event) {
event.waitUntil(self.clients.claim())
})
self.addEventListener('message', async function (event) {
const clientId = event.source.id
if (!clientId || !self.clients) {
return
}
const client = await self.clients.get(clientId)
if (!client) {
return
}
const allClients = await self.clients.matchAll({
type: 'window',
})
switch (event.data) {
case 'KEEPALIVE_REQUEST': {
sendToClient(client, {
type: 'KEEPALIVE_RESPONSE',
})
break
}
case 'INTEGRITY_CHECK_REQUEST': {
sendToClient(client, {
type: 'INTEGRITY_CHECK_RESPONSE',
payload: INTEGRITY_CHECKSUM,
})
break
}
case 'MOCK_ACTIVATE': {
activeClientIds.add(clientId)
sendToClient(client, {
type: 'MOCKING_ENABLED',
payload: true,
})
break
}
case 'MOCK_DEACTIVATE': {
activeClientIds.delete(clientId)
break
}
case 'CLIENT_CLOSED': {
activeClientIds.delete(clientId)
const remainingClients = allClients.filter((client) => {
return client.id !== clientId
})
// Unregister itself when there are no more clients
if (remainingClients.length === 0) {
self.registration.unregister()
}
break
}
}
})
self.addEventListener('fetch', function (event) {
const { request } = event
const accept = request.headers.get('accept') || ''
// Bypass server-sent events.
if (accept.includes('text/event-stream')) {
return
}
// Bypass navigation requests.
if (request.mode === 'navigate') {
return
}
// Opening the DevTools triggers the "only-if-cached" request
// that cannot be handled by the worker. Bypass such requests.
if (request.cache === 'only-if-cached' && request.mode !== 'same-origin') {
return
}
// Bypass all requests when there are no active clients.
// Prevents the self-unregistered worked from handling requests
// after it's been deleted (still remains active until the next reload).
if (activeClientIds.size === 0) {
return
}
// Generate unique request ID.
const requestId = Math.random().toString(16).slice(2)
event.respondWith(
handleRequest(event, requestId).catch((error) => {
if (error.name === 'NetworkError') {
console.warn(
'[MSW] Successfully emulated a network error for the "%s %s" request.',
request.method,
request.url,
)
return
}
// At this point, any exception indicates an issue with the original request/response.
console.error(
`\
[MSW] Caught an exception from the "%s %s" request (%s). This is probably not a problem with Mock Service Worker. There is likely an additional logging output above.`,
request.method,
request.url,
`${error.name}: ${error.message}`,
)
}),
)
})
async function handleRequest(event, requestId) {
const client = await resolveMainClient(event)
const response = await getResponse(event, client, requestId)
// Send back the response clone for the "response:*" life-cycle events.
// Ensure MSW is active and ready to handle the message, otherwise
// this message will pend indefinitely.
if (client && activeClientIds.has(client.id)) {
;(async function () {
const clonedResponse = response.clone()
sendToClient(client, {
type: 'RESPONSE',
payload: {
requestId,
type: clonedResponse.type,
ok: clonedResponse.ok,
status: clonedResponse.status,
statusText: clonedResponse.statusText,
body:
clonedResponse.body === null ? null : await clonedResponse.text(),
headers: Object.fromEntries(clonedResponse.headers.entries()),
redirected: clonedResponse.redirected,
},
})
})()
}
return response
}
// Resolve the main client for the given event.
// Client that issues a request doesn't necessarily equal the client
// that registered the worker. It's with the latter the worker should
// communicate with during the response resolving phase.
async function resolveMainClient(event) {
const client = await self.clients.get(event.clientId)
if (client?.frameType === 'top-level') {
return client
}
const allClients = await self.clients.matchAll({
type: 'window',
})
return allClients
.filter((client) => {
// Get only those clients that are currently visible.
return client.visibilityState === 'visible'
})
.find((client) => {
// Find the client ID that's recorded in the
// set of clients that have registered the worker.
return activeClientIds.has(client.id)
})
}
async function getResponse(event, client, requestId) {
const { request } = event
const clonedRequest = request.clone()
function passthrough() {
// Clone the request because it might've been already used
// (i.e. its body has been read and sent to the client).
const headers = Object.fromEntries(clonedRequest.headers.entries())
// Remove MSW-specific request headers so the bypassed requests
// comply with the server's CORS preflight check.
// Operate with the headers as an object because request "Headers"
// are immutable.
delete headers['x-msw-bypass']
return fetch(clonedRequest, { headers })
}
// Bypass mocking when the client is not active.
if (!client) {
return passthrough()
}
// Bypass initial page load requests (i.e. static assets).
// The absence of the immediate/parent client in the map of the active clients
// means that MSW hasn't dispatched the "MOCK_ACTIVATE" event yet
// and is not ready to handle requests.
if (!activeClientIds.has(client.id)) {
return passthrough()
}
// Bypass requests with the explicit bypass header.
// Such requests can be issued by "ctx.fetch()".
if (request.headers.get('x-msw-bypass') === 'true') {
return passthrough()
}
// Notify the client that a request has been intercepted.
const clientMessage = await sendToClient(client, {
type: 'REQUEST',
payload: {
id: requestId,
url: request.url,
method: request.method,
headers: Object.fromEntries(request.headers.entries()),
cache: request.cache,
mode: request.mode,
credentials: request.credentials,
destination: request.destination,
integrity: request.integrity,
redirect: request.redirect,
referrer: request.referrer,
referrerPolicy: request.referrerPolicy,
body: await request.text(),
bodyUsed: request.bodyUsed,
keepalive: request.keepalive,
},
})
switch (clientMessage.type) {
case 'MOCK_RESPONSE': {
return respondWithMock(clientMessage.data)
}
case 'MOCK_NOT_FOUND': {
return passthrough()
}
case 'NETWORK_ERROR': {
const { name, message } = clientMessage.data
const networkError = new Error(message)
networkError.name = name
// Rejecting a "respondWith" promise emulates a network error.
throw networkError
}
}
return passthrough()
}
function sendToClient(client, message) {
return new Promise((resolve, reject) => {
const channel = new MessageChannel()
channel.port1.onmessage = (event) => {
if (event.data && event.data.error) {
return reject(event.data.error)
}
resolve(event.data)
}
client.postMessage(message, [channel.port2])
})
}
function sleep(timeMs) {
return new Promise((resolve) => {
setTimeout(resolve, timeMs)
})
}
async function respondWithMock(response) {
await sleep(response.delay)
return new Response(response.body, response)
}
View File
+1 -1
View File
@@ -15,7 +15,7 @@ export const EmptyState = (props: EmptyStateProps) => {
<Box data-cy={props["data-cy"]} bg={backgroundColor} p="10" borderRadius="xl" shadow="base">
<Box display="flex" flexDirection="column" alignItems="center" gap="8" fontSize="lg">
<props.icon size="30" color="DarkOrange" />
<Text>{props.text}</Text>
<Text data-cy="cy-no-tasks">{props.text}</Text>
<NextLink href="/dashboard">
<Text color="blue.500">Go back to the dashboard</Text>
</NextLink>
-116
View File
@@ -1,116 +0,0 @@
import {
Box,
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Popover,
PopoverAnchor,
PopoverTrigger,
Tooltip,
useColorModeValue,
useDisclosure,
} from "@chakra-ui/react";
import { AlertCircle } from "lucide-react";
import { useState } from "react";
import { get, post } from "src/lib/api";
import { colors } from "src/styles/Theme/colors";
import { Message } from "src/types/Conversation";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
import { LabelInputGroup } from "./Survey/LabelInputGroup";
interface Label {
name: string;
display_text: string;
help_text: string;
}
interface FlaggableElementProps {
children: React.ReactNode;
message: Message;
}
interface ValidLabelsResponse {
valid_labels: Label[];
}
export const FlaggableElement = (props: FlaggableElementProps) => {
const { data: response } = useSWRImmutable<ValidLabelsResponse>("/api/valid_labels", get);
const { isOpen, onOpen, onClose } = useDisclosure();
const { valid_labels } = response || { valid_labels: [] };
const [values, setValues] = useState<number[]>([]);
const submittable =
values.some((value) => {
return value !== null;
}) &&
values.length === valid_labels.length &&
valid_labels.length > 0;
const { trigger } = useSWRMutation("/api/set_label", post, {
onSuccess: onClose,
onError: onClose,
});
const submitResponse = () => {
const label_map: Map<string, number> = new Map();
console.assert(valid_labels.length === values.length);
values.forEach((value, idx) => {
if (value !== null) {
label_map.set(valid_labels[idx].name, value);
}
});
trigger({
message_id: props.message.id,
label_map: Object.fromEntries(label_map),
text: props.message.text,
});
};
return (
<Popover isOpen={isOpen} onOpen={onOpen} onClose={onClose} closeOnBlur={false} isLazy lazyBehavior="keepMounted">
<Box display="flex" alignItems="center" flexDirection={["column", "row"]} gap="2">
<PopoverAnchor>{props.children}</PopoverAnchor>
<Tooltip label="Report" bg="red.500" aria-label="A tooltip">
<Box>
<PopoverTrigger>
<Box as="button" display="flex" alignItems="center" justifyContent="center" borderRadius="full" p="1">
<AlertCircle size="20" className="text-red-400" aria-hidden="true" />
</Box>
</PopoverTrigger>
</Box>
</Tooltip>
</Box>
<Modal isOpen={isOpen} onClose={onClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Select one or more labels that apply.</ModalHeader>
<ModalCloseButton />
<ModalBody>
<LabelInputGroup simple labelIDs={valid_labels.map(({ name }) => name)} onChange={setValues} />
</ModalBody>
<ModalFooter>
<Button
isDisabled={!submittable}
onClick={submitResponse}
className={`bg-indigo-600 text-${useColorModeValue(
colors.light.text,
colors.dark.text
)} hover:bg-indigo-700`}
>
Report
</Button>
</ModalFooter>
</ModalContent>
</Modal>
</Popover>
);
};
@@ -1,4 +1,4 @@
import { Box, Center, Progress, Text, useColorModeValue } from "@chakra-ui/react";
import { Box, Center, Progress, Text } from "@chakra-ui/react";
export const LoadingScreen = ({ text = "Loading..." } = {}) => {
return (
+2 -20
View File
@@ -1,26 +1,8 @@
import { Box, forwardRef, Grid, useColorMode } from "@chakra-ui/react";
import { Box, forwardRef, useColorMode } from "@chakra-ui/react";
import { useMemo } from "react";
import { Message } from "src/types/Conversation";
import { FlaggableElement } from "./FlaggableElement";
interface MessagesProps {
messages: Message[];
}
export const Messages = ({ messages }: MessagesProps) => {
const items = messages.map((messageProps: Message, i: number) => {
return (
<FlaggableElement message={messageProps} key={i + messageProps.id}>
<MessageView {...messageProps} />
</FlaggableElement>
);
});
// Maybe also show a legend of the colors?
return <Grid gap={2}>{items}</Grid>;
};
export const MessageView = forwardRef<Message, "div">((message: Message, ref) => {
export const MessageView = forwardRef<Partial<Message>, "div">((message: Partial<Message>, ref) => {
const { colorMode } = useColorMode();
const bgColor = useMemo(() => {
@@ -0,0 +1,32 @@
import { Button, Flex } from "@chakra-ui/react";
import { useTranslation } from "next-i18next";
import { getTypeSafei18nKey } from "src/lib/i18n";
interface LabelFlagGroupProps {
values: number[];
labelNames: string[];
isEditable?: boolean;
onChange: (values: number[]) => void;
}
export const LabelFlagGroup = ({ values, labelNames, isEditable = true, onChange }: LabelFlagGroupProps) => {
const { t } = useTranslation("labelling");
return (
<Flex wrap="wrap" gap="4">
{labelNames.map((name, idx) => (
<Button
key={name}
onClick={() => {
const newValues = values.slice();
newValues[idx] = newValues[idx] ? 0 : 1;
onChange(newValues);
}}
isDisabled={!isEditable}
colorScheme={values[idx] === 1 ? "blue" : undefined}
>
{t(getTypeSafei18nKey(name))}
</Button>
))}
</Flex>
);
};
@@ -0,0 +1,84 @@
import { Text, VStack } from "@chakra-ui/react";
import { Label } from "src/types/Tasks";
import { LabelLikertGroup } from "../Survey/LabelLikertGroup";
import { LabelFlagGroup } from "./LabelFlagGroup";
import { LabelYesNoGroup } from "./LabelYesNoGroup";
export interface LabelInputInstructions {
yesNoInstruction: string;
flagInstruction: string;
likertInstruction: string;
}
interface LabelInputGroupProps {
values: number[];
labels: Label[];
requiredLabels?: string[];
isEditable?: boolean;
instructions: LabelInputInstructions;
onChange: (values: number[]) => void;
}
export const LabelInputGroup = ({
labels,
values,
requiredLabels,
isEditable,
instructions,
onChange,
}: LabelInputGroupProps) => {
const yesNoIndexes = labels.map((label, idx) => (label.widget === "yes_no" ? idx : null)).filter((v) => v !== null);
const flagIndexes = labels.map((label, idx) => (label.widget === "flag" ? idx : null)).filter((v) => v !== null);
const likertIndexes = labels.map((label, idx) => (label.widget === "likert" ? idx : null)).filter((v) => v !== null);
return (
<VStack alignItems="stretch" spacing={6}>
{yesNoIndexes.length > 0 && (
<VStack alignItems="stretch" spacing={2}>
<Text>{instructions.yesNoInstruction}</Text>
<LabelYesNoGroup
values={yesNoIndexes.map((idx) => values[idx])}
labelNames={yesNoIndexes.map((idx) => labels[idx].name)}
isEditable={isEditable}
requiredLabels={requiredLabels}
onChange={(yesNoValues) => {
const newValues = values.slice();
yesNoIndexes.forEach((idx, yesNoIndex) => (newValues[idx] = yesNoValues[yesNoIndex]));
onChange(newValues);
}}
/>
</VStack>
)}
{flagIndexes.length > 0 && (
<VStack alignItems="stretch" spacing={2}>
<Text>{instructions.flagInstruction}</Text>
<LabelFlagGroup
values={flagIndexes.map((idx) => values[idx])}
labelNames={flagIndexes.map((idx) => labels[idx].name)}
isEditable={isEditable}
onChange={(flagValues) => {
const newValues = values.slice();
flagIndexes.forEach((idx, flagIndex) => (newValues[idx] = flagValues[flagIndex]));
onChange(newValues);
}}
/>
</VStack>
)}
{likertIndexes.length > 0 && (
<VStack alignItems="stretch" spacing={2}>
<Text>{instructions.likertInstruction}</Text>
<LabelLikertGroup
labelIDs={likertIndexes.map((idx) => labels[idx].name)}
isEditable={isEditable}
onChange={(likertValues) => {
const newValues = values.slice();
likertIndexes.forEach((idx, likertIndex) => (newValues[idx] = likertValues[likertIndex]));
onChange(newValues);
}}
/>
</VStack>
)}
</VStack>
);
};
@@ -0,0 +1,84 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
} from "@chakra-ui/react";
import { useTranslation } from "next-i18next";
import { useEffect, useState } from "react";
import { LabelInputGroup } from "src/components/Messages/LabelInputGroup";
import { get, post } from "src/lib/api";
import { Label } from "src/types/Tasks";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
interface LabelMessagePopupProps {
messageId: string;
show: boolean;
onClose: () => void;
}
interface ValidLabelsResponse {
valid_labels: Label[];
}
export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopupProps) => {
const { t } = useTranslation();
const { data: response } = useSWRImmutable<ValidLabelsResponse>(`/api/valid_labels?message_id=${messageId}`, get);
const valid_labels = response?.valid_labels ?? [];
const [values, setValues] = useState<number[]>(new Array(valid_labels.length).fill(null));
useEffect(() => {
setValues(new Array(valid_labels.length).fill(null));
}, [messageId, valid_labels.length]);
const { trigger: setLabels } = useSWRMutation("/api/set_label", post);
const submit = () => {
const label_map: Map<string, number> = new Map();
console.assert(valid_labels.length === values.length);
values.forEach((value, idx) => {
if (value !== null) {
label_map.set(valid_labels[idx].name, value);
}
});
setLabels({
message_id: messageId,
label_map: Object.fromEntries(label_map),
});
setValues(null);
onClose();
};
return (
<Modal isOpen={show} onClose={onClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>{t("message:label_title")}</ModalHeader>
<ModalCloseButton />
<ModalBody>
<LabelInputGroup
labels={valid_labels}
values={values}
instructions={{
yesNoInstruction: t("labelling:label_message_yes_no_instruction"),
flagInstruction: t("labelling:label_message_flag_instruction"),
likertInstruction: t("labelling:label_message_likert_instruction"),
}}
onChange={setValues}
/>
</ModalBody>
<ModalFooter>
<Button colorScheme="blue" mr={3} onClick={submit}>
{t("message:submit_labels")}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
};
@@ -0,0 +1,89 @@
import { Button, HStack, Text, Tooltip } from "@chakra-ui/react";
import { useTranslation } from "next-i18next";
import { getTypeSafei18nKey } from "src/lib/i18n";
interface LabelYesNoGroupProps {
values: number[];
labelNames: string[];
requiredLabels?: string[];
isEditable?: boolean;
onChange: (values: number[]) => void;
}
export const LabelYesNoGroup = ({
values,
labelNames,
requiredLabels = [],
isEditable = true,
onChange,
}: LabelYesNoGroupProps) => {
const { t } = useTranslation("labelling");
return (
<>
{labelNames.map((name, idx) => {
return (
<YesNoQuestion
key={name}
question={t(getTypeSafei18nKey(`${name}.question`))}
value={values[idx] === null ? null : values[idx] > 0.1 ? true : false}
onChange={(value) => {
const newValues = values.slice();
newValues[idx] = value;
onChange(newValues);
}}
isEditable={isEditable}
isRequired={requiredLabels.includes(name)}
/>
);
})}
</>
);
};
const YesNoQuestion = ({
isEditable,
question,
value,
isRequired,
onChange,
}: {
isEditable: boolean;
question: string;
value: boolean;
isRequired?: boolean;
onChange: (boolean) => void;
}) => {
const { t } = useTranslation();
return (
<div data-cy="label-question" style={{ maxWidth: "30em" }}>
<Text display="inline">
{question}
{isRequired ? <RequiredMark /> : undefined}
</Text>
<HStack style={{ float: "right" }}>
<Button
data-cy="yes"
isDisabled={!isEditable}
colorScheme={value === true ? "blue" : undefined}
onClick={() => onChange(isRequired ? true : value === null ? true : null)}
>
{t("yes")}
</Button>
<Button
data-cy="no"
isDisabled={!isEditable}
colorScheme={value === false ? "blue" : undefined}
onClick={() => onChange(isRequired ? false : value === null ? false : null)}
>
{t("no")}
</Button>
</HStack>
</div>
);
};
const RequiredMark = () => (
<Tooltip label="Required">
<span style={{ color: "red" }}>*</span>
</Tooltip>
);
@@ -0,0 +1,34 @@
import React from "react";
import { MessageEmojiButton } from "./MessageEmojiButton";
// eslint-disable-next-line import/no-anonymous-default-export
export default {
title: "Messages/MessageEmojiButton",
component: MessageEmojiButton,
};
const Template = ({ emoji, count, checked }: { emoji: string; count: number; checked?: boolean }) => {
return <MessageEmojiButton emoji={{ name: emoji, count }} checked={checked} onClick={undefined} />;
};
export const Default = Template.bind({});
Default.args = {
emoji: "+1",
count: 7,
checked: false,
};
export const BigNumber = Template.bind({});
BigNumber.args = {
emoji: "+1",
count: 999,
checked: false,
};
export const Checked = Template.bind({});
Checked.args = {
emoji: "+1",
count: 2,
checked: true,
};
@@ -0,0 +1,48 @@
import { Button } from "@chakra-ui/react";
import { BoxSelect, Flag, LucideProps, ThumbsDown, ThumbsUp } from "lucide-react";
import { ReactElement } from "react";
import { MessageEmoji } from "src/types/Conversation";
type EmojiIconPurpose = "MINI_BUTTON" | "NORMAL";
const defaultIconProps: (purpose: EmojiIconPurpose) => LucideProps = (purpose: EmojiIconPurpose) => {
if (purpose === "MINI_BUTTON") return { height: "1em" };
return {};
};
export const getEmojiIcon = (name: string, purpose: EmojiIconPurpose): ReactElement => {
switch (name) {
case "+1":
return <ThumbsUp {...defaultIconProps(purpose)} />;
case "-1":
return <ThumbsDown {...defaultIconProps(purpose)} />;
case "flag":
case "red_flag":
return <Flag {...defaultIconProps(purpose)} />;
default:
return <BoxSelect {...defaultIconProps(purpose)} />;
}
};
interface MessageEmojiButtonProps {
emoji: MessageEmoji;
checked?: boolean;
onClick: () => void;
}
export const MessageEmojiButton = ({ emoji, checked, onClick }: MessageEmojiButtonProps) => {
return (
<Button
onClick={onClick}
variant={checked ? "solid" : "ghost"}
colorScheme={checked ? "blue" : undefined}
size="sm"
height="1.6em"
minWidth={0}
padding="0"
>
{getEmojiIcon(emoji.name, "MINI_BUTTON")}
<span style={{ marginInlineEnd: "0.25em" }}>{emoji.count}</span>
</Button>
);
};
@@ -0,0 +1,110 @@
import React from "react";
import { Message } from "src/types/Conversation";
import { MessageTable } from "./MessageTable";
// eslint-disable-next-line import/no-anonymous-default-export
export default {
title: "Messages/MessageTable",
component: MessageTable,
};
const Template = ({
messages,
enableLink,
highlightLastMessage,
}: {
messages: Message[];
enableLink: boolean;
highlightLastMessage: boolean;
}) => {
return <MessageTable messages={messages} enableLink={enableLink} highlightLastMessage={highlightLastMessage} />;
};
export const Default = Template.bind({});
Default.args = {
messages: [
{
text: "I'm unsure how to interpret this. Is it a riddle?",
is_assistant: true,
id: "",
frontend_message_id: "",
emojis: {},
user_emojis: [],
},
{
text: "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?",
is_assistant: false,
id: "",
frontend_message_id: "",
emojis: { "-1": 11, red_flag: 2 },
user_emojis: [],
},
{
text: "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?",
is_assistant: false,
id: "",
frontend_message_id: "",
emojis: {},
user_emojis: [],
},
],
enableLink: true,
highlightLastMessage: false,
};
export const Conversation = Template.bind({});
Conversation.args = {
messages: [
{
text: "Hello! How can I help you?",
is_assistant: true,
id: "",
frontend_message_id: "",
emojis: {},
user_emojis: [],
},
{
text: "Who were the 8 presidents before George Washington?",
is_assistant: false,
id: "",
frontend_message_id: "",
emojis: {},
user_emojis: [],
},
],
enableLink: false,
highlightLastMessage: true,
};
export const LongText = Template.bind({});
LongText.args = {
messages: [
{
text: "Okay, here\u2019s my answer. The thing is, in a job interview you are applying to the position, right? Which means that you need to try to make yourself sound as good a fit as possible. So that means it\u2019s a mistake to wear something that people would think of as a silly or stupid outfit. It\u2019s good to wear something in line with what people expect the job to require, and ideally you want to wear something that you can walk around comfortably in. At the same time, if you wear something that looks kind of dorky, you might not stand out as much as someone who doesn\u2019t, because the other candidates probably all have something slightly weird or wacky about them too. So, I\u2019d say, do you like your outfit? Wear something that you\u2019re comfortable in, that you think you look good in, and is in line with what the job would require. Also, for interview prep, I would focus on making yourself look as professional as possible, from your hair to your outfit. What do you think of this?",
is_assistant: true,
id: "",
frontend_message_id: "",
emojis: {},
user_emojis: [],
},
{
text: "Yes, I think they can be helpful when the child misbehaves, but they should be used with a little bit of compassion and understanding that it\u2019s not the natural state of things to have an adult yelling at them. Time outs are also often used without letting the child know how they\u2019re getting out of the time out, which can make it feel arbitrary or like a punishment, rather than a consequence for something they did. It\u2019s really easy for adults to do this kind of thing unconsciously. It\u2019s easy to get caught up in the notion that \u201cThey\u2019re in time out, and that\u2019s the end of it!\u201d but kids can be pretty imaginative, and they can use their own creativity to make their way out of time outs. A compassionate time out ends when the child shows a sign of understanding what they\u2019ve done wrong, and are ready to begin again. That way the child knows they\u2019re learning, and that the parent is seeing them as an intelligent person, even if they sometimes mess up. You can still use the other techniques you were using to be tough when necessary, but using a compassionate approach will let you use them without actually using them!",
is_assistant: false,
id: "",
frontend_message_id: "",
emojis: {},
user_emojis: [],
},
{
text: "No. The USA was founded by a Puritan group of Protestants, but it didn\u2019t adopt the religion of the Puritans until much later, and it was always a secular state. The Puritans observed the Sabbath on Sunday, and the Puritans only had a small influence in the early history of the USA. It\u2019s difficult to trace the origins of closing stores on Sunday, but one early and short-lived attempt at forcing the Sabbath on people in the 1800s was motivated by the Protestant ideal that people should spend Sunday focusing on spiritual activities. By the mid-1800s, when the Sunday closing law was made, there was not a lot of pressure from that standpoint, but the church had begun to advocate for Sunday closing laws as a way of counteracting the negative effects of industrialization on the day of rest. Even after that shift, closing stores on Sunday was not always possible, since the religious Sunday was not always chosen for observance. And as industrialization accelerated and mechanization made it possible to operate stores on Sunday, the law was not enforced as much as people liked. The day of rest was also being violated by stores that stayed open all day on Sunday, so closing stores on Sundays became an effort to protect the Sabbath for all citizens.",
is_assistant: false,
id: "",
frontend_message_id: "",
emojis: {},
user_emojis: [],
},
],
enableLink: true,
highlightLastMessage: false,
};
@@ -11,11 +11,11 @@ interface MessageTableProps {
export function MessageTable({ messages, enableLink, highlightLastMessage }: MessageTableProps) {
return (
<Stack spacing="4">
{messages.map((item, idx) => (
{messages.map((message, idx) => (
<MessageTableEntry
enabled={enableLink}
item={item}
key={item.id + item.frontend_message_id}
message={message}
key={message.id + message.frontend_message_id}
highlight={highlightLastMessage && idx === messages.length - 1}
/>
))}
@@ -0,0 +1,62 @@
import React from "react";
import { Message } from "src/types/Conversation";
import { MessageTableEntry } from "./MessageTableEntry";
// eslint-disable-next-line import/no-anonymous-default-export
export default {
title: "Messages/MessageTableEntry",
component: MessageTableEntry,
};
const Template = ({ enabled, highlight, ...message }) => {
return <MessageTableEntry message={message as Message} enabled={enabled} highlight={highlight} />;
};
export const Default = Template.bind({});
Default.args = {
text: "Who were the 8 presidents before George Washington?",
is_assistant: false,
id: "",
frontend_message_id: "",
enabled: true,
highlight: false,
emojis: {},
user_emojis: [],
};
export const Asistant = Template.bind({});
Asistant.args = {
text: "Who were the 8 presidents before George Washington?",
is_assistant: true,
id: "",
frontend_message_id: "",
enabled: true,
highlight: false,
emojis: {},
user_emojis: [],
};
export const LongText = Template.bind({});
LongText.args = {
text: "Assistant: No. The USA was founded by a Puritan group of Protestants, but it didn\u2019t adopt the religion of the Puritans until much later, and it was always a secular state. The Puritans observed the Sabbath on Sunday, and the Puritans only had a small influence in the early history of the USA. It\u2019s difficult to trace the origins of closing stores on Sunday, but one early and short-lived attempt at forcing the Sabbath on people in the 1800s was motivated by the Protestant ideal that people should spend Sunday focusing on spiritual activities. By the mid-1800s, when the Sunday closing law was made, there was not a lot of pressure from that standpoint, but the church had begun to advocate for Sunday closing laws as a way of counteracting the negative effects of industrialization on the day of rest. Even after that shift, closing stores on Sunday was not always possible, since the religious Sunday was not always chosen for observance. And as industrialization accelerated and mechanization made it possible to operate stores on Sunday, the law was not enforced as much as people liked. The day of rest was also being violated by stores that stayed open all day on Sunday, so closing stores on Sundays became an effort to protect the Sabbath for all citizens.",
is_assistant: true,
id: "",
frontend_message_id: "",
enabled: true,
highlight: false,
emojis: {},
user_emojis: [],
};
export const WithEmoji = Template.bind({});
WithEmoji.args = {
text: "As you\u2019ve mentioned, Star Wars has many sequels, prequels, and crossovers. The official list of movies in Star Wars is:",
is_assistant: true,
id: "",
frontend_message_id: "",
enabled: true,
highlight: false,
emojis: { "-1": 5, "+1": 1 },
user_emojis: ["-1"],
};
@@ -1,23 +1,50 @@
import { Avatar, Box, HStack, useBreakpointValue, useColorModeValue } from "@chakra-ui/react";
import {
Avatar,
Box,
HStack,
Menu,
MenuButton,
MenuDivider,
MenuGroup,
MenuItem,
MenuList,
SimpleGrid,
useBreakpointValue,
useColorModeValue,
useDisclosure,
} from "@chakra-ui/react";
import { boolean } from "boolean";
import { ClipboardList, Flag, MessageSquare, MoreHorizontal } from "lucide-react";
import { useRouter } from "next/router";
import { useCallback, useMemo } from "react";
import { FlaggableElement } from "src/components/FlaggableElement";
import { Message } from "src/types/Conversation";
import { useTranslation } from "next-i18next";
import { useCallback, useEffect, useMemo, useState } from "react";
import { LabelMessagePopup } from "src/components/Messages/LabelPopup";
import { getEmojiIcon, MessageEmojiButton } from "src/components/Messages/MessageEmojiButton";
import { ReportPopup } from "src/components/Messages/ReportPopup";
import { post } from "src/lib/api";
import { Message, MessageEmojis } from "src/types/Conversation";
import { colors } from "styles/Theme/colors";
import useSWRMutation from "swr/mutation";
interface MessageTableEntryProps {
item: Message;
message: Message;
enabled?: boolean;
highlight?: boolean;
}
export function MessageTableEntry(props: MessageTableEntryProps) {
export function MessageTableEntry({ message, enabled, highlight }: MessageTableEntryProps) {
const router = useRouter();
const [emojiState, setEmojis] = useState<MessageEmojis>({ emojis: {}, user_emojis: [] });
useEffect(() => {
setEmojis({
emojis: message?.emojis || {},
user_emojis: message?.user_emojis || [],
});
}, [message.emojis, message.user_emojis]);
const { item } = props;
const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]);
const goToMessage = useCallback(() => router.push(`/messages/${message.id}`), [router, message.id]);
const { isOpen: reportPopupOpen, onOpen: showReportPopup, onClose: closeReportPopup } = useDisclosure();
const { isOpen: labelPopupOpen, onOpen: showLabelPopup, onClose: closeLabelPopup } = useDisclosure();
const backgroundColor = useColorModeValue("gray.100", "gray.700");
const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B");
@@ -32,34 +59,124 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
borderColor={borderColor}
size={inlineAvatar ? "xs" : "sm"}
mr={inlineAvatar ? 2 : 0}
name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`}
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
name={`${boolean(message.is_assistant) ? "Assistant" : "User"}`}
src={`${boolean(message.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
/>
),
[borderColor, inlineAvatar, item.is_assistant]
[borderColor, inlineAvatar, message.is_assistant]
);
const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight);
const { trigger: sendEmojiChange } = useSWRMutation(`/api/messages/${message.id}/emoji`, post, {
onSuccess: setEmojis,
});
const react = (emoji: string, state: boolean) => {
sendEmojiChange({ op: state ? "add" : "remove", emoji });
};
return (
<FlaggableElement message={item}>
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
{!inlineAvatar && avatar}
<Box
width={["full", "full", "full", "fit-content"]}
maxWidth={["full", "full", "full", "2xl"]}
p="4"
borderRadius="md"
bg={item.is_assistant ? backgroundColor : backgroundColor2}
outline={props.highlight && "2px solid black"}
outlineColor={highlightColor}
onClick={props.enabled && goToMessage}
_hover={props.enabled && { cursor: "pointer", opacity: 0.9 }}
whiteSpace="pre-wrap"
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
{!inlineAvatar && avatar}
<Box
width={["full", "full", "full", "fit-content"]}
maxWidth={["full", "full", "full", "2xl"]}
p="4"
borderRadius="md"
bg={message.is_assistant ? backgroundColor : backgroundColor2}
outline={highlight && "2px solid black"}
outlineColor={highlightColor}
onClick={enabled && goToMessage}
whiteSpace="pre-wrap"
cursor={enabled && "pointer"}
style={{ position: "relative" }}
>
{inlineAvatar && avatar}
{message.text}
<HStack
style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }}
onClick={(e) => e.stopPropagation()}
>
{inlineAvatar && avatar}
{item.text}
</Box>
</HStack>
</FlaggableElement>
{Object.entries(emojiState.emojis).map(([emoji, count]) => (
<MessageEmojiButton
key={emoji}
emoji={{ name: emoji, count }}
checked={emojiState.user_emojis.includes(emoji)}
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
/>
))}
<MessageActions
react={react}
userEmoji={emojiState.user_emojis}
onLabel={showLabelPopup}
onReport={showReportPopup}
messageId={message.id}
/>
<LabelMessagePopup messageId={message.id} show={labelPopupOpen} onClose={closeLabelPopup} />
<ReportPopup messageId={message.id} show={reportPopupOpen} onClose={closeReportPopup} />
</HStack>
</Box>
</HStack>
);
}
const EmojiMenuItem = ({
emoji,
checked,
react,
}: {
emoji: string;
checked?: boolean;
react: (emoji: string, state: boolean) => void;
}) => {
const activeColor = useColorModeValue(colors.light.active, colors.dark.active);
return (
<MenuItem onClick={() => react(emoji, !checked)} justifyContent="center" color={checked ? activeColor : undefined}>
{getEmojiIcon(emoji, "NORMAL")}
</MenuItem>
);
};
const MessageActions = ({
react,
userEmoji,
onLabel,
onReport,
messageId,
}: {
react: (emoji: string, state: boolean) => void;
userEmoji: string[];
onLabel: () => void;
onReport: () => void;
messageId: string;
}) => {
const { t } = useTranslation("message");
return (
<Menu>
<MenuButton>
<MoreHorizontal />
</MenuButton>
<MenuList>
<MenuGroup title={t("reactions")}>
<SimpleGrid columns={4}>
{["+1", "-1"].map((emoji) => (
<EmojiMenuItem key={emoji} emoji={emoji} checked={userEmoji?.includes(emoji)} react={react} />
))}
</SimpleGrid>
</MenuGroup>
<MenuDivider />
<MenuItem onClick={onLabel} icon={<ClipboardList />}>
{t("label_action")}
</MenuItem>
<MenuItem onClick={onReport} icon={<Flag />}>
{t("report_action")}
</MenuItem>
<MenuDivider />
<MenuItem as="a" href={`/messages/${messageId}`} target="_blank" icon={<MessageSquare />}>
{t("open_new_tab_action")}
</MenuItem>
</MenuList>
</Menu>
);
};
@@ -0,0 +1,99 @@
import { rest } from "msw";
import { MessageWithChildren } from "./MessageWithChildren";
// eslint-disable-next-line import/no-anonymous-default-export
export default {
title: "Messages/MessageWithChildren",
component: MessageWithChildren,
parameters: {
layout: "fullscreen",
msw: {
handlers: {
messagesDefault: [
rest.get("/api/messages/id-1", (req, res, ctx) => {
return res(
ctx.json({
text: "Some message Text",
is_assistant: false,
id: "id-1",
})
);
}),
rest.get("/api/messages/id-1/children", (req, res, ctx) => {
return res(ctx.json([]));
}),
],
},
},
},
};
const Template = (args) => <MessageWithChildren {...args} />;
export const NoChildren = Template.bind({});
NoChildren.args = {
id: "id-1",
maxDepth: 2,
};
export const WithChildren = Template.bind({});
WithChildren.args = {
id: "id-1",
maxDepth: 1,
};
WithChildren.parameters = {
msw: {
handlers: {
additionalMessages: [
rest.get("/api/messages/id-2", (req, res, ctx) => {
return res(
ctx.json({
text: "Some child message Text",
is_assistant: false,
id: "id-2",
})
);
}),
rest.get("/api/messages/id-3", (req, res, ctx) => {
return res(
ctx.json({
text: "Some child message Text",
is_assistant: false,
id: "id-3",
})
);
}),
rest.get("/api/messages/id-1/children", (req, res, ctx) => {
return res(
ctx.json([
{
text: "Some child message Text",
is_assistant: false,
id: "id-2",
},
{
text: "another child message Text",
is_assistant: false,
id: "id-3",
},
])
);
}),
rest.get("/api/messages/id-2/children", (req, res, ctx) => {
return res(
ctx.json([
{
text: "another message Text",
is_assistant: false,
id: "id-4",
},
])
);
}),
rest.get("/api/messages/id-3/children", (req, res, ctx) => {
return res(ctx.json([]));
}),
],
},
},
};
@@ -52,7 +52,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
{isFirst ? "Message" : depth === 1 ? "Children" : "Ancestor"}
</Text>
<Box width="fit-content" bg={backgroundColor} padding="4" borderRadius="xl" boxShadow="base">
<MessageTableEntry enabled item={message} />
<MessageTableEntry enabled message={message} />
</Box>
</Box>
</>
@@ -86,9 +86,9 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
gap="4"
shadow="base"
>
{children.map((item, idx) => (
{children.map((message, idx) => (
<Box flex="1" key={`recursiveMessageWChildren_${idx}`}>
<MessageTableEntry enabled item={item} />
<MessageTableEntry enabled message={message} />
</Box>
))}
</Box>
@@ -0,0 +1,56 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Textarea,
} from "@chakra-ui/react";
import { useTranslation } from "next-i18next";
import { useState } from "react";
import { post } from "src/lib/api";
import useSWRMutation from "swr/mutation";
interface ReportPopupProps {
messageId: string;
show: boolean;
onClose: () => void;
}
export const ReportPopup = ({ messageId, show, onClose }: ReportPopupProps) => {
const { t } = useTranslation("message");
const [text, setText] = useState("");
const { trigger } = useSWRMutation("/api/report", post);
const submit = () => {
trigger({
message_id: messageId,
text,
});
setText("");
onClose();
};
return (
<Modal isOpen={show} onClose={onClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>{t("report_title")}</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Textarea onChange={(e) => setText(e.target.value)} resize="none" placeholder={t("report_placeholder")} />
</ModalBody>
<ModalFooter>
<Button colorScheme="blue" mr={3} onClick={submit}>
{t("send_report")}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
};
@@ -1,198 +0,0 @@
import { Box, Grid, GridItem, Text, useColorModeValue } from "@chakra-ui/react";
import React from "react";
import { useState } from "react";
import { LikertButtons } from "src/components/Buttons/LikertButtons";
import { Explain } from "src/components/Explain";
interface LabelInputGroupProps {
labelIDs: Array<string>;
simple?: boolean;
onChange: (values: number[]) => unknown;
isEditable?: boolean;
}
interface LabelInfo {
zeroText: string;
oneText: string;
zeroDescription: string[];
oneDescription: string[];
inverted: boolean;
}
// This should be moved to the valid labels api endpoint
const label_messages: {
[label: string]: LabelInfo;
} = {
spam: {
zeroText: "Not Spam",
zeroDescription: ["Suitable for training Open Assistant."],
oneText: "Spam",
oneDescription: [
"Seems to be intentionally low-quality or irrelevant",
'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".',
"This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.",
"Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.",
],
inverted: true,
},
fails_task: {
zeroText: "Follows Instructions",
zeroDescription: [],
oneText: "Fails Task",
oneDescription: ["Fails to follow the correct instruction / task"],
inverted: true,
},
not_appropriate: {
zeroText: "Appropriate",
zeroDescription: [],
oneText: "Inappropriate",
oneDescription: ["Inappropriate for customer assistant"],
inverted: true,
},
violence: {
zeroText: "Harmless",
zeroDescription: [],
oneText: "Violent",
oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"],
inverted: true,
},
excessive_harm: {
zeroText: "Safe",
zeroDescription: [],
oneText: "Harmful",
oneDescription: [
"Content likely to cause excessive harm not justifiable in the context",
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
],
inverted: true,
},
sexual_content: {
zeroText: "Non Sexual",
zeroDescription: [],
oneText: "Sexual",
oneDescription: ["Contains sexual content"],
inverted: true,
},
toxicity: {
zeroText: "Polite",
zeroDescription: [],
oneText: "Rude",
oneDescription: ["Contains rude, abusive, profane or insulting content"],
inverted: true,
},
moral_judgement: {
zeroText: "Non-Judgemental",
zeroDescription: [],
oneText: "Judgemental",
oneDescription: ["Expresses moral judgement"],
inverted: true,
},
political_content: {
zeroText: "Apolitical",
zeroDescription: [],
oneText: "Political",
oneDescription: ["Expresses political views"],
inverted: true,
},
humor: {
zeroText: "Serious",
zeroDescription: [],
oneText: "Humorous",
oneDescription: ["Contains humorous content including sarcasm"],
inverted: false,
},
hate_speech: {
zeroText: "Safe",
zeroDescription: [],
oneText: "Hateful",
oneDescription: [
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
],
inverted: true,
},
threat: {
zeroText: "Safe",
zeroDescription: [],
oneText: "Threatening",
oneDescription: ["Contains a threat against a person or persons"],
inverted: true,
},
misleading: {
zeroText: "Accurate",
zeroDescription: [],
oneText: "Misleading",
oneDescription: ["Contains text which is incorrect or misleading"],
inverted: true,
},
helpful: {
zeroText: "Unhelful",
zeroDescription: [],
oneText: "Helpful",
oneDescription: ["Completes the task to a high standard"],
inverted: false,
},
creative: {
zeroText: "Boring",
zeroDescription: [],
oneText: "Creative",
oneDescription: ["Expresses creativity in responding to the task"],
inverted: false,
},
};
export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => {
const [labelValues, setLabelValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => null));
const cardColor = useColorModeValue("gray.50", "gray.800");
return (
<Grid templateColumns={"minmax(min-content, 30em)"} rowGap={2}>
{labelIDs.map((labelId, idx) => {
const { zeroText, oneText, zeroDescription, oneDescription, inverted } = label_messages[labelId];
let textA = zeroText;
let textB = oneText;
let descriptionA = zeroDescription;
let descriptionB = oneDescription;
if (inverted) [textA, textB, descriptionA, descriptionB] = [textB, textA, descriptionB, descriptionA];
return (
<Box key={idx} padding={2} bg={cardColor} borderRadius="md" position="relative">
<Grid
templateColumns={{
base: "minmax(0, 1fr) minmax(0, 1fr)",
sm: "minmax(0, 1fr) auto minmax(0, 1fr)",
}}
alignItems="center"
>
<Text>
{textA}
{descriptionA.length > 0 ? <Explain explanation={descriptionA} /> : null}
</Text>
<GridItem colSpan={{ base: 2, sm: 1 }} gridColumnStart={{ base: 1, sm: 2 }} gridRow={{ base: 2, sm: 1 }}>
<LikertButtons
isDisabled={!isEditable}
count={5}
data-cy="label-options"
onChange={(value) => {
const newState = labelValues.slice();
newState[idx] = value === null ? null : inverted ? 1 - value : value;
onChange(newState);
setLabelValues(newState);
}}
/>
</GridItem>
<GridItem>
<Text textAlign="right">
{textB}
{descriptionB.length > 0 ? <Explain explanation={descriptionB} /> : null}
</Text>
</GridItem>
</Grid>
</Box>
);
})}
</Grid>
);
};
@@ -0,0 +1,243 @@
import { Box, Grid, GridItem, Text, useColorModeValue } from "@chakra-ui/react";
import React from "react";
import { useState } from "react";
import { LikertButtons } from "src/components/Buttons/LikertButtons";
import { Explain } from "src/components/Explain";
interface LabelInputGroupProps {
labelIDs: Array<string>;
onChange: (values: number[]) => unknown;
isEditable?: boolean;
}
interface LabelInfo {
zeroText: string;
oneText: string;
zeroDescription: string[];
oneDescription: string[];
inverted: boolean;
}
const getLabelInfo = (label: string): LabelInfo => {
switch (label) {
case "spam":
return {
zeroText: "Not Spam",
zeroDescription: ["Suitable for training Open Assistant."],
oneText: "Spam",
oneDescription: [
"Seems to be intentionally low-quality or irrelevant",
'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".',
"This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.",
"Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.",
],
inverted: true,
};
case "fails_task":
return {
zeroText: "Follows Instructions",
zeroDescription: [],
oneText: "Fails Task",
oneDescription: ["Fails to follow the correct instruction / task"],
inverted: true,
};
case "not_appropriate":
return {
zeroText: "Appropriate",
zeroDescription: [],
oneText: "Inappropriate",
oneDescription: ["Inappropriate for customer assistant"],
inverted: true,
};
case "violence":
return {
zeroText: "Harmless",
zeroDescription: [],
oneText: "Violent",
oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"],
inverted: true,
};
case "excessive_harm":
return {
zeroText: "Safe",
zeroDescription: [],
oneText: "Harmful",
oneDescription: [
"Content likely to cause excessive harm not justifiable in the context",
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
],
inverted: true,
};
case "sexual_content":
return {
zeroText: "Non Sexual",
zeroDescription: [],
oneText: "Sexual",
oneDescription: ["Contains sexual content"],
inverted: true,
};
case "toxicity":
return {
zeroText: "Polite",
zeroDescription: [],
oneText: "Rude",
oneDescription: ["Contains rude, abusive, profane or insulting content"],
inverted: true,
};
case "moral_judgement":
return {
zeroText: "Non-Judgemental",
zeroDescription: [],
oneText: "Judgemental",
oneDescription: ["Expresses moral judgement"],
inverted: true,
};
case "political_content":
return {
zeroText: "Apolitical",
zeroDescription: [],
oneText: "Political",
oneDescription: ["Expresses political views"],
inverted: true,
};
case "humor":
return {
zeroText: "Serious",
zeroDescription: [],
oneText: "Humorous",
oneDescription: ["Contains humorous content including sarcasm"],
inverted: false,
};
case "hate_speech":
return {
zeroText: "Safe",
zeroDescription: [],
oneText: "Hateful",
oneDescription: [
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
],
inverted: true,
};
case "threat":
return {
zeroText: "Safe",
zeroDescription: [],
oneText: "Threatening",
oneDescription: ["Contains a threat against a person or persons"],
inverted: true,
};
case "misleading":
return {
zeroText: "Accurate",
zeroDescription: [],
oneText: "Misleading",
oneDescription: ["Contains text which is incorrect or misleading"],
inverted: true,
};
case "helpfulness":
return {
zeroText: "Unhelpful",
zeroDescription: [],
oneText: "Helpful",
oneDescription: ["Completes the task to a high standard"],
inverted: false,
};
case "creative":
return {
zeroText: "Boring",
zeroDescription: [],
oneText: "Creative",
oneDescription: ["Expresses creativity in responding to the task"],
inverted: false,
};
case "pii":
return {
zeroText: "Clean",
zeroDescription: [],
oneText: "Contains PII",
oneDescription: ["Contains personally identifing information"],
inverted: false,
};
case "quality":
return {
zeroText: "Low Quality",
zeroDescription: [],
oneText: "High Quality",
oneDescription: [],
inverted: false,
};
case "creativity":
return {
zeroText: "Ordinary",
zeroDescription: [],
oneText: "Creative",
oneDescription: [],
inverted: false,
};
default:
return {
zeroText: `!${label}`,
zeroDescription: [],
oneText: label,
oneDescription: [],
inverted: false,
};
}
};
export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => {
const [labelValues, setLabelValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => null));
const cardColor = useColorModeValue("gray.50", "gray.800");
return (
<Grid templateColumns={"minmax(min-content, 30em)"} rowGap={2}>
{labelIDs.map((labelId, idx) => {
const { zeroText, oneText, zeroDescription, oneDescription, inverted } = getLabelInfo(labelId);
let textA = zeroText;
let textB = oneText;
let descriptionA = zeroDescription;
let descriptionB = oneDescription;
if (inverted) [textA, textB, descriptionA, descriptionB] = [textB, textA, descriptionB, descriptionA];
return (
<Box key={idx} padding={2} bg={cardColor} borderRadius="md" position="relative">
<Grid
templateColumns={{
base: "minmax(0, 1fr) minmax(0, 1fr)",
sm: "minmax(0, 1fr) auto minmax(0, 1fr)",
}}
alignItems="center"
>
<Text as="div">
{textA}
{descriptionA.length > 0 ? <Explain explanation={descriptionA} /> : null}
</Text>
<GridItem colSpan={{ base: 2, sm: 1 }} gridColumnStart={{ base: 1, sm: 2 }} gridRow={{ base: 2, sm: 1 }}>
<LikertButtons
isDisabled={!isEditable}
count={5}
data-cy="label-options"
onChange={(value) => {
const newState = labelValues.slice();
newState[idx] = value === null ? null : inverted ? 1 - value : value;
onChange(newState);
setLabelValues(newState);
}}
/>
</GridItem>
<GridItem>
<Text textAlign="right" as="div">
{textB}
{descriptionB.length > 0 ? <Explain explanation={descriptionB} /> : null}
</Text>
</GridItem>
</Grid>
</Box>
);
})}
</Grid>
);
};
+46 -51
View File
@@ -1,71 +1,66 @@
import { Box, Flex, IconButton, Tooltip, useColorModeValue } from "@chakra-ui/react";
import { Box, Flex, IconButton, Progress, Tooltip, useColorModeValue } from "@chakra-ui/react";
import { Edit2 } from "lucide-react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { TaskStatus } from "src/components/Tasks/Task";
import { BaseTask } from "src/types/Task";
export interface TaskControlsProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
task: any;
className?: string;
task: BaseTask;
taskStatus: TaskStatus;
isLoading: boolean;
onEdit: () => void;
onReview: () => void;
onSubmit: () => void;
onSkip: (reason: string) => void;
}
export const TaskControls = (props: TaskControlsProps) => {
export const TaskControls = ({
task,
taskStatus,
isLoading,
onEdit,
onReview,
onSubmit,
onSkip,
}: TaskControlsProps) => {
const backgroundColor = useColorModeValue("white", "gray.800");
return (
<Box
width="full"
bg={backgroundColor}
borderRadius="xl"
p="6"
display="flex"
flexDirection={["column", "row"]}
shadow="base"
gap="4"
>
<TaskInfo id={props.task.id} output="Submit your answer" />
<Flex width={["full", "fit-content"]} justify="center" ml="auto" gap={2}>
{props.taskStatus === "REVIEW" || props.taskStatus === "SUBMITTED" ? (
<>
<Tooltip label="Edit">
<IconButton
size="lg"
data-cy="edit"
aria-label="edit"
onClick={props.onEdit}
icon={<Edit2 size="1em" />}
/>
</Tooltip>
<SubmitButton
colorScheme="green"
data-cy="submit"
isDisabled={props.taskStatus === "SUBMITTED"}
onClick={props.onSubmit}
>
Submit
</SubmitButton>
</>
) : (
<>
<SkipButton onSkip={props.onSkip} />
<SubmitButton
colorScheme="blue"
data-cy="review"
isDisabled={props.taskStatus === "NOT_SUBMITTABLE"}
onClick={props.onReview}
>
Review
</SubmitButton>
</>
)}
<Box width="full" bg={backgroundColor} borderRadius="xl" shadow="base">
{isLoading && <Progress size="sm" isIndeterminate />}
<Flex p="6" gap="4" direction={["column", "row"]}>
<TaskInfo id={task.id} output="Submit your answer" />
<Flex width={["full", "fit-content"]} justify="center" ml="auto" gap={2}>
{taskStatus.mode === "EDIT" ? (
<>
<SkipButton onSkip={onSkip} />
<SubmitButton
colorScheme="blue"
data-cy="review"
isDisabled={taskStatus.replyValidity === "INVALID"}
onClick={onReview}
>
Review
</SubmitButton>
</>
) : (
<>
<Tooltip label="Edit">
<IconButton size="lg" data-cy="edit" aria-label="edit" onClick={onEdit} icon={<Edit2 size="1em" />} />
</Tooltip>
<SubmitButton
colorScheme="green"
data-cy="submit"
isDisabled={taskStatus.mode === "SUBMITTED"}
onClick={onSubmit}
>
Submit
</SubmitButton>
</>
)}
</Flex>
</Flex>
</Box>
);

Some files were not shown because too many files have changed in this diff Show More