mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into add-debug-skip-password-env
This commit is contained in:
@@ -18,16 +18,20 @@ We can then take the resulting model and continue with completion sampling step
|
||||
|
||||
We are not going to stop at replicating ChatGPT. We want to build the assistant of the future, able to not only write email and cover letters, but do meaningful work, use APIs, dynamically research information, and much more, with the ability to be personalized and extended by anyone. And we want to do this in a way that is open and accessible, which means we must not only build a great assistant, but also make it small and efficient enough to run on consumer hardware.
|
||||
|
||||
### Slide Decks
|
||||
|
||||
[Important Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
|
||||
|
||||
## How can you help?
|
||||
|
||||
All open source projects begins with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity.
|
||||
|
||||
## I’m in! Now what?
|
||||
|
||||
[Fill out the contributor signup form](https://docs.google.com/forms/d/e/1FAIpQLSeuggO7UdYkBvGLEJldDvxp6DwaRbW5p7dl96UzFkZgziRTrQ/viewform)
|
||||
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e)
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord)
|
||||
|
||||
[Visit the Notion](https://ykilcher.com/open-assistant)
|
||||
|
||||
### Taking on Tasks
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""post ref for work_package
|
||||
|
||||
Revision ID: d24b37426857
|
||||
Revises: 3358eb6834e6
|
||||
Create Date: 2022-12-28 11:42:26.773704
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d24b37426857"
|
||||
down_revision = "3358eb6834e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("post", sa.Column("depth", sa.Integer(), server_default=sa.text("0"), nullable=False))
|
||||
op.add_column("post", sa.Column("children_count", sa.Integer(), server_default=sa.text("0"), nullable=False))
|
||||
op.add_column("post_reaction", sa.Column("work_package_id", postgresql.UUID(as_uuid=True), nullable=False))
|
||||
op.drop_constraint("post_reaction_post_id_fkey", "post_reaction", type_="foreignkey")
|
||||
op.create_foreign_key(None, "post_reaction", "work_package", ["work_package_id"], ["id"])
|
||||
op.drop_column("post_reaction", "post_id")
|
||||
op.add_column("work_package", sa.Column("done", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
op.add_column("work_package", sa.Column("ack", sa.Boolean(), nullable=True))
|
||||
op.add_column("work_package", sa.Column("frontend_ref_post_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
||||
op.add_column("work_package", sa.Column("thread_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
op.add_column("work_package", sa.Column("parent_post_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("work_package", "parent_post_id")
|
||||
op.drop_column("work_package", "thread_id")
|
||||
op.drop_column("work_package", "frontend_ref_post_id")
|
||||
op.drop_column("work_package", "ack")
|
||||
op.drop_column("work_package", "done")
|
||||
op.add_column("post_reaction", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=False))
|
||||
op.drop_constraint(None, "post_reaction", type_="foreignkey")
|
||||
op.create_foreign_key("post_reaction_post_id_fkey", "post_reaction", "post", ["post_id"], ["id"])
|
||||
op.drop_column("post_reaction", "work_package_id")
|
||||
op.drop_column("post", "children_count")
|
||||
op.drop_column("post", "depth")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import alembic.command
|
||||
@@ -7,10 +8,29 @@ import fastapi
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
|
||||
|
||||
|
||||
@app.exception_handler(OasstError)
|
||||
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
|
||||
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: fastapi.Request, ex: Exception):
|
||||
logger.exception(f"{request.method} {request.url} failed [UNHANDLED]: {repr(ex)}")
|
||||
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=status.value, content={"message": status.name, "error_code": OasstErrorCode.GENERIC_ERROR}
|
||||
)
|
||||
|
||||
|
||||
# Set all CORS enabled origins
|
||||
if settings.BACKEND_CORS_ORIGINS:
|
||||
app.add_middleware(
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi import Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
@@ -36,22 +37,26 @@ def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
) -> ApiClient:
|
||||
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
|
||||
if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
raise OasstError(
|
||||
"Could not validate credentials",
|
||||
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
@@ -1,39 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import random
|
||||
from typing import Any
|
||||
from typing import Any, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
def generate_task(
|
||||
request: protocol_schema.TaskRequest, pr: PromptRepository
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
thread_id = None
|
||||
parent_post_id = None
|
||||
|
||||
match request.type:
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
logger.info("Frontend requested a random task.")
|
||||
while request.type == protocol_schema.TaskRequestType.random:
|
||||
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
|
||||
return generate_task(request)
|
||||
case protocol_schema.TaskRequestType.summarize_story:
|
||||
logger.info("Generating a SummarizeStoryTask.")
|
||||
task = protocol_schema.SummarizeStoryTask(
|
||||
story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rate_summary:
|
||||
logger.info("Generating a RateSummaryTask.")
|
||||
task = protocol_schema.RateSummaryTask(
|
||||
full_text="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
summary="This is a summary.",
|
||||
scale=protocol_schema.RatingScale(min=1, max=5),
|
||||
)
|
||||
disabled_tasks = (
|
||||
protocol_schema.TaskRequestType.summarize_story,
|
||||
protocol_schema.TaskRequestType.rate_summary,
|
||||
)
|
||||
request.type = random.choice(
|
||||
tuple(set(protocol_schema.TaskRequestType).difference(disabled_tasks))
|
||||
).value
|
||||
return generate_task(request, pr)
|
||||
|
||||
# AKo: Summary tasks are currently disabled/supported, we focus on the conversation tasks.
|
||||
|
||||
# case protocol_schema.TaskRequestType.summarize_story:
|
||||
# logger.info("Generating a SummarizeStoryTask.")
|
||||
# task = protocol_schema.SummarizeStoryTask(
|
||||
# story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
# )
|
||||
# case protocol_schema.TaskRequestType.rate_summary:
|
||||
# logger.info("Generating a RateSummaryTask.")
|
||||
# task = protocol_schema.RateSummaryTask(
|
||||
# full_text="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
# summary="This is a summary.",
|
||||
# scale=protocol_schema.RatingScale(min=1, max=5),
|
||||
# )
|
||||
|
||||
case protocol_schema.TaskRequestType.initial_prompt:
|
||||
logger.info("Generating an InitialPromptTask.")
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
@@ -41,87 +56,72 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
task = protocol_schema.UserReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
posts = pr.fetch_random_conversation("assistant")
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
|
||||
for p in posts
|
||||
]
|
||||
|
||||
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = posts[-1].thread_id
|
||||
parent_post_id = posts[-1].id
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, write me an English essay about water.",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
posts = pr.fetch_random_conversation("user")
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
|
||||
for p in posts
|
||||
]
|
||||
|
||||
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = posts[-1].thread_id
|
||||
parent_post_id = posts[-1].id
|
||||
case protocol_schema.TaskRequestType.rank_initial_prompts:
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
task = protocol_schema.RankInitialPromptsTask(
|
||||
prompts=[
|
||||
"Please write a story about a time you were happy.",
|
||||
"Please write a story about a time you were sad.",
|
||||
]
|
||||
)
|
||||
|
||||
posts = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[p.payload.payload.text for p in posts])
|
||||
case protocol_schema.TaskRequestType.rank_user_replies:
|
||||
logger.info("Generating a RankUserRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(post_role="assistant")
|
||||
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
task = protocol_schema.RankUserRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
messages=messages,
|
||||
),
|
||||
replies=[
|
||||
"Oh come oooooon!",
|
||||
"What are the news?",
|
||||
],
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(post_role="user")
|
||||
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
replies=[
|
||||
"I'm not sure I understood correctly, could you rephrase that?",
|
||||
"The world is fine. All good.",
|
||||
"Crap is hitting the fan. Start farming.",
|
||||
],
|
||||
conversation=protocol_schema.Conversation(messages=messages),
|
||||
replies=replies,
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task
|
||||
return task, thread_id, parent_post_id
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
@@ -137,16 +137,15 @@ def request_task(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
task = generate_task(request)
|
||||
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
pr.store_task(task)
|
||||
task, thread_id, parent_post_id = generate_task(request, pr)
|
||||
pr.store_task(task, thread_id, parent_post_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to generate task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
logger.exception("Failed to generate task..")
|
||||
raise OasstError("Failed to generate task.", OasstErrorCode.TASK_GENERATION_FAILED)
|
||||
return task
|
||||
|
||||
|
||||
@@ -171,11 +170,11 @@ def acknowledge_task(
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
|
||||
return {}
|
||||
|
||||
|
||||
@@ -190,11 +189,15 @@ def acknowledge_task_failure(
|
||||
"""
|
||||
The frontend reports failure to implement a task.
|
||||
"""
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
# here we would store the post id in the database for the task
|
||||
return {}
|
||||
try:
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
@@ -219,8 +222,9 @@ def post_interaction(
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
# ToDo: role user or agent?
|
||||
pr.store_text_reply(interaction, role="unknown")
|
||||
pr.store_text_reply(
|
||||
text=interaction.text, post_id=interaction.post_id, user_post_id=interaction.user_post_id
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRating:
|
||||
@@ -242,13 +246,9 @@ def post_interaction(
|
||||
# here we would store the ranking in the database
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Interaction request failed.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from sqlmodel import create_engine
|
||||
|
||||
if settings.DATABASE_URI is None:
|
||||
raise ValueError("DATABASE_URI is not set")
|
||||
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI)
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from enum import IntEnum
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
class OasstErrorCode(IntEnum):
|
||||
"""
|
||||
Error codes of the Open-Assistant backend API.
|
||||
|
||||
Ranges:
|
||||
0-1000: general errors
|
||||
1000-2000: tasks endpoint
|
||||
2000-3000: prompt_repository
|
||||
"""
|
||||
|
||||
# 0-1000: general errors
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
TASK_NACK_FAILED = 1002
|
||||
TASK_INVALID_RESPONSE_TYPE = 1003
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1004
|
||||
TASK_GENERATION_FAILED = 1005
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_POST_ID = 2000
|
||||
POST_NOT_FOUND = 2001
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
INVALID_TASK_TYPE = 2004
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
NO_THREADS_FOUND = 2006
|
||||
WORK_PACKAGE_NOT_FOUND = 2100
|
||||
WORK_PACKAGE_EXPIRED = 2101
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
WORK_PACKAGE_ALREADY_UPDATED = 2103
|
||||
WORK_PACKAGE_NOT_ACK = 2104
|
||||
WORK_PACKAGE_ALREADY_DONE = 2105
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
"""Base class for Open-Assistant exceptions."""
|
||||
|
||||
message: str
|
||||
error_code: int
|
||||
http_status_code: HTTPStatus
|
||||
|
||||
def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
|
||||
super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
self.http_status_code = http_status_code
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f'{class_name}(message="{self.message}", error_code={self.error_code}, http_status_code={self.http_status_code})'
|
||||
@@ -31,3 +31,5 @@ class Post(SQLModel, table=True):
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
|
||||
@@ -13,8 +13,8 @@ from .payload_column_type import PayloadContainer, payload_column_type
|
||||
class PostReaction(SQLModel, table=True):
|
||||
__tablename__ = "post_reaction"
|
||||
|
||||
post_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=False, primary_key=True)
|
||||
work_package_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
person_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
@@ -26,3 +27,12 @@ class WorkPackage(SQLModel, table=True):
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
ack: Optional[bool] = None
|
||||
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
frontend_ref_post_id: Optional[str] = None
|
||||
thread_id: Optional[UUID] = None
|
||||
parent_post_id: Optional[UUID] = None
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
return self.expiry_date is not None and datetime.utcnow() < self.expiry_date
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
import random
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, func
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
@@ -52,9 +53,9 @@ class PromptRepository:
|
||||
|
||||
def validate_post_id(self, post_id: str) -> None:
|
||||
if not isinstance(post_id, str):
|
||||
raise TypeError(f"post_id must be string, not {type(post_id)}")
|
||||
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
|
||||
if not post_id:
|
||||
raise ValueError("post_id must not be empty")
|
||||
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
|
||||
self.validate_post_id(post_id)
|
||||
@@ -66,36 +67,36 @@ class PromptRepository:
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise KeyError(f"WorkPackage for task {task_id} not found")
|
||||
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
|
||||
raise RuntimeError("WorkPackage already expired.")
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
|
||||
work_pack.frontend_ref_post_id = post_id
|
||||
work_pack.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(work_pack)
|
||||
self.db.commit()
|
||||
|
||||
# check if task thread exits
|
||||
thread_root = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.workpackage_id == work_pack.id,
|
||||
Post.frontend_post_id == post_id,
|
||||
Post.parent_id is None,
|
||||
Post.api_client_id == self.api_client.id,
|
||||
)
|
||||
.one_or_none()
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if thread_root is None:
|
||||
thread_id = uuid4()
|
||||
thread_root = self.insert_post(
|
||||
post_id=thread_id,
|
||||
thread_id=thread_id,
|
||||
frontend_post_id=post_id,
|
||||
parent_id=None,
|
||||
role="system",
|
||||
workpackage_id=work_pack.id,
|
||||
payload=None,
|
||||
payload_type="bind",
|
||||
)
|
||||
return thread_root
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
|
||||
work_pack.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(work_pack)
|
||||
self.db.commit()
|
||||
|
||||
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
|
||||
self.validate_post_id(frontend_post_id)
|
||||
@@ -105,49 +106,64 @@ class PromptRepository:
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
|
||||
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
|
||||
return post
|
||||
|
||||
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
|
||||
self.validate_post_id(post_id)
|
||||
post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True)
|
||||
work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one()
|
||||
return work_pack
|
||||
|
||||
def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post:
|
||||
self.validate_post_id(reply.post_id)
|
||||
self.validate_post_id(reply.user_post_id)
|
||||
|
||||
work_package = self.fetch_workpackage_by_postid(reply.post_id)
|
||||
work_payload: db_payload.TaskPayload = work_package.payload.payload
|
||||
logger.info(f"found task work package in db: {work_payload}")
|
||||
|
||||
# find post with post-id
|
||||
parent_post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.api_client_id == self.api_client.id,
|
||||
Post.frontend_post_id == reply.post_id,
|
||||
# Post.person_id == self.person_id
|
||||
)
|
||||
work_pack = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_post_id == post_id)
|
||||
.one_or_none()
|
||||
)
|
||||
return work_pack
|
||||
|
||||
if parent_post is None:
|
||||
raise KeyError(f"Post for post_id {reply.post_id} not found.")
|
||||
def store_text_reply(self, text: str, post_id: str, user_post_id: str, role: str = None) -> Post:
|
||||
self.validate_post_id(post_id)
|
||||
self.validate_post_id(user_post_id)
|
||||
|
||||
wp = self.fetch_workpackage_by_postid(post_id)
|
||||
|
||||
if wp is None:
|
||||
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not wp.ack:
|
||||
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
|
||||
if wp.done:
|
||||
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
|
||||
# If there's no parent post assume user started new conversation
|
||||
role = "user"
|
||||
depth = 0
|
||||
|
||||
if wp.parent_post_id:
|
||||
parent_post = self.fetch_post(wp.parent_post_id)
|
||||
parent_post.children_count += 1
|
||||
self.db.add(parent_post)
|
||||
|
||||
depth = parent_post.depth + 1
|
||||
if parent_post.role == "assistant":
|
||||
role = "user"
|
||||
else:
|
||||
role = "assistant"
|
||||
|
||||
# create reply post
|
||||
user_post_id = uuid4()
|
||||
new_post_id = uuid4()
|
||||
user_post = self.insert_post(
|
||||
post_id=user_post_id,
|
||||
frontend_post_id=reply.user_post_id,
|
||||
parent_id=parent_post.id,
|
||||
thread_id=parent_post.thread_id,
|
||||
workpackage_id=parent_post.workpackage_id,
|
||||
post_id=new_post_id,
|
||||
frontend_post_id=user_post_id,
|
||||
parent_id=wp.parent_post_id,
|
||||
thread_id=wp.thread_id or new_post_id,
|
||||
workpackage_id=wp.id,
|
||||
role=role,
|
||||
payload=db_payload.PostPayload(text=reply.text),
|
||||
payload=db_payload.PostPayload(text=text),
|
||||
depth=depth,
|
||||
)
|
||||
self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text))
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
|
||||
return user_post
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
|
||||
@@ -156,12 +172,16 @@ class PromptRepository:
|
||||
work_package = self.fetch_workpackage_by_postid(rating.post_id)
|
||||
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
|
||||
raise OasstError(
|
||||
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
|
||||
OasstErrorCode.RATING_OUT_OF_RANGE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
@@ -171,10 +191,11 @@ class PromptRepository:
|
||||
return reaction
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
|
||||
post = self.fetch_post_by_frontend_post_id(ranking.post_id, fail_if_missing=True)
|
||||
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
)
|
||||
@@ -185,14 +206,16 @@ class PromptRepository:
|
||||
# validate ranking
|
||||
num_replies = len(work_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
# TODO: resolve post_id
|
||||
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
@@ -201,25 +224,33 @@ class PromptRepository:
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
# TODO: resolve post_id
|
||||
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
) -> WorkPackage:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
@@ -253,13 +284,24 @@ class PromptRepository:
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Invalid task type: {type(task)=}")
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
wp = self.insert_work_package(payload=payload, id=task.id)
|
||||
wp = self.insert_work_package(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
thread_id=thread_id,
|
||||
parent_post_id=parent_post_id,
|
||||
)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
|
||||
def insert_work_package(self, payload: db_payload.TaskPayload, id: UUID = None) -> WorkPackage:
|
||||
def insert_work_package(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
) -> WorkPackage:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
id=id,
|
||||
@@ -267,6 +309,8 @@ class PromptRepository:
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
thread_id=thread_id,
|
||||
parent_post_id=parent_post_id,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
@@ -284,6 +328,7 @@ class PromptRepository:
|
||||
role: str,
|
||||
payload: db_payload.PostPayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Post:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
@@ -302,19 +347,20 @@ class PromptRepository:
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
self.db.add(post)
|
||||
self.db.commit()
|
||||
self.db.refresh(post)
|
||||
return post
|
||||
|
||||
def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
raise ValueError("User required")
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
post_id=post_id,
|
||||
work_package_id=work_package_id,
|
||||
person_id=self.person_id,
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
@@ -338,3 +384,82 @@ class PromptRepository:
|
||||
self.db.commit()
|
||||
self.db.refresh(model)
|
||||
return model
|
||||
|
||||
def fetch_random_thread(self, require_role: str = None) -> list[Post]:
|
||||
"""
|
||||
Loads all posts of a random thread.
|
||||
|
||||
:param require_role: If set loads only thread which has
|
||||
at least one post with given role.
|
||||
"""
|
||||
distinct_threads = self.db.query(Post.thread_id).distinct(Post.thread_id)
|
||||
if require_role:
|
||||
distinct_threads = distinct_threads.filter(Post.role == require_role)
|
||||
distinct_threads = distinct_threads.subquery()
|
||||
|
||||
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1).subquery()
|
||||
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
|
||||
return thread_posts
|
||||
|
||||
def fetch_random_conversation(self, last_post_role: str = None) -> list[Post]:
|
||||
"""
|
||||
Picks a random linear conversation starting from any root post
|
||||
and ending somewhere in the thread, possibly at the root itself.
|
||||
|
||||
:param last_post_role: If set will form a conversation ending with a post
|
||||
created by this role. Necessary for the tasks like "user_reply" where
|
||||
the user should reply as a human and hence the last message of the conversation
|
||||
needs to have "assistant" role.
|
||||
"""
|
||||
thread_posts = self.fetch_random_thread(last_post_role)
|
||||
if not thread_posts:
|
||||
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
|
||||
if last_post_role:
|
||||
conv_posts = [p for p in thread_posts if p.role == last_post_role]
|
||||
conv_posts = [random.choice(conv_posts)]
|
||||
else:
|
||||
conv_posts = [random.choice(thread_posts)]
|
||||
thread_posts = {p.id: p for p in thread_posts}
|
||||
|
||||
while True:
|
||||
if not conv_posts[-1].parent_id:
|
||||
# reached the start of the conversation
|
||||
break
|
||||
|
||||
parent_post = thread_posts[conv_posts[-1].parent_id]
|
||||
conv_posts.append(parent_post)
|
||||
|
||||
return list(reversed(conv_posts))
|
||||
|
||||
def fetch_random_initial_prompts(self, size: int = 5):
|
||||
posts = self.db.query(Post).filter(Post.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return posts
|
||||
|
||||
def fetch_thread(self, thread_id: UUID):
|
||||
return self.db.query(Post).filter(Post.thread_id == thread_id).all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None):
|
||||
parent = self.db.query(Post.id).filter(Post.children_count > 1)
|
||||
if post_role:
|
||||
parent = parent.filter(Post.role == post_role)
|
||||
|
||||
parent = parent.order_by(func.random()).limit(1).subquery()
|
||||
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
|
||||
|
||||
thread = self.fetch_thread(replies[0].thread_id)
|
||||
thread = {p.id: p for p in thread}
|
||||
thread_posts = [thread[replies[0].parent_id]]
|
||||
while True:
|
||||
if not thread_posts[-1].parent_id:
|
||||
# reached start of the conversation
|
||||
break
|
||||
|
||||
parent_post = thread[thread_posts[-1].parent_id]
|
||||
thread_posts.append(parent_post)
|
||||
|
||||
thread_posts = reversed(thread_posts)
|
||||
|
||||
return thread_posts, replies
|
||||
|
||||
def fetch_post(self, post_id: UUID) -> Optional[Post]:
|
||||
return self.db.query(Post).filter(Post.id == post_id).one()
|
||||
|
||||
+4
-4
@@ -1,14 +1,14 @@
|
||||
version: "3.7"
|
||||
|
||||
services:
|
||||
# Use `docker compose up backend-dev` to start a database and work and the backend.
|
||||
# Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend.
|
||||
backend-dev:
|
||||
image: tianon/true
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, adminer]
|
||||
|
||||
# Use `docker compose up frontend-dev` to start all services needed to work on the frontend.
|
||||
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
|
||||
frontend-dev:
|
||||
image: tianon/true
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, webdb, adminer, maildev, backend]
|
||||
|
||||
# This DB is for the FastAPI Backend.
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
# Documentation
|
||||
|
||||
This directory contains the documentation for the project and other related organization documents.
|
||||
|
||||
## Contributing to this documentation
|
||||
|
||||
Please make a pull request to the `main` branch with your changes.
|
||||
|
||||
Consider that this folder is used for documenting the various code sub-parts, the high-level ideas, the ML aspects, experiments, contributor guides, guides for data creation, and many more things. Please try to keep the documentation as concise as possible and keep an organized folder structure that makes sense for everyone.
|
||||
@@ -1,6 +1,6 @@
|
||||
# Backend Development Setup
|
||||
|
||||
In root directory, run `docker compose up backend-dev` to start a database. The default settings are already configured to connect to the database at `localhost:5432`.
|
||||
In root directory, run `docker compose up backend-dev --build --attach-dependencies` to start a database. The default settings are already configured to connect to the database at `localhost:5432`.
|
||||
|
||||
Make sure you have all requirements installed. You can do this by running `pip install -r requirements.txt` inside the `backend` folder and `pip install -e .` inside the `oasst-shared` folder.
|
||||
Then, run the backend using the `run-local.sh` script. This will start the backend server at `http://localhost:8080`.
|
||||
|
||||
Executable
+3
@@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
docker compose -f "$parent_path/../../docker-compose.yaml" up backend-dev --build --attach-dependencies
|
||||
@@ -1,5 +1,5 @@
|
||||
# Frontend Development Setup
|
||||
|
||||
In root directory run `docker compose up frontend-dev --build` to start a database and the backend server.
|
||||
In root directory run `docker compose up frontend-dev --build --attach-dependencies` to start a database and the backend server.
|
||||
|
||||
Then, point your frontend at `http://localhost:8080` to start developing. During development, any API key will be accepted.
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
const path = require("path");
|
||||
|
||||
module.exports = {
|
||||
stories: [
|
||||
"../src/components/**/*.stories.mdx",
|
||||
"../src/components/**/*.stories.@(js|jsx|ts|tsx)",
|
||||
],
|
||||
addons: [
|
||||
"@storybook/addon-links",
|
||||
"@storybook/addon-essentials",
|
||||
"@storybook/addon-interactions",
|
||||
"@chakra-ui/storybook-addon",
|
||||
],
|
||||
framework: "@storybook/react",
|
||||
core: {
|
||||
builder: "@storybook/builder-webpack5",
|
||||
},
|
||||
staticDirs: ["../public"],
|
||||
// https://github.com/storybookjs/storybook/issues/15336#issuecomment-888528747
|
||||
typescript: { reactDocgen: false },
|
||||
// fix to make absolute imports working in storybook
|
||||
webpackFinal: async (config, { configType }) => {
|
||||
config.resolve.alias = {
|
||||
...config.resolve.alias,
|
||||
src: path.resolve(__dirname, "../src"),
|
||||
};
|
||||
return config;
|
||||
},
|
||||
features: {
|
||||
emotionAlias: false,
|
||||
},
|
||||
};
|
||||
@@ -0,0 +1,22 @@
|
||||
import "!style-loader!css-loader!postcss-loader!tailwindcss/tailwind.css";
|
||||
|
||||
export const parameters = {
|
||||
actions: { argTypesRegex: "^on[A-Z].*" },
|
||||
controls: {
|
||||
matchers: {
|
||||
color: /(background|color)$/i,
|
||||
date: /Date$/,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Hacky solution to get Images in next to work
|
||||
// https://dev.to/jonasmerlin/how-to-use-the-next-js-image-component-in-storybook-1415
|
||||
import * as NextImage from "next/image";
|
||||
|
||||
const OriginalNextImage = NextImage.default;
|
||||
|
||||
Object.defineProperty(NextImage, "default", {
|
||||
configurable: true,
|
||||
value: (props) => <OriginalNextImage {...props} unoptimized />,
|
||||
});
|
||||
+8
-3
@@ -49,9 +49,8 @@ installed:
|
||||
|
||||
If you're doing active development we suggest the following workflow:
|
||||
|
||||
1. In one tab, navigate to
|
||||
`${OPEN_ASSISTANT_ROOT}/scripts/frontend-development`.
|
||||
1. Run `docker compose up --build`. You can optionally include `-d` to detach and
|
||||
1. In one tab, navigate to the project root.
|
||||
1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can optionally include `-d` to detach and
|
||||
later track the logs if desired.
|
||||
1. In another tab navigate to `${OPEN_ASSISTANT_ROOT/website`.
|
||||
1. Run `npm install`
|
||||
@@ -71,6 +70,12 @@ You can use the debug credentials provider to log in without fancy emails or OAu
|
||||
1. Use the `Login` button in the top right to go to the login page.
|
||||
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
|
||||
|
||||
### Using Storybook
|
||||
|
||||
To develop components using [Storybook](https://storybook.js.org/) run `npm run storybook`. Then navigate to in your browser to `http://localhost:6006`.
|
||||
|
||||
To create a new story create a file named `[componentName].stories.js`. An example how such a story could look like, see `Header.stories.jsx`.
|
||||
|
||||
## Code Layout
|
||||
|
||||
### React Code
|
||||
|
||||
Generated
+28846
-2842
File diff suppressed because it is too large
Load Diff
+18
-2
@@ -7,7 +7,9 @@
|
||||
"dev": "next dev",
|
||||
"build": "next build",
|
||||
"start": "next start",
|
||||
"lint": "next lint"
|
||||
"lint": "next lint",
|
||||
"storybook": "start-storybook -p 6006",
|
||||
"build-storybook": "build-storybook"
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/react": "^2.4.4",
|
||||
@@ -40,9 +42,23 @@
|
||||
"use-debounce": "^9.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.20.7",
|
||||
"@chakra-ui/storybook-addon": "^4.0.16",
|
||||
"@storybook/addon-actions": "^6.5.15",
|
||||
"@storybook/addon-essentials": "^6.5.15",
|
||||
"@storybook/addon-interactions": "^6.5.15",
|
||||
"@storybook/addon-links": "^6.5.15",
|
||||
"@storybook/addon-postcss": "^2.0.0",
|
||||
"@storybook/builder-webpack5": "^6.5.15",
|
||||
"@storybook/manager-webpack5": "^6.5.15",
|
||||
"@storybook/react": "^6.5.15",
|
||||
"@storybook/testing-library": "^0.0.13",
|
||||
"@types/node": "18.11.17",
|
||||
"@types/react": "18.0.26",
|
||||
"babel-loader": "^8.3.0",
|
||||
"eslint-plugin-storybook": "^0.6.8",
|
||||
"prettier": "2.8.1",
|
||||
"prisma": "^4.7.1"
|
||||
"prisma": "^4.7.1",
|
||||
"typescript": "4.9.4"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,31 +1,75 @@
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
|
||||
import { FaGithub, FaDiscord } from "react-icons/fa";
|
||||
import { Container } from "./Container";
|
||||
import { NavLinks } from "./NavLinks";
|
||||
|
||||
export function Footer() {
|
||||
return (
|
||||
<footer className="border-t border-gray-200 bg-white">
|
||||
<Container className="">
|
||||
<div className="flex flex-col items-start justify-between gap-y-12 pt-16 pb-6 lg:flex-row lg:items-center lg:py-6">
|
||||
<div>
|
||||
<div className="flex items-center text-gray-900">
|
||||
<main>
|
||||
<Container className="">
|
||||
<div className="flex flex-wrap justify-between gap-y-12 py-10 lg:items-center lg:py-16">
|
||||
<div className="flex items-center text-black pr-8">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
|
||||
</Link>
|
||||
|
||||
<div className="ml-4">
|
||||
<p className="text-base font-semibold">Open Assistant</p>
|
||||
<p className="mt-1 text-sm">Conversational AI for everyone.</p>
|
||||
<div className="ml-2">
|
||||
<p className="text-base font-bold">Open Assistant</p>
|
||||
<p className="text-sm">Conversational AI for everyone.</p>
|
||||
</div>
|
||||
</div>
|
||||
{/* <nav className="mt-11 flex gap-8">
|
||||
<NavLinks />
|
||||
</nav> */}
|
||||
<nav className="flex justify-center gap-20">
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Information</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link href="#" aria-label="Our Team" className="hover:underline underline-offset-2">
|
||||
Our Team
|
||||
</Link>
|
||||
<Link href="#join-us" aria-label="Join Us" className="hover:underline underline-offset-2">
|
||||
Join Us
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Legal</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link href="#" aria-label="Privacy Policy" className="hover:underline underline-offset-2">
|
||||
Privacy Policy
|
||||
</Link>
|
||||
<Link href="#" aria-label="Terms of Service" className="hover:underline underline-offset-2">
|
||||
Terms of Service
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Connect</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link
|
||||
href="https://github.com/LAION-AI/Open-Assistant"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Privacy Policy"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Github
|
||||
</Link>
|
||||
<Link
|
||||
href="https://discord.gg/pXtnYk9c"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Discord
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
</div>
|
||||
</div>
|
||||
</Container>
|
||||
</Container>
|
||||
</main>
|
||||
</footer>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import { SessionContext } from "next-auth/react";
|
||||
import React from "react";
|
||||
|
||||
import { Header } from "./Header";
|
||||
|
||||
export default {
|
||||
title: "Header/Header",
|
||||
component: Header,
|
||||
parameters: {
|
||||
layout: "fullscreen",
|
||||
},
|
||||
};
|
||||
|
||||
const Template = (args) => {
|
||||
var { session } = args;
|
||||
return (
|
||||
<SessionContext.Provider value={session}>
|
||||
<Header {...args} />
|
||||
</SessionContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const Default = Template.bind({});
|
||||
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" } };
|
||||
@@ -6,9 +6,9 @@ import Link from "next/link";
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import { FaUser, FaSignOutAlt } from "react-icons/fa";
|
||||
|
||||
import { Avatar } from "./Avatar";
|
||||
import { Container } from "./Container";
|
||||
import { Container } from "src/components/Container";
|
||||
import { NavLinks } from "./NavLinks";
|
||||
import { UserMenu } from "./UserMenu";
|
||||
|
||||
function MenuIcon(props) {
|
||||
return (
|
||||
@@ -45,9 +45,9 @@ function AccountButton() {
|
||||
return;
|
||||
}
|
||||
return (
|
||||
<Link href="/auth/signup" aria-label="Home" className="flex items-center">
|
||||
<Link href="/auth/signin" aria-label="Home" className="flex items-center">
|
||||
<Button variant="outline" leftIcon={<FaUser />}>
|
||||
Log in
|
||||
Sign in
|
||||
</Button>
|
||||
</Link>
|
||||
);
|
||||
@@ -113,7 +113,7 @@ export function Header() {
|
||||
)}
|
||||
</Popover>
|
||||
<AccountButton />
|
||||
<Avatar />
|
||||
<UserMenu />
|
||||
</div>
|
||||
</Container>
|
||||
</nav>
|
||||
@@ -0,0 +1,14 @@
|
||||
import { NavLinks } from "./NavLinks";
|
||||
|
||||
export default {
|
||||
title: "Header/NavLinks",
|
||||
component: NavLinks,
|
||||
};
|
||||
|
||||
const Template = (args) => (
|
||||
<div className="hidden lg:flex lg:gap-10">
|
||||
<NavLinks {...args} />
|
||||
</div>
|
||||
);
|
||||
|
||||
export const Default = Template.bind({});
|
||||
@@ -0,0 +1,25 @@
|
||||
import { SessionContext } from "next-auth/react";
|
||||
import React from "react";
|
||||
|
||||
import UserMenu from "./UserMenu";
|
||||
|
||||
export default {
|
||||
title: "Header/UserMenu",
|
||||
component: UserMenu,
|
||||
};
|
||||
|
||||
const Template = (args) => {
|
||||
var { session } = args;
|
||||
return (
|
||||
<SessionContext.Provider value={session}>
|
||||
<div className="flex flex-col">
|
||||
<div className="self-end">
|
||||
<UserMenu {...args} />
|
||||
</div>
|
||||
</div>
|
||||
</SessionContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const Default = Template.bind({});
|
||||
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" } };
|
||||
@@ -5,18 +5,18 @@ import { Popover } from "@headlessui/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { FaCog, FaSignOutAlt, FaGithub } from "react-icons/fa";
|
||||
|
||||
export function Avatar() {
|
||||
export function UserMenu() {
|
||||
const { data: session } = useSession();
|
||||
|
||||
if (!session) {
|
||||
return <></>;
|
||||
}
|
||||
if (session && session.user) {
|
||||
const displayName = session.user.name || session.user.email;
|
||||
const email = session.user.email;
|
||||
const accountOptions = [
|
||||
{
|
||||
name: "Account Settings",
|
||||
href: "#",
|
||||
href: "/account",
|
||||
desc: "Account Settings",
|
||||
icon: FaCog,
|
||||
//For future use
|
||||
@@ -35,8 +35,7 @@ export function Avatar() {
|
||||
height="40"
|
||||
className="rounded-full"
|
||||
></Image>
|
||||
<p className="hidden lg:flex">{displayName}</p>
|
||||
{/* Will be changed to username once it is implemented */}
|
||||
<p className="hidden lg:flex">{session.user.name || session.user.email}</p>
|
||||
</div>
|
||||
</Popover.Button>
|
||||
<AnimatePresence initial={false}>
|
||||
@@ -72,7 +71,7 @@ export function Avatar() {
|
||||
))}
|
||||
<a
|
||||
className="flex items-center rounded-md hover:bg-gray-100 cursor-pointer"
|
||||
onClick={() => signOut()}
|
||||
onClick={() => signOut({ callbackUrl: "/" })}
|
||||
>
|
||||
<div className="p-4">
|
||||
<FaSignOutAlt />
|
||||
@@ -93,4 +92,4 @@ export function Avatar() {
|
||||
}
|
||||
}
|
||||
|
||||
export default Avatar;
|
||||
export default UserMenu;
|
||||
@@ -0,0 +1,3 @@
|
||||
export { Header } from "./Header";
|
||||
export { UserMenu } from "./UserMenu";
|
||||
export { NavLinks } from "./NavLinks";
|
||||
@@ -3,7 +3,7 @@
|
||||
import type { NextPage } from "next";
|
||||
|
||||
import { Footer } from "./Footer";
|
||||
import { Header } from "./Header";
|
||||
import { Header } from "src/components/Header";
|
||||
|
||||
export type NextPageWithLayout<P = {}, IP = P> = NextPage<P, IP> & {
|
||||
getLayout?: (page: React.ReactElement) => React.ReactNode;
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
import { LoadingScreen } from "./LoadingScreen";
|
||||
|
||||
export default {
|
||||
title: "Example/LoadingScreen",
|
||||
component: LoadingScreen,
|
||||
parameters: {
|
||||
layout: "fullscreen",
|
||||
},
|
||||
};
|
||||
|
||||
const Template = (args) => <LoadingScreen {...args} />;
|
||||
|
||||
export const Default = Template.bind({});
|
||||
|
||||
export const WithText = Template.bind({});
|
||||
WithText.args = { text: "Loading Text ..." };
|
||||
@@ -0,0 +1,12 @@
|
||||
import { Progress } from "@chakra-ui/react";
|
||||
|
||||
export const LoadingScreen = ({ text }) => (
|
||||
<div className="bg-slate-100">
|
||||
<Progress size="xs" isIndeterminate />
|
||||
{text && (
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold text-gray-800 mx-auto my-auto">{text}</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
@@ -4,5 +4,5 @@ export { default } from "next-auth/middleware";
|
||||
* Guards all pages under `/grading` and redirects them to the sign in page.
|
||||
*/
|
||||
export const config = {
|
||||
matcher: ["/create/:path*", "/evaluate/:path*"],
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*"],
|
||||
};
|
||||
|
||||
@@ -1,45 +1,19 @@
|
||||
import { useSession } from "next-auth/react";
|
||||
import { Footer } from "../components/Footer";
|
||||
import { Header } from "../components/Header";
|
||||
import { Header } from "src/components/Header";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
|
||||
export default function Error() {
|
||||
const { data: session } = useSession();
|
||||
|
||||
if (!session) {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta name="404" content="Sorry, this page doesn't exist." />
|
||||
</Head>
|
||||
<Header />
|
||||
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
|
||||
{"Sorry, the page you're looking for does not exist."}
|
||||
</main>
|
||||
<Footer />
|
||||
</>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
<title>404 - Open Assistant</title>
|
||||
<meta name="404" content="Sorry, this page doesn't exist." />
|
||||
</Head>
|
||||
<Header />
|
||||
<main>
|
||||
<h2>Open Chat Gpt</h2>
|
||||
|
||||
<p>You are logged in</p>
|
||||
|
||||
<Link href="/grading/grade-output">~Rate a prompt and output now~</Link>
|
||||
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
|
||||
<p>Sorry, the page you are looking for does not exist.</p>
|
||||
</main>
|
||||
<Footer />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import React, { useState } from "react";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { Button, Input, InputGroup, Stack } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Router from "next/router";
|
||||
|
||||
export default function Account() {
|
||||
const { data: session } = useSession();
|
||||
const [username, setUsername] = useState("");
|
||||
const updateUser = async (e: React.SyntheticEvent) => {
|
||||
e.preventDefault();
|
||||
try {
|
||||
const body = { username };
|
||||
await fetch("/api/username", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
await Router.push("/account");
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
};
|
||||
|
||||
if (!session) {
|
||||
return;
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<form onSubmit={updateUser}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
placeholder="Edit Username"
|
||||
type="text"
|
||||
value={username}
|
||||
></Input>
|
||||
<Button disabled={!username} type="submit" value="Change">
|
||||
Submit
|
||||
</Button>
|
||||
</InputGroup>
|
||||
</form>
|
||||
<p>{session.user.email}</p>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import React, { useState } from "react";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { Button } from "@chakra-ui/react";
|
||||
|
||||
export default function Account() {
|
||||
const { data: session } = useSession();
|
||||
const [username, setUsername] = useState("null");
|
||||
|
||||
const handleUpdate = async () => {
|
||||
const response = await fetch("../api/update", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ username }),
|
||||
});
|
||||
const { name } = await response.json();
|
||||
setUsername(name);
|
||||
};
|
||||
|
||||
if (!session) {
|
||||
return;
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
|
||||
<p>{username}</p>
|
||||
<Button>
|
||||
<Link href="/account/edit">Edit Username</Link>
|
||||
</Button>
|
||||
<p>{session.user.email}</p>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { AuthOptions } from "next-auth";
|
||||
import NextAuth from "next-auth";
|
||||
import { NextApiHandler } from "next";
|
||||
import DiscordProvider from "next-auth/providers/discord";
|
||||
import EmailProvider from "next-auth/providers/email";
|
||||
import CredentialsProvider from "next-auth/providers/credentials";
|
||||
@@ -56,7 +57,7 @@ export const authOptions: AuthOptions = {
|
||||
adapter: PrismaAdapter(prisma),
|
||||
providers,
|
||||
pages: {
|
||||
signIn: "/auth/signup",
|
||||
signIn: "/auth/signin",
|
||||
verifyRequest: "/auth/verify",
|
||||
// error: "/auth/error", -Will be used later
|
||||
},
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { getSession } from "next-auth/react";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import Email from "next-auth/providers/email";
|
||||
|
||||
// POST /api/post
|
||||
// Required fields in body: title
|
||||
// Optional fields in body: content
|
||||
export default async function handle(req, res) {
|
||||
const { username } = req.body;
|
||||
const { email } = req.body;
|
||||
|
||||
const session = await getSession({ req });
|
||||
const result = await prisma.user.update({
|
||||
where: {
|
||||
email: session.user.email,
|
||||
},
|
||||
data: {
|
||||
name: username,
|
||||
},
|
||||
});
|
||||
res.json({ name: result.name });
|
||||
}
|
||||
@@ -1,35 +1,19 @@
|
||||
import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
|
||||
import { useRef } from "react";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import Link from "next/link";
|
||||
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
|
||||
export default function Signin({ csrfToken, providers }) {
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const { discord, email, github } = providers;
|
||||
const emailEl = useRef(null);
|
||||
const debugUsernameEl = useRef(null);
|
||||
|
||||
const signinWithDiscord = () => {
|
||||
signIn(discord.id, { callbackUrl: "/" });
|
||||
};
|
||||
|
||||
const signinWithEmail = (ev: React.FormEvent) => {
|
||||
ev.preventDefault();
|
||||
const signinWithEmail = () => {
|
||||
signIn(email.id, { callbackUrl: "/", email: emailEl.current.value });
|
||||
};
|
||||
|
||||
const signinWithGithub = () => {
|
||||
signIn(github.id, { callbackUrl: "/" });
|
||||
};
|
||||
|
||||
function signinWithDebugCredentials(ev: React.FormEvent) {
|
||||
ev.preventDefault();
|
||||
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
@@ -37,27 +21,20 @@ export default function Signin({ csrfToken, providers }) {
|
||||
<meta name="Sign Up" content="Sign up to access Open Assistant" />
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<Stack spacing="6">
|
||||
{credentials && (
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-200 rounded-md p-4 relative">
|
||||
<span className="text-orange-600 absolute -top-3 left-5 bg-white px-1">For Debugging Only</span>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
|
||||
<Button size={"lg"} leftIcon={<FaBug />} colorScheme="gray" type="submit">
|
||||
Continue with Debug User
|
||||
</Button>
|
||||
</Stack>
|
||||
</form>
|
||||
)}
|
||||
<Stack spacing="2">
|
||||
{email && (
|
||||
<form onSubmit={signinWithEmail}>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Button size={"lg"} leftIcon={<FaEnvelope />} colorScheme="gray">
|
||||
Continue with Email
|
||||
</Button>
|
||||
</Stack>
|
||||
</form>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Button
|
||||
size={"lg"}
|
||||
leftIcon={<FaEnvelope />}
|
||||
colorScheme="gray"
|
||||
onClick={signinWithEmail}
|
||||
// isDisabled="false"
|
||||
>
|
||||
Continue with Email
|
||||
</Button>
|
||||
</Stack>
|
||||
)}
|
||||
{discord && (
|
||||
<Button
|
||||
@@ -69,7 +46,8 @@ export default function Signin({ csrfToken, providers }) {
|
||||
size="lg"
|
||||
leftIcon={<FaDiscord />}
|
||||
color="white"
|
||||
onClick={signinWithDiscord}
|
||||
onClick={() => signIn(discord, { callbackUrl: "/" })}
|
||||
// isDisabled="false"
|
||||
>
|
||||
Continue with Discord
|
||||
</Button>
|
||||
@@ -84,7 +62,7 @@ export default function Signin({ csrfToken, providers }) {
|
||||
size={"lg"}
|
||||
leftIcon={<FaGithub />}
|
||||
colorScheme="blue"
|
||||
onClick={signinWithGithub}
|
||||
// isDisabled="false"
|
||||
>
|
||||
Continue with Github
|
||||
</Button>
|
||||
@@ -13,14 +13,6 @@ export default function Verify() {
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
|
||||
<hr className="mt-14 mb-4 h-px bg-gray-200 border-0" />
|
||||
<Link
|
||||
href="#"
|
||||
aria-label="Log In"
|
||||
className="flex justify-center font-medium text-black hover:underline underline-offset-4"
|
||||
>
|
||||
Already have an account? Log In
|
||||
</Link>
|
||||
</AuthLayout>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -8,6 +8,7 @@ import poster from "src/lib/poster";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { Button } from "src/components/Button";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
|
||||
const AssistantReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
@@ -39,11 +40,12 @@ const AssistantReply = () => {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
@@ -9,6 +9,7 @@ import poster from "src/lib/poster";
|
||||
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { Button } from "src/components/Button";
|
||||
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
|
||||
|
||||
const SummarizeStory = () => {
|
||||
// Use an array of tasks that record the sequence of steps until a task is
|
||||
@@ -49,11 +50,12 @@ const SummarizeStory = () => {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className=" p-6 bg-slate-100 text-gray-800">Loading...</div>;
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -8,6 +8,7 @@ import poster from "src/lib/poster";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { Button } from "src/components/Button";
|
||||
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
|
||||
|
||||
const UserReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
@@ -39,11 +40,12 @@ const UserReply = () => {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
@@ -9,6 +9,7 @@ import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
|
||||
import { Button } from "src/components/Button";
|
||||
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
|
||||
|
||||
const RankInitialPrompts = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
@@ -44,12 +45,14 @@ const RankInitialPrompts = () => {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
}
|
||||
|
||||
const prompts = tasks[0].task.prompts as string[];
|
||||
const items = ranking.map((i) => ({
|
||||
text: prompts[i],
|
||||
|
||||
@@ -11,6 +11,7 @@ import poster from "src/lib/poster";
|
||||
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { Button } from "src/components/Button";
|
||||
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
|
||||
|
||||
const RateSummary = () => {
|
||||
// Use an array of tasks that record the sequence of steps until a task is
|
||||
@@ -49,11 +50,12 @@ const RateSummary = () => {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import RankItem from "@/components/RankItem";
|
||||
import RankItem from "src/components/RankItem";
|
||||
import { BarsArrowUpIcon, BarsArrowDownIcon } from "@heroicons/react/24/solid";
|
||||
import Image from "next/image";
|
||||
import { HiBarsArrowUp, HiBarsArrowDown } from "react-icons/hi2";
|
||||
|
||||
Reference in New Issue
Block a user