Merge branch 'LAION-AI:main' into main

This commit is contained in:
Riley Sandborg
2023-01-15 15:52:58 -06:00
committed by GitHub
56 changed files with 2765 additions and 280 deletions
+5 -6
View File
@@ -4,7 +4,6 @@ on:
push:
branches:
- main
- add-api-docs-workflow
paths:
- "oasst-shared/**"
- "backend/**"
@@ -46,8 +45,8 @@ jobs:
- run: ./scripts/backend-development/stop-mock-server.sh
- uses: stefanzweifel/git-auto-commit-action@v4
with:
file_pattern: "docs/docs/api/openapi.json"
commit_message:
update docs/docs/api/openapi.json by run ${{ github.run_id }}
#- uses: stefanzweifel/git-auto-commit-action@v4
# with:
# file_pattern: "docs/docs/api/openapi.json"
# commit_message:
# update docs/docs/api/openapi.json by run ${{ github.run_id }}
+8 -3
View File
@@ -19,17 +19,20 @@
ansible.builtin.file:
path: "./{{ stack_name }}"
state: directory
mode: 0755
- name: Copy redis.conf to managed node
ansible.builtin.copy:
src: ./redis.conf
dest: "./{{ stack_name }}/redis.conf"
mode: 0644
- name: Set up Redis
community.docker.docker_container:
name: "oasst-{{ stack_name }}-redis"
image: redis
state: started
recreate: "{{ (stack_name == 'dev') | bool }}"
restart_policy: always
network_mode: "oasst-{{ stack_name }}"
healthcheck:
@@ -46,6 +49,7 @@
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
image: postgres:15
state: started
recreate: "{{ (stack_name == 'dev') | bool }}"
restart_policy: always
network_mode: "oasst-{{ stack_name }}"
env:
@@ -73,11 +77,12 @@
env:
POSTGRES_HOST: "oasst-{{ stack_name }}-postgres-backend"
REDIS_HOST: "oasst-{{ stack_name }}-redis"
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_ALLOW_DEBUG_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
DEBUG_ALLOW_SELF_LABELING: "true"
DEBUG_ALLOW_SELF_LABELING:
"{{ 'true' if stack_name == 'dev' else 'false' }}"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
DEBUG_SKIP_TOXICITY_CALCULATION: "true"
ports:
@@ -0,0 +1,83 @@
"""change user_stats ranking counts
Revision ID: 7c98102efbca
Revises: 619255ae9076
Create Date: 2023-01-15 00:02:45.622986
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "7c98102efbca"
down_revision = "619255ae9076"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Drop ``user_stats`` and recreate it with the merged ``reply_ranked_N`` counters.

    The table is dropped before it is recreated, so this function does not
    carry existing rows over into the new table.
    """
    # Non-nullable integer counters, in the order they appear in the table.
    counter_columns = (
        "leader_score",
        "prompts",
        "replies_assistant",
        "replies_prompter",
        "labels_simple",
        "labels_full",
        "rankings_total",
        "rankings_good",
        "accepted_prompts",
        "accepted_replies_assistant",
        "accepted_replies_prompter",
        "reply_ranked_1",
        "reply_ranked_2",
        "reply_ranked_3",
    )

    op.drop_table("user_stats")
    op.create_table(
        "user_stats",
        sa.Column("user_id", UUID(as_uuid=True), nullable=False),
        sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("base_date", sa.DateTime(), nullable=True),
        sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        *(sa.Column(name, sa.Integer(), nullable=False) for name in counter_columns),
        sa.Column("streak_last_day_date", sa.DateTime(), nullable=True),
        sa.Column("streak_days", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
        sa.PrimaryKeyConstraint("user_id", "time_frame"),
    )
def downgrade() -> None:
    """Re-add the per-role ``reply_*_ranked_N`` counters and drop ``reply_ranked_N``."""
    table = "user_stats"

    # Added back in the same order as the auto-generated migration listed them.
    per_role_columns = (
        "reply_prompter_ranked_3",
        "reply_assistant_ranked_1",
        "reply_assistant_ranked_2",
        "reply_prompter_ranked_2",
        "reply_prompter_ranked_1",
        "reply_assistant_ranked_3",
    )
    for column_name in per_role_columns:
        op.add_column(
            table,
            sa.Column(column_name, sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
        )

    for column_name in ("reply_ranked_3", "reply_ranked_2", "reply_ranked_1"):
        op.drop_column(table, column_name)
@@ -0,0 +1,30 @@
"""add indices for created_date
Revision ID: 423557e869e4
Revises: 7c98102efbca
Create Date: 2023-01-15 11:39:10.407859
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "423557e869e4"
down_revision = "7c98102efbca"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create a non-unique index on ``created_date`` for three tables."""
    indices = (
        ("ix_message_created_date", "message"),
        ("ix_message_reaction_created_date", "message_reaction"),
        ("ix_text_labels_created_date", "text_labels"),
    )
    for index_name, table_name in indices:
        op.create_index(op.f(index_name), table_name, ["created_date"], unique=False)
def downgrade() -> None:
    """Drop the ``created_date`` indices, in reverse order of their creation."""
    indices = (
        ("ix_text_labels_created_date", "text_labels"),
        ("ix_message_reaction_created_date", "message_reaction"),
        ("ix_message_created_date", "message"),
    )
    for index_name, table in indices:
        op.drop_index(op.f(index_name), table_name=table)
@@ -0,0 +1,33 @@
"""add rank and indices to user_stats
Revision ID: 0964ac95170d
Revises: 423557e869e4
Create Date: 2023-01-15 16:54:09.510018
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "0964ac95170d"
down_revision = "423557e869e4"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable ``rank`` column and two unique indices to ``user_stats``."""
    table = "user_stats"
    rank_column = sa.Column("rank", sa.Integer(), nullable=True)

    op.add_column(table, rank_column)
    op.create_index(
        "ix_user_stats__timeframe__rank__user_id",
        table,
        ["time_frame", "rank", "user_id"],
        unique=True,
    )
    op.create_index(
        "ix_user_stats__timeframe__user_id",
        table,
        ["time_frame", "user_id"],
        unique=True,
    )
def downgrade() -> None:
    """Drop the two ``user_stats`` indices and then the ``rank`` column."""
    table = "user_stats"
    for index_name in ("ix_user_stats__timeframe__user_id", "ix_user_stats__timeframe__rank__user_id"):
        op.drop_index(index_name, table_name=table)
    op.drop_column(table, "rank")
+50
View File
@@ -9,6 +9,7 @@ import alembic.config
import fastapi
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from fastapi_utils.tasks import repeat_every
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
@@ -18,6 +19,7 @@ from oasst_backend.database import engine
from oasst_backend.models import message_tree_state
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -195,6 +197,54 @@ def ensure_tree_states():
logger.exception("TreeManager.ensure_tree_states() failed.")
def _update_leader_board(time_frame: UserStatsTimeFrame, label: str) -> None:
    """Recompute the user stats of one time frame and commit the result.

    Any exception is logged through ``logger.exception`` and not re-raised.

    Args:
        time_frame: Time frame handed to ``UserStatsRepository.update_stats``.
        label: Human-readable name of the time frame, used in the error log.
    """
    try:
        with Session(engine) as session:
            UserStatsRepository(session).update_stats(time_frame=time_frame)
            session.commit()
    except Exception:
        logger.exception(f"Error during leaderboard update ({label})")


# The USER_STATS_INTERVAL_* settings are given in minutes; repeat_every takes seconds.
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
def update_leader_board_day() -> None:
    """Refresh the stats of the ``day`` time frame."""
    _update_leader_board(UserStatsTimeFrame.day, "daily")


@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
def update_leader_board_week() -> None:
    """Refresh the stats of the ``week`` time frame."""
    _update_leader_board(UserStatsTimeFrame.week, "weekly")


@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
def update_leader_board_month() -> None:
    """Refresh the stats of the ``month`` time frame."""
    _update_leader_board(UserStatsTimeFrame.month, "monthly")


@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
def update_leader_board_total() -> None:
    """Refresh the stats of the ``total`` time frame."""
    _update_leader_board(UserStatsTimeFrame.total, "total")
app.include_router(api_router, prefix=settings.API_V1_STR)
+52 -13
View File
@@ -1,9 +1,9 @@
from http import HTTPStatus
from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import Depends, Request, Response, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from fastapi_limiter.depends import RateLimiter
from loguru import logger
@@ -22,6 +22,8 @@ def get_db() -> Generator:
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
bearer_token = HTTPBearer(auto_error=False)
async def get_api_key(
api_key_query: str = Security(api_key_query),
@@ -33,22 +35,47 @@ async def get_api_key(
return api_key_header
def get_dummy_api_client(db: Session) -> ApiClient:
def create_api_client(
    *,
    session: Session,
    description: str,
    frontend_type: str,
    trusted: bool | None = False,
    admin_email: str | None = None,
    api_key: str | None = None,
) -> ApiClient:
    """Insert a new ``ApiClient`` row and return the refreshed instance.

    Args:
        session: Database session used to add, commit and refresh the row.
        description: Stored on the new client.
        frontend_type: Stored on the new client.
        trusted: Stored on the new client.
        admin_email: Stored on the new client.
        api_key: Key to store; when ``None`` a random ``token_hex(32)`` value
            is generated.

    Returns:
        The committed ``ApiClient``; its ``api_key`` attribute holds the key.
    """
    if api_key is None:
        api_key = token_hex(32)

    # The key is a credential: log only non-secret attributes so that it does
    # not end up in log files. Callers get the key from the returned object.
    logger.info(f"Creating new api client ({description=}, {frontend_type=}, {trusted=})")

    api_client = ApiClient(
        api_key=api_key,
        description=description,
        frontend_type=frontend_type,
        trusted=trusted,
        admin_email=admin_email,
    )
    session.add(api_client)
    session.commit()
    session.refresh(api_client)
    return api_client
def get_dummy_api_client(session: Session) -> ApiClient:
# make sure that a dummy api key exists in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
DUMMY_API_KEY = "1234"
api_client: ApiClient = session.query(ApiClient).filter(ApiClient.api_key == DUMMY_API_KEY).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(
id=ANY_API_KEY_ID,
api_key=token,
description="ANY_API_KEY, random token",
logger.info(f"ANY_API_KEY missing, inserting api_key: {DUMMY_API_KEY}")
api_client = create_api_client(
session=session,
api_key=DUMMY_API_KEY,
description="Dummy api key for debugging",
trusted=True,
frontend_type="Test frontend",
)
db.add(api_client)
db.commit()
session.add(api_client)
session.commit()
return api_client
@@ -58,7 +85,7 @@ def api_auth(
) -> ApiClient:
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_DEBUG_API_KEY:
return get_dummy_api_client(db)
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
@@ -93,6 +120,18 @@ def get_trusted_api_client(
return client
def get_root_token(bearer_token: HTTPAuthorizationCredentials = Security(bearer_token)) -> str:
    """Return the bearer token if it is listed in ``settings.ROOT_TOKENS``.

    Raises:
        OasstError: ``ROOT_TOKEN_NOT_AUTHORIZED`` with HTTP 403 when no
            credentials were supplied, the token is empty, or it is not listed.
    """
    token = bearer_token.credentials if bearer_token else None
    if not token or token not in settings.ROOT_TOKENS:
        raise OasstError(
            "Could not validate credentials",
            error_code=OasstErrorCode.ROOT_TOKEN_NOT_AUTHORIZED,
            http_status_code=HTTPStatus.FORBIDDEN,
        )
    return token
class UserRateLimiter(RateLimiter):
def __init__(
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
+31
View File
@@ -0,0 +1,31 @@
import pydantic
from fastapi import APIRouter, Depends
from loguru import logger
from oasst_backend.api import deps
router = APIRouter()
class CreateApiClientRequest(pydantic.BaseModel):
    """Request body for creating an API client."""

    # Required fields.
    description: str
    frontend_type: str
    # Optional fields; omitted values fall back to the defaults below.
    trusted: bool | None = False
    admin_email: str | None = None
@router.post("/api_client")
async def create_api_client(
    request: CreateApiClientRequest,
    root_token: str = Depends(deps.get_root_token),
    session: deps.Session = Depends(deps.get_db),
) -> str:
    """Create an API client and return its API key.

    Args:
        request: Attributes of the client to create.
        root_token: Resolved by ``deps.get_root_token``; not used in the body,
            it is declared so that the dependency runs for this route.
        session: Database session from ``deps.get_db``.

    Returns:
        The ``api_key`` of the newly created client.
    """
    logger.info(f"Creating new api client with {request=}")
    api_client = deps.create_api_client(
        session=session,
        description=request.description,
        frontend_type=request.frontend_type,
        trusted=request.trusted,
        admin_email=request.admin_email,
    )
    # Log the row id rather than the key: the key is a credential and is
    # already delivered to the caller in the response.
    logger.info(f"Created api_client with id {api_client.id}")
    return api_client.api_key
+3 -1
View File
@@ -1,5 +1,6 @@
from fastapi import APIRouter
from oasst_backend.api.v1 import (
admin,
frontend_messages,
frontend_users,
hugging_face,
@@ -19,5 +20,6 @@ api_router.include_router(frontend_messages.router, prefix="/frontend_messages",
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
+11 -16
View File
@@ -1,26 +1,21 @@
from fastapi import APIRouter, Depends
from typing import Optional
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.user_repository import UserRepository
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
router = APIRouter()
@router.get("/create/assistant")
def get_assistant_leaderboard(
@router.get("/{time_frame}", response_model=LeaderboardStats)
def get_leaderboard(
time_frame: UserStatsTimeFrame,
max_count: Optional[int] = Query(100, gt=0, le=10000),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="assistant")
@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="prompter")
usr = UserStatsRepository(db)
return usr.get_leaderboard(time_frame, limit=max_count)
+22
View File
@@ -8,6 +8,7 @@ from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient, User
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
@@ -96,3 +97,24 @@ def mark_user_messages_deleted(
pr = PromptRepository(db, api_client)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
@router.get("/{user_id}/stats", response_model=dict[str, protocol.UserScore | None])
def query_user_stats(
    user_id: UUID,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Return the user's stats keyed by time frame."""
    repository = UserStatsRepository(db)
    stats_by_time_frame = repository.get_user_stats_all_time_frames(user_id=user_id)
    return stats_by_time_frame
@router.get("/{user_id}/stats/{time_frame}", response_model=protocol.UserScore)
def query_user_stats_timeframe(
    user_id: UUID,
    time_frame: UserStatsTimeFrame,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Return the user's stats for a single time frame.

    Raises:
        HTTPException: 404 when the repository returns no entry for the
            requested time frame.
    """
    # Imported locally so the block does not depend on the module's import list.
    from fastapi import HTTPException

    usr = UserStatsRepository(db)
    stats_by_time_frame = usr.get_user_stats_all_time_frames(user_id=user_id)

    # The dict may lack the requested key; indexing it directly would raise
    # KeyError and surface as an internal server error.
    stats = stats_by_time_frame.get(time_frame.value)
    if stats is None:
        raise HTTPException(
            status_code=404,
            detail=f"No stats found for user {user_id} and time frame '{time_frame.value}'",
        )
    return stats
+19 -1
View File
@@ -71,7 +71,7 @@ class Settings(BaseSettings):
REDIS_HOST: str = "localhost"
REDIS_PORT: str = "6379"
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_ALLOW_DEBUG_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
@@ -83,6 +83,8 @@ class Settings(BaseSettings):
HUGGING_FACE_API_KEY: str = ""
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
@validator("DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
if isinstance(v, str):
@@ -109,6 +111,22 @@ class Settings(BaseSettings):
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
USER_STATS_INTERVAL_DAY: int = 15 # minutes
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
@validator(
    "USER_STATS_INTERVAL_DAY",
    "USER_STATS_INTERVAL_WEEK",
    "USER_STATS_INTERVAL_MONTH",
    "USER_STATS_INTERVAL_TOTAL",
)
def validate_user_stats_intervals(cls, v: int):
    """Reject user-stats refresh intervals below one minute.

    Raises:
        ValueError: If ``v`` is less than 1.
    """
    if v < 1:
        # Say what is wrong; a bare number as the message is not actionable.
        raise ValueError(f"user stats interval must be at least 1 minute, got {v}")
    return v
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
+2 -1
View File
@@ -8,12 +8,13 @@ from .message_tree_state import MessageTreeState
from .task import Task
from .text_labels import TextLabels
from .user import User
from .user_stats import UserStats
from .user_stats import UserStats, UserStatsTimeFrame
__all__ = [
"ApiClient",
"User",
"UserStats",
"UserStatsTimeFrame",
"Message",
"MessageEmbedding",
"MessageReaction",
@@ -65,12 +65,16 @@ class RankingReactionPayload(ReactionPayload):
type: Literal["message_ranking"] = "message_ranking"
ranking: list[int]
ranked_message_ids: list[UUID]
ranking_parent_id: Optional[UUID]
message_tree_id: Optional[UUID]
@payload_type
class RankConversationRepliesPayload(TaskPayload):
conversation: protocol_schema.Conversation # the conversation so far
reply_messages: list[protocol_schema.ConversationMessage]
ranking_parent_id: Optional[UUID]
message_tree_id: Optional[UUID]
@payload_type
@@ -104,6 +108,7 @@ class LabelInitialPromptPayload(TaskPayload):
prompt: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[protocol_schema.LabelTaskMode]
@payload_type
@@ -115,6 +120,7 @@ class LabelConversationReplyPayload(TaskPayload):
reply: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[protocol_schema.LabelTaskMode]
@payload_type
+1 -1
View File
@@ -30,7 +30,7 @@ class Message(SQLModel, table=True):
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
)
payload_type: str = Field(nullable=False, max_length=200)
payload: Optional[PayloadContainer] = Field(
@@ -19,7 +19,7 @@ class MessageReaction(SQLModel, table=True):
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
+2 -1
View File
@@ -4,6 +4,7 @@ from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_shared.utils import utcnow
from sqlalchemy import false
from sqlmodel import Field, SQLModel
@@ -35,4 +36,4 @@ class Task(SQLModel, table=True):
@property
def expired(self) -> bool:
return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
return self.expiry_date is not None and utcnow() > self.expiry_date
+1 -1
View File
@@ -17,7 +17,7 @@ class TextLabels(SQLModel, table=True):
)
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True),
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
+29 -9
View File
@@ -5,7 +5,7 @@ from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
from sqlmodel import Field, Index, SQLModel
class UserStatsTimeFrame(str, Enum):
@@ -17,17 +17,24 @@ class UserStatsTimeFrame(str, Enum):
class UserStats(SQLModel, table=True):
__tablename__ = "user_stats"
__table_args__ = (
Index("ix_user_stats__timeframe__user_id", "time_frame", "user_id", unique=True),
Index("ix_user_stats__timeframe__rank__user_id", "time_frame", "rank", "user_id", unique=True),
)
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
user_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
)
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
leader_score: int = 0
modified_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
rank: int = Field(nullable=True)
prompts: int = 0
replies_assistant: int = 0
replies_prompter: int = 0
@@ -40,14 +47,27 @@ class UserStats(SQLModel, table=True):
accepted_replies_assistant: int = 0
accepted_replies_prompter: int = 0
reply_assistant_ranked_1: int = 0
reply_assistant_ranked_2: int = 0
reply_assistant_ranked_3: int = 0
reply_prompter_ranked_1: int = 0
reply_prompter_ranked_2: int = 0
reply_prompter_ranked_3: int = 0
reply_ranked_1: int = 0
reply_ranked_2: int = 0
reply_ranked_3: int = 0
# only used for time span "total"
streak_last_day_date: Optional[datetime] = Field(nullable=True)
streak_days: Optional[int] = Field(nullable=True)
def compute_leader_score(self) -> int:
    """Return the weighted sum of this row's counters."""
    # (attribute name, weight) pairs; attributes not listed do not contribute.
    weights = (
        ("prompts", 1),
        ("replies_assistant", 4),
        ("replies_prompter", 1),
        ("labels_simple", 1),
        ("labels_full", 2),
        ("rankings_total", 1),
        ("rankings_good", 1),
        ("accepted_prompts", 1),
        ("accepted_replies_assistant", 4),
        ("accepted_replies_prompter", 1),
        ("reply_ranked_1", 9),
        ("reply_ranked_2", 3),
        ("reply_ranked_3", 1),
    )
    return sum(getattr(self, name) * weight for name, weight in weights)
+4 -1
View File
@@ -260,7 +260,10 @@ class PromptRepository:
self.db.add(message)
reaction_payload = db_payload.RankingReactionPayload(
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
ranking=ranking.ranking,
ranked_message_ids=ranked_message_ids,
ranking_parent_id=task_payload.ranking_parent_id,
message_tree_id=task_payload.message_tree_id,
)
reaction = self.insert_reaction(task.id, reaction_payload)
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
+18 -3
View File
@@ -67,17 +67,30 @@ class TaskRepository:
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
type=task.type,
conversation=task.conversation,
reply_messages=task.reply_messages,
ranking_parent_id=task.ranking_parent_id,
message_tree_id=task.message_tree_id,
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
type=task.type,
conversation=task.conversation,
reply_messages=task.reply_messages,
ranking_parent_id=task.ranking_parent_id,
message_tree_id=task.message_tree_id,
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
type=task.type,
message_id=task.message_id,
prompt=task.prompt,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
)
case protocol_schema.LabelPrompterReplyTask:
@@ -88,6 +101,7 @@ class TaskRepository:
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
)
case protocol_schema.LabelAssistantReplyTask:
@@ -98,6 +112,7 @@ class TaskRepository:
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
)
case _:
+13 -2
View File
@@ -207,6 +207,9 @@ class TreeManager:
ranking_parent_id = random.choice(incomplete_rankings).parent_id
messages = self.pr.fetch_message_conversation(ranking_parent_id)
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
ranking_parent = messages[-1]
assert not ranking_parent.deleted and ranking_parent.review_result
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
@@ -218,12 +221,20 @@ class TreeManager:
if messages[-1].role == "assistant":
logger.info("Generating a RankPrompterRepliesTask.")
task = protocol_schema.RankPrompterRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
else:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
parent_message_id = ranking_parent_id
+2 -24
View File
@@ -1,11 +1,10 @@
from typing import Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Message, User
from oasst_backend.models import ApiClient, User
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session, func
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -137,27 +136,6 @@ class UserRepository:
self.db.commit()
return user
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants
"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted is not True, Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)
result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]
return LeaderboardStats(leaderboard=result)
def query_users(
self,
api_client_id: Optional[UUID] = None,
@@ -0,0 +1,303 @@
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
from loguru import logger
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
LabelPrompterReplyPayload,
RankingReactionPayload,
)
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy.dialects import postgresql
from sqlmodel import Session, delete, func, text
def _create_user_score(r):
    """Build a ``UserScore`` from one result row.

    The row must provide ``UserStats`` plus the ``user_id``, ``username``,
    ``auth_method`` and ``display_name`` entries. When the ``UserStats`` entry
    is falsy, only ``modified_date`` (current UTC time) and the user fields
    are passed on.
    """
    stats = r["UserStats"]
    fields = stats.dict() if stats else {"modified_date": utcnow()}
    fields.update((key, r[key]) for key in ("user_id", "username", "auth_method", "display_name"))
    return UserScore(**fields)
class UserStatsRepository:
def __init__(self, session: Session):
    """Keep the database session that the repository's queries are issued on."""
    self.session = session
def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
    """
    Return up to ``limit`` users for the given time frame, ordered by descending leader score
    """
    columns = (User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
    qry = self.session.query(*columns)
    qry = qry.join(UserStats, User.id == UserStats.user_id)
    qry = qry.filter(UserStats.time_frame == time_frame.value)
    qry = qry.order_by(UserStats.leader_score.desc()).limit(limit)

    rows = self.session.exec(qry)
    entries = list(map(_create_user_score, rows))
    return LeaderboardStats(time_frame=time_frame.value, leaderboard=entries)
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
    """Return the user's scores keyed by time frame.

    The result is empty when the query yields no row for ``user_id``. Time
    frames without a ``UserStats`` row are not added as keys unless the
    fallback branch below runs.
    """
    # Outer join: presumably a user without any stats still yields one row
    # whose UserStats entry is None -- TODO confirm against the database.
    qry = (
        self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
        .outerjoin(UserStats, User.id == UserStats.user_id)
        .filter(User.id == user_id)
    )

    stats_by_timeframe = {}
    for r in self.session.exec(qry):
        us = r["UserStats"]
        if us is not None:
            stats_by_timeframe[us.time_frame] = _create_user_score(r)
        else:
            # No stats on this row: give every time frame a score built from the user fields only.
            stats_by_timeframe = {tf.value: _create_user_score(r) for tf in UserStatsTimeFrame}
    return stats_by_timeframe
def query_total_prompts_per_user(
    self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
):
    """Build a query of ``(user_id, count)`` over non-deleted messages without a parent.

    ``reference_time`` keeps only messages created at or after that time;
    ``only_reviewed`` keeps only messages whose ``review_result`` is true.
    """
    conditions = [Message.deleted == sa.false(), Message.parent_id.is_(None)]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)
    if only_reviewed:
        conditions.append(Message.review_result == sa.true())

    counts = self.session.query(Message.user_id, func.count())
    return counts.filter(*conditions).group_by(Message.user_id)
def query_replies_by_role_per_user(
    self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
) -> list:
    """Build a query of ``(user_id, role, count)`` over non-deleted messages that have a parent.

    ``reference_time`` keeps only messages created at or after that time;
    ``only_reviewed`` keeps only messages whose ``review_result`` is true.
    The value returned is the query object built from ``self.session.query``.
    """
    conditions = [Message.deleted == sa.false(), Message.parent_id.is_not(None)]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)
    if only_reviewed:
        conditions.append(Message.review_result == sa.true())

    counts = self.session.query(Message.user_id, Message.role, func.count())
    return counts.filter(*conditions).group_by(Message.user_id, Message.role)
def query_labels_by_mode_per_user(
    self, payload_type: str = LabelAssistantReplyPayload.__name__, reference_time: Optional[datetime] = None
):
    """Build a query of ``(user_id, mode, count)`` over done tasks of one payload type.

    ``mode`` is read as text from the task payload at ``payload -> mode``.
    ``reference_time`` keeps only tasks created at or after that time.
    """
    mode = Task.payload["payload", "mode"].astext
    conditions = [Task.done == sa.true(), Task.payload_type == payload_type]
    if reference_time:
        conditions.append(Task.created_date >= reference_time)

    counts = self.session.query(Task.user_id, mode, func.count())
    return counts.filter(*conditions).group_by(Task.user_id, mode)
def query_rankings_per_user(self, reference_time: Optional[datetime] = None):
    """
    Build the query that counts ranking reactions per user.

    Only message reactions whose payload_type is RankingReactionPayload are
    counted. When reference_time is set, only reactions created at or after
    it are counted. The returned query yields (user_id, count) rows when iterated.
    """
    conditions = [MessageReaction.payload_type == RankingReactionPayload.__name__]
    if reference_time:
        conditions.append(MessageReaction.created_date >= reference_time)

    counts = self.session.query(MessageReaction.user_id, func.count())
    return counts.filter(*conditions).group_by(MessageReaction.user_id)
def query_ranking_result_users(self, rank: int = 0, reference_time: Optional[datetime] = None):
    """
    Build the query that counts, per message author, how often one of their
    messages was placed at position `rank` of a ranking.

    The message id is read from index `rank` of the ranked_message_ids list
    inside the ranking reaction payload and joined to the message table.
    When reference_time is set, only reactions created at or after it are
    counted. The returned query yields (user_id, count) rows when iterated.
    """
    # id of the message at the requested position, cast from json text to uuid
    id_as_text = MessageReaction.payload["payload", "ranked_message_ids", rank].astext
    ranked_message_id = id_as_text.cast(postgresql.UUID(as_uuid=True))

    counts = self.session.query(Message.user_id, func.count()).select_from(MessageReaction)
    counts = counts.join(Message, ranked_message_id == Message.id)
    counts = counts.filter(MessageReaction.payload_type == RankingReactionPayload.__name__)
    if reference_time:
        counts = counts.filter(MessageReaction.created_date >= reference_time)
    return counts.group_by(Message.user_id)
def _update_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
    """
    Recompute and replace all user_stats rows of one time frame.

    Per-user counts are gathered from the query_* helpers, with base_date
    passed to them as reference_time. Afterwards the existing rows of the
    time frame are deleted, the leader score of every new row is computed,
    the new rows are added and flushed and the ranks of the time frame are
    updated. The session is not committed here.
    """
    # gather user data
    time_frame_key = time_frame.value
    stats_by_user: dict[UUID, UserStats] = {}
    now = utcnow()

    def get_stats(user_id: UUID) -> UserStats:
        """Return the stats object of user_id, creating it on first access."""
        us = stats_by_user.get(user_id)
        if us is None:
            us = UserStats(user_id=user_id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
            stats_by_user[user_id] = us
        return us

    # total prompts
    qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=False)
    for uid, count in qry:
        get_stats(uid).prompts = count

    # accepted prompts
    qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=True)
    for uid, count in qry:
        get_stats(uid).accepted_prompts = count

    # total replies
    qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=False)
    for uid, role, count in qry:
        s = get_stats(uid)
        if role == "assistant":
            s.replies_assistant += count
        elif role == "prompter":
            s.replies_prompter += count

    # accepted replies
    qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=True)
    for uid, role, count in qry:
        s = get_stats(uid)
        if role == "assistant":
            s.accepted_replies_assistant += count
        elif role == "prompter":
            s.accepted_replies_prompter += count

    # simple and full labels: assistant-reply labels set the counters ...
    qry = self.query_labels_by_mode_per_user(
        payload_type=LabelAssistantReplyPayload.__name__, reference_time=base_date
    )
    for uid, mode, count in qry:
        s = get_stats(uid)
        if mode == "simple":
            s.labels_simple = count
        elif mode == "full":
            s.labels_full = count

    # ... and prompter-reply labels are added on top
    qry = self.query_labels_by_mode_per_user(
        payload_type=LabelPrompterReplyPayload.__name__, reference_time=base_date
    )
    for uid, mode, count in qry:
        s = get_stats(uid)
        if mode == "simple":
            s.labels_simple += count
        elif mode == "full":
            s.labels_full += count

    # rankings submitted by the user
    qry = self.query_rankings_per_user(reference_time=base_date)
    for uid, count in qry:
        get_stats(uid).rankings_total = count

    # how often the user's replies were ranked 1st, 2nd and 3rd
    rank_field_names = ["reply_ranked_1", "reply_ranked_2", "reply_ranked_3"]
    for i, fn in enumerate(rank_field_names):
        # the i-th field must be fed from position i of the ranking; querying
        # position 0 for every field would store the 1st-place count three times
        qry = self.query_ranking_result_users(reference_time=base_date, rank=i)
        for uid, count in qry:
            setattr(get_stats(uid), fn, count)

    # delete all existing stats for time frame
    d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
    self.session.execute(d)

    # compute magic leader score
    for v in stats_by_user.values():
        v.leader_score = v.compute_leader_score()

    # insert user objects
    self.session.add_all(stats_by_user.values())
    self.session.flush()

    self.update_ranks(time_frame=time_frame)
@log_timing(log_kwargs=True)
def update_ranks(self, time_frame: Optional[UserStatsTimeFrame] = None):
    """
    Update user_stats ranks. The persisted rank values allow to
    quickly query the rank of a single user and to query nearby users.

    Ranks are numbered per time frame by descending leader_score, with
    user_id as tie-breaker. When time_frame is None the ranks of all time
    frames are updated, otherwise only the ranks of the given time frame.
    """

    # todo: convert sql to sqlalchemy query..
    # ranks = self.session.query(
    #     func.row_number()
    #     .over(partition_by=UserStats.time_frame, order_by=[UserStats.leader_score.desc(), UserStats.user_id])
    #     .label("rank"),
    #     UserStats.user_id,
    #     UserStats.time_frame,
    # )

    sql_update_rank = """
-- update rank
UPDATE user_stats us
SET "rank" = r."rank"
FROM
(SELECT
ROW_NUMBER () OVER(
PARTITION BY time_frame
ORDER BY leader_score DESC, user_id
) AS "rank", user_id, time_frame
FROM user_stats
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
us.user_id = r.user_id
AND us.time_frame = r.time_frame;"""
    r = self.session.execute(
        text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
    )
    # log under this method's own name so the message can be traced back to it
    logger.debug(f"update_ranks updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
    """
    Recompute the stats of one time frame, using reference_time as base date,
    and commit the session afterwards.
    """
    self._update_stats_internal(time_frame=time_frame, base_date=reference_time)
    self.session.commit()
@log_timing(log_kwargs=True, level="INFO")
def update_stats(self, *, time_frame: UserStatsTimeFrame):
    """
    Update the stats of one time frame with the reference time that belongs to it:

    day: exactly 24 hours before now
    week: start of the day 7 days before today
    month: start of the day 30 days before today
    total: no reference time

    Any other value is ignored.
    """
    now = utcnow()

    def start_of_day(days_back: int) -> datetime:
        # midnight of the day `days_back` days before today, in the time zone of `now`
        day = now.date() - timedelta(days=days_back)
        return datetime(day.year, day.month, day.day, tzinfo=now.tzinfo)

    if time_frame == UserStatsTimeFrame.day:
        reference_time = now - timedelta(days=1)
    elif time_frame == UserStatsTimeFrame.week:
        reference_time = start_of_day(7)
    elif time_frame == UserStatsTimeFrame.month:
        reference_time = start_of_day(30)
    elif time_frame == UserStatsTimeFrame.total:
        reference_time = None
    else:
        return

    self.update_stats_time_frame(time_frame, reference_time)
@log_timing(level="INFO")
def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]):
    """
    Update the stats of the given time frames one after another, in list order.
    """
    for frame in time_frames:
        self.update_stats(time_frame=frame)
@log_timing(level="INFO")
def update_all_time_frames(self):
    """
    Update the stats of every member of UserStatsTimeFrame.
    """
    every_frame = list(UserStatsTimeFrame)
    self.update_multiple_time_frames(every_frame)
if __name__ == "__main__":
    # Manual run for development: opens a session on the configured engine
    # and calls the repository directly.
    from oasst_backend.api.deps import get_dummy_api_client
    from oasst_backend.database import engine

    with Session(engine) as session:
        # NOTE(review): api_client is not used below; presumably the call is kept
        # for its side effects - confirm before removing.
        api_client = get_dummy_api_client(session)
        usr = UserStatsRepository(session)
        # usr.update_all_time_frames()
        # session.commit()
        # usr.get_leader_board(UserStatsTimeFrame.total)
        # the result is discarded, the call only exercises the query
        usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
+1
View File
@@ -1,6 +1,7 @@
alembic==1.8.1
fastapi==0.88.0
fastapi-limiter==0.1.5
fastapi-utils==0.2.1
loguru==0.6.0
numpy==1.22.4
psycopg2-binary==2.9.5
+1 -1
View File
@@ -29,7 +29,7 @@ environments:
variables:
# Note: this has to be a valid JSON list for Pydantic to parse it.
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
DEBUG_ALLOW_ANY_API_KEY: True
DEBUG_ALLOW_DEBUG_API_KEY: True
DEBUG_SKIP_API_KEY_CHECK: True
MAX_WORKERS: 1
+3
View File
@@ -0,0 +1,3 @@
# Deployment files
Copy these to the node you want to deploy to.
+19
View File
@@ -0,0 +1,19 @@
version: "3"

services:
  # nginx reverse proxy, configured by the nginx.conf next to this file
  webserver:
    image: nginx:latest
    # Host networking: nginx listens on ports 80/443 of the host itself and
    # reaches the proxied services on 127.0.0.1. A `ports` mapping must not be
    # combined with host networking, so none is declared.
    network_mode: host
    restart: always
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
      # ACME challenge files written by certbot
      - ./certbot/www:/var/www/certbot/:ro
      # certificates obtained by certbot
      - ./certbot/conf/:/etc/nginx/ssl/:ro

  # run on demand to obtain or renew certificates
  certbot:
    image: certbot/certbot:latest
    volumes:
      - ./certbot/www/:/var/www/certbot/:rw
      - ./certbot/conf/:/etc/letsencrypt/:rw
+3
View File
@@ -0,0 +1,3 @@
#!/bin/bash
# Obtain a certificate for one domain with certbot's webroot plugin.
# The domain name is expected as the first argument.

if [ -z "${1:-}" ]; then
    echo "usage: $0 <domain>" >&2
    exit 1
fi

# the argument is quoted so that it reaches certbot as exactly one word
docker compose run --rm certbot certonly -m admin@open-assistant.io --agree-tos --webroot --webroot-path /var/www/certbot/ -d "$1"
+81
View File
@@ -0,0 +1,81 @@
# Reverse proxy for the dev and staging hosts: plain HTTP is redirected to
# HTTPS, each HTTPS host is forwarded to a port on 127.0.0.1.
events {}

http {
    # Port 80: serve the ACME challenge files, redirect everything else to HTTPS.
    server {
        listen 80;
        listen [::]:80;

        server_name *.open-assistant.io;
        server_tokens off;

        location /.well-known/acme-challenge/ {
            root /var/www/certbot;
        }

        location / {
            return 301 https://$host$request_uri;
        }
    }

    # web.dev -> 127.0.0.1:3000
    server {
        listen 443 ssl http2;

        server_name web.dev.open-assistant.io;

        ssl_certificate /etc/nginx/ssl/live/web.dev.open-assistant.io/fullchain.pem;
        ssl_certificate_key /etc/nginx/ssl/live/web.dev.open-assistant.io/privkey.pem;

        location / {
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_pass http://127.0.0.1:3000;
        }
    }

    # backend.dev -> 127.0.0.1:8080
    server {
        listen 443 ssl http2;

        server_name backend.dev.open-assistant.io;

        ssl_certificate /etc/nginx/ssl/live/backend.dev.open-assistant.io/fullchain.pem;
        ssl_certificate_key /etc/nginx/ssl/live/backend.dev.open-assistant.io/privkey.pem;

        location / {
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_pass http://127.0.0.1:8080;
        }
    }

    # web.staging -> 127.0.0.1:3100
    server {
        listen 443 ssl http2;

        server_name web.staging.open-assistant.io;

        ssl_certificate /etc/nginx/ssl/live/web.staging.open-assistant.io/fullchain.pem;
        ssl_certificate_key /etc/nginx/ssl/live/web.staging.open-assistant.io/privkey.pem;

        location / {
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_pass http://127.0.0.1:3100;
        }
    }

    # backend.staging -> 127.0.0.1:8180
    server {
        listen 443 ssl http2;

        server_name backend.staging.open-assistant.io;

        ssl_certificate /etc/nginx/ssl/live/backend.staging.open-assistant.io/fullchain.pem;
        ssl_certificate_key /etc/nginx/ssl/live/backend.staging.open-assistant.io/privkey.pem;

        location / {
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_pass http://127.0.0.1:8180;
        }
    }
}
+3
View File
@@ -0,0 +1,3 @@
#!/bin/bash
# Run `certbot renew` in a throw-away container of the compose service "certbot".

docker compose run --rm certbot renew
@@ -0,0 +1,19 @@
version: "3"

services:
  # nginx reverse proxy, configured by the nginx.conf next to this file
  webserver:
    image: nginx:latest
    # Host networking: nginx listens on ports 80/443 of the host itself and
    # reaches the proxied services on 127.0.0.1. A `ports` mapping must not be
    # combined with host networking, so none is declared.
    network_mode: host
    restart: always
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
      # ACME challenge files written by certbot
      - ./certbot/www:/var/www/certbot/:ro
      # certificates obtained by certbot
      - ./certbot/conf/:/etc/nginx/ssl/:ro

  # run on demand to obtain or renew certificates
  certbot:
    image: certbot/certbot:latest
    volumes:
      - ./certbot/www/:/var/www/certbot/:rw
      - ./certbot/conf/:/etc/letsencrypt/:rw
+3
View File
@@ -0,0 +1,3 @@
#!/bin/bash
# Obtain a certificate for one domain with certbot's webroot plugin.
# The domain name is expected as the first argument.

if [ -z "${1:-}" ]; then
    echo "usage: $0 <domain>" >&2
    exit 1
fi

# the argument is quoted so that it reaches certbot as exactly one word
docker compose run --rm certbot certonly -m admin@open-assistant.io --agree-tos --webroot --webroot-path /var/www/certbot/ -d "$1"
+62
View File
@@ -0,0 +1,62 @@
# Reverse proxy for the prod hosts: plain HTTP is redirected to HTTPS, the
# bare domain is redirected to web.prod, the web and backend hosts are
# forwarded to a port on 127.0.0.1.
events {}

http {
    # Port 80: serve the ACME challenge files, redirect everything else to HTTPS.
    server {
        listen 80;
        listen [::]:80;

        server_name *.open-assistant.io open-assistant.io;
        server_tokens off;

        location /.well-known/acme-challenge/ {
            root /var/www/certbot;
        }

        location / {
            return 301 https://$host$request_uri;
        }
    }

    # bare domain -> redirect to web.prod
    server {
        listen 443 ssl http2;

        server_name open-assistant.io;

        ssl_certificate /etc/nginx/ssl/live/open-assistant.io/fullchain.pem;
        ssl_certificate_key /etc/nginx/ssl/live/open-assistant.io/privkey.pem;

        location / {
            return 301 https://web.prod.open-assistant.io$request_uri;
        }
    }

    # web.prod -> 127.0.0.1:3000
    server {
        listen 443 ssl http2;

        server_name web.prod.open-assistant.io;

        ssl_certificate /etc/nginx/ssl/live/web.prod.open-assistant.io/fullchain.pem;
        ssl_certificate_key /etc/nginx/ssl/live/web.prod.open-assistant.io/privkey.pem;

        location / {
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_pass http://127.0.0.1:3000;
        }
    }

    # backend.prod -> 127.0.0.1:8080
    server {
        listen 443 ssl http2;

        server_name backend.prod.open-assistant.io;

        ssl_certificate /etc/nginx/ssl/live/backend.prod.open-assistant.io/fullchain.pem;
        ssl_certificate_key /etc/nginx/ssl/live/backend.prod.open-assistant.io/privkey.pem;

        location / {
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_pass http://127.0.0.1:8080;
        }
    }
}
+3
View File
@@ -0,0 +1,3 @@
#!/bin/bash
# Run `certbot renew` in a throw-away container of the compose service "certbot".

docker compose run --rm certbot renew
+5 -6
View File
@@ -32,12 +32,15 @@ const config = {
presets: [
[
"classic",
"docusaurus-preset-openapi",
/** @type {import('@docusaurus/preset-classic').Options} */
({
docs: {
sidebarPath: require.resolve("./sidebars.js"),
},
api: {
path: "docs/api/openapi.json",
},
blog: false,
theme: {
customCss: require.resolve("./src/css/custom.css"),
@@ -62,11 +65,7 @@ const config = {
position: "left",
label: "Docs",
},
{
href: "https://editor.swagger.io/?url=https://raw.githubusercontent.com/LAION-AI/Open-Assistant/main/docs/docs/api/openapi.json",
label: "API",
position: "left",
},
{ to: "/api", label: "API", position: "left" },
{
href: "https://github.com/LAION-AI/Open-Assistant",
label: "GitHub",
+3 -1
View File
@@ -19,9 +19,11 @@
"@docusaurus/preset-classic": "2.2.0",
"@mdx-js/react": "^1.6.22",
"clsx": "^1.2.1",
"docusaurus-preset-openapi": "^0.6.3",
"prism-react-renderer": "^1.3.5",
"react": "^17.0.2",
"react-dom": "^17.0.2"
"react-dom": "^17.0.2",
"url": "^0.11.0"
},
"devDependencies": {
"@docusaurus/module-type-aliases": "2.2.0",
+972 -37
View File
File diff suppressed because it is too large Load Diff
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
GENERIC_ERROR = 0
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
ROOT_TOKEN_NOT_AUTHORIZED = 3
TOO_MANY_REQUESTS = 429
SERVER_ERROR0 = 500
+30 -2
View File
@@ -169,6 +169,8 @@ class RankConversationRepliesTask(Task):
conversation: Conversation # the conversation so far
replies: list[str] # deprecated, use reply_messages
reply_messages: list[ConversationMessage]
message_tree_id: UUID
ranking_parent_id: UUID
class RankPrompterRepliesTask(RankConversationRepliesTask):
@@ -356,14 +358,40 @@ class SystemStats(BaseModel):
class UserScore(BaseModel):
ranking: int
rank: Optional[int]
user_id: UUID
username: str
auth_method: str
display_name: str
score: int
leader_score: int = 0
base_date: Optional[datetime]
modified_date: Optional[datetime]
prompts: int = 0
replies_assistant: int = 0
replies_prompter: int = 0
labels_simple: int = 0
labels_full: int = 0
rankings_total: int = 0
rankings_good: int = 0
accepted_prompts: int = 0
accepted_replies_assistant: int = 0
accepted_replies_prompter: int = 0
reply_ranked_1: int = 0
reply_ranked_2: int = 0
reply_ranked_3: int = 0
# only used for time frame "total"
streak_last_day_date: Optional[datetime]
streak_days: Optional[int]
class LeaderboardStats(BaseModel):
    # Leaderboard payload: one list of user scores together with the time
    # frame it belongs to.

    # value of the time frame, as a plain string
    time_frame: str
    # NOTE(review): presumably ordered by rank - confirm against the producer
    leaderboard: List[UserScore]
+26
View File
@@ -1,6 +1,32 @@
import time
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable

from loguru import logger
def utcnow() -> datetime:
    """Get the present moment as a timezone-aware datetime whose tzinfo is UTC."""
    current = datetime.now(tz=timezone.utc)
    return current
def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
elapsed = end - start
if log_kwargs:
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")
else:
logger.log(level, f"Function '{func.__name__}' executed in {elapsed:f} s")
return result
return wrapped
if func and callable(func):
return decorator(func)
return decorator
+1 -1
View File
@@ -4,7 +4,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to backend directory
pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_SKIP_API_KEY_CHECK=False
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_TOXICITY_CALCULATION=True
export DEBUG_ALLOW_SELF_LABELING=True
+50 -12
View File
@@ -25,10 +25,10 @@ from bs4 import BeautifulSoup as bs
from logic.logic_injector import LogicBug
from nltk.corpus import wordnet
from syntax.syntax_injector import SyntaxBug
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, pipeline
class DataArgumenter:
class DataAugmenter:
def __init__(self):
raise NotImplementedError()
@@ -48,7 +48,7 @@ class DataArgumenter:
pass
class EssayInstructor(DataArgumenter):
class EssayInstructor(DataAugmenter):
def __init__(self, model_name=None):
if model_name is None:
model_name = "snrspeaks/t5-one-line-summary"
@@ -86,7 +86,7 @@ class EssayInstructor(DataArgumenter):
return prompts, essay_paragraphs
class EssayReviser(DataArgumenter):
class EssayReviser(DataAugmenter):
def __init__(self):
nltk.download("wordnet")
nltk.download("omw-1.4")
@@ -132,7 +132,7 @@ class EssayReviser(DataArgumenter):
return instructions, [essay] * len(instructions)
class StackExchangeBuilder(DataArgumenter):
class StackExchangeBuilder(DataAugmenter):
def __init__(self, base_url=None, filter_opts=None):
self.base_url = (
base_url
@@ -271,7 +271,7 @@ class StackExchangeBuilder(DataArgumenter):
return questions, answers
class HierachicalSummarizer(DataArgumenter):
class HierachicalSummarizer(DataAugmenter):
def __init__(self):
self.summarizer = pipeline(
"summarization",
@@ -342,7 +342,7 @@ class HierachicalSummarizer(DataArgumenter):
return instructions, answers
class EntityRecognizedSummarizer(DataArgumenter):
class EntityRecognizedSummarizer(DataAugmenter):
def __init__(self):
self.nlp = spacy.load("en_core_web_sm") # run !python -m spacy download en_core_web_sm in order to download
@@ -357,7 +357,7 @@ class EntityRecognizedSummarizer(DataArgumenter):
return [question], [answer]
class CodeBugger(DataArgumenter):
class CodeBugger(DataAugmenter):
"""
https://github.com/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.md
Openbugger is a Python package that allows you to inject syntax and logic errors into your code.
@@ -391,6 +391,38 @@ class CodeBugger(DataArgumenter):
return [question], [answer]
class CodeInstructor(DataAugmenter):
    """
    Builds instruction/answer pairs from code: a T5 code summarization model
    describes each snippet, the description becomes the instruction and the
    snippet itself the answer.
    """

    MODEL_NAME = "Graverman/t5-code-summary"

    def __init__(self):
        # the base class constructor only raises NotImplementedError, so it is not called
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = T5ForConditionalGeneration.from_pretrained(self.MODEL_NAME)

    def parse(self, codes):
        """
        Summarize the given code and return (questions, answers).

        codes: one code snippet as a string or a list of code snippets.

        Both returned values are lists of equal length; answers holds the
        snippets in input order. Empty input yields two empty lists.
        """
        # a single snippet is handled as a batch of one, so that answers is a
        # list aligned with questions instead of a bare string
        if isinstance(codes, str):
            codes = [codes]
        if not codes:
            return [], []

        source_encoding = self.tokenizer(
            codes,
            max_length=300,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        outputs = self.model.generate(
            input_ids=source_encoding["input_ids"],
            attention_mask=source_encoding["attention_mask"],
            max_length=100,
            length_penalty=0.75,
            repetition_penalty=2.5,
            early_stopping=True,
            use_cache=True,
        )
        summaries = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

        questions = ["Write a script that does the following:\n" + s for s in summaries]
        answers = codes

        return questions, answers
def recognize_entities(text, model, n=4, person="ignore"):
"""Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
doc = model(text)
@@ -417,20 +449,23 @@ def parse_arguments():
args.add_argument("--output", type=str, required=True)
args = args.parse_args()
assert args.dataset.endswith(".tsv"), "Dataset file must be a tsv file, containing a list of files to be augmented"
assert args.dataset.endswith(".tsv") or args.dataset.endswith(
".csv"
), "Dataset file must be a tsv or csv file, containing a list of files to be augmented"
assert args.output.endswith(".json"), "Output file must be a json file"
return args
def read_data(args):
files = pd.read_csv(args.dataset, sep="\t", header=None)
files = files[0].tolist()
files = pd.read_csv(args.dataset, sep=",", header=None, names=["file"])
files = files["file"].tolist()
data = []
for file in files:
with open(file, "r") as f:
text = f.read()
data.append(text)
return data
@@ -453,9 +488,12 @@ def get_augmenter(args):
elif args.augmenter == "codebugger":
augmenter = CodeBugger()
elif args.augmenter == "codeinstructor":
augmenter = CodeInstructor()
else:
raise ValueError(
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger"
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger', 'codeinstructor"
)
return augmenter
+616 -29
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -58,6 +58,7 @@
"react-dom": "18.2.0",
"react-feature-flags": "^1.0.0",
"react-icons": "^4.7.1",
"sharp": "^0.31.3",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
"use-debounce": "^9.0.2"
+1
View File
@@ -41,6 +41,7 @@ model User {
email String? @unique
emailVerified DateTime?
image String?
isNew Boolean @default(true)
role String @default("general")
accounts Account[]
+7 -7
View File
@@ -3,7 +3,6 @@ import {
Button,
Checkbox,
Flex,
Grid,
Popover,
PopoverAnchor,
PopoverArrow,
@@ -146,24 +145,25 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
isLazy
lazyBehavior="keepMounted"
>
<Grid display="flex" alignItems="center" gap="2">
<Box display="flex" alignItems="center" gap="2">
<PopoverAnchor>{props.children}</PopoverAnchor>
<Tooltip label="Report" bg="red.500" aria-label="A tooltip">
<div>
<Box>
<PopoverTrigger>
<Box as="button" display="flex" alignItems="center" justifyContent="center" borderRadius="full" p="1">
<FiAlertCircle size="20" className="text-red-400" aria-hidden="true" />
</Box>
</PopoverTrigger>
</div>
</Box>
</Tooltip>
</Grid>
</Box>
<PopoverContent width="auto" p="3" m="4" maxWidth="calc(100vw - 2rem)">
<PopoverArrow />
<div className="relative h-4">
<Box className="relative h-4">
<PopoverCloseButton />
</div>
</Box>
<PopoverBody>
{report.label_values.map(({ label, checked, value }, i) => (
<FlagCheckbox
+90 -87
View File
@@ -11,102 +11,105 @@ import {
Text,
useColorModeValue,
} from "@chakra-ui/react";
import { signOut, useSession } from "next-auth/react";
import NextLink from "next/link";
import React from "react";
import { signOut, useSession } from "next-auth/react";
import React, { ElementType, useCallback } from "react";
import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi";
interface MenuOption {
name: string;
href: string;
desc: string;
icon: ElementType;
isExternal: boolean;
}
export function UserMenu() {
const borderColor = useColorModeValue("gray.300", "gray.600");
const handleSignOut = useCallback(() => {
signOut({ callbackUrl: "/" });
}, []);
const { data: session, status } = useSession();
const { data: session } = useSession();
if (!session) {
return <></>;
if (!session || status !== "authenticated") {
return null;
}
if (session && session.user) {
const accountOptions = [
{
name: "Dashboard",
href: "/dashboard",
desc: "Dashboard",
icon: FiLayout,
},
{
name: "Account Settings",
href: "/account",
desc: "Account Settings",
icon: FiSettings,
},
];
const helpOptions = [
{
name: "Report a Bug",
href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose",
desc: "Report a Bug",
icon: FiAlertTriangle,
},
];
const options: MenuOption[] = [
{
name: "Dashboard",
href: "/dashboard",
desc: "Dashboard",
icon: FiLayout,
isExternal: false,
},
{
name: "Account Settings",
href: "/account",
desc: "Account Settings",
icon: FiSettings,
isExternal: false,
},
{
name: "Report a Bug",
href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose",
desc: "Report a Bug",
icon: FiAlertTriangle,
isExternal: true,
},
];
if (session.user.role === "admin") {
accountOptions.unshift({
name: "Admin Dashboard",
href: "/admin",
desc: "Admin Dashboard",
icon: FiShield,
});
}
return (
<>
<Menu>
<MenuButton border="solid" borderRadius="full" borderWidth="thin" borderColor={borderColor}>
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
<Avatar size="sm" bgImage={session.user.image}></Avatar>
<Text data-cy="username" className="hidden lg:flex">
{session.user.name || session.user.email}
</Text>
</Box>
</MenuButton>
<MenuList p="2" borderRadius="xl" shadow="none">
<Box display="flex" flexDirection="column" alignItems="center" borderRadius="md" p="4">
<Text>{session.user.name}</Text>
<Text color="blue.500" fontWeight="bold" fontSize="xl">
3,200
</Text>
</Box>
<MenuDivider />
<MenuGroup>
{accountOptions.map((item) => (
<Link as={NextLink} key={item.name} href={item.href} _hover={{ textDecoration: "none" }}>
<MenuItem gap="3" borderRadius="md" p="4">
<item.icon className="text-blue-500" aria-hidden="true" />
<Text>{item.name}</Text>
</MenuItem>
</Link>
))}
</MenuGroup>
<MenuDivider />
<MenuGroup>
{helpOptions.map((item) => (
<Link as={NextLink} key={item.name} href={item.href} isExternal _hover={{ textDecoration: "none" }}>
<MenuItem gap="3" borderRadius="md" p="4">
<item.icon className="text-blue-500" aria-hidden="true" />
<Text>{item.name}</Text>
</MenuItem>
</Link>
))}
</MenuGroup>
<MenuDivider />
<MenuItem gap="3" borderRadius="md" p="4" onClick={() => signOut({ callbackUrl: "/" })}>
<FiLogOut className="text-blue-500" aria-hidden="true" />
<Text>Sign Out</Text>
</MenuItem>
</MenuList>
</Menu>
</>
);
if (session.user.role === "admin") {
options.unshift({
name: "Admin Dashboard",
href: "/admin",
desc: "Admin Dashboard",
icon: FiShield,
isExternal: false,
});
}
return (
<Menu>
<MenuButton border="solid" borderRadius="full" borderWidth="thin" borderColor={borderColor}>
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
<Avatar size="sm" bgImage={session.user.image}></Avatar>
<Text data-cy="username" className="hidden lg:flex">
{session.user.name || session.user.email}
</Text>
</Box>
</MenuButton>
<MenuList p="2" borderRadius="xl" shadow="none">
<Box display="flex" flexDirection="column" alignItems="center" borderRadius="md" p="4">
<Text>{session.user.name}</Text>
{/* <Text color="blue.500" fontWeight="bold" fontSize="xl">
3,200
</Text> */}
</Box>
<MenuDivider />
<MenuGroup>
{options.map((item) => (
<Link
key={item.name}
as={item.isExternal ? "a" : NextLink}
isExternal={item.isExternal}
href={item.href}
_hover={{ textDecoration: "none" }}
>
<MenuItem gap="3" borderRadius="md" p="4">
<item.icon className="text-blue-500" aria-hidden="true" />
<Text>{item.name}</Text>
</MenuItem>
</Link>
))}
</MenuGroup>
<MenuDivider />
<MenuItem gap="3" borderRadius="md" p="4" onClick={handleSignOut}>
<FiLogOut className="text-blue-500" aria-hidden="true" />
<Text>Sign Out</Text>
</MenuItem>
</MenuList>
</Menu>
);
}
export default UserMenu;
+1
View File
@@ -54,6 +54,7 @@ export const getDashboardLayout = (page: React.ReactElement) => (
>
{page}
</SideMenuLayout>
<Footer />
</div>
);
@@ -19,7 +19,7 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
return (
<FlaggableElement message={item}>
<HStack w="100%" gap={2}>
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
<Box borderRadius="full" border="solid" borderWidth="1px" borderColor={borderColor} bg={avatarColor}>
<Avatar
size="sm"
@@ -28,21 +28,20 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
/>
</Box>
{props.enabled ? (
<Box maxWidth="xl">
<Box width={["full", "full", "full", "fit-content"]} maxWidth={["full", "full", "full", "2xl"]}>
<Link href={`/messages/${item.id}`}>
<LinkBox
bg={item.is_assistant ? backgroundColor : backgroundColor2}
className={`p-4 rounded-md whitespace-pre-wrap w-full`}
>
<LinkBox bg={item.is_assistant ? backgroundColor : backgroundColor2} p="4" borderRadius="md">
{item.text}
</LinkBox>
</Link>
</Box>
) : (
<Box
maxWidth="xl"
width={["full", "full", "full", "fit-content"]}
maxWidth={["full", "full", "full", "2xl"]}
bg={item.is_assistant ? backgroundColor : backgroundColor2}
className={`p-4 rounded-md whitespace-pre-wrap w-full`}
p="4"
borderRadius="md"
>
{item.text}
</Box>
+5 -3
View File
@@ -50,7 +50,7 @@ if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development")
where: {
id: user.id,
},
update: {},
update: user,
create: user,
});
return user;
@@ -86,6 +86,7 @@ export const authOptions: AuthOptions = {
*/
async session({ session, token }) {
session.user.role = token.role;
session.user.isNew = token.isNew;
return session;
},
/**
@@ -93,11 +94,12 @@ export const authOptions: AuthOptions = {
* This let's use forward the role to the session object.
*/
async jwt({ token }) {
const { role } = await prisma.user.findUnique({
const { isNew, role } = await prisma.user.findUnique({
where: { id: token.sub },
select: { role: true },
select: { role: true, isNew: true },
});
token.role = role;
token.isNew = isNew;
return token;
},
},
+3
View File
@@ -17,6 +17,9 @@ const handler = withoutRole("banned", async (req, res, token) => {
// Parse out the local task ID and the interaction contents.
const { id: frontendId, content, update_type } = req.body;
// Record that the user has done meaningful work and is no longer new.
await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } });
// Accept the task so that we can complete it, this will probably go away soon.
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
const task = registeredTask.task as Prisma.JsonObject;
+8 -1
View File
@@ -94,7 +94,14 @@ function Signin({ csrfToken, providers }) {
{email && (
<form onSubmit={signinWithEmail}>
<Stack>
<Input data-cy="email-address" variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
<Input
type="email"
data-cy="email-address"
variant="outline"
size="lg"
placeholder="Email Address"
ref={emailEl}
/>
<Button
data-cy="signin-email-button"
size={"lg"}
+6
View File
@@ -1,9 +1,15 @@
import Head from "next/head";
import { useSession } from "next-auth/react";
import { LeaderboardTable, TaskOption } from "src/components/Dashboard";
import { getDashboardLayout } from "src/components/Layout";
import { TaskCategory } from "src/components/Tasks/TaskTypes";
const Dashboard = () => {
const { data: session } = useSession();
// TODO(#670): Do something more meaningful when the user is new.
console.log(session?.user?.isNew);
return (
<>
<Head>
+1 -1
View File
@@ -27,6 +27,6 @@ const RandomTask = () => {
);
};
RandomTask.getLayout = getDashboardLayout;
RandomTask.getLayout = (page) => getDashboardLayout(page);
export default RandomTask;
+4
View File
@@ -6,6 +6,8 @@ declare module "next-auth" {
user: {
/** The user's role. */
role: string;
/** True when the user is new. */
isNew: boolean;
} & DefaultSession["user"];
}
}
@@ -14,5 +16,7 @@ declare module "next-auth/jwt" {
interface JWT {
/** The user's role. */
role?: string;
/** True when the user is new. */
isNew?: boolean;
}
}