Merge branch 'main' into eslint

This commit is contained in:
Desmond Grealy
2022-12-29 10:32:21 -08:00
committed by GitHub
21 changed files with 504 additions and 781 deletions
+4
View File
@@ -4,7 +4,11 @@ on:
push:
branches:
- main
paths:
- website/**
pull_request:
paths:
- website/**
workflow_call:
jobs:
+1 -1
View File
@@ -14,7 +14,7 @@ repos:
# and which break the standard YAML check. The alternative would be to
# skip any unsafe errors (and thus break YAML compatibility) or use
# some other checker that may not work in general.
exclude: copilot/web/addons/*
exclude: "^copilot/web/addons/.*$"
- id: check-json
- id: check-case-conflict
- id: detect-private-key
+1 -3
View File
@@ -107,9 +107,7 @@ To start the demo, run this, in root directory:
docker compose up --build
```
Then, navigate to `http://localhost:3000` and interact with the website. When
logging in, navigate to `http://localhost:1080` to get the magic email login
link.
Then, navigate to `http://localhost:3000` and interact with the website.
### Website
+1
View File
@@ -51,6 +51,7 @@
env:
POSTGRES_HOST: oasst-postgres
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
ports:
- 8080:8080
+121
View File
@@ -1,14 +1,21 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from pathlib import Path
from typing import Optional
import alembic.command
import alembic.config
import fastapi
import pydantic
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.middleware.cors import CORSMiddleware
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
@@ -56,4 +63,118 @@ if settings.UPDATE_ALEMBIC:
logger.exception("Alembic upgrade failed on startup")
if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
def seed_data():
class DummyPost(pydantic.BaseModel):
task_post_id: str
user_post_id: str
parent_post_id: Optional[str]
text: str
role: str
try:
logger.info("Seed data check began")
with Session(engine) as db:
api_client = get_dummy_api_client(db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
dummy_posts = [
DummyPost(
task_post_id="de111fa8",
user_post_id="6f1d0711",
parent_post_id=None,
text="Hi!",
role="user",
),
DummyPost(
task_post_id="74c381d4",
user_post_id="4a24530b",
parent_post_id="6f1d0711",
text="Hello! How can I help you?",
role="assistant",
),
DummyPost(
task_post_id="3d5dc440",
user_post_id="a8c01c04",
parent_post_id="4a24530b",
text="Do you have a recipe for potato soup?",
role="user",
),
DummyPost(
task_post_id="643716c1",
user_post_id="f43a93b7",
parent_post_id="4a24530b",
text="Who were the 8 presidents before George Washington?",
role="user",
),
DummyPost(
task_post_id="2e4e1e6",
user_post_id="c886920",
parent_post_id="6f1d0711",
text="Hey buddy! How can I serve you?",
role="assistant",
),
DummyPost(
task_post_id="970c437d",
user_post_id="cec432cf",
parent_post_id=None,
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
role="user",
),
DummyPost(
task_post_id="6066118e",
user_post_id="4f85f637",
parent_post_id="cec432cf",
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
role="assistant",
),
DummyPost(
task_post_id="ba87780d",
user_post_id="0e276b98",
parent_post_id="cec432cf",
text="I'm unsure how to interpret this. Is it a riddle?",
role="assistant",
),
]
for p in dummy_posts:
wp = pr.fetch_workpackage_by_postid(p.task_post_id)
if wp and not wp.ack:
logger.warning("Deleting unacknowledged seed data work package")
db.delete(wp)
wp = None
if not wp:
if p.parent_post_id is None:
wp = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_post_id=None
)
else:
print("p.parent_post_id", p.parent_post_id)
parent_post = pr.fetch_post_by_frontend_post_id(p.parent_post_id, fail_if_missing=True)
wp = pr.store_task(
protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
)
),
thread_id=parent_post.thread_id,
parent_post_id=parent_post.id,
)
pr.bind_frontend_post_id(wp.id, p.task_post_id)
post = pr.store_text_reply(p.text, p.task_post_id, p.user_post_id)
logger.info(
f"Inserted: post_id: {post.id}, payload: {post.payload.payload}, parent_post_id: {post.parent_id}"
)
else:
logger.debug(f"seed data work_package found: {wp.id}")
logger.info("Seed data check completed")
except Exception:
logger.exception("Seed data insertion failed")
app.include_router(api_router, prefix=settings.API_V1_STR)
+14 -10
View File
@@ -33,6 +33,19 @@ async def get_api_key(
return api_key_header
def get_dummy_api_client(db: Session) -> ApiClient:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
def api_auth(
api_key: APIKey,
db: Session,
@@ -40,16 +53,7 @@ def api_auth(
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
return get_dummy_api_client(db)
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
+1
View File
@@ -17,6 +17,7 @@ class Settings(BaseSettings):
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
@validator("DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
+1
View File
@@ -34,6 +34,7 @@ class OasstErrorCode(IntEnum):
INVALID_TASK_TYPE = 2004
USER_NOT_SPECIFIED = 2005
NO_THREADS_FOUND = 2006
NO_REPLIES_FOUND = 2007
WORK_PACKAGE_NOT_FOUND = 2100
WORK_PACKAGE_EXPIRED = 2101
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
+4 -2
View File
@@ -397,7 +397,7 @@ class PromptRepository:
distinct_threads = distinct_threads.filter(Post.role == require_role)
distinct_threads = distinct_threads.subquery()
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1).subquery()
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1)
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
return thread_posts
@@ -443,8 +443,10 @@ class PromptRepository:
if post_role:
parent = parent.filter(Post.role == post_role)
parent = parent.order_by(func.random()).limit(1).subquery()
parent = parent.order_by(func.random()).limit(1)
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
if not replies:
raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)
thread = self.fetch_thread(replies[0].thread_id)
thread = {p.id: p for p in thread}
+2
View File
@@ -71,6 +71,7 @@ services:
environment:
- POSTGRES_HOST=db
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1
depends_on:
db:
@@ -92,6 +93,7 @@ services:
- EMAIL_SERVER_PORT=1025
- EMAIL_FROM=info@example.com
- NEXTAUTH_URL=http://localhost:3000
- DEBUG_LOGIN=true
depends_on:
webdb:
condition: service_healthy
+1
View File
@@ -5,6 +5,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_USE_SEED_DATA=True
uvicorn main:app --reload --port 8080 --host 0.0.0.0
+2 -2
View File
@@ -64,9 +64,9 @@ If you're doing active development we suggest the following workflow:
### Using debug user credentials
Whenever the website runs in development mode, you can use the debug credentials provider to log in without fancy emails or OAuth.
You can use the debug credentials provider to log in without fancy emails or OAuth.
1. Development mode is automatically active when you start the website with `npm run dev`.
1. This feature is automatically on in development mode, i.e. when you run `npm run dev`. In case you want to do the same with a production build (for example, the docker image), then run the website with environment variable `DEBUG_LOGIN=true`.
1. Use the `Login` button in the top right to go to the login page.
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
+321 -748
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -23,6 +23,7 @@
"@tailwindcss/forms": "^0.5.3",
"autoprefixer": "^10.4.13",
"axios": "^1.2.1",
"boolean": "^3.2.0",
"clsx": "^1.2.1",
"eslint": "8.29.0",
"eslint-config-next": "13.0.6",
+1 -1
View File
@@ -25,7 +25,7 @@ export function Footer() {
<Link href="#" aria-label="Our Team" className="hover:underline underline-offset-2">
Our Team
</Link>
<Link href="#join-us" aria-label="Join Us" className="hover:underline underline-offset-2">
<Link href="/#join-us" aria-label="Join Us" className="hover:underline underline-offset-2">
Join Us
</Link>
</div>
@@ -21,4 +21,4 @@ const Template = (args) => {
};
export const Default = Template.bind({});
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" } };
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" }, transparent: false };
+9 -6
View File
@@ -3,8 +3,10 @@ import { Popover } from "@headlessui/react";
import { AnimatePresence, motion } from "framer-motion";
import Image from "next/image";
import Link from "next/link";
import { useSession } from "next-auth/react";
import { FaUser } from "react-icons/fa";
import { signOut, useSession } from "next-auth/react";
import { FaUser, FaSignOutAlt } from "react-icons/fa";
import clsx from "clsx";
import { Container } from "src/components/Container";
import { NavLinks } from "./NavLinks";
@@ -53,9 +55,10 @@ function AccountButton() {
);
}
export function Header() {
export function Header(props) {
const transparent = props.transparent ?? false;
return (
<header className="bg-white">
<header className={clsx(!transparent && "bg-white")}>
<nav>
<Container className="relative z-10 flex justify-between py-8">
<div className="relative z-10 flex items-center gap-16">
@@ -101,8 +104,8 @@ export function Header() {
className="absolute inset-x-0 top-0 z-0 origin-top rounded-b-2xl bg-white px-6 pb-6 pt-32 shadow-2xl shadow-gray-900/20"
>
<div className="space-y-4">
<MobileNavLink href="#join-us">Join Us</MobileNavLink>
<MobileNavLink href="#faqs">FAQs</MobileNavLink>
<MobileNavLink href="/#join-us">Join Us</MobileNavLink>
<MobileNavLink href="/#faqs">FAQs</MobileNavLink>
</div>
<div className="mt-8 flex flex-col gap-4"></div>
</Popover.Panel>
+2 -2
View File
@@ -8,8 +8,8 @@ export function NavLinks(): JSX.Element {
return (
<>
{[
["Join Us", "#join-us"],
["FAQ", "#faq"],
["Join Us", "/#join-us"],
["FAQ", "/#faq"],
].map(([label, href], index) => (
<Link
key={label}
@@ -7,12 +7,12 @@ export const TaskSelection = () => {
return (
<Flex gap={10} wrap="wrap" justifyContent="space-evenly" width="full" height="full" alignItems={"center"}>
<TaskOptions key="create" title="Create">
<TaskOption
{/* <TaskOption
alt="Summarize Stories"
img="/images/logos/logo.svg"
title="Summarize stories"
link="/create/summarize_story"
/>
/> */}
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
<TaskOption
alt="Reply as Assistant"
@@ -22,12 +22,12 @@ export const TaskSelection = () => {
/>
</TaskOptions>
<TaskOptions key="evaluate" title="Evaluate">
<TaskOption
{/* <TaskOption
alt="Rate Prompts"
img="/images/logos/logo.svg"
title="Rate Prompts"
link="/evaluate/rate_summary"
/>
/> */}
<TaskOption
alt="Rank Initial Prompts"
img="/images/logos/logo.svg"
+2 -1
View File
@@ -5,6 +5,7 @@ import DiscordProvider from "next-auth/providers/discord";
import EmailProvider from "next-auth/providers/email";
import CredentialsProvider from "next-auth/providers/credentials";
import { PrismaAdapter } from "@next-auth/prisma-adapter";
import { boolean } from "boolean";
import prisma from "src/lib/prismadb";
@@ -34,7 +35,7 @@ if (process.env.DISCORD_CLIENT_ID) {
);
}
if (process.env.NODE_ENV === "development") {
if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development") {
providers.push(
CredentialsProvider({
name: "Debug Credentials",
+10
View File
@@ -5,6 +5,8 @@ import { CallToAction } from "src/components/CallToAction";
import { Faq } from "src/components/Faq";
import { Hero } from "src/components/Hero";
import { TaskSelection } from "src/components/TaskSelection";
import { Header } from "src/components/Header";
import { Footer } from "src/components/Footer";
const Home = () => {
const { data: session } = useSession();
@@ -34,4 +36,12 @@ const Home = () => {
);
};
Home.getLayout = (page) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
{page}
<Footer />
</div>
);
export default Home;