mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
Merge branch 'main' into 766_admin_enhancement
This commit is contained in:
+29
@@ -0,0 +1,29 @@
|
||||
"""use 'en' instead 'en-US' as default lang
|
||||
|
||||
Revision ID: 160ac010efcc
|
||||
Revises: 4f26fec4d204
|
||||
Create Date: 2023-01-20 16:50:00
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "160ac010efcc"
|
||||
down_revision = "4f26fec4d204"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Narrow ``message.lang`` to 32 characters and make ``"en"`` its server default.

    The column is altered in place instead of being dropped and re-created, so
    language tags already stored on existing messages are kept.
    """
    # Map the previous "en-US" default onto the new "en" default. Values that
    # would not fit into the narrower column are reset as well, because the
    # type change below could otherwise fail on them.
    op.execute("UPDATE message SET lang = 'en' WHERE lang = 'en-US' OR length(lang) > 32")
    op.alter_column(
        "message",
        "lang",
        existing_type=sa.VARCHAR(length=200),
        type_=sa.String(length=32),
        server_default="en",
        existing_nullable=False,
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore ``message.lang`` as a NOT NULL ``VARCHAR(200)`` without a server default.

    The column is dropped and re-created, so every existing row ends up with
    the back-fill value ``"en-US"``.
    """
    op.drop_column("message", "lang")
    # A NOT NULL column can only be added to a table that already contains rows
    # when a default is available to back-fill them, so add it with a temporary
    # server default first.
    op.add_column(
        "message",
        sa.Column("lang", sa.VARCHAR(length=200), server_default="en-US", autoincrement=False, nullable=False),
    )
    # The schema being restored had no server-side default on this column.
    op.alter_column(
        "message",
        "lang",
        existing_type=sa.VARCHAR(length=200),
        existing_nullable=False,
        server_default=None,
    )
|
||||
@@ -128,6 +128,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
user_message_id: str
|
||||
parent_message_id: Optional[str]
|
||||
text: str
|
||||
lang: Optional[str]
|
||||
role: str
|
||||
tree_state: Optional[message_tree_state.State]
|
||||
|
||||
@@ -184,6 +185,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(
|
||||
msg.text,
|
||||
msg.lang,
|
||||
msg.task_message_id,
|
||||
msg.user_message_id,
|
||||
review_count=5,
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.api.v1.messages import get_messages_cursor
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
@@ -76,20 +77,47 @@ def query_frontend_user_messages(
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
gte_created_date=start_date,
|
||||
lte_created_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{auth_method}/{username}/messages/cursor", response_model=protocol.MessagePage)
def query_frontend_user_messages_cursor(
    auth_method: str,
    username: str,
    lt: Optional[str] = None,
    gt: Optional[str] = None,
    only_roots: Optional[bool] = False,
    include_deleted: Optional[bool] = False,
    max_count: Optional[int] = Query(10, gt=0, le=1000),
    desc: Optional[bool] = False,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Query one page of a frontend user's messages.

    The user is identified by `auth_method` and `username`; all arguments are
    forwarded unchanged to `get_messages_cursor`, whose result is returned.
    """
    page_request = dict(
        lt=lt,
        gt=gt,
        auth_method=auth_method,
        username=username,
        only_roots=only_roots,
        include_deleted=include_deleted,
        max_count=max_count,
        desc=desc,
        api_client=api_client,
        db=db,
    )
    return get_messages_cursor(**page_request)
|
||||
|
||||
|
||||
@router.delete("/{auth_method}/{username}/messages", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_frontend_user_messages_deleted(
|
||||
auth_method: str,
|
||||
@@ -98,5 +126,10 @@ def mark_frontend_user_messages_deleted(
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(auth_method=auth_method, username=username, api_client_id=api_client.id)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
api_client_id=api_client.id,
|
||||
limit=None,
|
||||
)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -6,8 +7,8 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
@@ -16,31 +17,30 @@ router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=list[protocol.Message])
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
max_count: int = Query(10, gt=0, le=1000),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
allow_deleted: bool = False,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client_id: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
only_roots: Optional[bool] = False,
|
||||
desc: Optional[bool] = True,
|
||||
allow_deleted: Optional[bool] = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
start_date = unaware_to_utc(start_date)
|
||||
end_date = unaware_to_utc(end_date)
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
gte_created_date=start_date,
|
||||
lte_created_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if allow_deleted else False,
|
||||
)
|
||||
@@ -48,6 +48,61 @@ def query_messages(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/cursor", response_model=protocol.MessagePage)
def get_messages_cursor(
    lt: Optional[str] = None,
    gt: Optional[str] = None,
    user_id: Optional[UUID] = None,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client_id: Optional[str] = None,
    only_roots: Optional[bool] = False,
    include_deleted: Optional[bool] = False,
    max_count: Optional[int] = Query(10, gt=0, le=1000),
    desc: Optional[bool] = False,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Query one page of messages ordered by creation date.

    `lt` and `gt` are cursor values. Each is either `<message-uuid>$<ISO datetime>`
    (the format of the `prev` / `next` values returned by this endpoint) or a
    bare ISO datetime. The remaining arguments are passed through as filters to
    `PromptRepository.query_messages_ordered_by_created_date`.

    Raises:
        OasstError: with `INVALID_CURSOR_VALUE` when a cursor cannot be parsed.
    """

    def split_cursor(x: str | None) -> tuple[datetime | None, UUID | None]:
        # Split a cursor into (created_date, id); either part may be absent.
        if not x:
            return None, None
        try:
            m = utils.split_uuid_pattern.match(x)
            if m:
                return datetime.fromisoformat(m[2]), UUID(m[1])
            return datetime.fromisoformat(x), None
        except ValueError as err:
            # Keep the parsing error as the cause for easier debugging.
            raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE) from err

    lte_created_date, lt_id = split_cursor(lt)
    gte_created_date, gt_id = split_cursor(gt)

    pr = PromptRepository(db, api_client)
    messages = pr.query_messages_ordered_by_created_date(
        user_id=user_id,
        auth_method=auth_method,
        username=username,
        api_client_id=api_client_id,
        gte_created_date=gte_created_date,
        gt_id=gt_id,
        lte_created_date=lte_created_date,
        lt_id=lt_id,
        only_roots=only_roots,
        deleted=None if include_deleted else False,
        desc=desc,
        limit=max_count,
    )

    items = utils.prepare_message_list(messages)

    # The cursors point at the first and last item of the returned page;
    # both stay None for an empty page.
    prev_cursor, next_cursor = None, None
    if items:
        first, last = items[0], items[-1]
        prev_cursor = str(first.id) + "$" + first.created_date.isoformat()
        next_cursor = str(last.id) + "$" + last.created_date.isoformat()

    order = "desc" if desc else "asc"
    return protocol.MessagePage(
        prev=prev_cursor, next=next_cursor, sort_key="created_date", order=order, items=items
    )
|
||||
|
||||
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
|
||||
@@ -39,7 +39,7 @@ def request_task(
|
||||
pr.ensure_user_is_enabled()
|
||||
|
||||
tm = TreeManager(db, pr)
|
||||
task, message_tree_id, parent_message_id = tm.next_task(request.type)
|
||||
task, message_tree_id, parent_message_id = tm.next_task(desired_task_type=request.type, lang=request.lang)
|
||||
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
@@ -54,6 +54,7 @@ def request_task(
|
||||
def tasks_availability(
|
||||
*,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
lang: Optional[str] = "en",
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
):
|
||||
@@ -62,7 +63,7 @@ def tasks_availability(
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, client_user=user)
|
||||
tm = TreeManager(db, pr)
|
||||
return tm.determine_task_availability()
|
||||
return tm.determine_task_availability(lang)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
|
||||
@@ -5,10 +5,12 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.api.v1.messages import get_messages_cursor
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -70,6 +72,70 @@ def get_users_ordered_by_display_name(
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/cursor", response_model=protocol.FrontEndUserPage)
def get_users_cursor(
    lt: Optional[str] = None,
    gt: Optional[str] = None,
    sort_key: Optional[str] = Query("username", max_length=32),
    max_count: Optional[int] = Query(100, gt=0, le=10000),
    api_client_id: Optional[UUID] = None,
    search_text: Optional[str] = None,
    auth_method: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Query one page of users sorted by `username` or `display_name`.

    `lt` and `gt` are cursor values, either `<user-uuid>$<sort value>` or just
    the sort value. Any other `sort_key` raises an `OasstError` with
    `SORT_KEY_UNSUPPORTED`.
    """

    def split_cursor(x: str | None) -> tuple[str, UUID]:
        # Split a cursor into (sort value, user id); the id part is optional.
        if x:
            match = utils.split_uuid_pattern.match(x)
            return (match[2], UUID(match[1])) if match else (x, None)
        return None, None

    if sort_key not in ("username", "display_name"):
        raise OasstError(f"Unsupported sort key: '{sort_key}'", OasstErrorCode.SORT_KEY_UNSUPPORTED)

    lt_value, lt_id = split_cursor(lt)
    gt_value, gt_id = split_cursor(gt)

    # Arguments that both listing functions take under the same name.
    shared = dict(
        api_client_id=api_client_id,
        gt_id=gt_id,
        lt_id=lt_id,
        auth_method=auth_method,
        search_text=search_text,
        max_count=max_count,
        api_client=api_client,
        db=db,
    )

    if sort_key == "username":
        items = get_users_ordered_by_username(gte_username=gt_value, lte_username=lt_value, **shared)
        cursor_field = "id"
    else:
        items = get_users_ordered_by_display_name(
            gte_display_name=gt_value, lte_display_name=lt_value, **shared
        )
        cursor_field = "display_name"

    prev_cursor, next_cursor = None, None
    if items:
        head, tail = items[0], items[-1]
        prev_cursor = str(head.user_id) + "$" + getattr(head, cursor_field)
        next_cursor = str(tail.user_id) + "$" + getattr(tail, cursor_field)

    return protocol.FrontEndUserPage(
        prev=prev_cursor, next=next_cursor, sort_key=sort_key, order="asc", items=items
    )
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=protocol.FrontEndUser)
|
||||
def get_user(
|
||||
user_id: UUID,
|
||||
@@ -130,13 +196,13 @@ def query_user_messages(
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
gte_created_date=start_date,
|
||||
lte_created_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
@@ -144,12 +210,37 @@ def query_user_messages(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{user_id}/messages/cursor", response_model=protocol.MessagePage)
def query_user_messages_cursor(
    user_id: Optional[UUID],
    lt: Optional[str] = None,
    gt: Optional[str] = None,
    only_roots: Optional[bool] = False,
    include_deleted: Optional[bool] = False,
    max_count: Optional[int] = Query(10, gt=0, le=1000),
    desc: Optional[bool] = False,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Query one page of the messages of the user with the given `user_id`.

    All arguments are forwarded unchanged to `get_messages_cursor`, whose
    result is returned.
    """
    page_request = dict(
        lt=lt,
        gt=gt,
        user_id=user_id,
        only_roots=only_roots,
        include_deleted=include_deleted,
        max_count=max_count,
        desc=desc,
        api_client=api_client,
        db=db,
    )
    return get_messages_cursor(**page_request)
|
||||
|
||||
|
||||
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_user_messages_deleted(
    user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
    """
    Mark all messages of the user with the given `user_id` as deleted.
    """
    pr = PromptRepository(db, api_client)
    # limit=None lifts the repository's default page size so that every
    # message of the user is selected, not only the first page.
    messages = pr.query_messages_ordered_by_created_date(user_id=user_id, limit=None)
    pr.mark_messages_deleted(messages)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import Message
|
||||
@@ -10,6 +11,7 @@ def prepare_message(m: Message) -> protocol.Message:
|
||||
frontend_message_id=m.frontend_message_id,
|
||||
parent_id=m.parent_id,
|
||||
text=m.text,
|
||||
lang=m.lang,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
)
|
||||
@@ -22,10 +24,11 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
    """Convert database messages into protocol conversation messages.

    Each result carries the message's id, frontend message id, text and
    language; `is_assistant` is true when the message's role is "assistant".
    """
    return [
        protocol.ConversationMessage(
            id=message.id,
            frontend_message_id=message.frontend_message_id,
            text=message.text,
            lang=message.lang,
            is_assistant=(message.role == "assistant"),
        )
        for message in messages
    ]
|
||||
@@ -41,3 +44,8 @@ def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages.append(prepare_message(message))
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
|
||||
|
||||
# Matches cursor strings of the form "<uuid>$<payload>": group 1 is the UUID,
# group 2 is everything after the first "$". DOTALL lets the payload contain
# line breaks, so such a cursor is still split instead of failing to match.
split_uuid_pattern = re.compile(
    r"^([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})\$(.*)$",
    flags=re.DOTALL,
)
|
||||
|
||||
@@ -38,7 +38,7 @@ class Message(SQLModel, table=True):
|
||||
payload: Optional[PayloadContainer] = Field(
|
||||
sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
|
||||
)
|
||||
lang: str = Field(nullable=False, max_length=200, default="en-US")
|
||||
lang: str = Field(sa_column=sa.Column(sa.String(32), server_default="en", nullable=False))
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
@@ -28,7 +28,8 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlmodel import Session, func, not_, text, update
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from sqlmodel import Session, and_, func, not_, or_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -85,6 +86,7 @@ class PromptRepository:
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
lang: str,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
review_count: int = 0,
|
||||
@@ -107,6 +109,7 @@ class PromptRepository:
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
lang=lang,
|
||||
depth=depth,
|
||||
review_count=review_count,
|
||||
review_result=review_result,
|
||||
@@ -146,6 +149,7 @@ class PromptRepository:
|
||||
def store_text_reply(
|
||||
self,
|
||||
text: str,
|
||||
lang: str,
|
||||
frontend_message_id: str,
|
||||
user_frontend_message_id: str,
|
||||
review_count: int = 0,
|
||||
@@ -209,6 +213,7 @@ class PromptRepository:
|
||||
task_id=task.id,
|
||||
role=role,
|
||||
payload=db_payload.MessagePayload(text=text),
|
||||
lang=lang or "en",
|
||||
depth=depth,
|
||||
review_count=review_count,
|
||||
review_result=review_result,
|
||||
@@ -660,58 +665,85 @@ class PromptRepository:
|
||||
max_message = max(tree, key=lambda m: m.children_count)
|
||||
return max_message, [m for m in tree if m.parent_id == max_message.id]
|
||||
|
||||
def query_messages_ordered_by_created_date(
    self,
    user_id: Optional[UUID] = None,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client_id: Optional[UUID] = None,
    gte_created_date: Optional[datetime] = None,
    gt_id: Optional[UUID] = None,
    lte_created_date: Optional[datetime] = None,
    lt_id: Optional[UUID] = None,
    only_roots: bool = False,
    deleted: Optional[bool] = None,
    desc: bool = False,
    limit: Optional[int] = 100,
) -> list[Message]:
    """Query messages sorted by `created_date`, with ties broken by `id`.

    Args:
        user_id: only messages of this user.
        auth_method: together with `username`, selects the messages' user;
            both must be given or both omitted.
        username: see `auth_method`.
        api_client_id: only messages of this api client. Untrusted api
            clients may only query their own messages; for them an omitted
            value is replaced by their own id.
        gte_created_date: lower bound on `created_date`. Inclusive on its
            own; combined with `gt_id` it selects rows strictly after the
            (`created_date`, `id`) pair.
        gt_id: id part of the lower cursor; requires `gte_created_date`.
        lte_created_date: upper bound on `created_date`. Inclusive on its
            own; combined with `lt_id` it selects rows strictly before the
            (`created_date`, `id`) pair.
        lt_id: id part of the upper cursor; requires `lte_created_date`.
        only_roots: only messages without a parent.
        deleted: filter on the deleted flag; `None` applies no filter.
        desc: sort descending instead of ascending.
        limit: maximum number of rows; `None` applies no limit.

    Raises:
        OasstError: when an untrusted api client asks for another client's
            messages, when only one of `username` / `auth_method` is
            given, or when a cursor id is given without its date.
    """
    if not self.api_client.trusted:
        if not api_client_id:
            # Let unprivileged api clients query their own messages without api_client_id being set
            api_client_id = self.api_client.id

        if api_client_id != self.api_client.id:
            # Unprivileged api client asks for foreign messages
            raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)

    qry = self.db.query(Message)
    if user_id:
        qry = qry.filter(Message.user_id == user_id)
    if username or auth_method:
        # Both parts are needed to identify a user. The parentheses matter:
        # `not username and auth_method` would let a username-only query through.
        if not (username and auth_method):
            raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
        qry = qry.join(User)
        qry = qry.filter(User.username == username, User.auth_method == auth_method)
    if api_client_id:
        qry = qry.filter(Message.api_client_id == api_client_id)

    # NOTE(review): presumably attaches UTC to timezone-naive values - confirm in oasst_shared.utils
    gte_created_date = unaware_to_utc(gte_created_date)
    lte_created_date = unaware_to_utc(lte_created_date)

    if gte_created_date is not None:
        if gt_id:
            # Rows strictly after the (created_date, id) cursor.
            qry = qry.filter(
                or_(
                    Message.created_date > gte_created_date,
                    and_(Message.created_date == gte_created_date, Message.id > gt_id),
                )
            )
        else:
            qry = qry.filter(Message.created_date >= gte_created_date)
    elif gt_id:
        raise OasstError("Need id and date for keyset pagination", OasstErrorCode.GENERIC_ERROR)

    if lte_created_date is not None:
        if lt_id:
            # Rows strictly before the (created_date, id) cursor.
            qry = qry.filter(
                or_(
                    Message.created_date < lte_created_date,
                    and_(Message.created_date == lte_created_date, Message.id < lt_id),
                )
            )
        else:
            qry = qry.filter(Message.created_date <= lte_created_date)
    elif lt_id:
        raise OasstError("Need id and date for keyset pagination", OasstErrorCode.GENERIC_ERROR)

    if only_roots:
        qry = qry.filter(Message.parent_id.is_(None))

    if deleted is not None:
        qry = qry.filter(Message.deleted == deleted)

    # The id is a secondary sort key so that rows sharing a created_date have a stable order.
    if desc:
        qry = qry.order_by(Message.created_date.desc(), Message.id.desc())
    else:
        qry = qry.order_by(Message.created_date.asc(), Message.id.asc())

    if limit is not None:
        qry = qry.limit(limit)

    return qry.all()
|
||||
|
||||
def update_children_counts(self, message_tree_id: UUID):
|
||||
sql_update_children_count = """
|
||||
|
||||
@@ -190,14 +190,18 @@ class TreeManager:
|
||||
|
||||
return task_count_by_type
|
||||
|
||||
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
extendible_parents = self.query_extendible_parents()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
incomplete_rankings = self.query_incomplete_rankings()
|
||||
if not lang:
|
||||
lang = "en"
|
||||
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
|
||||
|
||||
num_active_trees = self.query_num_active_trees(lang=lang)
|
||||
extendible_parents = self.query_extendible_parents(lang=lang)
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
|
||||
|
||||
return self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
@@ -208,23 +212,29 @@ class TreeManager:
|
||||
)
|
||||
|
||||
def next_task(
|
||||
self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random
|
||||
self,
|
||||
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
|
||||
lang: str = "en",
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
|
||||
logger.debug("TreeManager.next_task()")
|
||||
logger.debug(f"TreeManager.next_task({desired_task_type=}, {lang=})")
|
||||
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
extendible_parents = self.query_extendible_parents()
|
||||
if not lang:
|
||||
lang = "en"
|
||||
logger.warning("Task request without lang tag received, assuming 'en'.")
|
||||
|
||||
incomplete_rankings = self.query_incomplete_rankings()
|
||||
num_active_trees = self.query_num_active_trees(lang=lang)
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
extendible_parents = self.query_extendible_parents(lang=lang)
|
||||
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
|
||||
if not self.cfg.rank_prompter_replies:
|
||||
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))
|
||||
|
||||
active_tree_sizes = self.query_extendible_trees()
|
||||
active_tree_sizes = self.query_extendible_trees(lang=lang)
|
||||
|
||||
# determine type of task to generate
|
||||
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
|
||||
@@ -458,6 +468,7 @@ class TreeManager:
|
||||
# here we store the text reply in the database
|
||||
message = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
lang=interaction.lang,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
@@ -665,7 +676,7 @@ class TreeManager:
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
|
||||
def query_prompts_need_review(self) -> list[Message]:
|
||||
def query_prompts_need_review(self, lang: str) -> list[Message]:
|
||||
"""
|
||||
Select initial prompt messages with less than required rankings in active message tree
|
||||
(active == True in message_tree_state)
|
||||
@@ -682,6 +693,7 @@ class TreeManager:
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_initial_prompt,
|
||||
Message.parent_id.is_(None),
|
||||
Message.lang == lang,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -690,7 +702,7 @@ class TreeManager:
|
||||
|
||||
return qry.all()
|
||||
|
||||
def query_replies_need_review(self) -> list[Message]:
|
||||
def query_replies_need_review(self, lang: str) -> list[Message]:
|
||||
"""
|
||||
Select child messages (parent_id IS NOT NULL) with less than required rankings
|
||||
in active message tree (active == True in message_tree_state)
|
||||
@@ -707,6 +719,7 @@ class TreeManager:
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_reply,
|
||||
Message.parent_id.is_not(None),
|
||||
Message.lang == lang,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -724,13 +737,14 @@ FROM message_tree_state mts
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :ranking_state -- message tree must be in ranking state
|
||||
AND m.review_result -- must be reviewed
|
||||
AND m.lang = :lang -- matches lang
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id, m.role
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
|
||||
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
|
||||
"""Query parents which have childern that need further rankings"""
|
||||
|
||||
r = self.db.execute(
|
||||
@@ -738,6 +752,7 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
{
|
||||
"num_required_rankings": self.cfg.num_required_rankings,
|
||||
"ranking_state": message_tree_state.State.RANKING,
|
||||
"lang": lang,
|
||||
},
|
||||
)
|
||||
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
|
||||
@@ -753,13 +768,14 @@ WHERE mts.active -- only consider active trees
|
||||
AND NOT m.deleted -- ignore deleted messages as parents
|
||||
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
|
||||
AND m.review_result -- parent node must have positive review
|
||||
AND m.lang = :lang -- parent matches lang
|
||||
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
|
||||
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
"""
|
||||
|
||||
def query_extendible_parents(self) -> list[ExtendibleParentRow]:
|
||||
def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]:
|
||||
"""Query parent messages that have not reached the maximum number of replies."""
|
||||
|
||||
r = self.db.execute(
|
||||
@@ -767,6 +783,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
{
|
||||
"growing_state": message_tree_state.State.GROWING,
|
||||
"num_reviews_reply": self.cfg.num_reviews_reply,
|
||||
"lang": lang,
|
||||
},
|
||||
)
|
||||
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
|
||||
@@ -787,7 +804,7 @@ GROUP BY m.message_tree_id, mts.goal_tree_size
|
||||
HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
"""
|
||||
|
||||
def query_extendible_trees(self) -> list[ActiveTreeSizeRow]:
|
||||
def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
|
||||
"""Query size of active message trees in growing state."""
|
||||
|
||||
r = self.db.execute(
|
||||
@@ -795,6 +812,7 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
{
|
||||
"growing_state": message_tree_state.State.GROWING,
|
||||
"num_reviews_reply": self.cfg.num_reviews_reply,
|
||||
"lang": lang,
|
||||
},
|
||||
)
|
||||
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
|
||||
@@ -894,8 +912,12 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
|
||||
self._insert_default_state(id, state=state)
|
||||
|
||||
def query_num_active_trees(self) -> int:
|
||||
query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active)
|
||||
def query_num_active_trees(self, lang: str) -> int:
    """Count the active message trees for a single language.

    Every tree-state row is joined to the message whose id equals the tree id
    (NOTE(review): presumably the tree's root message -- confirm against the
    schema); a tree is counted when it is active and that message's ``lang``
    equals the requested one.

    :param lang: language code compared against ``Message.lang``.
    :return: number of matching trees.
    """
    tree_count = func.count(MessageTreeState.message_tree_id)
    trees_with_message = self.db.query(tree_count).join(
        Message, MessageTreeState.message_tree_id == Message.id
    )
    # Consecutive filter() calls are combined with AND, same as one call with two criteria.
    active_in_lang = trees_with_message.filter(MessageTreeState.active).filter(Message.lang == lang)
    return active_in_lang.scalar()
|
||||
|
||||
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
|
||||
|
||||
@@ -202,9 +202,11 @@ class UserRepository:
|
||||
) -> list[User]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
# Let unprivileged api clients query their own users without api_client_id being set
|
||||
api_client_id = self.api_client.id
|
||||
|
||||
if api_client_id != self.api_client.id:
|
||||
# Unprivileged api client asks for foreign users
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
qry = self.db.query(User).order_by(User.display_name, User.id)
|
||||
|
||||
@@ -19,6 +19,10 @@ class OasstErrorCode(IntEnum):
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
ROOT_TOKEN_NOT_AUTHORIZED = 3
|
||||
DATABASE_MAX_RETRIES_EXHAUSTED = 4
|
||||
|
||||
SORT_KEY_UNSUPPORTED = 100
|
||||
INVALID_CURSOR_VALUE = 101
|
||||
|
||||
TOO_MANY_REQUESTS = 429
|
||||
|
||||
SERVER_ERROR0 = 500
|
||||
|
||||
@@ -37,12 +37,25 @@ class FrontEndUser(User):
|
||||
created_date: Optional[datetime] = None
|
||||
|
||||
|
||||
# Base model for one page of a paginated listing. `items` is left untyped here;
# FrontEndUserPage and MessagePage re-declare it with a concrete element type.
class PageResult(BaseModel):
    prev: Optional[str]  # NOTE(review): presumably the cursor of the previous page -- confirm
    next: Optional[str]  # NOTE(review): presumably the cursor of the next page -- confirm
    sort_key: str  # NOTE(review): presumably the field the items are ordered by -- confirm
    items: list
    order: Literal["asc", "desc"]  # validation only accepts these two literals
|
||||
|
||||
|
||||
# PageResult specialisation whose `items` are FrontEndUser entries;
# all other fields are inherited unchanged from PageResult.
class FrontEndUserPage(PageResult):
    items: list[FrontEndUser]
|
||||
|
||||
|
||||
class ConversationMessage(BaseModel):
    """Represents a message in a conversation between the user and the assistant."""

    # Both identifiers are optional and default to None when not supplied.
    id: Optional[UUID] = None
    frontend_message_id: Optional[str] = None

    text: str
    lang: Optional[str]  # BCP 47
    is_assistant: bool  # NOTE(review): presumably True for assistant-authored messages -- confirm
|
||||
|
||||
|
||||
@@ -57,6 +70,10 @@ class Message(ConversationMessage):
|
||||
created_date: Optional[datetime] = None
|
||||
|
||||
|
||||
# PageResult specialisation whose `items` are Message entries;
# all other fields are inherited unchanged from PageResult.
class MessagePage(PageResult):
    items: list[Message]
|
||||
|
||||
|
||||
class MessageTree(BaseModel):
|
||||
"""All messages belonging to the same message tree."""
|
||||
|
||||
@@ -72,6 +89,7 @@ class TaskRequest(BaseModel):
|
||||
# this is optional. https://github.com/pydantic/pydantic/issues/1270
|
||||
user: Optional[User] = Field(None, nullable=True)
|
||||
collective: bool = False
|
||||
lang: Optional[str] = Field(None, nullable=True) # BCP 47
|
||||
|
||||
|
||||
class TaskAck(BaseModel):
|
||||
@@ -266,6 +284,7 @@ class TextReplyToMessage(Interaction):
|
||||
message_id: str
|
||||
user_message_id: str
|
||||
text: constr(min_length=1, strip_whitespace=True)
|
||||
lang: Optional[str] # BCP 47
|
||||
|
||||
|
||||
class MessageRating(Interaction):
|
||||
|
||||
@@ -73,6 +73,7 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
|
||||
message_id="123",
|
||||
user_message_id="321",
|
||||
text="This is my reply",
|
||||
lang="en",
|
||||
user=protocol_schema.User(
|
||||
id="123",
|
||||
display_name="lomz",
|
||||
|
||||
@@ -13,5 +13,6 @@
|
||||
"sign_in": "Sign In",
|
||||
"sign_out": "Sign Out",
|
||||
"terms_of_service": "Terms of Service",
|
||||
"title": "Open Assistant"
|
||||
"title": "Open Assistant",
|
||||
"last_updated_at": "Last updated at: {{val, datetime}}"
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Table, TableContainer, Tbody, Td, Th, Thead, Tr, useColorModeValue } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
import { Table, TableContainer, Tbody, Td, Text, Th, Thead, Tr, useColorModeValue } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import React, { useMemo } from "react";
|
||||
import { useTable } from "react-table";
|
||||
import { get } from "src/lib/api";
|
||||
import { LeaderboardEntity, LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const columns = [
|
||||
@@ -26,13 +27,26 @@ const columns = [
|
||||
* Presents a grid of leaderboard entries with more detailed information.
|
||||
*/
|
||||
const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) => {
|
||||
const { data } = useSWRImmutable<LeaderboardEntity[]>(`/api/leaderboard?time_frame=${timeFrame}`, get, {
|
||||
fallbackData: [],
|
||||
const { t } = useTranslation();
|
||||
const { data: reply } = useSWRImmutable<LeaderboardReply>(`/api/leaderboard?time_frame=${timeFrame}`, get, {
|
||||
revalidateOnMount: true,
|
||||
});
|
||||
|
||||
const { getTableProps, getTableBodyProps, headerGroups, rows, prepareRow } = useTable({
|
||||
columns,
|
||||
data: reply?.leaderboard ?? [],
|
||||
});
|
||||
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
|
||||
const { getTableProps, getTableBodyProps, headerGroups, rows, prepareRow } = useTable({ columns, data });
|
||||
const lastUpdated = useMemo(() => {
|
||||
const val = new Date(reply?.last_updated);
|
||||
return t("last_updated_at", { val, formatParams: { val: { dateStyle: "full", timeStyle: "short" } } });
|
||||
}, [t, reply?.last_updated]);
|
||||
|
||||
if (!reply) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<TableContainer>
|
||||
@@ -66,6 +80,7 @@ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame })
|
||||
})}
|
||||
</Tbody>
|
||||
</Table>
|
||||
<Text p="2">{lastUpdated}</Text>
|
||||
</TableContainer>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -50,6 +50,7 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
onClick={props.enabled && goToMessage}
|
||||
_hover={props.enabled && { cursor: "pointer", opacity: 0.9 }}
|
||||
whiteSpace="pre-wrap"
|
||||
>
|
||||
{inlineAvatar && avatar}
|
||||
{item.text}
|
||||
|
||||
@@ -6,9 +6,9 @@ import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
* Returns the leaderboard reply for the requested time frame, defaulting to `LeaderboardTimeFrame.day`.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res) => {
  // Use `LeaderboardTimeFrame.day` when the `time_frame` query parameter is absent.
  const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day;

  // Forward the complete backend reply rather than only its `leaderboard` array:
  // the leaderboard grid reads both `leaderboard` and `last_updated` from this response.
  const info = await oasstApiClient.fetch_leaderboard(time_frame);

  // Respond exactly once; a second declaration of `time_frame` or a second
  // `res.json(...)` call in this callback would be an error.
  res.status(200).json(info);
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -12,6 +12,7 @@ export const enum LeaderboardTimeFrame {
|
||||
}
|
||||
/**
 * Shape of the reply served by the `/api/leaderboard` route.
 */
export interface LeaderboardReply {
  /** Time frame of this reply (NOTE(review): presumably echoes the requested one; confirm). */
  time_frame: LeaderboardTimeFrame;
  /** Date time ISO string; the leaderboard grid parses it with `new Date(...)`. */
  last_updated: string;
  /** Entries the leaderboard grid renders as table rows. */
  leaderboard: LeaderboardEntity[];
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user