mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' of github.com:LAION-AI/Open-Chat-GPT
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import alembic.command
|
||||
@@ -7,10 +8,29 @@ import fastapi
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
|
||||
|
||||
|
||||
@app.exception_handler(OasstError)
|
||||
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
|
||||
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: fastapi.Request, ex: Exception):
|
||||
logger.exception(f"{request.method} {request.url} failed [UNHANDLED]: {repr(ex)}")
|
||||
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=status.value, content={"message": status.name, "error_code": OasstErrorCode.GENERIC_ERROR}
|
||||
)
|
||||
|
||||
|
||||
# Set all CORS enabled origins
|
||||
if settings.BACKEND_CORS_ORIGINS:
|
||||
app.add_middleware(
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi import Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
@@ -36,22 +37,26 @@ def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
) -> ApiClient:
|
||||
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
|
||||
if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
raise OasstError(
|
||||
"Could not validate credentials",
|
||||
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
@@ -3,14 +3,14 @@ import random
|
||||
from typing import Any, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -107,10 +107,7 @@ def generate_task(
|
||||
replies=replies,
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
@@ -134,11 +131,11 @@ def request_task(
|
||||
task, thread_id, parent_post_id = generate_task(request, pr)
|
||||
pr.store_task(task, thread_id, parent_post_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to generate task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
logger.exception("Failed to generate task..")
|
||||
raise OasstError("Failed to generate task.", OasstErrorCode.TASK_GENERATION_FAILED)
|
||||
return task
|
||||
|
||||
|
||||
@@ -163,11 +160,11 @@ def acknowledge_task(
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
|
||||
return {}
|
||||
|
||||
|
||||
@@ -189,10 +186,8 @@ def acknowledge_task_failure(
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
@@ -241,13 +236,9 @@ def post_interaction(
|
||||
# here we would store the ranking in the database
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Interaction request failed.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from sqlmodel import create_engine
|
||||
|
||||
if settings.DATABASE_URI is None:
|
||||
raise ValueError("DATABASE_URI is not set")
|
||||
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI)
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from enum import IntEnum
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
class OasstErrorCode(IntEnum):
|
||||
"""
|
||||
Error codes of the Open-Assistant backend API.
|
||||
|
||||
Ranges:
|
||||
0-1000: general errors
|
||||
1000-2000: tasks endpoint
|
||||
2000-3000: prompt_repository
|
||||
"""
|
||||
|
||||
# 0-1000: general errors
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
TASK_NACK_FAILED = 1002
|
||||
TASK_INVALID_RESPONSE_TYPE = 1003
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1004
|
||||
TASK_GENERATION_FAILED = 1005
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_POST_ID = 2000
|
||||
POST_NOT_FOUND = 2001
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
INVALID_TASK_TYPE = 2004
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
NO_THREADS_FOUND = 2006
|
||||
WORK_PACKAGE_NOT_FOUND = 2100
|
||||
WORK_PACKAGE_EXPIRED = 2101
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
WORK_PACKAGE_ALREADY_UPDATED = 2103
|
||||
WORK_PACKAGE_NOT_ACK = 2104
|
||||
WORK_PACKAGE_ALREADY_DONE = 2105
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
"""Base class for Open-Assistant exceptions."""
|
||||
|
||||
message: str
|
||||
error_code: int
|
||||
http_status_code: HTTPStatus
|
||||
|
||||
def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
|
||||
super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
self.http_status_code = http_status_code
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f'{class_name}(message="{self.message}", error_code={self.error_code}, http_status_code={self.http_status_code})'
|
||||
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
@@ -52,9 +53,9 @@ class PromptRepository:
|
||||
|
||||
def validate_post_id(self, post_id: str) -> None:
|
||||
if not isinstance(post_id, str):
|
||||
raise TypeError(f"post_id must be string, not {type(post_id)}")
|
||||
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
|
||||
if not post_id:
|
||||
raise ValueError("post_id must not be empty")
|
||||
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
|
||||
self.validate_post_id(post_id)
|
||||
@@ -66,11 +67,11 @@ class PromptRepository:
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise KeyError(f"WorkPackage for task {task_id} not found")
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise RuntimeError("WorkPackage already expired.")
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise RuntimeError("WorkPackage already updated.")
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
|
||||
work_pack.frontend_ref_post_id = post_id
|
||||
work_pack.ack = True
|
||||
@@ -86,11 +87,11 @@ class PromptRepository:
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise KeyError(f"WorkPackage for task {task_id} not found")
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise RuntimeError("WorkPackage already expired.")
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise RuntimeError("WorkPackage already updated.")
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
|
||||
work_pack.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
@@ -105,7 +106,7 @@ class PromptRepository:
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
|
||||
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
|
||||
return post
|
||||
|
||||
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
|
||||
@@ -124,13 +125,13 @@ class PromptRepository:
|
||||
wp = self.fetch_workpackage_by_postid(post_id)
|
||||
|
||||
if wp is None:
|
||||
raise KeyError(f"WorkPackage for {post_id=} not found")
|
||||
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise RuntimeError("WorkPackage already expired.")
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not wp.ack:
|
||||
raise RuntimeError("WorkPackage is not acknowledged.")
|
||||
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
|
||||
if wp.done:
|
||||
raise RuntimeError("WorkPackage already done.")
|
||||
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
|
||||
# If there's no parent post assume user started new conversation
|
||||
role = "user"
|
||||
@@ -171,12 +172,16 @@ class PromptRepository:
|
||||
work_package = self.fetch_workpackage_by_postid(rating.post_id)
|
||||
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
|
||||
raise OasstError(
|
||||
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
|
||||
OasstErrorCode.RATING_OUT_OF_RANGE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
@@ -201,8 +206,9 @@ class PromptRepository:
|
||||
# validate ranking
|
||||
num_replies = len(work_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
@@ -218,8 +224,9 @@ class PromptRepository:
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
@@ -233,8 +240,9 @@ class PromptRepository:
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(
|
||||
@@ -276,7 +284,7 @@ class PromptRepository:
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Invalid task type: {type(task)=}")
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
wp = self.insert_work_package(
|
||||
payload=payload,
|
||||
@@ -348,7 +356,7 @@ class PromptRepository:
|
||||
|
||||
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
raise ValueError("User required")
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
@@ -405,7 +413,7 @@ class PromptRepository:
|
||||
"""
|
||||
thread_posts = self.fetch_random_thread(last_post_role)
|
||||
if not thread_posts:
|
||||
raise RuntimeError("No threads found")
|
||||
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
|
||||
if last_post_role:
|
||||
conv_posts = [p for p in thread_posts if p.role == last_post_role]
|
||||
conv_posts = [random.choice(conv_posts)]
|
||||
|
||||
@@ -1,31 +1,75 @@
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
|
||||
import { FaGithub, FaDiscord } from "react-icons/fa";
|
||||
import { Container } from "./Container";
|
||||
import { NavLinks } from "./NavLinks";
|
||||
|
||||
export function Footer() {
|
||||
return (
|
||||
<footer className="border-t border-gray-200 bg-white">
|
||||
<Container className="">
|
||||
<div className="flex flex-col items-start justify-between gap-y-12 pt-16 pb-6 lg:flex-row lg:items-center lg:py-6">
|
||||
<div>
|
||||
<div className="flex items-center text-gray-900">
|
||||
<main>
|
||||
<Container className="">
|
||||
<div className="flex flex-wrap justify-between gap-y-12 py-10 lg:items-center lg:py-16">
|
||||
<div className="flex items-center text-black pr-8">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
|
||||
</Link>
|
||||
|
||||
<div className="ml-4">
|
||||
<p className="text-base font-semibold">Open Assistant</p>
|
||||
<p className="mt-1 text-sm">Conversational AI for everyone.</p>
|
||||
<div className="ml-2">
|
||||
<p className="text-base font-bold">Open Assistant</p>
|
||||
<p className="text-sm">Conversational AI for everyone.</p>
|
||||
</div>
|
||||
</div>
|
||||
{/* <nav className="mt-11 flex gap-8">
|
||||
<NavLinks />
|
||||
</nav> */}
|
||||
<nav className="flex justify-center gap-20">
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Information</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link href="#" aria-label="Our Team" className="hover:underline underline-offset-2">
|
||||
Our Team
|
||||
</Link>
|
||||
<Link href="#join-us" aria-label="Join Us" className="hover:underline underline-offset-2">
|
||||
Join Us
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Legal</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link href="#" aria-label="Privacy Policy" className="hover:underline underline-offset-2">
|
||||
Privacy Policy
|
||||
</Link>
|
||||
<Link href="#" aria-label="Terms of Service" className="hover:underline underline-offset-2">
|
||||
Terms of Service
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Connect</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link
|
||||
href="https://github.com/LAION-AI/Open-Assistant"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Privacy Policy"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Github
|
||||
</Link>
|
||||
<Link
|
||||
href="https://discord.gg/pXtnYk9c"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Discord
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
</div>
|
||||
</div>
|
||||
</Container>
|
||||
</Container>
|
||||
</main>
|
||||
</footer>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import Link from "next/link";
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import { FaUser, FaSignOutAlt } from "react-icons/fa";
|
||||
|
||||
import { Avatar } from "./Avatar";
|
||||
import { UserMenu } from "./UserMenu";
|
||||
import { Container } from "./Container";
|
||||
import { NavLinks } from "./NavLinks";
|
||||
|
||||
@@ -45,9 +45,9 @@ function AccountButton() {
|
||||
return;
|
||||
}
|
||||
return (
|
||||
<Link href="/auth/signup" aria-label="Home" className="flex items-center">
|
||||
<Link href="/auth/signin" aria-label="Home" className="flex items-center">
|
||||
<Button variant="outline" leftIcon={<FaUser />}>
|
||||
Log in
|
||||
Sign in
|
||||
</Button>
|
||||
</Link>
|
||||
);
|
||||
@@ -113,7 +113,7 @@ export function Header() {
|
||||
)}
|
||||
</Popover>
|
||||
<AccountButton />
|
||||
<Avatar />
|
||||
<UserMenu />
|
||||
</div>
|
||||
</Container>
|
||||
</nav>
|
||||
|
||||
@@ -5,18 +5,18 @@ import { Popover } from "@headlessui/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { FaCog, FaSignOutAlt, FaGithub } from "react-icons/fa";
|
||||
|
||||
export function Avatar() {
|
||||
export function UserMenu() {
|
||||
const { data: session } = useSession();
|
||||
|
||||
if (!session) {
|
||||
return <></>;
|
||||
}
|
||||
if (session && session.user) {
|
||||
const displayName = session.user.name || session.user.email;
|
||||
const email = session.user.email;
|
||||
const accountOptions = [
|
||||
{
|
||||
name: "Account Settings",
|
||||
href: "#",
|
||||
href: "/account",
|
||||
desc: "Account Settings",
|
||||
icon: FaCog,
|
||||
//For future use
|
||||
@@ -35,8 +35,7 @@ export function Avatar() {
|
||||
height="40"
|
||||
className="rounded-full"
|
||||
></Image>
|
||||
<p className="hidden lg:flex">{displayName}</p>
|
||||
{/* Will be changed to username once it is implemented */}
|
||||
<p className="hidden lg:flex">{session.user.name || session.user.email}</p>
|
||||
</div>
|
||||
</Popover.Button>
|
||||
<AnimatePresence initial={false}>
|
||||
@@ -72,7 +71,7 @@ export function Avatar() {
|
||||
))}
|
||||
<a
|
||||
className="flex items-center rounded-md hover:bg-gray-100 cursor-pointer"
|
||||
onClick={() => signOut()}
|
||||
onClick={() => signOut({ callbackUrl: "/" })}
|
||||
>
|
||||
<div className="p-4">
|
||||
<FaSignOutAlt />
|
||||
@@ -93,4 +92,4 @@ export function Avatar() {
|
||||
}
|
||||
}
|
||||
|
||||
export default Avatar;
|
||||
export default UserMenu;
|
||||
@@ -4,5 +4,5 @@ export { default } from "next-auth/middleware";
|
||||
* Guards all pages under `/grading` and redirects them to the sign in page.
|
||||
*/
|
||||
export const config = {
|
||||
matcher: ["/create/:path*", "/evaluate/:path*"],
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*"],
|
||||
};
|
||||
|
||||
@@ -5,41 +5,15 @@ import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
|
||||
export default function Error() {
|
||||
const { data: session } = useSession();
|
||||
|
||||
if (!session) {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta name="404" content="Sorry, this page doesn't exist." />
|
||||
</Head>
|
||||
<Header />
|
||||
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
|
||||
{"Sorry, the page you're looking for does not exist."}
|
||||
</main>
|
||||
<Footer />
|
||||
</>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
<title>404 - Open Assistant</title>
|
||||
<meta name="404" content="Sorry, this page doesn't exist." />
|
||||
</Head>
|
||||
<Header />
|
||||
<main>
|
||||
<h2>Open Chat Gpt</h2>
|
||||
|
||||
<p>You are logged in</p>
|
||||
|
||||
<Link href="/grading/grade-output">~Rate a prompt and output now~</Link>
|
||||
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
|
||||
<p>Sorry, the page you are looking for does not exist.</p>
|
||||
</main>
|
||||
<Footer />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import React, { useState } from "react";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { Button, Input, InputGroup, Stack } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Router from "next/router";
|
||||
|
||||
export default function Account() {
|
||||
const { data: session } = useSession();
|
||||
const [username, setUsername] = useState("");
|
||||
const updateUser = async (e: React.SyntheticEvent) => {
|
||||
e.preventDefault();
|
||||
try {
|
||||
const body = { username };
|
||||
await fetch("/api/username", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
await Router.push("/account");
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
};
|
||||
|
||||
if (!session) {
|
||||
return;
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<form onSubmit={updateUser}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
placeholder="Edit Username"
|
||||
type="text"
|
||||
value={username}
|
||||
></Input>
|
||||
<Button disabled={!username} type="submit" value="Change">
|
||||
Submit
|
||||
</Button>
|
||||
</InputGroup>
|
||||
</form>
|
||||
<p>{session.user.email}</p>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import React, { useState } from "react";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { Button } from "@chakra-ui/react";
|
||||
|
||||
export default function Account() {
|
||||
const { data: session } = useSession();
|
||||
const [username, setUsername] = useState("null");
|
||||
|
||||
const handleUpdate = async () => {
|
||||
const response = await fetch("../api/update", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ username }),
|
||||
});
|
||||
const { name } = await response.json();
|
||||
setUsername(name);
|
||||
};
|
||||
|
||||
if (!session) {
|
||||
return;
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
|
||||
<p>{username}</p>
|
||||
<Button>
|
||||
<Link href="/account/edit">Edit Username</Link>
|
||||
</Button>
|
||||
<p>{session.user.email}</p>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { AuthOptions } from "next-auth";
|
||||
import NextAuth from "next-auth";
|
||||
import { NextApiHandler } from "next";
|
||||
import DiscordProvider from "next-auth/providers/discord";
|
||||
import EmailProvider from "next-auth/providers/email";
|
||||
import CredentialsProvider from "next-auth/providers/credentials";
|
||||
@@ -55,7 +56,7 @@ export const authOptions: AuthOptions = {
|
||||
adapter: PrismaAdapter(prisma),
|
||||
providers,
|
||||
pages: {
|
||||
signIn: "/auth/signup",
|
||||
signIn: "/auth/signin",
|
||||
verifyRequest: "/auth/verify",
|
||||
// error: "/auth/error", -Will be used later
|
||||
},
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { getSession } from "next-auth/react";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import Email from "next-auth/providers/email";
|
||||
|
||||
// POST /api/post
|
||||
// Required fields in body: title
|
||||
// Optional fields in body: content
|
||||
export default async function handle(req, res) {
|
||||
const { username } = req.body;
|
||||
const { email } = req.body;
|
||||
|
||||
const session = await getSession({ req });
|
||||
const result = await prisma.user.update({
|
||||
where: {
|
||||
email: session.user.email,
|
||||
},
|
||||
data: {
|
||||
name: username,
|
||||
},
|
||||
});
|
||||
res.json({ name: result.name });
|
||||
}
|
||||
@@ -1,35 +1,19 @@
|
||||
import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
|
||||
import { useRef } from "react";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import Link from "next/link";
|
||||
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
|
||||
export default function Signin({ csrfToken, providers }) {
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const { discord, email, github } = providers;
|
||||
const emailEl = useRef(null);
|
||||
const debugUsernameEl = useRef(null);
|
||||
|
||||
const signinWithDiscord = () => {
|
||||
signIn(discord.id, { callbackUrl: "/" });
|
||||
};
|
||||
|
||||
const signinWithEmail = (ev: React.FormEvent) => {
|
||||
ev.preventDefault();
|
||||
const signinWithEmail = () => {
|
||||
signIn(email.id, { callbackUrl: "/", email: emailEl.current.value });
|
||||
};
|
||||
|
||||
const signinWithGithub = () => {
|
||||
signIn(github.id, { callbackUrl: "/" });
|
||||
};
|
||||
|
||||
function signinWithDebugCredentials(ev: React.FormEvent) {
|
||||
ev.preventDefault();
|
||||
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
@@ -37,27 +21,20 @@ export default function Signin({ csrfToken, providers }) {
|
||||
<meta name="Sign Up" content="Sign up to access Open Assistant" />
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<Stack spacing="6">
|
||||
{credentials && (
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-200 rounded-md p-4 relative">
|
||||
<span className="text-orange-600 absolute -top-3 left-5 bg-white px-1">For Debugging Only</span>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
|
||||
<Button size={"lg"} leftIcon={<FaBug />} colorScheme="gray" type="submit">
|
||||
Continue with Debug User
|
||||
</Button>
|
||||
</Stack>
|
||||
</form>
|
||||
)}
|
||||
<Stack spacing="2">
|
||||
{email && (
|
||||
<form onSubmit={signinWithEmail}>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Button size={"lg"} leftIcon={<FaEnvelope />} colorScheme="gray">
|
||||
Continue with Email
|
||||
</Button>
|
||||
</Stack>
|
||||
</form>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Button
|
||||
size={"lg"}
|
||||
leftIcon={<FaEnvelope />}
|
||||
colorScheme="gray"
|
||||
onClick={signinWithEmail}
|
||||
// isDisabled="false"
|
||||
>
|
||||
Continue with Email
|
||||
</Button>
|
||||
</Stack>
|
||||
)}
|
||||
{discord && (
|
||||
<Button
|
||||
@@ -69,7 +46,8 @@ export default function Signin({ csrfToken, providers }) {
|
||||
size="lg"
|
||||
leftIcon={<FaDiscord />}
|
||||
color="white"
|
||||
onClick={signinWithDiscord}
|
||||
onClick={() => signIn(discord, { callbackUrl: "/" })}
|
||||
// isDisabled="false"
|
||||
>
|
||||
Continue with Discord
|
||||
</Button>
|
||||
@@ -84,7 +62,7 @@ export default function Signin({ csrfToken, providers }) {
|
||||
size={"lg"}
|
||||
leftIcon={<FaGithub />}
|
||||
colorScheme="blue"
|
||||
onClick={signinWithGithub}
|
||||
// isDisabled="false"
|
||||
>
|
||||
Continue with Github
|
||||
</Button>
|
||||
@@ -13,14 +13,6 @@ export default function Verify() {
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
|
||||
<hr className="mt-14 mb-4 h-px bg-gray-200 border-0" />
|
||||
<Link
|
||||
href="#"
|
||||
aria-label="Log In"
|
||||
className="flex justify-center font-medium text-black hover:underline underline-offset-4"
|
||||
>
|
||||
Already have an account? Log In
|
||||
</Link>
|
||||
</AuthLayout>
|
||||
</>
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user