mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into eslint
This commit is contained in:
@@ -4,7 +4,11 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- website/**
|
||||
pull_request:
|
||||
paths:
|
||||
- website/**
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -14,7 +14,7 @@ repos:
|
||||
# and which break the standard YAML check. The alternative would be to
|
||||
# skip any unsafe errors (and thus break YAML compatibility) or use
|
||||
# some other checker that may not work in general.
|
||||
exclude: copilot/web/addons/*
|
||||
exclude: "^copilot/web/addons/.*$"
|
||||
- id: check-json
|
||||
- id: check-case-conflict
|
||||
- id: detect-private-key
|
||||
|
||||
@@ -107,9 +107,7 @@ To start the demo, run this, in root directory:
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
Then, navigate to `http://localhost:3000` and interact with the website. When
|
||||
logging in, navigate to `http://localhost:1080` to get the magic email login
|
||||
link.
|
||||
Then, navigate to `http://localhost:3000` and interact with the website.
|
||||
|
||||
### Website
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@
|
||||
env:
|
||||
POSTGRES_HOST: oasst-postgres
|
||||
DEBUG_ALLOW_ANY_API_KEY: "true"
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
MAX_WORKERS: "1"
|
||||
ports:
|
||||
- 8080:8080
|
||||
|
||||
+121
@@ -1,14 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import alembic.command
|
||||
import alembic.config
|
||||
import fastapi
|
||||
import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
|
||||
@@ -56,4 +63,118 @@ if settings.UPDATE_ALEMBIC:
|
||||
logger.exception("Alembic upgrade failed on startup")
|
||||
|
||||
|
||||
if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
def seed_data():
|
||||
class DummyPost(pydantic.BaseModel):
|
||||
task_post_id: str
|
||||
user_post_id: str
|
||||
parent_post_id: Optional[str]
|
||||
text: str
|
||||
role: str
|
||||
|
||||
try:
|
||||
logger.info("Seed data check began")
|
||||
with Session(engine) as db:
|
||||
api_client = get_dummy_api_client(db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
|
||||
|
||||
dummy_posts = [
|
||||
DummyPost(
|
||||
task_post_id="de111fa8",
|
||||
user_post_id="6f1d0711",
|
||||
parent_post_id=None,
|
||||
text="Hi!",
|
||||
role="user",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="74c381d4",
|
||||
user_post_id="4a24530b",
|
||||
parent_post_id="6f1d0711",
|
||||
text="Hello! How can I help you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="3d5dc440",
|
||||
user_post_id="a8c01c04",
|
||||
parent_post_id="4a24530b",
|
||||
text="Do you have a recipe for potato soup?",
|
||||
role="user",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="643716c1",
|
||||
user_post_id="f43a93b7",
|
||||
parent_post_id="4a24530b",
|
||||
text="Who were the 8 presidents before George Washington?",
|
||||
role="user",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="2e4e1e6",
|
||||
user_post_id="c886920",
|
||||
parent_post_id="6f1d0711",
|
||||
text="Hey buddy! How can I serve you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="970c437d",
|
||||
user_post_id="cec432cf",
|
||||
parent_post_id=None,
|
||||
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
|
||||
role="user",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="6066118e",
|
||||
user_post_id="4f85f637",
|
||||
parent_post_id="cec432cf",
|
||||
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="ba87780d",
|
||||
user_post_id="0e276b98",
|
||||
parent_post_id="cec432cf",
|
||||
text="I'm unsure how to interpret this. Is it a riddle?",
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
|
||||
for p in dummy_posts:
|
||||
wp = pr.fetch_workpackage_by_postid(p.task_post_id)
|
||||
if wp and not wp.ack:
|
||||
logger.warning("Deleting unacknowledged seed data work package")
|
||||
db.delete(wp)
|
||||
wp = None
|
||||
if not wp:
|
||||
if p.parent_post_id is None:
|
||||
wp = pr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_post_id=None
|
||||
)
|
||||
else:
|
||||
print("p.parent_post_id", p.parent_post_id)
|
||||
parent_post = pr.fetch_post_by_frontend_post_id(p.parent_post_id, fail_if_missing=True)
|
||||
wp = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
|
||||
)
|
||||
),
|
||||
thread_id=parent_post.thread_id,
|
||||
parent_post_id=parent_post.id,
|
||||
)
|
||||
pr.bind_frontend_post_id(wp.id, p.task_post_id)
|
||||
post = pr.store_text_reply(p.text, p.task_post_id, p.user_post_id)
|
||||
|
||||
logger.info(
|
||||
f"Inserted: post_id: {post.id}, payload: {post.payload.payload}, parent_post_id: {post.parent_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data work_package found: {wp.id}")
|
||||
logger.info("Seed data check completed")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Seed data insertion failed")
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
@@ -33,6 +33,19 @@ async def get_api_key(
|
||||
return api_key_header
|
||||
|
||||
|
||||
def get_dummy_api_client(db: Session) -> ApiClient:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
|
||||
def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
@@ -40,16 +53,7 @@ def api_auth(
|
||||
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
return get_dummy_api_client(db)
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
|
||||
@@ -17,6 +17,7 @@ class Settings(BaseSettings):
|
||||
|
||||
DEBUG_ALLOW_ANY_API_KEY: bool = False
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
|
||||
@validator("DATABASE_URI", pre=True)
|
||||
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
|
||||
|
||||
@@ -34,6 +34,7 @@ class OasstErrorCode(IntEnum):
|
||||
INVALID_TASK_TYPE = 2004
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
NO_THREADS_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
WORK_PACKAGE_NOT_FOUND = 2100
|
||||
WORK_PACKAGE_EXPIRED = 2101
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
|
||||
@@ -397,7 +397,7 @@ class PromptRepository:
|
||||
distinct_threads = distinct_threads.filter(Post.role == require_role)
|
||||
distinct_threads = distinct_threads.subquery()
|
||||
|
||||
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1).subquery()
|
||||
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1)
|
||||
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
|
||||
return thread_posts
|
||||
|
||||
@@ -443,8 +443,10 @@ class PromptRepository:
|
||||
if post_role:
|
||||
parent = parent.filter(Post.role == post_role)
|
||||
|
||||
parent = parent.order_by(func.random()).limit(1).subquery()
|
||||
parent = parent.order_by(func.random()).limit(1)
|
||||
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
|
||||
if not replies:
|
||||
raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)
|
||||
|
||||
thread = self.fetch_thread(replies[0].thread_id)
|
||||
thread = {p.id: p for p in thread}
|
||||
|
||||
@@ -71,6 +71,7 @@ services:
|
||||
environment:
|
||||
- POSTGRES_HOST=db
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- MAX_WORKERS=1
|
||||
depends_on:
|
||||
db:
|
||||
@@ -92,6 +93,7 @@ services:
|
||||
- EMAIL_SERVER_PORT=1025
|
||||
- EMAIL_FROM=info@example.com
|
||||
- NEXTAUTH_URL=http://localhost:3000
|
||||
- DEBUG_LOGIN=true
|
||||
depends_on:
|
||||
webdb:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -5,6 +5,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
|
||||
+2
-2
@@ -64,9 +64,9 @@ If you're doing active development we suggest the following workflow:
|
||||
|
||||
### Using debug user credentials
|
||||
|
||||
Whenever the website runs in development mode, you can use the debug credentials provider to log in without fancy emails or OAuth.
|
||||
You can use the debug credentials provider to log in without fancy emails or OAuth.
|
||||
|
||||
1. Development mode is automatically active when you start the website with `npm run dev`.
|
||||
1. This feature is automatically on in development mode, i.e. when you run `npm run dev`. In case you want to do the same with a production build (for example, the docker image), then run the website with environment variable `DEBUG_LOGIN=true`.
|
||||
1. Use the `Login` button in the top right to go to the login page.
|
||||
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
|
||||
|
||||
|
||||
Generated
+321
-748
File diff suppressed because it is too large
Load Diff
@@ -23,6 +23,7 @@
|
||||
"@tailwindcss/forms": "^0.5.3",
|
||||
"autoprefixer": "^10.4.13",
|
||||
"axios": "^1.2.1",
|
||||
"boolean": "^3.2.0",
|
||||
"clsx": "^1.2.1",
|
||||
"eslint": "8.29.0",
|
||||
"eslint-config-next": "13.0.6",
|
||||
|
||||
@@ -25,7 +25,7 @@ export function Footer() {
|
||||
<Link href="#" aria-label="Our Team" className="hover:underline underline-offset-2">
|
||||
Our Team
|
||||
</Link>
|
||||
<Link href="#join-us" aria-label="Join Us" className="hover:underline underline-offset-2">
|
||||
<Link href="/#join-us" aria-label="Join Us" className="hover:underline underline-offset-2">
|
||||
Join Us
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
@@ -21,4 +21,4 @@ const Template = (args) => {
|
||||
};
|
||||
|
||||
export const Default = Template.bind({});
|
||||
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" } };
|
||||
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" }, transparent: false };
|
||||
|
||||
@@ -3,8 +3,10 @@ import { Popover } from "@headlessui/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { FaUser } from "react-icons/fa";
|
||||
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import { FaUser, FaSignOutAlt } from "react-icons/fa";
|
||||
import clsx from "clsx";
|
||||
|
||||
import { Container } from "src/components/Container";
|
||||
import { NavLinks } from "./NavLinks";
|
||||
@@ -53,9 +55,10 @@ function AccountButton() {
|
||||
);
|
||||
}
|
||||
|
||||
export function Header() {
|
||||
export function Header(props) {
|
||||
const transparent = props.transparent ?? false;
|
||||
return (
|
||||
<header className="bg-white">
|
||||
<header className={clsx(!transparent && "bg-white")}>
|
||||
<nav>
|
||||
<Container className="relative z-10 flex justify-between py-8">
|
||||
<div className="relative z-10 flex items-center gap-16">
|
||||
@@ -101,8 +104,8 @@ export function Header() {
|
||||
className="absolute inset-x-0 top-0 z-0 origin-top rounded-b-2xl bg-white px-6 pb-6 pt-32 shadow-2xl shadow-gray-900/20"
|
||||
>
|
||||
<div className="space-y-4">
|
||||
<MobileNavLink href="#join-us">Join Us</MobileNavLink>
|
||||
<MobileNavLink href="#faqs">FAQs</MobileNavLink>
|
||||
<MobileNavLink href="/#join-us">Join Us</MobileNavLink>
|
||||
<MobileNavLink href="/#faqs">FAQs</MobileNavLink>
|
||||
</div>
|
||||
<div className="mt-8 flex flex-col gap-4"></div>
|
||||
</Popover.Panel>
|
||||
|
||||
@@ -8,8 +8,8 @@ export function NavLinks(): JSX.Element {
|
||||
return (
|
||||
<>
|
||||
{[
|
||||
["Join Us", "#join-us"],
|
||||
["FAQ", "#faq"],
|
||||
["Join Us", "/#join-us"],
|
||||
["FAQ", "/#faq"],
|
||||
].map(([label, href], index) => (
|
||||
<Link
|
||||
key={label}
|
||||
|
||||
@@ -7,12 +7,12 @@ export const TaskSelection = () => {
|
||||
return (
|
||||
<Flex gap={10} wrap="wrap" justifyContent="space-evenly" width="full" height="full" alignItems={"center"}>
|
||||
<TaskOptions key="create" title="Create">
|
||||
<TaskOption
|
||||
{/* <TaskOption
|
||||
alt="Summarize Stories"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Summarize stories"
|
||||
link="/create/summarize_story"
|
||||
/>
|
||||
/> */}
|
||||
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
|
||||
<TaskOption
|
||||
alt="Reply as Assistant"
|
||||
@@ -22,12 +22,12 @@ export const TaskSelection = () => {
|
||||
/>
|
||||
</TaskOptions>
|
||||
<TaskOptions key="evaluate" title="Evaluate">
|
||||
<TaskOption
|
||||
{/* <TaskOption
|
||||
alt="Rate Prompts"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rate Prompts"
|
||||
link="/evaluate/rate_summary"
|
||||
/>
|
||||
/> */}
|
||||
<TaskOption
|
||||
alt="Rank Initial Prompts"
|
||||
img="/images/logos/logo.svg"
|
||||
|
||||
@@ -5,6 +5,7 @@ import DiscordProvider from "next-auth/providers/discord";
|
||||
import EmailProvider from "next-auth/providers/email";
|
||||
import CredentialsProvider from "next-auth/providers/credentials";
|
||||
import { PrismaAdapter } from "@next-auth/prisma-adapter";
|
||||
import { boolean } from "boolean";
|
||||
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
@@ -34,7 +35,7 @@ if (process.env.DISCORD_CLIENT_ID) {
|
||||
);
|
||||
}
|
||||
|
||||
if (process.env.NODE_ENV === "development") {
|
||||
if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development") {
|
||||
providers.push(
|
||||
CredentialsProvider({
|
||||
name: "Debug Credentials",
|
||||
|
||||
@@ -5,6 +5,8 @@ import { CallToAction } from "src/components/CallToAction";
|
||||
import { Faq } from "src/components/Faq";
|
||||
import { Hero } from "src/components/Hero";
|
||||
import { TaskSelection } from "src/components/TaskSelection";
|
||||
import { Header } from "src/components/Header";
|
||||
import { Footer } from "src/components/Footer";
|
||||
|
||||
const Home = () => {
|
||||
const { data: session } = useSession();
|
||||
@@ -34,4 +36,12 @@ const Home = () => {
|
||||
);
|
||||
};
|
||||
|
||||
Home.getLayout = (page) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
{page}
|
||||
<Footer />
|
||||
</div>
|
||||
);
|
||||
|
||||
export default Home;
|
||||
|
||||
Reference in New Issue
Block a user