mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Merge branch 'LAION-AI:main' into main
This commit is contained in:
@@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- add-api-docs-workflow
|
||||
paths:
|
||||
- "oasst-shared/**"
|
||||
- "backend/**"
|
||||
@@ -46,8 +45,8 @@ jobs:
|
||||
|
||||
- run: ./scripts/backend-development/stop-mock-server.sh
|
||||
|
||||
- uses: stefanzweifel/git-auto-commit-action@v4
|
||||
with:
|
||||
file_pattern: "docs/docs/api/openapi.json"
|
||||
commit_message:
|
||||
update docs/docs/api/openapi.json by run ${{ github.run_id }}
|
||||
#- uses: stefanzweifel/git-auto-commit-action@v4
|
||||
# with:
|
||||
# file_pattern: "docs/docs/api/openapi.json"
|
||||
# commit_message:
|
||||
# update docs/docs/api/openapi.json by run ${{ github.run_id }}
|
||||
|
||||
@@ -19,17 +19,20 @@
|
||||
ansible.builtin.file:
|
||||
path: "./{{ stack_name }}"
|
||||
state: directory
|
||||
mode: 0755
|
||||
|
||||
- name: Copy redis.conf to managed node
|
||||
ansible.builtin.copy:
|
||||
src: ./redis.conf
|
||||
dest: "./{{ stack_name }}/redis.conf"
|
||||
mode: 0644
|
||||
|
||||
- name: Set up Redis
|
||||
community.docker.docker_container:
|
||||
name: "oasst-{{ stack_name }}-redis"
|
||||
image: redis
|
||||
state: started
|
||||
recreate: "{{ (stack_name == 'dev') | bool }}"
|
||||
restart_policy: always
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
healthcheck:
|
||||
@@ -46,6 +49,7 @@
|
||||
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
|
||||
image: postgres:15
|
||||
state: started
|
||||
recreate: "{{ (stack_name == 'dev') | bool }}"
|
||||
restart_policy: always
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
env:
|
||||
@@ -73,11 +77,12 @@
|
||||
env:
|
||||
POSTGRES_HOST: "oasst-{{ stack_name }}-postgres-backend"
|
||||
REDIS_HOST: "oasst-{{ stack_name }}-redis"
|
||||
DEBUG_ALLOW_ANY_API_KEY: "true"
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: "true"
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
DEBUG_ALLOW_SELF_LABELING: "true"
|
||||
DEBUG_ALLOW_SELF_LABELING:
|
||||
"{{ 'true' if stack_name == 'dev' else 'false' }}"
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "false"
|
||||
RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: "true"
|
||||
ports:
|
||||
|
||||
+83
@@ -0,0 +1,83 @@
|
||||
"""change user_stats ranking counts
|
||||
|
||||
Revision ID: 7c98102efbca
|
||||
Revises: 619255ae9076
|
||||
Create Date: 2023-01-15 00:02:45.622986
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7c98102efbca"
|
||||
down_revision = "619255ae9076"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("user_stats")
|
||||
op.create_table(
|
||||
"user_stats",
|
||||
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("base_date", sa.DateTime(), nullable=True),
|
||||
sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("leader_score", sa.Integer(), nullable=False),
|
||||
sa.Column("prompts", sa.Integer(), nullable=False),
|
||||
sa.Column("replies_assistant", sa.Integer(), nullable=False),
|
||||
sa.Column("replies_prompter", sa.Integer(), nullable=False),
|
||||
sa.Column("labels_simple", sa.Integer(), nullable=False),
|
||||
sa.Column("labels_full", sa.Integer(), nullable=False),
|
||||
sa.Column("rankings_total", sa.Integer(), nullable=False),
|
||||
sa.Column("rankings_good", sa.Integer(), nullable=False),
|
||||
sa.Column("accepted_prompts", sa.Integer(), nullable=False),
|
||||
sa.Column("accepted_replies_assistant", sa.Integer(), nullable=False),
|
||||
sa.Column("accepted_replies_prompter", sa.Integer(), nullable=False),
|
||||
sa.Column("reply_ranked_1", sa.Integer(), nullable=False),
|
||||
sa.Column("reply_ranked_2", sa.Integer(), nullable=False),
|
||||
sa.Column("reply_ranked_3", sa.Integer(), nullable=False),
|
||||
sa.Column("streak_last_day_date", sa.DateTime(), nullable=True),
|
||||
sa.Column("streak_days", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("user_id", "time_frame"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"user_stats",
|
||||
sa.Column("reply_prompter_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"user_stats",
|
||||
sa.Column("reply_assistant_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"user_stats",
|
||||
sa.Column("reply_assistant_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"user_stats",
|
||||
sa.Column("reply_prompter_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"user_stats",
|
||||
sa.Column("reply_prompter_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"user_stats",
|
||||
sa.Column("reply_assistant_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
|
||||
)
|
||||
op.drop_column("user_stats", "reply_ranked_3")
|
||||
op.drop_column("user_stats", "reply_ranked_2")
|
||||
op.drop_column("user_stats", "reply_ranked_1")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add indices for created_date
|
||||
|
||||
Revision ID: 423557e869e4
|
||||
Revises: 7c98102efbca
|
||||
Create Date: 2023-01-15 11:39:10.407859
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "423557e869e4"
|
||||
down_revision = "7c98102efbca"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index(op.f("ix_message_created_date"), "message", ["created_date"], unique=False)
|
||||
op.create_index(op.f("ix_message_reaction_created_date"), "message_reaction", ["created_date"], unique=False)
|
||||
op.create_index(op.f("ix_text_labels_created_date"), "text_labels", ["created_date"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_text_labels_created_date"), table_name="text_labels")
|
||||
op.drop_index(op.f("ix_message_reaction_created_date"), table_name="message_reaction")
|
||||
op.drop_index(op.f("ix_message_created_date"), table_name="message")
|
||||
# ### end Alembic commands ###
|
||||
+33
@@ -0,0 +1,33 @@
|
||||
"""add rank and indices to user_stats
|
||||
|
||||
Revision ID: 0964ac95170d
|
||||
Revises: 423557e869e4
|
||||
Create Date: 2023-01-15 16:54:09.510018
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0964ac95170d"
|
||||
down_revision = "423557e869e4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("user_stats", sa.Column("rank", sa.Integer(), nullable=True))
|
||||
op.create_index(
|
||||
"ix_user_stats__timeframe__rank__user_id", "user_stats", ["time_frame", "rank", "user_id"], unique=True
|
||||
)
|
||||
op.create_index("ix_user_stats__timeframe__user_id", "user_stats", ["time_frame", "user_id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_user_stats__timeframe__user_id", table_name="user_stats")
|
||||
op.drop_index("ix_user_stats__timeframe__rank__user_id", table_name="user_stats")
|
||||
op.drop_column("user_stats", "rank")
|
||||
# ### end Alembic commands ###
|
||||
@@ -9,6 +9,7 @@ import alembic.config
|
||||
import fastapi
|
||||
import redis.asyncio as redis
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
@@ -18,6 +19,7 @@ from oasst_backend.database import engine
|
||||
from oasst_backend.models import message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -195,6 +197,54 @@ def ensure_tree_states():
|
||||
logger.exception("TreeManager.ensure_tree_states() failed.")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
|
||||
def update_leader_board_day() -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.day)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during leaderboard update (daily)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
|
||||
def update_leader_board_week() -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.week)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (weekly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
|
||||
def update_leader_board_month() -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.month)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (monthly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
|
||||
def update_leader_board_total() -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.total)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (total)")
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from loguru import logger
|
||||
@@ -22,6 +22,8 @@ def get_db() -> Generator:
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
api_key_query: str = Security(api_key_query),
|
||||
@@ -33,22 +35,47 @@ async def get_api_key(
|
||||
return api_key_header
|
||||
|
||||
|
||||
def get_dummy_api_client(db: Session) -> ApiClient:
|
||||
def create_api_client(
|
||||
*,
|
||||
session: Session,
|
||||
description: str,
|
||||
frontend_type: str,
|
||||
trusted: bool | None = False,
|
||||
admin_email: str | None = None,
|
||||
api_key: str | None = None,
|
||||
) -> ApiClient:
|
||||
if api_key is None:
|
||||
api_key = token_hex(32)
|
||||
|
||||
logger.info(f"Creating new api client with {api_key=}")
|
||||
api_client = ApiClient(
|
||||
api_key=api_key,
|
||||
description=description,
|
||||
frontend_type=frontend_type,
|
||||
trusted=trusted,
|
||||
admin_email=admin_email,
|
||||
)
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
session.refresh(api_client)
|
||||
return api_client
|
||||
|
||||
|
||||
def get_dummy_api_client(session: Session) -> ApiClient:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
DUMMY_API_KEY = "1234"
|
||||
api_client: ApiClient = session.query(ApiClient).filter(ApiClient.api_key == DUMMY_API_KEY).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(
|
||||
id=ANY_API_KEY_ID,
|
||||
api_key=token,
|
||||
description="ANY_API_KEY, random token",
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {DUMMY_API_KEY}")
|
||||
api_client = create_api_client(
|
||||
session=session,
|
||||
api_key=DUMMY_API_KEY,
|
||||
description="Dummy api key for debugging",
|
||||
trusted=True,
|
||||
frontend_type="Test frontend",
|
||||
)
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
return api_client
|
||||
|
||||
|
||||
@@ -58,7 +85,7 @@ def api_auth(
|
||||
) -> ApiClient:
|
||||
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_DEBUG_API_KEY:
|
||||
return get_dummy_api_client(db)
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
@@ -93,6 +120,18 @@ def get_trusted_api_client(
|
||||
return client
|
||||
|
||||
|
||||
def get_root_token(bearer_token: HTTPAuthorizationCredentials = Security(bearer_token)) -> str:
|
||||
if bearer_token:
|
||||
token = bearer_token.credentials
|
||||
if token and token in settings.ROOT_TOKENS:
|
||||
return token
|
||||
raise OasstError(
|
||||
"Could not validate credentials",
|
||||
error_code=OasstErrorCode.ROOT_TOKEN_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimiter(RateLimiter):
|
||||
def __init__(
|
||||
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateApiClientRequest(pydantic.BaseModel):
|
||||
description: str
|
||||
frontend_type: str
|
||||
trusted: bool | None = False
|
||||
admin_email: str | None = None
|
||||
|
||||
|
||||
@router.post("/api_client")
|
||||
async def create_api_client(
|
||||
request: CreateApiClientRequest,
|
||||
root_token: str = Depends(deps.get_root_token),
|
||||
session: deps.Session = Depends(deps.get_db),
|
||||
) -> str:
|
||||
logger.info(f"Creating new api client with {request=}")
|
||||
api_client = deps.create_api_client(
|
||||
session=session,
|
||||
description=request.description,
|
||||
frontend_type=request.frontend_type,
|
||||
trusted=request.trusted,
|
||||
admin_email=request.admin_email,
|
||||
)
|
||||
logger.info(f"Created api_client with key {api_client.api_key}")
|
||||
return api_client.api_key
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import (
|
||||
admin,
|
||||
frontend_messages,
|
||||
frontend_users,
|
||||
hugging_face,
|
||||
@@ -19,5 +20,6 @@ api_router.include_router(frontend_messages.router, prefix="/frontend_messages",
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/create/assistant")
|
||||
def get_assistant_leaderboard(
|
||||
@router.get("/{time_frame}", response_model=LeaderboardStats)
|
||||
def get_leaderboard(
|
||||
time_frame: UserStatsTimeFrame,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="assistant")
|
||||
|
||||
|
||||
@router.get("/create/prompter")
|
||||
def get_prompter_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="prompter")
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_leaderboard(time_frame, limit=max_count)
|
||||
|
||||
@@ -8,6 +8,7 @@ from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -96,3 +97,24 @@ def mark_user_messages_deleted(
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
|
||||
@router.get("/{user_id}/stats", response_model=dict[str, protocol.UserScore | None])
|
||||
def query_user_stats(
|
||||
user_id: UUID,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_user_stats_all_time_frames(user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/{user_id}/stats/{time_frame}", response_model=protocol.UserScore)
|
||||
def query_user_stats_timeframe(
|
||||
user_id: UUID,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_user_stats_all_time_frames(user_id=user_id)[time_frame.value]
|
||||
|
||||
@@ -71,7 +71,7 @@ class Settings(BaseSettings):
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: str = "6379"
|
||||
|
||||
DEBUG_ALLOW_ANY_API_KEY: bool = False
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: bool = False
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
@@ -83,6 +83,8 @@ class Settings(BaseSettings):
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
|
||||
|
||||
@validator("DATABASE_URI", pre=True)
|
||||
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
|
||||
if isinstance(v, str):
|
||||
@@ -109,6 +111,22 @@ class Settings(BaseSettings):
|
||||
|
||||
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
|
||||
|
||||
USER_STATS_INTERVAL_DAY: int = 15 # minutes
|
||||
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
|
||||
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
|
||||
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
|
||||
|
||||
@validator(
|
||||
"USER_STATS_INTERVAL_DAY",
|
||||
"USER_STATS_INTERVAL_WEEK",
|
||||
"USER_STATS_INTERVAL_MONTH",
|
||||
"USER_STATS_INTERVAL_TOTAL",
|
||||
)
|
||||
def validate_user_stats_intervals(cls, v: int):
|
||||
if v < 1:
|
||||
raise ValueError(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
@@ -8,12 +8,13 @@ from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
from .user import User
|
||||
from .user_stats import UserStats
|
||||
from .user_stats import UserStats, UserStatsTimeFrame
|
||||
|
||||
__all__ = [
|
||||
"ApiClient",
|
||||
"User",
|
||||
"UserStats",
|
||||
"UserStatsTimeFrame",
|
||||
"Message",
|
||||
"MessageEmbedding",
|
||||
"MessageReaction",
|
||||
|
||||
@@ -65,12 +65,16 @@ class RankingReactionPayload(ReactionPayload):
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
ranking: list[int]
|
||||
ranked_message_ids: list[UUID]
|
||||
ranking_parent_id: Optional[UUID]
|
||||
message_tree_id: Optional[UUID]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankConversationRepliesPayload(TaskPayload):
|
||||
conversation: protocol_schema.Conversation # the conversation so far
|
||||
reply_messages: list[protocol_schema.ConversationMessage]
|
||||
ranking_parent_id: Optional[UUID]
|
||||
message_tree_id: Optional[UUID]
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -104,6 +108,7 @@ class LabelInitialPromptPayload(TaskPayload):
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[protocol_schema.LabelTaskMode]
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -115,6 +120,7 @@ class LabelConversationReplyPayload(TaskPayload):
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[protocol_schema.LabelTaskMode]
|
||||
|
||||
|
||||
@payload_type
|
||||
|
||||
@@ -30,7 +30,7 @@ class Message(SQLModel, table=True):
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: Optional[PayloadContainer] = Field(
|
||||
|
||||
@@ -19,7 +19,7 @@ class MessageReaction(SQLModel, table=True):
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_shared.utils import utcnow
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
@@ -35,4 +36,4 @@ class Task(SQLModel, table=True):
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
|
||||
return self.expiry_date is not None and utcnow() > self.expiry_date
|
||||
|
||||
@@ -17,7 +17,7 @@ class TextLabels(SQLModel, table=True):
|
||||
)
|
||||
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True),
|
||||
)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
text: str = Field(nullable=False, max_length=2**16)
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class UserStatsTimeFrame(str, Enum):
|
||||
@@ -17,17 +17,24 @@ class UserStatsTimeFrame(str, Enum):
|
||||
|
||||
class UserStats(SQLModel, table=True):
|
||||
__tablename__ = "user_stats"
|
||||
__table_args__ = (
|
||||
Index("ix_user_stats__timeframe__user_id", "time_frame", "user_id", unique=True),
|
||||
Index("ix_user_stats__timeframe__rank__user_id", "time_frame", "rank", "user_id", unique=True),
|
||||
)
|
||||
|
||||
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
|
||||
user_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
|
||||
)
|
||||
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
|
||||
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
|
||||
leader_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
rank: int = Field(nullable=True)
|
||||
|
||||
prompts: int = 0
|
||||
replies_assistant: int = 0
|
||||
replies_prompter: int = 0
|
||||
@@ -40,14 +47,27 @@ class UserStats(SQLModel, table=True):
|
||||
accepted_replies_assistant: int = 0
|
||||
accepted_replies_prompter: int = 0
|
||||
|
||||
reply_assistant_ranked_1: int = 0
|
||||
reply_assistant_ranked_2: int = 0
|
||||
reply_assistant_ranked_3: int = 0
|
||||
|
||||
reply_prompter_ranked_1: int = 0
|
||||
reply_prompter_ranked_2: int = 0
|
||||
reply_prompter_ranked_3: int = 0
|
||||
reply_ranked_1: int = 0
|
||||
reply_ranked_2: int = 0
|
||||
reply_ranked_3: int = 0
|
||||
|
||||
# only used for time span "total"
|
||||
streak_last_day_date: Optional[datetime] = Field(nullable=True)
|
||||
streak_days: Optional[int] = Field(nullable=True)
|
||||
|
||||
def compute_leader_score(self) -> int:
|
||||
return (
|
||||
self.prompts
|
||||
+ self.replies_assistant * 4
|
||||
+ self.replies_prompter
|
||||
+ self.labels_simple
|
||||
+ self.labels_full * 2
|
||||
+ self.rankings_total
|
||||
+ self.rankings_good
|
||||
+ self.accepted_prompts
|
||||
+ self.accepted_replies_assistant * 4
|
||||
+ self.accepted_replies_prompter
|
||||
+ self.reply_ranked_1 * 9
|
||||
+ self.reply_ranked_2 * 3
|
||||
+ self.reply_ranked_3
|
||||
)
|
||||
|
||||
@@ -260,7 +260,10 @@ class PromptRepository:
|
||||
self.db.add(message)
|
||||
|
||||
reaction_payload = db_payload.RankingReactionPayload(
|
||||
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
|
||||
ranking=ranking.ranking,
|
||||
ranked_message_ids=ranked_message_ids,
|
||||
ranking_parent_id=task_payload.ranking_parent_id,
|
||||
message_tree_id=task_payload.message_tree_id,
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
|
||||
|
||||
@@ -67,17 +67,30 @@ class TaskRepository:
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
|
||||
type=task.type,
|
||||
conversation=task.conversation,
|
||||
reply_messages=task.reply_messages,
|
||||
ranking_parent_id=task.ranking_parent_id,
|
||||
message_tree_id=task.message_tree_id,
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
|
||||
type=task.type,
|
||||
conversation=task.conversation,
|
||||
reply_messages=task.reply_messages,
|
||||
ranking_parent_id=task.ranking_parent_id,
|
||||
message_tree_id=task.message_tree_id,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
prompt=task.prompt,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
@@ -88,6 +101,7 @@ class TaskRepository:
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
@@ -98,6 +112,7 @@ class TaskRepository:
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
)
|
||||
|
||||
case _:
|
||||
|
||||
@@ -207,6 +207,9 @@ class TreeManager:
|
||||
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
||||
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
|
||||
@@ -218,12 +221,20 @@ class TreeManager:
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=conversation, replies=replies, reply_messages=reply_messages
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
else:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=conversation, replies=replies, reply_messages=reply_messages
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
|
||||
parent_message_id = ranking_parent_id
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Message, User
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session, func
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -137,27 +136,6 @@ class UserRepository:
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
|
||||
def query_users(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
LabelPrompterReplyPayload,
|
||||
RankingReactionPayload,
|
||||
)
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
|
||||
from oasst_shared.utils import log_timing, utcnow
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlmodel import Session, delete, func, text
|
||||
|
||||
|
||||
def _create_user_score(r):
|
||||
if r["UserStats"]:
|
||||
d = r["UserStats"].dict()
|
||||
else:
|
||||
d = {"modified_date": utcnow()}
|
||||
for k in ["user_id", "username", "auth_method", "display_name"]:
|
||||
d[k] = r[k]
|
||||
return UserScore(**d)
|
||||
|
||||
|
||||
class UserStatsRepository:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for the specified time frame
|
||||
"""
|
||||
|
||||
qry = (
|
||||
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value)
|
||||
.order_by(UserStats.leader_score.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
leaderboard = [_create_user_score(r) for r in self.session.exec(qry)]
|
||||
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
|
||||
|
||||
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
|
||||
qry = (
|
||||
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
|
||||
.outerjoin(UserStats, User.id == UserStats.user_id)
|
||||
.filter(User.id == user_id)
|
||||
)
|
||||
|
||||
stats_by_timeframe = {}
|
||||
for r in self.session.exec(qry):
|
||||
us = r["UserStats"]
|
||||
if us is not None:
|
||||
stats_by_timeframe[us.time_frame] = _create_user_score(r)
|
||||
else:
|
||||
stats_by_timeframe = {tf.value: _create_user_score(r) for tf in UserStatsTimeFrame}
|
||||
return stats_by_timeframe
|
||||
|
||||
def query_total_prompts_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
):
|
||||
qry = self.session.query(Message.user_id, func.count()).filter(
|
||||
Message.deleted == sa.false(), Message.parent_id.is_(None)
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
if only_reviewed:
|
||||
qry = qry.filter(Message.review_result == sa.true())
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def query_replies_by_role_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
) -> list:
|
||||
qry = self.session.query(Message.user_id, Message.role, func.count()).filter(
|
||||
Message.deleted == sa.false(), Message.parent_id.is_not(None)
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
if only_reviewed:
|
||||
qry = qry.filter(Message.review_result == sa.true())
|
||||
qry = qry.group_by(Message.user_id, Message.role)
|
||||
return qry
|
||||
|
||||
def query_labels_by_mode_per_user(
|
||||
self, payload_type: str = LabelAssistantReplyPayload.__name__, reference_time: Optional[datetime] = None
|
||||
):
|
||||
qry = self.session.query(Task.user_id, Task.payload["payload", "mode"].astext, func.count()).filter(
|
||||
Task.done == sa.true(), Task.payload_type == payload_type
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(Task.created_date >= reference_time)
|
||||
qry = qry.group_by(Task.user_id, Task.payload["payload", "mode"].astext)
|
||||
return qry
|
||||
|
||||
def query_rankings_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = self.session.query(MessageReaction.user_id, func.count()).filter(
|
||||
MessageReaction.payload_type == RankingReactionPayload.__name__
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(MessageReaction.created_date >= reference_time)
|
||||
qry = qry.group_by(MessageReaction.user_id)
|
||||
return qry
|
||||
|
||||
def query_ranking_result_users(self, rank: int = 0, reference_time: Optional[datetime] = None):
|
||||
ranked_message_id = MessageReaction.payload["payload", "ranked_message_ids", rank].astext.cast(
|
||||
postgresql.UUID(as_uuid=True)
|
||||
)
|
||||
qry = (
|
||||
self.session.query(Message.user_id, func.count())
|
||||
.select_from(MessageReaction)
|
||||
.join(Message, ranked_message_id == Message.id)
|
||||
.filter(MessageReaction.payload_type == RankingReactionPayload.__name__)
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(MessageReaction.created_date >= reference_time)
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def _update_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
|
||||
# gather user data
|
||||
|
||||
time_frame_key = time_frame.value
|
||||
|
||||
stats_by_user: dict[UUID, UserStats] = dict()
|
||||
now = utcnow()
|
||||
|
||||
def get_stats(id: UUID) -> UserStats:
|
||||
us = stats_by_user.get(id)
|
||||
if not us:
|
||||
us = UserStats(user_id=id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
|
||||
stats_by_user[id] = us
|
||||
return us
|
||||
|
||||
# total prompts
|
||||
qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=False)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
get_stats(uid).prompts = count
|
||||
|
||||
# accepted prompts
|
||||
qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=True)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
get_stats(uid).accepted_prompts = count
|
||||
|
||||
# total replies
|
||||
qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=False)
|
||||
for r in qry:
|
||||
uid, role, count = r
|
||||
s = get_stats(uid)
|
||||
if role == "assistant":
|
||||
s.replies_assistant += count
|
||||
elif role == "prompter":
|
||||
s.replies_prompter += count
|
||||
|
||||
# accepted replies
|
||||
qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=True)
|
||||
for r in qry:
|
||||
uid, role, count = r
|
||||
s = get_stats(uid)
|
||||
if role == "assistant":
|
||||
s.accepted_replies_assistant += count
|
||||
elif role == "prompter":
|
||||
s.accepted_replies_prompter += count
|
||||
|
||||
# simple and full labels
|
||||
qry = self.query_labels_by_mode_per_user(
|
||||
payload_type=LabelAssistantReplyPayload.__name__, reference_time=base_date
|
||||
)
|
||||
for r in qry:
|
||||
uid, mode, count = r
|
||||
s = get_stats(uid)
|
||||
if mode == "simple":
|
||||
s.labels_simple = count
|
||||
elif mode == "full":
|
||||
s.labels_full = count
|
||||
|
||||
qry = self.query_labels_by_mode_per_user(
|
||||
payload_type=LabelPrompterReplyPayload.__name__, reference_time=base_date
|
||||
)
|
||||
for r in qry:
|
||||
uid, mode, count = r
|
||||
s = get_stats(uid)
|
||||
if mode == "simple":
|
||||
s.labels_simple += count
|
||||
elif mode == "full":
|
||||
s.labels_full += count
|
||||
|
||||
qry = self.query_rankings_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
get_stats(uid).rankings_total = count
|
||||
|
||||
rank_field_names = ["reply_ranked_1", "reply_ranked_2", "reply_ranked_3"]
|
||||
for i, fn in enumerate(rank_field_names):
|
||||
qry = self.query_ranking_result_users(reference_time=base_date, rank=0)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
setattr(get_stats(uid), fn, count)
|
||||
|
||||
# delete all existing stast for time frame
|
||||
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
|
||||
self.session.execute(d)
|
||||
|
||||
# compute magic leader score
|
||||
for v in stats_by_user.values():
|
||||
v.leader_score = v.compute_leader_score()
|
||||
|
||||
# insert user objects
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
self.update_ranks(time_frame=time_frame)
|
||||
|
||||
@log_timing(log_kwargs=True)
|
||||
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
"""
|
||||
Update user_stats ranks. The persisted rank values allow to
|
||||
quickly the rank of a single user and to query nearby users.
|
||||
"""
|
||||
|
||||
# todo: convert sql to sqlalchemy query..
|
||||
# ranks = self.session.query(
|
||||
# func.row_number()
|
||||
# .over(partition_by=UserStats.time_frame, order_by=[UserStats.leader_score.desc(), UserStats.user_id])
|
||||
# .label("rank"),
|
||||
# UserStats.user_id,
|
||||
# UserStats.time_frame,
|
||||
# )
|
||||
|
||||
sql_update_rank = """
|
||||
-- update rank
|
||||
UPDATE user_stats us
|
||||
SET "rank" = r."rank"
|
||||
FROM
|
||||
(SELECT
|
||||
ROW_NUMBER () OVER(
|
||||
PARTITION BY time_frame
|
||||
ORDER BY leader_score DESC, user_id
|
||||
) AS "rank", user_id, time_frame
|
||||
FROM user_stats
|
||||
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
|
||||
WHERE
|
||||
us.user_id = r.user_id
|
||||
AND us.time_frame = r.time_frame;"""
|
||||
r = self.session.execute(
|
||||
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
|
||||
)
|
||||
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
|
||||
self._update_stats_internal(time_frame, reference_time)
|
||||
self.session.commit()
|
||||
|
||||
@log_timing(log_kwargs=True, level="INFO")
|
||||
def update_stats(self, *, time_frame: UserStatsTimeFrame):
|
||||
now = utcnow()
|
||||
match time_frame:
|
||||
case UserStatsTimeFrame.day:
|
||||
r = now - timedelta(days=1)
|
||||
self.update_stats_time_frame(time_frame, r)
|
||||
|
||||
case UserStatsTimeFrame.week:
|
||||
r = now.date() - timedelta(days=7)
|
||||
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
|
||||
self.update_stats_time_frame(time_frame, r)
|
||||
|
||||
case UserStatsTimeFrame.month:
|
||||
r = now.date() - timedelta(days=30)
|
||||
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
|
||||
self.update_stats_time_frame(time_frame, r)
|
||||
|
||||
case UserStatsTimeFrame.total:
|
||||
self.update_stats_time_frame(time_frame, None)
|
||||
|
||||
@log_timing(level="INFO")
|
||||
def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]):
|
||||
for t in time_frames:
|
||||
self.update_stats(time_frame=t)
|
||||
|
||||
@log_timing(level="INFO")
|
||||
def update_all_time_frames(self):
|
||||
self.update_multiple_time_frames(list(UserStatsTimeFrame))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.database import engine
|
||||
|
||||
with Session(engine) as session:
|
||||
api_client = get_dummy_api_client(session)
|
||||
usr = UserStatsRepository(session)
|
||||
# usr.update_all_time_frames()
|
||||
# session.commit()
|
||||
# usr.get_leader_board(UserStatsTimeFrame.total)
|
||||
usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
|
||||
@@ -1,6 +1,7 @@
|
||||
alembic==1.8.1
|
||||
fastapi==0.88.0
|
||||
fastapi-limiter==0.1.5
|
||||
fastapi-utils==0.2.1
|
||||
loguru==0.6.0
|
||||
numpy==1.22.4
|
||||
psycopg2-binary==2.9.5
|
||||
|
||||
@@ -29,7 +29,7 @@ environments:
|
||||
variables:
|
||||
# Note: this has to be a valid JSON list for Pydantic to parse it.
|
||||
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
|
||||
DEBUG_ALLOW_ANY_API_KEY: True
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: True
|
||||
DEBUG_SKIP_API_KEY_CHECK: True
|
||||
MAX_WORKERS: 1
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
# Deployment files
|
||||
|
||||
Copy these to the node you want to deploy to.
|
||||
@@ -0,0 +1,19 @@
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
webserver:
|
||||
image: nginx:latest
|
||||
network_mode: host
|
||||
ports:
|
||||
- 80:80
|
||||
- 443:443
|
||||
restart: always
|
||||
volumes:
|
||||
- ./nginx.conf:/etc/nginx/nginx.conf:ro
|
||||
- ./certbot/www:/var/www/certbot/:ro
|
||||
- ./certbot/conf/:/etc/nginx/ssl/:ro
|
||||
certbot:
|
||||
image: certbot/certbot:latest
|
||||
volumes:
|
||||
- ./certbot/www/:/var/www/certbot/:rw
|
||||
- ./certbot/conf/:/etc/letsencrypt/:rw
|
||||
Executable
+3
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
docker compose run --rm certbot certonly -m admin@open-assistant.io --agree-tos --webroot --webroot-path /var/www/certbot/ -d $1
|
||||
@@ -0,0 +1,81 @@
|
||||
events {}
|
||||
http {
|
||||
server {
|
||||
listen 80;
|
||||
listen [::]:80;
|
||||
|
||||
server_name *.open-assistant.io;
|
||||
server_tokens off;
|
||||
|
||||
location /.well-known/acme-challenge/ {
|
||||
root /var/www/certbot;
|
||||
}
|
||||
|
||||
location / {
|
||||
return 301 https://$host$request_uri;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name web.dev.open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/web.dev.open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/web.dev.open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:3000;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name backend.dev.open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/backend.dev.open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/backend.dev.open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:8080;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name web.staging.open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/web.staging.open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/web.staging.open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:3100;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name backend.staging.open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/backend.staging.open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/backend.staging.open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:8180;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
Executable
+3
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
docker compose run --rm certbot renew
|
||||
@@ -0,0 +1,19 @@
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
webserver:
|
||||
image: nginx:latest
|
||||
network_mode: host
|
||||
ports:
|
||||
- 80:80
|
||||
- 443:443
|
||||
restart: always
|
||||
volumes:
|
||||
- ./nginx.conf:/etc/nginx/nginx.conf:ro
|
||||
- ./certbot/www:/var/www/certbot/:ro
|
||||
- ./certbot/conf/:/etc/nginx/ssl/:ro
|
||||
certbot:
|
||||
image: certbot/certbot:latest
|
||||
volumes:
|
||||
- ./certbot/www/:/var/www/certbot/:rw
|
||||
- ./certbot/conf/:/etc/letsencrypt/:rw
|
||||
Executable
+3
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
docker compose run --rm certbot certonly -m admin@open-assistant.io --agree-tos --webroot --webroot-path /var/www/certbot/ -d $1
|
||||
@@ -0,0 +1,62 @@
|
||||
events {}
|
||||
http {
|
||||
server {
|
||||
listen 80;
|
||||
listen [::]:80;
|
||||
|
||||
server_name *.open-assistant.io open-assistant.io;
|
||||
server_tokens off;
|
||||
|
||||
location /.well-known/acme-challenge/ {
|
||||
root /var/www/certbot;
|
||||
}
|
||||
|
||||
location / {
|
||||
return 301 https://$host$request_uri;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
return 301 https://web.prod.open-assistant.io$request_uri;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name web.prod.open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/web.prod.open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/web.prod.open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:3000;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name backend.prod.open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/backend.prod.open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/backend.prod.open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:8080;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Executable
+3
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
docker compose run --rm certbot renew
|
||||
@@ -32,12 +32,15 @@ const config = {
|
||||
|
||||
presets: [
|
||||
[
|
||||
"classic",
|
||||
"docusaurus-preset-openapi",
|
||||
/** @type {import('@docusaurus/preset-classic').Options} */
|
||||
({
|
||||
docs: {
|
||||
sidebarPath: require.resolve("./sidebars.js"),
|
||||
},
|
||||
api: {
|
||||
path: "docs/api/openapi.json",
|
||||
},
|
||||
blog: false,
|
||||
theme: {
|
||||
customCss: require.resolve("./src/css/custom.css"),
|
||||
@@ -62,11 +65,7 @@ const config = {
|
||||
position: "left",
|
||||
label: "Docs",
|
||||
},
|
||||
{
|
||||
href: "https://editor.swagger.io/?url=https://raw.githubusercontent.com/LAION-AI/Open-Assistant/main/docs/docs/api/openapi.json",
|
||||
label: "API",
|
||||
position: "left",
|
||||
},
|
||||
{ to: "/api", label: "API", position: "left" },
|
||||
{
|
||||
href: "https://github.com/LAION-AI/Open-Assistant",
|
||||
label: "GitHub",
|
||||
|
||||
+3
-1
@@ -19,9 +19,11 @@
|
||||
"@docusaurus/preset-classic": "2.2.0",
|
||||
"@mdx-js/react": "^1.6.22",
|
||||
"clsx": "^1.2.1",
|
||||
"docusaurus-preset-openapi": "^0.6.3",
|
||||
"prism-react-renderer": "^1.3.5",
|
||||
"react": "^17.0.2",
|
||||
"react-dom": "^17.0.2"
|
||||
"react-dom": "^17.0.2",
|
||||
"url": "^0.11.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@docusaurus/module-type-aliases": "2.2.0",
|
||||
|
||||
+972
-37
File diff suppressed because it is too large
Load Diff
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
ROOT_TOKEN_NOT_AUTHORIZED = 3
|
||||
TOO_MANY_REQUESTS = 429
|
||||
|
||||
SERVER_ERROR0 = 500
|
||||
|
||||
@@ -169,6 +169,8 @@ class RankConversationRepliesTask(Task):
|
||||
conversation: Conversation # the conversation so far
|
||||
replies: list[str] # deprecated, use reply_messages
|
||||
reply_messages: list[ConversationMessage]
|
||||
message_tree_id: UUID
|
||||
ranking_parent_id: UUID
|
||||
|
||||
|
||||
class RankPrompterRepliesTask(RankConversationRepliesTask):
|
||||
@@ -356,14 +358,40 @@ class SystemStats(BaseModel):
|
||||
|
||||
|
||||
class UserScore(BaseModel):
|
||||
ranking: int
|
||||
rank: Optional[int]
|
||||
user_id: UUID
|
||||
username: str
|
||||
auth_method: str
|
||||
display_name: str
|
||||
score: int
|
||||
|
||||
leader_score: int = 0
|
||||
|
||||
base_date: Optional[datetime]
|
||||
modified_date: Optional[datetime]
|
||||
|
||||
prompts: int = 0
|
||||
replies_assistant: int = 0
|
||||
replies_prompter: int = 0
|
||||
labels_simple: int = 0
|
||||
labels_full: int = 0
|
||||
rankings_total: int = 0
|
||||
rankings_good: int = 0
|
||||
|
||||
accepted_prompts: int = 0
|
||||
accepted_replies_assistant: int = 0
|
||||
accepted_replies_prompter: int = 0
|
||||
|
||||
reply_ranked_1: int = 0
|
||||
reply_ranked_2: int = 0
|
||||
reply_ranked_3: int = 0
|
||||
|
||||
# only used for time frame "total"
|
||||
streak_last_day_date: Optional[datetime]
|
||||
streak_days: Optional[int]
|
||||
|
||||
|
||||
class LeaderboardStats(BaseModel):
|
||||
time_frame: str
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,32 @@
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Return the current utc date and time with tzinfo set to UTC."""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None:
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
start = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
elapsed = end - start
|
||||
if log_kwargs:
|
||||
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
|
||||
logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")
|
||||
else:
|
||||
logger.log(level, f"Function '{func.__name__}' executed in {elapsed:f} s")
|
||||
return result
|
||||
|
||||
return wrapped
|
||||
|
||||
if func and callable(func):
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
@@ -4,7 +4,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_SKIP_API_KEY_CHECK=False
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
|
||||
@@ -25,10 +25,10 @@ from bs4 import BeautifulSoup as bs
|
||||
from logic.logic_injector import LogicBug
|
||||
from nltk.corpus import wordnet
|
||||
from syntax.syntax_injector import SyntaxBug
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, pipeline
|
||||
|
||||
|
||||
class DataArgumenter:
|
||||
class DataAugmenter:
|
||||
def __init__(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -48,7 +48,7 @@ class DataArgumenter:
|
||||
pass
|
||||
|
||||
|
||||
class EssayInstructor(DataArgumenter):
|
||||
class EssayInstructor(DataAugmenter):
|
||||
def __init__(self, model_name=None):
|
||||
if model_name is None:
|
||||
model_name = "snrspeaks/t5-one-line-summary"
|
||||
@@ -86,7 +86,7 @@ class EssayInstructor(DataArgumenter):
|
||||
return prompts, essay_paragraphs
|
||||
|
||||
|
||||
class EssayReviser(DataArgumenter):
|
||||
class EssayReviser(DataAugmenter):
|
||||
def __init__(self):
|
||||
nltk.download("wordnet")
|
||||
nltk.download("omw-1.4")
|
||||
@@ -132,7 +132,7 @@ class EssayReviser(DataArgumenter):
|
||||
return instructions, [essay] * len(instructions)
|
||||
|
||||
|
||||
class StackExchangeBuilder(DataArgumenter):
|
||||
class StackExchangeBuilder(DataAugmenter):
|
||||
def __init__(self, base_url=None, filter_opts=None):
|
||||
self.base_url = (
|
||||
base_url
|
||||
@@ -271,7 +271,7 @@ class StackExchangeBuilder(DataArgumenter):
|
||||
return questions, answers
|
||||
|
||||
|
||||
class HierachicalSummarizer(DataArgumenter):
|
||||
class HierachicalSummarizer(DataAugmenter):
|
||||
def __init__(self):
|
||||
self.summarizer = pipeline(
|
||||
"summarization",
|
||||
@@ -342,7 +342,7 @@ class HierachicalSummarizer(DataArgumenter):
|
||||
return instructions, answers
|
||||
|
||||
|
||||
class EntityRecognizedSummarizer(DataArgumenter):
|
||||
class EntityRecognizedSummarizer(DataAugmenter):
|
||||
def __init__(self):
|
||||
self.nlp = spacy.load("en_core_web_sm") # run !python -m spacy download en_core_web_sm in order to download
|
||||
|
||||
@@ -357,7 +357,7 @@ class EntityRecognizedSummarizer(DataArgumenter):
|
||||
return [question], [answer]
|
||||
|
||||
|
||||
class CodeBugger(DataArgumenter):
|
||||
class CodeBugger(DataAugmenter):
|
||||
"""
|
||||
https://github.com/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.md
|
||||
Openbugger is a Python package that allows you to inject syntax and logic errors into your code.
|
||||
@@ -391,6 +391,38 @@ class CodeBugger(DataArgumenter):
|
||||
return [question], [answer]
|
||||
|
||||
|
||||
class CodeInstructor(DataAugmenter):
|
||||
def __init__(self):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("Graverman/t5-code-summary")
|
||||
self.model = T5ForConditionalGeneration.from_pretrained("Graverman/t5-code-summary")
|
||||
|
||||
def parse(self, codes):
|
||||
source_encoding = self.tokenizer(
|
||||
codes,
|
||||
max_length=300,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
outputs = self.model.generate(
|
||||
input_ids=source_encoding["input_ids"],
|
||||
attention_mask=source_encoding["attention_mask"],
|
||||
max_length=100,
|
||||
length_penalty=0.75,
|
||||
repetition_penalty=2.5,
|
||||
early_stopping=True,
|
||||
use_cache=True,
|
||||
)
|
||||
summaries = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
|
||||
|
||||
questions = ["Write a script that does the following:\n" + s for s in summaries]
|
||||
answers = codes
|
||||
|
||||
return questions, answers
|
||||
|
||||
|
||||
def recognize_entities(text, model, n=4, person="ignore"):
|
||||
"""Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
|
||||
doc = model(text)
|
||||
@@ -417,20 +449,23 @@ def parse_arguments():
|
||||
args.add_argument("--output", type=str, required=True)
|
||||
args = args.parse_args()
|
||||
|
||||
assert args.dataset.endswith(".tsv"), "Dataset file must be a tsv file, containing a list of files to be augmented"
|
||||
assert args.dataset.endswith(".tsv") or args.dataset.endswith(
|
||||
".csv"
|
||||
), "Dataset file must be a tsv or csv file, containing a list of files to be augmented"
|
||||
assert args.output.endswith(".json"), "Output file must be a json file"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def read_data(args):
|
||||
files = pd.read_csv(args.dataset, sep="\t", header=None)
|
||||
files = files[0].tolist()
|
||||
files = pd.read_csv(args.dataset, sep=",", header=None, names=["file"])
|
||||
files = files["file"].tolist()
|
||||
data = []
|
||||
for file in files:
|
||||
with open(file, "r") as f:
|
||||
text = f.read()
|
||||
data.append(text)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -453,9 +488,12 @@ def get_augmenter(args):
|
||||
elif args.augmenter == "codebugger":
|
||||
augmenter = CodeBugger()
|
||||
|
||||
elif args.augmenter == "codeinstructor":
|
||||
augmenter = CodeInstructor()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger"
|
||||
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger', 'codeinstructor"
|
||||
)
|
||||
|
||||
return augmenter
|
||||
|
||||
Generated
+616
-29
File diff suppressed because it is too large
Load Diff
@@ -58,6 +58,7 @@
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-icons": "^4.7.1",
|
||||
"sharp": "^0.31.3",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"use-debounce": "^9.0.2"
|
||||
|
||||
@@ -41,6 +41,7 @@ model User {
|
||||
email String? @unique
|
||||
emailVerified DateTime?
|
||||
image String?
|
||||
isNew Boolean @default(true)
|
||||
role String @default("general")
|
||||
|
||||
accounts Account[]
|
||||
|
||||
@@ -3,7 +3,6 @@ import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Flex,
|
||||
Grid,
|
||||
Popover,
|
||||
PopoverAnchor,
|
||||
PopoverArrow,
|
||||
@@ -146,24 +145,25 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
isLazy
|
||||
lazyBehavior="keepMounted"
|
||||
>
|
||||
<Grid display="flex" alignItems="center" gap="2">
|
||||
<Box display="flex" alignItems="center" gap="2">
|
||||
<PopoverAnchor>{props.children}</PopoverAnchor>
|
||||
|
||||
<Tooltip label="Report" bg="red.500" aria-label="A tooltip">
|
||||
<div>
|
||||
<Box>
|
||||
<PopoverTrigger>
|
||||
<Box as="button" display="flex" alignItems="center" justifyContent="center" borderRadius="full" p="1">
|
||||
<FiAlertCircle size="20" className="text-red-400" aria-hidden="true" />
|
||||
</Box>
|
||||
</PopoverTrigger>
|
||||
</div>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
</Grid>
|
||||
</Box>
|
||||
|
||||
<PopoverContent width="auto" p="3" m="4" maxWidth="calc(100vw - 2rem)">
|
||||
<PopoverArrow />
|
||||
<div className="relative h-4">
|
||||
<Box className="relative h-4">
|
||||
<PopoverCloseButton />
|
||||
</div>
|
||||
</Box>
|
||||
<PopoverBody>
|
||||
{report.label_values.map(({ label, checked, value }, i) => (
|
||||
<FlagCheckbox
|
||||
|
||||
@@ -11,102 +11,105 @@ import {
|
||||
Text,
|
||||
useColorModeValue,
|
||||
} from "@chakra-ui/react";
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import NextLink from "next/link";
|
||||
import React from "react";
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import React, { ElementType, useCallback } from "react";
|
||||
import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi";
|
||||
|
||||
interface MenuOption {
|
||||
name: string;
|
||||
href: string;
|
||||
desc: string;
|
||||
icon: ElementType;
|
||||
isExternal: boolean;
|
||||
}
|
||||
|
||||
export function UserMenu() {
|
||||
const borderColor = useColorModeValue("gray.300", "gray.600");
|
||||
const handleSignOut = useCallback(() => {
|
||||
signOut({ callbackUrl: "/" });
|
||||
}, []);
|
||||
const { data: session, status } = useSession();
|
||||
|
||||
const { data: session } = useSession();
|
||||
|
||||
if (!session) {
|
||||
return <></>;
|
||||
if (!session || status !== "authenticated") {
|
||||
return null;
|
||||
}
|
||||
if (session && session.user) {
|
||||
const accountOptions = [
|
||||
{
|
||||
name: "Dashboard",
|
||||
href: "/dashboard",
|
||||
desc: "Dashboard",
|
||||
icon: FiLayout,
|
||||
},
|
||||
{
|
||||
name: "Account Settings",
|
||||
href: "/account",
|
||||
desc: "Account Settings",
|
||||
icon: FiSettings,
|
||||
},
|
||||
];
|
||||
const helpOptions = [
|
||||
{
|
||||
name: "Report a Bug",
|
||||
href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose",
|
||||
desc: "Report a Bug",
|
||||
icon: FiAlertTriangle,
|
||||
},
|
||||
];
|
||||
const options: MenuOption[] = [
|
||||
{
|
||||
name: "Dashboard",
|
||||
href: "/dashboard",
|
||||
desc: "Dashboard",
|
||||
icon: FiLayout,
|
||||
isExternal: false,
|
||||
},
|
||||
{
|
||||
name: "Account Settings",
|
||||
href: "/account",
|
||||
desc: "Account Settings",
|
||||
icon: FiSettings,
|
||||
isExternal: false,
|
||||
},
|
||||
{
|
||||
name: "Report a Bug",
|
||||
href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose",
|
||||
desc: "Report a Bug",
|
||||
icon: FiAlertTriangle,
|
||||
isExternal: true,
|
||||
},
|
||||
];
|
||||
|
||||
if (session.user.role === "admin") {
|
||||
accountOptions.unshift({
|
||||
name: "Admin Dashboard",
|
||||
href: "/admin",
|
||||
desc: "Admin Dashboard",
|
||||
icon: FiShield,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Menu>
|
||||
<MenuButton border="solid" borderRadius="full" borderWidth="thin" borderColor={borderColor}>
|
||||
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
|
||||
<Avatar size="sm" bgImage={session.user.image}></Avatar>
|
||||
<Text data-cy="username" className="hidden lg:flex">
|
||||
{session.user.name || session.user.email}
|
||||
</Text>
|
||||
</Box>
|
||||
</MenuButton>
|
||||
<MenuList p="2" borderRadius="xl" shadow="none">
|
||||
<Box display="flex" flexDirection="column" alignItems="center" borderRadius="md" p="4">
|
||||
<Text>{session.user.name}</Text>
|
||||
<Text color="blue.500" fontWeight="bold" fontSize="xl">
|
||||
3,200
|
||||
</Text>
|
||||
</Box>
|
||||
<MenuDivider />
|
||||
<MenuGroup>
|
||||
{accountOptions.map((item) => (
|
||||
<Link as={NextLink} key={item.name} href={item.href} _hover={{ textDecoration: "none" }}>
|
||||
<MenuItem gap="3" borderRadius="md" p="4">
|
||||
<item.icon className="text-blue-500" aria-hidden="true" />
|
||||
<Text>{item.name}</Text>
|
||||
</MenuItem>
|
||||
</Link>
|
||||
))}
|
||||
</MenuGroup>
|
||||
<MenuDivider />
|
||||
<MenuGroup>
|
||||
{helpOptions.map((item) => (
|
||||
<Link as={NextLink} key={item.name} href={item.href} isExternal _hover={{ textDecoration: "none" }}>
|
||||
<MenuItem gap="3" borderRadius="md" p="4">
|
||||
<item.icon className="text-blue-500" aria-hidden="true" />
|
||||
<Text>{item.name}</Text>
|
||||
</MenuItem>
|
||||
</Link>
|
||||
))}
|
||||
</MenuGroup>
|
||||
<MenuDivider />
|
||||
<MenuItem gap="3" borderRadius="md" p="4" onClick={() => signOut({ callbackUrl: "/" })}>
|
||||
<FiLogOut className="text-blue-500" aria-hidden="true" />
|
||||
<Text>Sign Out</Text>
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</>
|
||||
);
|
||||
if (session.user.role === "admin") {
|
||||
options.unshift({
|
||||
name: "Admin Dashboard",
|
||||
href: "/admin",
|
||||
desc: "Admin Dashboard",
|
||||
icon: FiShield,
|
||||
isExternal: false,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton border="solid" borderRadius="full" borderWidth="thin" borderColor={borderColor}>
|
||||
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
|
||||
<Avatar size="sm" bgImage={session.user.image}></Avatar>
|
||||
<Text data-cy="username" className="hidden lg:flex">
|
||||
{session.user.name || session.user.email}
|
||||
</Text>
|
||||
</Box>
|
||||
</MenuButton>
|
||||
<MenuList p="2" borderRadius="xl" shadow="none">
|
||||
<Box display="flex" flexDirection="column" alignItems="center" borderRadius="md" p="4">
|
||||
<Text>{session.user.name}</Text>
|
||||
{/* <Text color="blue.500" fontWeight="bold" fontSize="xl">
|
||||
3,200
|
||||
</Text> */}
|
||||
</Box>
|
||||
<MenuDivider />
|
||||
<MenuGroup>
|
||||
{options.map((item) => (
|
||||
<Link
|
||||
key={item.name}
|
||||
as={item.isExternal ? "a" : NextLink}
|
||||
isExternal={item.isExternal}
|
||||
href={item.href}
|
||||
_hover={{ textDecoration: "none" }}
|
||||
>
|
||||
<MenuItem gap="3" borderRadius="md" p="4">
|
||||
<item.icon className="text-blue-500" aria-hidden="true" />
|
||||
<Text>{item.name}</Text>
|
||||
</MenuItem>
|
||||
</Link>
|
||||
))}
|
||||
</MenuGroup>
|
||||
<MenuDivider />
|
||||
<MenuItem gap="3" borderRadius="md" p="4" onClick={handleSignOut}>
|
||||
<FiLogOut className="text-blue-500" aria-hidden="true" />
|
||||
<Text>Sign Out</Text>
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
}
|
||||
|
||||
export default UserMenu;
|
||||
|
||||
@@ -54,6 +54,7 @@ export const getDashboardLayout = (page: React.ReactElement) => (
|
||||
>
|
||||
{page}
|
||||
</SideMenuLayout>
|
||||
<Footer />
|
||||
</div>
|
||||
);
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
|
||||
return (
|
||||
<FlaggableElement message={item}>
|
||||
<HStack w="100%" gap={2}>
|
||||
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
|
||||
<Box borderRadius="full" border="solid" borderWidth="1px" borderColor={borderColor} bg={avatarColor}>
|
||||
<Avatar
|
||||
size="sm"
|
||||
@@ -28,21 +28,20 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
/>
|
||||
</Box>
|
||||
{props.enabled ? (
|
||||
<Box maxWidth="xl">
|
||||
<Box width={["full", "full", "full", "fit-content"]} maxWidth={["full", "full", "full", "2xl"]}>
|
||||
<Link href={`/messages/${item.id}`}>
|
||||
<LinkBox
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
className={`p-4 rounded-md whitespace-pre-wrap w-full`}
|
||||
>
|
||||
<LinkBox bg={item.is_assistant ? backgroundColor : backgroundColor2} p="4" borderRadius="md">
|
||||
{item.text}
|
||||
</LinkBox>
|
||||
</Link>
|
||||
</Box>
|
||||
) : (
|
||||
<Box
|
||||
maxWidth="xl"
|
||||
width={["full", "full", "full", "fit-content"]}
|
||||
maxWidth={["full", "full", "full", "2xl"]}
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
className={`p-4 rounded-md whitespace-pre-wrap w-full`}
|
||||
p="4"
|
||||
borderRadius="md"
|
||||
>
|
||||
{item.text}
|
||||
</Box>
|
||||
|
||||
@@ -50,7 +50,7 @@ if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development")
|
||||
where: {
|
||||
id: user.id,
|
||||
},
|
||||
update: {},
|
||||
update: user,
|
||||
create: user,
|
||||
});
|
||||
return user;
|
||||
@@ -86,6 +86,7 @@ export const authOptions: AuthOptions = {
|
||||
*/
|
||||
async session({ session, token }) {
|
||||
session.user.role = token.role;
|
||||
session.user.isNew = token.isNew;
|
||||
return session;
|
||||
},
|
||||
/**
|
||||
@@ -93,11 +94,12 @@ export const authOptions: AuthOptions = {
|
||||
* This let's use forward the role to the session object.
|
||||
*/
|
||||
async jwt({ token }) {
|
||||
const { role } = await prisma.user.findUnique({
|
||||
const { isNew, role } = await prisma.user.findUnique({
|
||||
where: { id: token.sub },
|
||||
select: { role: true },
|
||||
select: { role: true, isNew: true },
|
||||
});
|
||||
token.role = role;
|
||||
token.isNew = isNew;
|
||||
return token;
|
||||
},
|
||||
},
|
||||
|
||||
@@ -17,6 +17,9 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, content, update_type } = req.body;
|
||||
|
||||
// Record that the user has done meaningful work and is no longer new.
|
||||
await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } });
|
||||
|
||||
// Accept the task so that we can complete it, this will probably go away soon.
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
|
||||
@@ -94,7 +94,14 @@ function Signin({ csrfToken, providers }) {
|
||||
{email && (
|
||||
<form onSubmit={signinWithEmail}>
|
||||
<Stack>
|
||||
<Input data-cy="email-address" variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Input
|
||||
type="email"
|
||||
data-cy="email-address"
|
||||
variant="outline"
|
||||
size="lg"
|
||||
placeholder="Email Address"
|
||||
ref={emailEl}
|
||||
/>
|
||||
<Button
|
||||
data-cy="signin-email-button"
|
||||
size={"lg"}
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import Head from "next/head";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { LeaderboardTable, TaskOption } from "src/components/Dashboard";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { TaskCategory } from "src/components/Tasks/TaskTypes";
|
||||
|
||||
const Dashboard = () => {
|
||||
const { data: session } = useSession();
|
||||
|
||||
// TODO(#670): Do something more meaningful when the user is new.
|
||||
console.log(session?.user?.isNew);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
|
||||
@@ -27,6 +27,6 @@ const RandomTask = () => {
|
||||
);
|
||||
};
|
||||
|
||||
RandomTask.getLayout = getDashboardLayout;
|
||||
RandomTask.getLayout = (page) => getDashboardLayout(page);
|
||||
|
||||
export default RandomTask;
|
||||
|
||||
Vendored
+4
@@ -6,6 +6,8 @@ declare module "next-auth" {
|
||||
user: {
|
||||
/** The user's role. */
|
||||
role: string;
|
||||
/** True when the user is new. */
|
||||
isNew: boolean;
|
||||
} & DefaultSession["user"];
|
||||
}
|
||||
}
|
||||
@@ -14,5 +16,7 @@ declare module "next-auth/jwt" {
|
||||
interface JWT {
|
||||
/** The user's role. */
|
||||
role?: string;
|
||||
/** True when the user is new. */
|
||||
isNew?: boolean;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user