Merge branch 'main' into add-debug-skip-password-env

This commit is contained in:
Yannic Kilcher
2022-12-29 13:34:28 +01:00
46 changed files with 29778 additions and 3195 deletions
+6 -2
View File
@@ -18,16 +18,20 @@ We can then take the resulting model and continue with completion sampling step
We are not going to stop at replicating ChatGPT. We want to build the assistant of the future, able to not only write email and cover letters, but do meaningful work, use APIs, dynamically research information, and much more, with the ability to be personalized and extended by anyone. And we want to do this in a way that is open and accessible, which means we must not only build a great assistant, but also make it small and efficient enough to run on consumer hardware.
### Slide Decks
[Important Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
## How can you help?
All open source projects begins with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity.
## Im in! Now what?
[Fill out the contributor signup form](https://docs.google.com/forms/d/e/1FAIpQLSeuggO7UdYkBvGLEJldDvxp6DwaRbW5p7dl96UzFkZgziRTrQ/viewform)
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e)
[and / or the YK Discord Server](https://ykilcher.com/discord)
[Visit the Notion](https://ykilcher.com/open-assistant)
### Taking on Tasks
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""post ref for work_package
Revision ID: d24b37426857
Revises: 3358eb6834e6
Create Date: 2022-12-28 11:42:26.773704
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "d24b37426857"
down_revision = "3358eb6834e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("post", sa.Column("depth", sa.Integer(), server_default=sa.text("0"), nullable=False))
op.add_column("post", sa.Column("children_count", sa.Integer(), server_default=sa.text("0"), nullable=False))
op.add_column("post_reaction", sa.Column("work_package_id", postgresql.UUID(as_uuid=True), nullable=False))
op.drop_constraint("post_reaction_post_id_fkey", "post_reaction", type_="foreignkey")
op.create_foreign_key(None, "post_reaction", "work_package", ["work_package_id"], ["id"])
op.drop_column("post_reaction", "post_id")
op.add_column("work_package", sa.Column("done", sa.Boolean(), server_default=sa.text("false"), nullable=False))
op.add_column("work_package", sa.Column("ack", sa.Boolean(), nullable=True))
op.add_column("work_package", sa.Column("frontend_ref_post_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.add_column("work_package", sa.Column("thread_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
op.add_column("work_package", sa.Column("parent_post_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("work_package", "parent_post_id")
op.drop_column("work_package", "thread_id")
op.drop_column("work_package", "frontend_ref_post_id")
op.drop_column("work_package", "ack")
op.drop_column("work_package", "done")
op.add_column("post_reaction", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=False))
op.drop_constraint(None, "post_reaction", type_="foreignkey")
op.create_foreign_key("post_reaction_post_id_fkey", "post_reaction", "post", ["post_id"], ["id"])
op.drop_column("post_reaction", "work_package_id")
op.drop_column("post", "children_count")
op.drop_column("post", "depth")
# ### end Alembic commands ###
+20
View File
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from pathlib import Path
import alembic.command
@@ -7,10 +8,29 @@ import fastapi
from loguru import logger
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.exceptions import OasstError, OasstErrorCode
from starlette.middleware.cors import CORSMiddleware
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
@app.exception_handler(OasstError)
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
return fastapi.responses.JSONResponse(
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: fastapi.Request, ex: Exception):
logger.exception(f"{request.method} {request.url} failed [UNHANDLED]: {repr(ex)}")
status = HTTPStatus.INTERNAL_SERVER_ERROR
return fastapi.responses.JSONResponse(
status_code=status.value, content={"message": status.name, "error_code": OasstErrorCode.GENERIC_ERROR}
)
# Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS:
app.add_middleware(
+23 -18
View File
@@ -1,16 +1,17 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import HTTPException, Security
from fastapi import Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN
def get_db() -> Generator:
@@ -36,22 +37,26 @@ def api_auth(
api_key: APIKey,
db: Session,
) -> ApiClient:
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
raise OasstError(
"Could not validate credentials",
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
+104 -104
View File
@@ -1,39 +1,54 @@
# -*- coding: utf-8 -*-
import random
from typing import Any
from typing import Any, Optional, Tuple
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
def generate_task(
request: protocol_schema.TaskRequest, pr: PromptRepository
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
thread_id = None
parent_post_id = None
match request.type:
case protocol_schema.TaskRequestType.random:
logger.info("Frontend requested a random task.")
while request.type == protocol_schema.TaskRequestType.random:
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
return generate_task(request)
case protocol_schema.TaskRequestType.summarize_story:
logger.info("Generating a SummarizeStoryTask.")
task = protocol_schema.SummarizeStoryTask(
story="This is a story. A very long story. So long, it needs to be summarized.",
)
case protocol_schema.TaskRequestType.rate_summary:
logger.info("Generating a RateSummaryTask.")
task = protocol_schema.RateSummaryTask(
full_text="This is a story. A very long story. So long, it needs to be summarized.",
summary="This is a summary.",
scale=protocol_schema.RatingScale(min=1, max=5),
)
disabled_tasks = (
protocol_schema.TaskRequestType.summarize_story,
protocol_schema.TaskRequestType.rate_summary,
)
request.type = random.choice(
tuple(set(protocol_schema.TaskRequestType).difference(disabled_tasks))
).value
return generate_task(request, pr)
# AKo: Summary tasks are currently disabled/supported, we focus on the conversation tasks.
# case protocol_schema.TaskRequestType.summarize_story:
# logger.info("Generating a SummarizeStoryTask.")
# task = protocol_schema.SummarizeStoryTask(
# story="This is a story. A very long story. So long, it needs to be summarized.",
# )
# case protocol_schema.TaskRequestType.rate_summary:
# logger.info("Generating a RateSummaryTask.")
# task = protocol_schema.RateSummaryTask(
# full_text="This is a story. A very long story. So long, it needs to be summarized.",
# summary="This is a summary.",
# scale=protocol_schema.RatingScale(min=1, max=5),
# )
case protocol_schema.TaskRequestType.initial_prompt:
logger.info("Generating an InitialPromptTask.")
task = protocol_schema.InitialPromptTask(
@@ -41,87 +56,72 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
)
case protocol_schema.TaskRequestType.user_reply:
logger.info("Generating a UserReplyTask.")
task = protocol_schema.UserReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
protocol_schema.ConversationMessage(
text="I'm not sure I understood correctly, could you rephrase that?",
is_assistant=True,
),
],
)
)
posts = pr.fetch_random_conversation("assistant")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
]
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, write me an English essay about water.",
is_assistant=False,
),
],
)
)
posts = pr.fetch_random_conversation("user")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
]
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")
task = protocol_schema.RankInitialPromptsTask(
prompts=[
"Please write a story about a time you were happy.",
"Please write a story about a time you were sad.",
]
)
posts = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[p.payload.payload.text for p in posts])
case protocol_schema.TaskRequestType.rank_user_replies:
logger.info("Generating a RankUserRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="assistant")
messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
)
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankUserRepliesTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
protocol_schema.ConversationMessage(
text="I'm not sure I understood correctly, could you rephrase that?",
is_assistant=True,
),
],
messages=messages,
),
replies=[
"Oh come oooooon!",
"What are the news?",
],
replies=replies,
)
case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="user")
messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
)
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
],
),
replies=[
"I'm not sure I understood correctly, could you rephrase that?",
"The world is fine. All good.",
"Crap is hitting the fan. Start farming.",
],
conversation=protocol_schema.Conversation(messages=messages),
replies=replies,
)
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid request type.",
)
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
logger.info(f"Generated {task=}.")
return task
return task, thread_id, parent_post_id
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
@@ -137,16 +137,15 @@ def request_task(
api_client = deps.api_auth(api_key, db)
try:
task = generate_task(request)
pr = PromptRepository(db, api_client, request.user)
pr.store_task(task)
task, thread_id, parent_post_id = generate_task(request, pr)
pr.store_task(task, thread_id, parent_post_id)
except OasstError:
raise
except Exception:
logger.exception("Failed to generate task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
logger.exception("Failed to generate task..")
raise OasstError("Failed to generate task.", OasstErrorCode.TASK_GENERATION_FAILED)
return task
@@ -171,11 +170,11 @@ def acknowledge_task(
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
except OasstError:
raise
except Exception:
logger.exception("Failed to acknowledge task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
return {}
@@ -190,11 +189,15 @@ def acknowledge_task_failure(
"""
The frontend reports failure to implement a task.
"""
deps.api_auth(api_key, db)
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
# here we would store the post id in the database for the task
return {}
try:
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@router.post("/interaction")
@@ -219,8 +222,9 @@ def post_interaction(
)
# here we store the text reply in the database
# ToDo: role user or agent?
pr.store_text_reply(interaction, role="unknown")
pr.store_text_reply(
text=interaction.text, post_id=interaction.post_id, user_post_id=interaction.user_post_id
)
return protocol_schema.TaskDone()
case protocol_schema.PostRating:
@@ -242,13 +246,9 @@ def post_interaction(
# here we would store the ranking in the database
return protocol_schema.TaskDone()
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid response type.",
)
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
except OasstError:
raise
except Exception:
logger.exception("Interaction request failed.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
+2 -1
View File
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
from oasst_backend.config import settings
from oasst_backend.exceptions import OasstError, OasstErrorCode
from sqlmodel import create_engine
if settings.DATABASE_URI is None:
raise ValueError("DATABASE_URI is not set")
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
engine = create_engine(settings.DATABASE_URI)
+60
View File
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
from enum import IntEnum
from http import HTTPStatus
class OasstErrorCode(IntEnum):
"""
Error codes of the Open-Assistant backend API.
Ranges:
0-1000: general errors
1000-2000: tasks endpoint
2000-3000: prompt_repository
"""
# 0-1000: general errors
GENERIC_ERROR = 0
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
TASK_ACK_FAILED = 1001
TASK_NACK_FAILED = 1002
TASK_INVALID_RESPONSE_TYPE = 1003
TASK_INTERACTION_REQUEST_FAILED = 1004
TASK_GENERATION_FAILED = 1005
# 2000-3000: prompt_repository
INVALID_POST_ID = 2000
POST_NOT_FOUND = 2001
RATING_OUT_OF_RANGE = 2002
INVALID_RANKING_VALUE = 2003
INVALID_TASK_TYPE = 2004
USER_NOT_SPECIFIED = 2005
NO_THREADS_FOUND = 2006
WORK_PACKAGE_NOT_FOUND = 2100
WORK_PACKAGE_EXPIRED = 2101
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
WORK_PACKAGE_ALREADY_UPDATED = 2103
WORK_PACKAGE_NOT_ACK = 2104
WORK_PACKAGE_ALREADY_DONE = 2105
class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
message: str
error_code: int
http_status_code: HTTPStatus
def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member)
self.message = message
self.error_code = error_code
self.http_status_code = http_status_code
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f'{class_name}(message="{self.message}", error_code={self.error_code}, http_status_code={self.http_status_code})'
+2
View File
@@ -31,3 +31,5 @@ class Post(SQLModel, table=True):
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
@@ -13,8 +13,8 @@ from .payload_column_type import PayloadContainer, payload_column_type
class PostReaction(SQLModel, table=True):
__tablename__ = "post_reaction"
post_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=False, primary_key=True)
work_package_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
)
person_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlalchemy import false
from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
@@ -26,3 +27,12 @@ class WorkPackage(SQLModel, table=True):
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
ack: Optional[bool] = None
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
frontend_ref_post_id: Optional[str] = None
thread_id: Optional[UUID] = None
parent_post_id: Optional[UUID] = None
@property
def expired(self) -> bool:
return self.expiry_date is not None and datetime.utcnow() < self.expiry_date
+208 -83
View File
@@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
from datetime import datetime
import random
from typing import Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from sqlmodel import Session, func
class PromptRepository:
@@ -52,9 +53,9 @@ class PromptRepository:
def validate_post_id(self, post_id: str) -> None:
if not isinstance(post_id, str):
raise TypeError(f"post_id must be string, not {type(post_id)}")
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
if not post_id:
raise ValueError("post_id must not be empty")
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
self.validate_post_id(post_id)
@@ -66,36 +67,36 @@ class PromptRepository:
.first()
)
if work_pack is None:
raise KeyError(f"WorkPackage for task {task_id} not found")
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
raise RuntimeError("WorkPackage already expired.")
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
work_pack.frontend_ref_post_id = post_id
work_pack.ack = True
# ToDo: check race-condition, transaction
self.db.add(work_pack)
self.db.commit()
# check if task thread exits
thread_root = (
self.db.query(Post)
.filter(
Post.workpackage_id == work_pack.id,
Post.frontend_post_id == post_id,
Post.parent_id is None,
Post.api_client_id == self.api_client.id,
)
.one_or_none()
def acknowledge_task_failure(self, task_id):
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if thread_root is None:
thread_id = uuid4()
thread_root = self.insert_post(
post_id=thread_id,
thread_id=thread_id,
frontend_post_id=post_id,
parent_id=None,
role="system",
workpackage_id=work_pack.id,
payload=None,
payload_type="bind",
)
return thread_root
if work_pack is None:
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
work_pack.ack = False
# ToDo: check race-condition, transaction
self.db.add(work_pack)
self.db.commit()
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
self.validate_post_id(frontend_post_id)
@@ -105,49 +106,64 @@ class PromptRepository:
.one_or_none()
)
if fail_if_missing and post is None:
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
return post
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
self.validate_post_id(post_id)
post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True)
work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one()
return work_pack
def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post:
self.validate_post_id(reply.post_id)
self.validate_post_id(reply.user_post_id)
work_package = self.fetch_workpackage_by_postid(reply.post_id)
work_payload: db_payload.TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# find post with post-id
parent_post: Post = (
self.db.query(Post)
.filter(
Post.api_client_id == self.api_client.id,
Post.frontend_post_id == reply.post_id,
# Post.person_id == self.person_id
)
work_pack = (
self.db.query(WorkPackage)
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_post_id == post_id)
.one_or_none()
)
return work_pack
if parent_post is None:
raise KeyError(f"Post for post_id {reply.post_id} not found.")
def store_text_reply(self, text: str, post_id: str, user_post_id: str, role: str = None) -> Post:
self.validate_post_id(post_id)
self.validate_post_id(user_post_id)
wp = self.fetch_workpackage_by_postid(post_id)
if wp is None:
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not wp.ack:
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
if wp.done:
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
# If there's no parent post assume user started new conversation
role = "user"
depth = 0
if wp.parent_post_id:
parent_post = self.fetch_post(wp.parent_post_id)
parent_post.children_count += 1
self.db.add(parent_post)
depth = parent_post.depth + 1
if parent_post.role == "assistant":
role = "user"
else:
role = "assistant"
# create reply post
user_post_id = uuid4()
new_post_id = uuid4()
user_post = self.insert_post(
post_id=user_post_id,
frontend_post_id=reply.user_post_id,
parent_id=parent_post.id,
thread_id=parent_post.thread_id,
workpackage_id=parent_post.workpackage_id,
post_id=new_post_id,
frontend_post_id=user_post_id,
parent_id=wp.parent_post_id,
thread_id=wp.thread_id or new_post_id,
workpackage_id=wp.id,
role=role,
payload=db_payload.PostPayload(text=reply.text),
payload=db_payload.PostPayload(text=text),
depth=depth,
)
self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text))
wp.done = True
self.db.add(wp)
self.db.commit()
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
return user_post
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
@@ -156,12 +172,16 @@ class PromptRepository:
work_package = self.fetch_workpackage_by_postid(rating.post_id)
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
if type(work_payload) != db_payload.RateSummaryPayload:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
)
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
raise OasstError(
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
OasstErrorCode.RATING_OUT_OF_RANGE,
)
# store reaction to post
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
@@ -171,10 +191,11 @@ class PromptRepository:
return reaction
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
post = self.fetch_post_by_frontend_post_id(ranking.post_id, fail_if_missing=True)
# fetch work_package
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
work_package.done = True
self.db.add(work_package)
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
)
@@ -185,14 +206,16 @@ class PromptRepository:
# validate ranking
num_replies = len(work_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
# TODO: resolve post_id
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
@@ -201,25 +224,33 @@ class PromptRepository:
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
# TODO: resolve post_id
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
return reaction
case _:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
)
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
def store_task(
self,
task: protocol_schema.Task,
thread_id: UUID = None,
parent_post_id: UUID = None,
) -> WorkPackage:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
@@ -253,13 +284,24 @@ class PromptRepository:
)
case _:
raise ValueError(f"Invalid task type: {type(task)=}")
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
wp = self.insert_work_package(payload=payload, id=task.id)
wp = self.insert_work_package(
payload=payload,
id=task.id,
thread_id=thread_id,
parent_post_id=parent_post_id,
)
assert wp.id == task.id
return wp
def insert_work_package(self, payload: db_payload.TaskPayload, id: UUID = None) -> WorkPackage:
def insert_work_package(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
thread_id: UUID = None,
parent_post_id: UUID = None,
) -> WorkPackage:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
id=id,
@@ -267,6 +309,8 @@ class PromptRepository:
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
thread_id=thread_id,
parent_post_id=parent_post_id,
)
self.db.add(wp)
self.db.commit()
@@ -284,6 +328,7 @@ class PromptRepository:
role: str,
payload: db_payload.PostPayload,
payload_type: str = None,
depth: int = 0,
) -> Post:
if payload_type is None:
if payload is None:
@@ -302,19 +347,20 @@ class PromptRepository:
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
self.db.add(post)
self.db.commit()
self.db.refresh(post)
return post
def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
if self.person_id is None:
raise ValueError("User required")
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
container = PayloadContainer(payload=payload)
reaction = PostReaction(
post_id=post_id,
work_package_id=work_package_id,
person_id=self.person_id,
payload=container,
api_client_id=self.api_client.id,
@@ -338,3 +384,82 @@ class PromptRepository:
self.db.commit()
self.db.refresh(model)
return model
def fetch_random_thread(self, require_role: str = None) -> list[Post]:
"""
Loads all posts of a random thread.
:param require_role: If set loads only thread which has
at least one post with given role.
"""
distinct_threads = self.db.query(Post.thread_id).distinct(Post.thread_id)
if require_role:
distinct_threads = distinct_threads.filter(Post.role == require_role)
distinct_threads = distinct_threads.subquery()
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1).subquery()
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
return thread_posts
def fetch_random_conversation(self, last_post_role: str = None) -> list[Post]:
"""
Picks a random linear conversation starting from any root post
and ending somewhere in the thread, possibly at the root itself.
:param last_post_role: If set will form a conversation ending with a post
created by this role. Necessary for the tasks like "user_reply" where
the user should reply as a human and hence the last message of the conversation
needs to have "assistant" role.
"""
thread_posts = self.fetch_random_thread(last_post_role)
if not thread_posts:
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
if last_post_role:
conv_posts = [p for p in thread_posts if p.role == last_post_role]
conv_posts = [random.choice(conv_posts)]
else:
conv_posts = [random.choice(thread_posts)]
thread_posts = {p.id: p for p in thread_posts}
while True:
if not conv_posts[-1].parent_id:
# reached the start of the conversation
break
parent_post = thread_posts[conv_posts[-1].parent_id]
conv_posts.append(parent_post)
return list(reversed(conv_posts))
def fetch_random_initial_prompts(self, size: int = 5):
posts = self.db.query(Post).filter(Post.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return posts
def fetch_thread(self, thread_id: UUID):
return self.db.query(Post).filter(Post.thread_id == thread_id).all()
def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None):
parent = self.db.query(Post.id).filter(Post.children_count > 1)
if post_role:
parent = parent.filter(Post.role == post_role)
parent = parent.order_by(func.random()).limit(1).subquery()
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
thread = self.fetch_thread(replies[0].thread_id)
thread = {p.id: p for p in thread}
thread_posts = [thread[replies[0].parent_id]]
while True:
if not thread_posts[-1].parent_id:
# reached start of the conversation
break
parent_post = thread[thread_posts[-1].parent_id]
thread_posts.append(parent_post)
thread_posts = reversed(thread_posts)
return thread_posts, replies
def fetch_post(self, post_id: UUID) -> Optional[Post]:
return self.db.query(Post).filter(Post.id == post_id).one()
+4 -4
View File
@@ -1,14 +1,14 @@
version: "3.7"
services:
# Use `docker compose up backend-dev` to start a database and work and the backend.
# Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend.
backend-dev:
image: tianon/true
image: sverrirab/sleep
depends_on: [db, adminer]
# Use `docker compose up frontend-dev` to start all services needed to work on the frontend.
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
frontend-dev:
image: tianon/true
image: sverrirab/sleep
depends_on: [db, webdb, adminer, maildev, backend]
# This DB is for the FastAPI Backend.
+9
View File
@@ -0,0 +1,9 @@
# Documentation
This directory contains the documentation for the project and other related organization documents.
## Contributing to this documentation
Please make a pull request to the `main` branch with your changes.
Consider that this folder is used for documenting the various code sub-parts, the high-level ideas, the ML aspects, experiments, contributor guides, guides for data creation, and many more things. Please try to keep the documentation as concise as possible and keep an organized folder structure that makes sense for everyone.
+1 -1
View File
@@ -1,6 +1,6 @@
# Backend Development Setup
In root directory, run `docker compose up backend-dev` to start a database. The default settings are already configured to connect to the database at `localhost:5432`.
In root directory, run `docker compose up backend-dev --build --attach-dependencies` to start a database. The default settings are already configured to connect to the database at `localhost:5432`.
Make sure you have all requirements installed. You can do this by running `pip install -r requirements.txt` inside the `backend` folder and `pip install -e .` inside the `oasst-shared` folder.
Then, run the backend using the `run-local.sh` script. This will start the backend server at `http://localhost:8080`.
+3
View File
@@ -0,0 +1,3 @@
#!/usr/bin/env bash
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
docker compose -f "$parent_path/../../docker-compose.yaml" up backend-dev --build --attach-dependencies
+1 -1
View File
@@ -1,5 +1,5 @@
# Frontend Development Setup
In root directory run `docker compose up frontend-dev --build` to start a database and the backend server.
In root directory run `docker compose up frontend-dev --build --attach-dependencies` to start a database and the backend server.
Then, point your frontend at `http://localhost:8080` to start developing. During development, any API key will be accepted.
+32
View File
@@ -0,0 +1,32 @@
const path = require("path");
module.exports = {
stories: [
"../src/components/**/*.stories.mdx",
"../src/components/**/*.stories.@(js|jsx|ts|tsx)",
],
addons: [
"@storybook/addon-links",
"@storybook/addon-essentials",
"@storybook/addon-interactions",
"@chakra-ui/storybook-addon",
],
framework: "@storybook/react",
core: {
builder: "@storybook/builder-webpack5",
},
staticDirs: ["../public"],
// https://github.com/storybookjs/storybook/issues/15336#issuecomment-888528747
typescript: { reactDocgen: false },
// fix to make absolute imports working in storybook
webpackFinal: async (config, { configType }) => {
config.resolve.alias = {
...config.resolve.alias,
src: path.resolve(__dirname, "../src"),
};
return config;
},
features: {
emotionAlias: false,
},
};
+22
View File
@@ -0,0 +1,22 @@
import "!style-loader!css-loader!postcss-loader!tailwindcss/tailwind.css";
export const parameters = {
actions: { argTypesRegex: "^on[A-Z].*" },
controls: {
matchers: {
color: /(background|color)$/i,
date: /Date$/,
},
},
};
// Hacky solution to get Images in next to work
// https://dev.to/jonasmerlin/how-to-use-the-next-js-image-component-in-storybook-1415
import * as NextImage from "next/image";
const OriginalNextImage = NextImage.default;
Object.defineProperty(NextImage, "default", {
configurable: true,
value: (props) => <OriginalNextImage {...props} unoptimized />,
});
+8 -3
View File
@@ -49,9 +49,8 @@ installed:
If you're doing active development we suggest the following workflow:
1. In one tab, navigate to
`${OPEN_ASSISTANT_ROOT}/scripts/frontend-development`.
1. Run `docker compose up --build`. You can optionally include `-d` to detach and
1. In one tab, navigate to the project root.
1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can optionally include `-d` to detach and
later track the logs if desired.
1. In another tab navigate to `${OPEN_ASSISTANT_ROOT/website`.
1. Run `npm install`
@@ -71,6 +70,12 @@ You can use the debug credentials provider to log in without fancy emails or OAu
1. Use the `Login` button in the top right to go to the login page.
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
### Using Storybook
To develop components using [Storybook](https://storybook.js.org/) run `npm run storybook`. Then navigate to in your browser to `http://localhost:6006`.
To create a new story create a file named `[componentName].stories.js`. An example how such a story could look like, see `Header.stories.jsx`.
## Code Layout
### React Code
+28846 -2842
View File
File diff suppressed because it is too large Load Diff
+18 -2
View File
@@ -7,7 +7,9 @@
"dev": "next dev",
"build": "next build",
"start": "next start",
"lint": "next lint"
"lint": "next lint",
"storybook": "start-storybook -p 6006",
"build-storybook": "build-storybook"
},
"dependencies": {
"@chakra-ui/react": "^2.4.4",
@@ -40,9 +42,23 @@
"use-debounce": "^9.0.2"
},
"devDependencies": {
"@babel/core": "^7.20.7",
"@chakra-ui/storybook-addon": "^4.0.16",
"@storybook/addon-actions": "^6.5.15",
"@storybook/addon-essentials": "^6.5.15",
"@storybook/addon-interactions": "^6.5.15",
"@storybook/addon-links": "^6.5.15",
"@storybook/addon-postcss": "^2.0.0",
"@storybook/builder-webpack5": "^6.5.15",
"@storybook/manager-webpack5": "^6.5.15",
"@storybook/react": "^6.5.15",
"@storybook/testing-library": "^0.0.13",
"@types/node": "18.11.17",
"@types/react": "18.0.26",
"babel-loader": "^8.3.0",
"eslint-plugin-storybook": "^0.6.8",
"prettier": "2.8.1",
"prisma": "^4.7.1"
"prisma": "^4.7.1",
"typescript": "4.9.4"
}
}
+58 -14
View File
@@ -1,31 +1,75 @@
import Image from "next/image";
import Link from "next/link";
import { FaGithub, FaDiscord } from "react-icons/fa";
import { Container } from "./Container";
import { NavLinks } from "./NavLinks";
export function Footer() {
return (
<footer className="border-t border-gray-200 bg-white">
<Container className="">
<div className="flex flex-col items-start justify-between gap-y-12 pt-16 pb-6 lg:flex-row lg:items-center lg:py-6">
<div>
<div className="flex items-center text-gray-900">
<main>
<Container className="">
<div className="flex flex-wrap justify-between gap-y-12 py-10 lg:items-center lg:py-16">
<div className="flex items-center text-black pr-8">
<Link href="/" aria-label="Home" className="flex items-center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
</Link>
<div className="ml-4">
<p className="text-base font-semibold">Open Assistant</p>
<p className="mt-1 text-sm">Conversational AI for everyone.</p>
<div className="ml-2">
<p className="text-base font-bold">Open Assistant</p>
<p className="text-sm">Conversational AI for everyone.</p>
</div>
</div>
{/* <nav className="mt-11 flex gap-8">
<NavLinks />
</nav> */}
<nav className="flex justify-center gap-20">
<div className="flex flex-col text-sm leading-7">
<b>Information</b>
<div className="flex flex-col leading-5">
<Link href="#" aria-label="Our Team" className="hover:underline underline-offset-2">
Our Team
</Link>
<Link href="#join-us" aria-label="Join Us" className="hover:underline underline-offset-2">
Join Us
</Link>
</div>
</div>
<div className="flex flex-col text-sm leading-7">
<b>Legal</b>
<div className="flex flex-col leading-5">
<Link href="#" aria-label="Privacy Policy" className="hover:underline underline-offset-2">
Privacy Policy
</Link>
<Link href="#" aria-label="Terms of Service" className="hover:underline underline-offset-2">
Terms of Service
</Link>
</div>
</div>
<div className="flex flex-col text-sm leading-7">
<b>Connect</b>
<div className="flex flex-col leading-5">
<Link
href="https://github.com/LAION-AI/Open-Assistant"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Privacy Policy"
className="hover:underline underline-offset-2"
>
Github
</Link>
<Link
href="https://discord.gg/pXtnYk9c"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Terms of Service"
className="hover:underline underline-offset-2"
>
Discord
</Link>
</div>
</div>
</nav>
</div>
</div>
</Container>
</Container>
</main>
</footer>
);
}
@@ -0,0 +1,24 @@
import { SessionContext } from "next-auth/react";
import React from "react";
import { Header } from "./Header";
export default {
title: "Header/Header",
component: Header,
parameters: {
layout: "fullscreen",
},
};
const Template = (args) => {
var { session } = args;
return (
<SessionContext.Provider value={session}>
<Header {...args} />
</SessionContext.Provider>
);
};
export const Default = Template.bind({});
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" } };
@@ -6,9 +6,9 @@ import Link from "next/link";
import { signOut, useSession } from "next-auth/react";
import { FaUser, FaSignOutAlt } from "react-icons/fa";
import { Avatar } from "./Avatar";
import { Container } from "./Container";
import { Container } from "src/components/Container";
import { NavLinks } from "./NavLinks";
import { UserMenu } from "./UserMenu";
function MenuIcon(props) {
return (
@@ -45,9 +45,9 @@ function AccountButton() {
return;
}
return (
<Link href="/auth/signup" aria-label="Home" className="flex items-center">
<Link href="/auth/signin" aria-label="Home" className="flex items-center">
<Button variant="outline" leftIcon={<FaUser />}>
Log in
Sign in
</Button>
</Link>
);
@@ -113,7 +113,7 @@ export function Header() {
)}
</Popover>
<AccountButton />
<Avatar />
<UserMenu />
</div>
</Container>
</nav>
@@ -0,0 +1,14 @@
import { NavLinks } from "./NavLinks";
export default {
title: "Header/NavLinks",
component: NavLinks,
};
const Template = (args) => (
<div className="hidden lg:flex lg:gap-10">
<NavLinks {...args} />
</div>
);
export const Default = Template.bind({});
@@ -0,0 +1,25 @@
import { SessionContext } from "next-auth/react";
import React from "react";
import UserMenu from "./UserMenu";
export default {
title: "Header/UserMenu",
component: UserMenu,
};
const Template = (args) => {
var { session } = args;
return (
<SessionContext.Provider value={session}>
<div className="flex flex-col">
<div className="self-end">
<UserMenu {...args} />
</div>
</div>
</SessionContext.Provider>
);
};
export const Default = Template.bind({});
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" } };
@@ -5,18 +5,18 @@ import { Popover } from "@headlessui/react";
import { AnimatePresence, motion } from "framer-motion";
import { FaCog, FaSignOutAlt, FaGithub } from "react-icons/fa";
export function Avatar() {
export function UserMenu() {
const { data: session } = useSession();
if (!session) {
return <></>;
}
if (session && session.user) {
const displayName = session.user.name || session.user.email;
const email = session.user.email;
const accountOptions = [
{
name: "Account Settings",
href: "#",
href: "/account",
desc: "Account Settings",
icon: FaCog,
//For future use
@@ -35,8 +35,7 @@ export function Avatar() {
height="40"
className="rounded-full"
></Image>
<p className="hidden lg:flex">{displayName}</p>
{/* Will be changed to username once it is implemented */}
<p className="hidden lg:flex">{session.user.name || session.user.email}</p>
</div>
</Popover.Button>
<AnimatePresence initial={false}>
@@ -72,7 +71,7 @@ export function Avatar() {
))}
<a
className="flex items-center rounded-md hover:bg-gray-100 cursor-pointer"
onClick={() => signOut()}
onClick={() => signOut({ callbackUrl: "/" })}
>
<div className="p-4">
<FaSignOutAlt />
@@ -93,4 +92,4 @@ export function Avatar() {
}
}
export default Avatar;
export default UserMenu;
+3
View File
@@ -0,0 +1,3 @@
export { Header } from "./Header";
export { UserMenu } from "./UserMenu";
export { NavLinks } from "./NavLinks";
+1 -1
View File
@@ -3,7 +3,7 @@
import type { NextPage } from "next";
import { Footer } from "./Footer";
import { Header } from "./Header";
import { Header } from "src/components/Header";
export type NextPageWithLayout<P = {}, IP = P> = NextPage<P, IP> & {
getLayout?: (page: React.ReactElement) => React.ReactNode;
@@ -0,0 +1,16 @@
import { LoadingScreen } from "./LoadingScreen";
export default {
title: "Example/LoadingScreen",
component: LoadingScreen,
parameters: {
layout: "fullscreen",
},
};
const Template = (args) => <LoadingScreen {...args} />;
export const Default = Template.bind({});
export const WithText = Template.bind({});
WithText.args = { text: "Loading Text ..." };
@@ -0,0 +1,12 @@
import { Progress } from "@chakra-ui/react";
export const LoadingScreen = ({ text }) => (
<div className="bg-slate-100">
<Progress size="xs" isIndeterminate />
{text && (
<div className="flex h-full">
<div className="text-xl font-bold text-gray-800 mx-auto my-auto">{text}</div>
</div>
)}
</div>
);
+1 -1
View File
@@ -4,5 +4,5 @@ export { default } from "next-auth/middleware";
* Guards all pages under `/grading` and redirects them to the sign in page.
*/
export const config = {
matcher: ["/create/:path*", "/evaluate/:path*"],
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*"],
};
+5 -31
View File
@@ -1,45 +1,19 @@
import { useSession } from "next-auth/react";
import { Footer } from "../components/Footer";
import { Header } from "../components/Header";
import { Header } from "src/components/Header";
import Head from "next/head";
import Link from "next/link";
export default function Error() {
const { data: session } = useSession();
if (!session) {
return (
<>
<Head>
<title>Open Assistant</title>
<meta name="404" content="Sorry, this page doesn't exist." />
</Head>
<Header />
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
{"Sorry, the page you're looking for does not exist."}
</main>
<Footer />
</>
);
}
return (
<>
<Head>
<title>Open Assistant</title>
<meta
name="description"
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
<title>404 - Open Assistant</title>
<meta name="404" content="Sorry, this page doesn't exist." />
</Head>
<Header />
<main>
<h2>Open Chat Gpt</h2>
<p>You are logged in</p>
<Link href="/grading/grade-output">~Rate a prompt and output now~</Link>
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
<p>Sorry, the page you are looking for does not exist.</p>
</main>
<Footer />
</>
);
}
+56
View File
@@ -0,0 +1,56 @@
import React, { useState } from "react";
import { useSession } from "next-auth/react";
import { Button, Input, InputGroup, Stack } from "@chakra-ui/react";
import Head from "next/head";
import Router from "next/router";
export default function Account() {
const { data: session } = useSession();
const [username, setUsername] = useState("");
const updateUser = async (e: React.SyntheticEvent) => {
e.preventDefault();
try {
const body = { username };
await fetch("/api/username", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(body),
});
await Router.push("/account");
} catch (error) {
console.error(error);
}
};
if (!session) {
return;
}
return (
<>
<Head>
<title>Open Assistant</title>
<meta
name="description"
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<form onSubmit={updateUser}>
<InputGroup>
<Input
onChange={(e) => setUsername(e.target.value)}
placeholder="Edit Username"
type="text"
value={username}
></Input>
<Button disabled={!username} type="submit" value="Change">
Submit
</Button>
</InputGroup>
</form>
<p>{session.user.email}</p>
</main>
</>
);
}
+44
View File
@@ -0,0 +1,44 @@
import Head from "next/head";
import Link from "next/link";
import React, { useState } from "react";
import { useSession } from "next-auth/react";
import { Button } from "@chakra-ui/react";
export default function Account() {
const { data: session } = useSession();
const [username, setUsername] = useState("null");
const handleUpdate = async () => {
const response = await fetch("../api/update", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ username }),
});
const { name } = await response.json();
setUsername(name);
};
if (!session) {
return;
}
return (
<>
<Head>
<title>Open Assistant</title>
<meta
name="description"
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
<p>{username}</p>
<Button>
<Link href="/account/edit">Edit Username</Link>
</Button>
<p>{session.user.email}</p>
</main>
</>
);
}
+2 -1
View File
@@ -1,5 +1,6 @@
import type { AuthOptions } from "next-auth";
import NextAuth from "next-auth";
import { NextApiHandler } from "next";
import DiscordProvider from "next-auth/providers/discord";
import EmailProvider from "next-auth/providers/email";
import CredentialsProvider from "next-auth/providers/credentials";
@@ -56,7 +57,7 @@ export const authOptions: AuthOptions = {
adapter: PrismaAdapter(prisma),
providers,
pages: {
signIn: "/auth/signup",
signIn: "/auth/signin",
verifyRequest: "/auth/verify",
// error: "/auth/error", -Will be used later
},
+22
View File
@@ -0,0 +1,22 @@
import { getSession } from "next-auth/react";
import { Prisma } from "@prisma/client";
import Email from "next-auth/providers/email";
// POST /api/post
// Required fields in body: title
// Optional fields in body: content
export default async function handle(req, res) {
const { username } = req.body;
const { email } = req.body;
const session = await getSession({ req });
const result = await prisma.user.update({
where: {
email: session.user.email,
},
data: {
name: username,
},
});
res.json({ name: result.name });
}
@@ -1,35 +1,19 @@
import { Button, Input, Stack } from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
import { useRef } from "react";
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
import Link from "next/link";
import { AuthLayout } from "src/components/AuthLayout";
export default function Signin({ csrfToken, providers }) {
const { discord, email, github, credentials } = providers;
const { discord, email, github } = providers;
const emailEl = useRef(null);
const debugUsernameEl = useRef(null);
const signinWithDiscord = () => {
signIn(discord.id, { callbackUrl: "/" });
};
const signinWithEmail = (ev: React.FormEvent) => {
ev.preventDefault();
const signinWithEmail = () => {
signIn(email.id, { callbackUrl: "/", email: emailEl.current.value });
};
const signinWithGithub = () => {
signIn(github.id, { callbackUrl: "/" });
};
function signinWithDebugCredentials(ev: React.FormEvent) {
ev.preventDefault();
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
}
return (
<>
<Head>
@@ -37,27 +21,20 @@ export default function Signin({ csrfToken, providers }) {
<meta name="Sign Up" content="Sign up to access Open Assistant" />
</Head>
<AuthLayout>
<Stack spacing="6">
{credentials && (
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-200 rounded-md p-4 relative">
<span className="text-orange-600 absolute -top-3 left-5 bg-white px-1">For Debugging Only</span>
<Stack>
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
<Button size={"lg"} leftIcon={<FaBug />} colorScheme="gray" type="submit">
Continue with Debug User
</Button>
</Stack>
</form>
)}
<Stack spacing="2">
{email && (
<form onSubmit={signinWithEmail}>
<Stack>
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
<Button size={"lg"} leftIcon={<FaEnvelope />} colorScheme="gray">
Continue with Email
</Button>
</Stack>
</form>
<Stack>
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
<Button
size={"lg"}
leftIcon={<FaEnvelope />}
colorScheme="gray"
onClick={signinWithEmail}
// isDisabled="false"
>
Continue with Email
</Button>
</Stack>
)}
{discord && (
<Button
@@ -69,7 +46,8 @@ export default function Signin({ csrfToken, providers }) {
size="lg"
leftIcon={<FaDiscord />}
color="white"
onClick={signinWithDiscord}
onClick={() => signIn(discord, { callbackUrl: "/" })}
// isDisabled="false"
>
Continue with Discord
</Button>
@@ -84,7 +62,7 @@ export default function Signin({ csrfToken, providers }) {
size={"lg"}
leftIcon={<FaGithub />}
colorScheme="blue"
onClick={signinWithGithub}
// isDisabled="false"
>
Continue with Github
</Button>
-8
View File
@@ -13,14 +13,6 @@ export default function Verify() {
</Head>
<AuthLayout>
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
<hr className="mt-14 mb-4 h-px bg-gray-200 border-0" />
<Link
href="#"
aria-label="Log In"
className="flex justify-center font-medium text-black hover:underline underline-offset-4"
>
Already have an account? Log In
</Link>
</AuthLayout>
</>
);
+6 -4
View File
@@ -8,6 +8,7 @@ import poster from "src/lib/poster";
import { Messages } from "src/components/Messages";
import { TwoColumns } from "src/components/TwoColumns";
import { Button } from "src/components/Button";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
const AssistantReply = () => {
const [tasks, setTasks] = useState([]);
@@ -39,11 +40,12 @@ const AssistantReply = () => {
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
}
const task = tasks[0].task;
+6 -4
View File
@@ -9,6 +9,7 @@ import poster from "src/lib/poster";
import { TwoColumns } from "src/components/TwoColumns";
import { Button } from "src/components/Button";
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
const SummarizeStory = () => {
// Use an array of tasks that record the sequence of steps until a task is
@@ -49,11 +50,12 @@ const SummarizeStory = () => {
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className=" p-6 bg-slate-100 text-gray-800">Loading...</div>;
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
}
return (
+6 -4
View File
@@ -8,6 +8,7 @@ import poster from "src/lib/poster";
import { Messages } from "src/components/Messages";
import { TwoColumns } from "src/components/TwoColumns";
import { Button } from "src/components/Button";
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
const UserReply = () => {
const [tasks, setTasks] = useState([]);
@@ -39,11 +40,12 @@ const UserReply = () => {
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
}
const task = tasks[0].task;
@@ -9,6 +9,7 @@ import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import { Button } from "src/components/Button";
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
const RankInitialPrompts = () => {
const [tasks, setTasks] = useState([]);
@@ -44,12 +45,14 @@ const RankInitialPrompts = () => {
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
}
const prompts = tasks[0].task.prompts as string[];
const items = ranking.map((i) => ({
text: prompts[i],
+6 -4
View File
@@ -11,6 +11,7 @@ import poster from "src/lib/poster";
import { TwoColumns } from "src/components/TwoColumns";
import { Button } from "src/components/Button";
import { LoadingScreen } from "@/components/Loading/LoadingScreen";
const RateSummary = () => {
// Use an array of tasks that record the sequence of steps until a task is
@@ -49,11 +50,12 @@ const RateSummary = () => {
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
}
return (
@@ -1,4 +1,4 @@
import RankItem from "@/components/RankItem";
import RankItem from "src/components/RankItem";
import { BarsArrowUpIcon, BarsArrowDownIcon } from "@heroicons/react/24/solid";
import Image from "next/image";
import { HiBarsArrowUp, HiBarsArrowDown } from "react-icons/hi2";