Merge branch 'main' into pre-commit-jupyter-black

This commit is contained in:
Janosh Riebesell
2023-01-03 06:42:40 -08:00
committed by GitHub
76 changed files with 6275 additions and 1159 deletions
+1 -1
View File
@@ -39,7 +39,7 @@ repos:
# and which break the standard YAML check. The alternative would be to
# skip any unsafe errors (and thus break YAML compatibility) or use
# some other checker that may not work in general.
exclude: ^copilot/web/addons/.*$
exclude: ^copilot/.*/addons/.*$
- id: check-json
- id: check-case-conflict
- id: detect-private-key
+2
View File
@@ -1,2 +1,4 @@
* @yk @andreaskoepf
/website/ @fozziethebeat @k-nearest-neighbor @AbdBarho
/model/ @theblackcat102 @sanagno
/copilot/ @fozziethebeat @andreaskoepf @yk
+6 -1
View File
@@ -26,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V
@app.exception_handler(OasstError)
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
return fastapi.responses.JSONResponse(
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
status_code=int(ex.http_status_code),
content=protocol_schema.OasstErrorResponse(
message=ex.message,
error_code=OasstErrorCode(ex.error_code),
).dict(),
)
+38
View File
@@ -0,0 +1,38 @@
# The manifest for the "api" service.
# Read the full specification for the "Load Balanced Web Service" type at:
# https://aws.github.io/copilot-cli/docs/manifest/lb-web-service/
name: api
type: Load Balanced Web Service
http:
path: "/"
healthcheck:
path: "/docs"
image:
build:
dockerfile: docker/Dockerfile.backend
context: ./
port: 8080
cpu: 256
memory: 512
platform: linux/x86_64
count: 1
exec: true
network:
connect: true
environments:
staging:
variables:
# Note: this has to be a valid JSON list for Pydantic to parse it.
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
DEBUG_ALLOW_ANY_API_KEY: True
DEBUG_SKIP_API_KEY_CHECK: True
MAX_WORKERS: 1
secrets:
# Note: URI, not URL.
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
-161
View File
@@ -1,161 +0,0 @@
Parameters:
App:
Type: String
Description: Your application's name.
Env:
Type: String
Description:
The environment name your service, job, or workflow is being deployed to.
Name:
Type: String
Description: The name of the service, job, or workflow being deployed.
# Customize your Aurora Serverless cluster by setting the default value of the following parameters.
webclusterDBName:
Type: String
Description:
The name of the initial database to be created in the Aurora Serverless v2
cluster.
Default: oassist_web
# Cannot have special characters
# Naming constraints: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_Limits.html#RDS_Limits.Constraints
Mappings:
webclusterEnvScalingConfigurationMap:
staging:
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
All:
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
Resources:
webclusterDBSubnetGroup:
Type: "AWS::RDS::DBSubnetGroup"
Properties:
DBSubnetGroupDescription:
Group of Copilot private subnets for Aurora Serverless v2 cluster.
SubnetIds:
!Split [",", { "Fn::ImportValue": !Sub "${App}-${Env}-PrivateSubnets" }]
webclusterSecurityGroup:
Metadata:
"aws:copilot:description":
"A security group for your workload to access the Aurora Serverless v2
cluster webcluster"
Type: "AWS::EC2::SecurityGroup"
Properties:
GroupDescription:
!Sub "The Security Group for ${Name} to access Aurora Serverless v2
cluster webcluster."
VpcId:
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
Tags:
- Key: Name
Value: !Sub "copilot-${App}-${Env}-${Name}-Aurora"
webclusterDBClusterSecurityGroup:
Metadata:
"aws:copilot:description":
"A security group for your Aurora Serverless v2 cluster webcluster"
Type: AWS::EC2::SecurityGroup
Properties:
GroupDescription: The Security Group for the Aurora Serverless v2 cluster.
SecurityGroupIngress:
- ToPort: 5432
FromPort: 5432
IpProtocol: tcp
Description:
!Sub "From the Aurora Security Group of the workload ${Name}."
SourceSecurityGroupId: !Ref webclusterSecurityGroup
VpcId:
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
webclusterAuroraSecret:
Metadata:
"aws:copilot:description":
"A Secrets Manager secret to store your DB credentials"
Type: AWS::SecretsManager::Secret
Properties:
Description: !Sub Aurora main user secret for ${AWS::StackName}
GenerateSecretString:
SecretStringTemplate: '{"username": "postgres"}'
GenerateStringKey: "password"
ExcludePunctuation: true
IncludeSpace: false
PasswordLength: 16
webclusterDBClusterParameterGroup:
Metadata:
"aws:copilot:description":
"A DB parameter group for engine configuration values"
Type: "AWS::RDS::DBClusterParameterGroup"
Properties:
Description: !Ref "AWS::StackName"
Family: "aurora-postgresql14"
Parameters:
client_encoding: "UTF8"
webclusterDBCluster:
Metadata:
"aws:copilot:description":
"The webcluster Aurora Serverless v2 database cluster"
Type: "AWS::RDS::DBCluster"
Properties:
MasterUsername:
!Join [
"",
[
"{{resolve:secretsmanager:",
!Ref webclusterAuroraSecret,
":SecretString:username}}",
],
]
MasterUserPassword:
!Join [
"",
[
"{{resolve:secretsmanager:",
!Ref webclusterAuroraSecret,
":SecretString:password}}",
],
]
DatabaseName: !Ref webclusterDBName
Engine: "aurora-postgresql"
EngineVersion: "14.4"
DBClusterParameterGroupName: !Ref webclusterDBClusterParameterGroup
DBSubnetGroupName: !Ref webclusterDBSubnetGroup
Port: 5432
VpcSecurityGroupIds:
- !Ref webclusterDBClusterSecurityGroup
ServerlessV2ScalingConfiguration:
# Replace "All" below with "!Ref Env" to set different autoscaling limits per environment.
MinCapacity:
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMinCapacity]
MaxCapacity:
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMaxCapacity]
webclusterDBWriterInstance:
Metadata:
"aws:copilot:description":
"The webcluster Aurora Serverless v2 writer instance"
Type: "AWS::RDS::DBInstance"
Properties:
DBClusterIdentifier: !Ref webclusterDBCluster
DBInstanceClass: db.serverless
Engine: "aurora-postgresql"
PromotionTier: 1
AvailabilityZone: !Select
- 0
- !GetAZs
Ref: AWS::Region
webclusterSecretAuroraClusterAttachment:
Type: AWS::SecretsManager::SecretTargetAttachment
Properties:
SecretId: !Ref webclusterAuroraSecret
TargetId: !Ref webclusterDBCluster
TargetType: AWS::RDS::DBCluster
Outputs:
webclusterSecret: # injected as WEBCLUSTER_SECRET environment variable by Copilot.
Description:
"The JSON secret that holds the database username and password. Fields are
'host', 'port', 'dbname', 'username', 'password', 'dbClusterIdentifier'
and 'engine'"
Value: !Ref webclusterAuroraSecret
webclusterSecurityGroup:
Description: "The security group to attach to the workload."
Value: !Ref webclusterSecurityGroup
+1 -1
View File
@@ -26,6 +26,7 @@ environments:
staging:
variables:
NEXTAUTH_URL: https://web.staging.open-assistant.surfacedata.org
FASTAPI_URL: https://api.staging.open-assistant.surfacedata.org
secrets:
DATABASE_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/DATABASE_URL
@@ -37,5 +38,4 @@ secrets:
EMAIL_SERVER_USER: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_SERVER_USER
EMAIL_FROM: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_FROM
FASTAPI_KEY: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_KEY
FASTAPI_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_URL
NEXTAUTH_SECRET: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/NEXTAUTH_SECRET
+1 -1
View File
@@ -1,7 +1,7 @@
BOT_TOKEN=<discord bot token>
DECLARE_GLOBAL_COMMANDS=<testing guild id>
OWNER_IDS=[<your user id>, <other user ids>]
PREFIX="./"
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
OASST_API_URL="http://localhost:8080" # No trailing '/'
OASST_API_KEY=""
+4 -1
View File
@@ -48,7 +48,7 @@ Remember to save your changes.
below to invite your bot.
```
https://discord.com/oauth2/authorize?client_id=YOUR_ID_HERE&permissions=8&scope=bot%20applications.commands
https://discord.com/oauth2/authorize?client_id=YOUR_CLIENT_ID_HERE&permissions=8&scope=bot%20applications.commands
```
### Environment Setup
@@ -66,6 +66,9 @@ pip install -e .
cp .env.example .env
# edit .env and add your bot token and other values
# BOT_TOKEN is given by the discord developer portal when you create a bot
# DECLARE_GLOBAL_COMMANDS is the id of the server where you added the bot (right click on the server icon and copy id)
# OWNER_ID can be leave as an empty list
python -V # 3.10
+13 -6
View File
@@ -6,7 +6,7 @@ import hikari
import lightbulb
import miru
from bot.settings import Settings
from bot.utils import EMPTY, mention
from bot.utils import mention
from oasst_shared.api_client import OasstApiClient
settings = Settings()
@@ -34,8 +34,11 @@ async def on_starting(event: hikari.StartingEvent):
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
# A set of user id's that are currently doing work.
bot.d.currently_working = set()
# A `dict[hikari.Message | None, UUID | None]]` that maps user IDs to (task msg ID, task UUIDs).
# Either both are `None` or both are not `None`.
# If both are `None`, the user is not currently selecting a task.
# TODO: Grow this on startup so we don't have to re-allocate memory every time it needs to grow
bot.d.currently_working = {}
@bot.listen()
@@ -50,13 +53,13 @@ async def _send_error_embed(
) -> None:
ctx.command
embed = hikari.Embed(
title=f"`{exception.__class__.__name__}` Error{f' in `{ctx.command.name}`' if ctx.command else '' }",
title=f"`{exception.__class__.__name__}` Error{f' in `/{ctx.command.name}`' if ctx.command else '' }",
description=content,
color=0xFF0000,
timestamp=datetime.now().astimezone(),
).set_author(name=ctx.author.username, url=str(ctx.author.avatar_url))
await ctx.respond(EMPTY, embed=embed)
await ctx.respond(embed=embed)
@bot.listen(lightbulb.CommandErrorEvent)
@@ -65,6 +68,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
# Unwrap the exception to get the original cause
exc = event.exception.__cause__ or event.exception
ctx = event.context
if not ctx.bot.rest.is_alive:
return
if isinstance(event.exception, lightbulb.CommandInvocationError):
if not event.context.command:
@@ -114,6 +119,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
ctx,
)
elif isinstance(exc, lightbulb.errors.MissingRequiredAttachment):
await _send_error_embed("Not enough attachemnts were supplied to this command.", exc, ctx)
await _send_error_embed("Not enough attachments were supplied to this command.", exc, ctx)
elif isinstance(exc, lightbulb.errors.CommandNotFound):
await ctx.respond(f"`/{exc.invoked_with}` is not a valid command. Use `/help` to see a list of commands.")
else:
raise exc
@@ -78,7 +78,6 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None:
# if the bot's permissions for this channel don't contain SEND_MESSAGE
# This will also filter out categories and voice channels
print(permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES)
if not permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES:
await ctx.respond(f"I don't have permission to send messages in {ch.mention}.")
return
+2 -3
View File
@@ -7,7 +7,6 @@ import lightbulb
import miru
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from bot.utils import EMPTY
from loguru import logger
plugin = lightbulb.Plugin(
@@ -74,7 +73,7 @@ class LabelModal(miru.Modal):
)
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
assert isinstance(channel, hikari.TextableChannel)
await channel.send(EMPTY, embed=embed)
await channel.send(embed=embed)
class LabelSelect(miru.View):
@@ -164,7 +163,7 @@ async def label_message_text(ctx: lightbulb.MessageContext):
msg.content,
timeout=60,
)
resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
resp = await ctx.respond(embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
await label_select_view.start(await resp.message())
await label_select_view.wait()
+184 -186
View File
@@ -1,15 +1,27 @@
"""Work plugin for collecting user data."""
import asyncio
import typing as t
from datetime import datetime
from uuid import UUID
import hikari
import lightbulb
import lightbulb.decorators
import miru
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from bot.utils import EMPTY
from bot.messages import (
assistant_reply_message,
confirm_ranking_response_message,
confirm_text_response_message,
initial_prompt_message,
invalid_user_input_embed,
plain_embed,
prompter_reply_message,
rank_assistant_reply_message,
rank_initial_prompts_message,
rank_prompter_reply_message,
task_complete_embed,
)
from bot.settings import Settings
from loguru import logger
from oasst_shared.api_client import OasstApiClient, TaskType
from oasst_shared.schemas import protocol as protocol_schema
@@ -20,6 +32,8 @@ plugin = lightbulb.Plugin("WorkPlugin")
MAX_TASK_TIME = 60 * 60 # 1 hour
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
settings = Settings()
@plugin.command
@lightbulb.option(
@@ -31,31 +45,56 @@ MAX_TASK_ACCEPT_TIME = 60 # 1 minute
type=str,
)
@lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand)
async def work(ctx: lightbulb.SlashContext):
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def work(ctx: lightbulb.Context):
"""Create and handle a task."""
# make sure the user isn't currently doing a task
currently_working: set[hikari.Snowflakeish] = ctx.bot.d.currently_working
# Only send this message if started from a server
if ctx.guild_id is not None:
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
currently_working: dict[
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
] = ctx.bot.d.currently_working
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
if ctx.author.id in currently_working:
await ctx.respond(
"You are already performing a task. Please complete that one first.", flags=hikari.MessageFlag.EPHEMERAL
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(
embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"),
flags=hikari.MessageFlag.EPHEMERAL,
components=yn_view,
)
return
await yn_view.start(msg)
await yn_view.wait()
currently_working.add(ctx.author.id)
match yn_view.choice:
case False | None:
return
case True:
old_msg, task_id = currently_working[ctx.author.id]
if old_msg is not None:
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
map(lambda c: c, old_msg.components)
await old_msg.delete()
if task_id is not None:
await oasst_api.nack_task(task_id, reason="user cancelled")
await msg.delete()
currently_working[ctx.author.id] = (None, None)
# Create a TaskRequestType from the stringified enum value
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
logger.debug(f"Starting task_type: {task_type!r}")
try:
await _handle_task(ctx, task_type)
finally:
currently_working.remove(ctx.author.id)
del currently_working[ctx.author.id]
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
"""Handle creating and collecting user input for a task.
Continually present tasks to the user until they select one, cancel, or time out.
@@ -72,38 +111,79 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
task, msg_id = await _select_task(ctx, task_type)
if task is None:
# User cancelled
return
# Task action loop
completed = False
while not completed:
await ctx.author.send("Please type your response here:")
await ctx.author.send(embed=plain_embed("Please type your response here"))
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
hikari.DMMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id
and not (e.message.content or "").startswith(settings.prefix),
)
except asyncio.TimeoutError:
await ctx.author.send("Task timed out. Exiting")
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
await oasst_api.nack_task(task.id, reason="timed out")
logger.info(f"Task {task.id} timed out")
return
# Invalid response
if event.content is None or not _validate_user_input(event.content, task):
await ctx.author.send("Invalid response")
valid, err_msg = _validate_user_input(event.content, task)
if not valid or event.content is None:
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
continue
logger.debug(f"Successful user input received: {event.content}")
# Confirm user input
if isinstance(task, protocol_schema.RankConversationRepliesTask):
content = confirm_ranking_response_message(event.content, task.replies)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
content = confirm_ranking_response_message(event.content, task.prompts)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
content = confirm_text_response_message(event.content)
else:
logger.critical(f"Unknown task type: {task.type}")
raise ValueError(f"Unknown task type: {task.type}")
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
msg = await ctx.author.send(content, components=confirm_resp_view)
await confirm_resp_view.start(msg)
await confirm_resp_view.wait()
match confirm_resp_view.choice:
case False | None:
continue
case True:
await msg.delete() # buttons are already gone
# Send the response to the backend
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
user_message_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
text=event.content,
)
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
reply = protocol_schema.MessageRanking(
message_id=str(msg_id),
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
user_message_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
text=event.content,
)
else:
logger.critical(f"Unexpected task type received: {task.type}")
raise ValueError(f"Unexpected task type received: {task.type}")
logger.debug(f"Sending reply to backend: {reply!r}")
# Get next task
@@ -111,63 +191,55 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
logger.info(f"New task {new_task}")
if new_task.type == TaskType.done:
await ctx.author.send("Task completed")
await ctx.author.send(embed=plain_embed("Task completed"))
completed = True
continue
else:
logger.critical(f"Unexpected task type received: {new_task.type}")
# Send a message in the log channel that the task is complete
# TODO: Maybe do something with the msg ID so users can rate the "answer"
assert ctx.guild_id is not None
# Send a message in all the log channels that the task is complete
conn: Connection = ctx.bot.d.db
guild_settings = await GuildSettings.from_db(conn, ctx.guild_id)
async with conn.cursor() as cursor:
await cursor.execute("SELECT log_channel_id FROM guild_settings")
log_channel_ids = await cursor.fetchall()
if guild_settings is not None and guild_settings.log_channel_id is not None:
channels = [
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
for id in log_channel_ids
]
channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id)
assert isinstance(channel, hikari.TextableChannel) # option converter
done_embed = (
hikari.Embed(
title="Task Completion",
description=f"`{task.type}` completed by {ctx.author.mention}",
color=hikari.Color(0x00FF00),
timestamp=datetime.now().astimezone(),
)
.add_field("Total Tasks", "0", inline=True)
.add_field("Server Ranking", "0/0", inline=True)
.add_field("Global Ranking", "0/0", inline=True)
.set_footer(f"Task ID: {task.id}")
)
await channel.send(EMPTY, embed=done_embed)
done_embed = task_complete_embed(task, ctx.author.mention)
# This will definitely get the bot rate limited, but that's a future problem
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
# ask the user if they want to do another task
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send("Would you like another task?", components=choice_view)
await choice_view.start(msg)
await choice_view.wait()
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
await another_task_view.start(msg)
await another_task_view.wait()
match choice_view.choice:
match another_task_view.choice:
case False | None:
done = True
await ctx.author.send("Exiting, goodbye!")
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
case True:
pass
async def _select_task(
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
) -> tuple[protocol_schema.Task | None, str]:
"""Present tasks to the user until they accept one, cancel, or time out."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
logger.debug(f"Starting task selection for {task_type}")
# Loop until the user accepts a task, cancels, or times out
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
while True:
logger.debug(f"Requesting task of type {task_type}")
task = await oasst_api.fetch_task(task_type, user)
resp, msg_id = await _send_task(ctx, task)
resp, msg = await _send_task(ctx, task, msg)
msg_id = str(msg.id)
logger.debug(f"User choice: {resp}")
match resp:
@@ -179,25 +251,24 @@ async def _select_task(
case "next":
logger.info(f"Task {task.id} rejected, sending NACK")
await oasst_api.nack_task(task.id, "rejected")
await ctx.author.send("Sending next task...")
continue
case "cancel":
logger.info(f"Task {task.id} canceled, sending NACK")
await oasst_api.nack_task(task.id, "canceled")
await ctx.author.send("Task canceled. Exiting")
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
return None, msg_id
case None:
logger.info(f"Task {task.id} timed out, sending NACK")
await oasst_api.nack_task(task.id, "timed out")
await ctx.author.send("Task timed out. Exiting")
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
return None, msg_id
async def _send_task(
ctx: lightbulb.SlashContext, task: protocol_schema.Task
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
"""Send a task to the user.
Returns the user's choice and the message ID of the task message.
@@ -206,37 +277,38 @@ async def _send_task(
# but the tasks aren't discord specific so that doesn't really make sense.
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.debug("sending initial prompt task")
embed = _initial_prompt_embed(task)
content = initial_prompt_message(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.debug("sending rank initial prompt task")
embed = _rank_initial_prompt_embed(task)
content = rank_initial_prompts_message(task)
elif task.type == TaskRequestType.rank_prompter_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
logger.debug("sending rank user reply task")
embed = _rank_prompter_reply_embed(task)
content = rank_prompter_reply_message(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.debug("sending rank assistant reply task")
embed = _rank_assistant_reply_embed(task)
content = rank_assistant_reply_message(task)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
embed = _prompter_reply_embed(task)
content = prompter_reply_message(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
logger.debug("sending assistant reply task")
embed = _assistant_reply_embed(task)
content = assistant_reply_message(task)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
@@ -248,24 +320,34 @@ async def _send_task(
raise ValueError(f"unknown task type {task.type}")
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(
EMPTY,
embed=embed,
components=view,
)
if not msg:
msg = await ctx.author.send(
content,
embed=embed,
components=view,
)
else:
await msg.edit(
content,
embed=embed,
components=view,
)
assert msg is not None
# Set the choice id as the current msg id
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
await view.start(msg)
await view.wait()
return view.choice, str(msg.id)
return view.choice, msg
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
"""Returns whether the user's input is valid for the task type."""
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
"""Returns whether the user's input is valid for the task type and an error message."""
if content is None:
return False
return False, "No input provided"
# User message input
if (
@@ -277,22 +359,28 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> boo
task,
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
)
return len(content) > 0
return len(content) > 0, "Message must be at least one character long."
# Ranking tasks
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
rankings = content.split(",")
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
"Message must contain numbers for all replies.",
)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = content.split(",")
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
"Message must contain numbers for all prompts.",
)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
@@ -316,22 +404,29 @@ class TaskAcceptView(miru.View):
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Accept button pressed")
self.choice = "accept"
await ctx.message.edit(component=None)
self.stop()
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Next button pressed")
self.choice = "next"
await ctx.message.edit(component=None)
self.stop()
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Cancel button pressed")
self.choice = "cancel"
await ctx.message.edit(component=None)
self.stop()
async def on_timeout(self) -> None:
if self.message is not None:
await self.message.edit(component=None)
class ChoiceView(miru.View):
class YesNoView(miru.View):
"""View with two buttons: yes and no.
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
@@ -342,115 +437,18 @@ class ChoiceView(miru.View):
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = True
await ctx.message.edit(component=None)
self.stop()
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = False
await ctx.message.edit(component=None)
self.stop()
################################################################
# Template Embeds #
################################################################
# TODO: Maybe implement a better way of creating embeds, like `from_json` or something
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
return (
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
.set_footer(text=f"OASST Assistant | {task.id}")
)
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank Initial Prompt",
description="Rank the following tasks from best to worst (1,2,3,4,5)",
timestamp=datetime.now().astimezone(),
)
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
.set_footer(text=f"OASST Assistant | {task.id}")
)
for i, prompt in enumerate(task.prompts):
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
return embed
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank User Reply",
description="Rank the following user replies from best to worst. e.g. 1,2,5,3,4",
timestamp=datetime.now().astimezone(),
)
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for i, reply in enumerate(task.replies):
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
return embed
def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank Assistant Reply",
description="Rank the following assistant replies from best to worst. e.g. 1,2,5,3,4",
timestamp=datetime.now().astimezone(),
)
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for i, reply in enumerate(task.replies):
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
return embed
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="User Reply",
description=f"""\
Send the next message in the conversation as if you were the user.
{'Hint: ' if task.hint else ''}
""",
timestamp=datetime.now().astimezone(),
)
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for message in task.conversation.messages:
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
return embed
def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="User Reply",
description="Send the next message in the conversation as if you were the user.",
timestamp=datetime.now().astimezone(),
)
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for message in task.conversation.messages:
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
return embed
async def on_timeout(self) -> None:
if self.message is not None:
await self.message.edit(component=None)
def load(bot: lightbulb.BotApp):
+207
View File
@@ -0,0 +1,207 @@
"""All user-facing messages and embeds."""
from datetime import datetime
import hikari
from oasst_shared.schemas import protocol as protocol_schema
NUMBER_EMOJIS = [":one:", ":two:", ":three:", ":four:", ":five:", ":six:", ":seven:", ":eight:", ":nine:", ":ten:"]
NL = "\n"
###
# Reusable 'components'
###
def _h1(text: str) -> str:
return f"\n:small_blue_diamond: __**{text}**__ :small_blue_diamond:"
def _h2(text: str) -> str:
return f"__**{text}**__"
def _h3(text: str) -> str:
return f"__{text}__"
def _writing_prompt(text: str) -> str:
return f":pencil: _{text}_"
def _ranking_prompt(text: str) -> str:
return f":trophy: _{text}_"
def _response_prompt(text: str) -> str:
return f":speech_balloon: _{text}_"
def _summarize_prompt(text: str) -> str:
return f":notepad_spiral: _{text}_"
def _user(text: str | None) -> str:
return f"""\
:person_red_hair: {_h3("User")}:{f"{NL}> **{text}**" if text is not None else ""}
"""
def _assistant(text: str | None) -> str:
return f"""\
:robot: {_h3("Assistant")}:{f"{NL}> {text}" if text is not None else ""}
"""
def _make_ordered_list(items: list[str]) -> list[str]:
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
def _ordered_list(items: list[str]) -> str:
return "\n\n".join(_make_ordered_list(items))
def _conversation(conv: protocol_schema.Conversation) -> str:
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
def _hint(hint: str | None) -> str:
return f"{NL}Hint: {hint}" if hint else ""
###
# Messages
###
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
return f"""\
{_h1("INITIAL PROMPT")}
{_writing_prompt("Please provide an initial prompt to the assistant.")}
{_hint(task.hint)}
"""
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
return f"""\
{_h1("RANK INITIAL PROMPTS")}
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
{_ordered_list(task.prompts)}
"""
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
return f"""\
{_h1("RANK PROMPTER REPLIES")}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
{_conversation(task.conversation)}
{_user(None)}
{_ordered_list(task.replies)}
"""
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
return f"""\
{_h1("RANK ASSISTANT REPLIES")}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
{_conversation(task.conversation)}
{_assistant(None)}
{_ordered_list(task.replies)}
"""
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
return f"""\
{_h1("PROMPTER REPLY")}
{_response_prompt("Please provide a reply to the assistant.")}
{_conversation(task.conversation)}
{_hint(task.hint)}
"""
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
return f"""\
{_h1("ASSISTANT REPLY")}
{_response_prompt("Please provide a reply to the assistant.")}
{_conversation(task.conversation)}
"""
def confirm_text_response_message(content: str) -> str:
return f"""\
{_h2("CONFIRM RESPONSE")}
> {content}
"""
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
original_list = _make_ordered_list(items)
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
return f"""\
{_h2("CONFIRM RESPONSE")}
{user_ranked_list}
"""
###
# Embeds
###
def task_complete_embed(task: protocol_schema.Task, mention: str) -> hikari.Embed:
return (
hikari.Embed(
title="Task Completion",
description=f"`{task.type}` completed by {mention}",
color=hikari.Color(0x00FF00),
timestamp=datetime.now().astimezone(),
)
.add_field("Total Tasks", "0", inline=True)
.add_field("Server Ranking", "0/0", inline=True)
.add_field("Global Ranking", "0/0", inline=True)
.set_footer(f"Task ID: {task.id}")
)
def invalid_user_input_embed(error_message: str) -> hikari.Embed:
return hikari.Embed(
title="Invalid User Input",
description=error_message,
color=hikari.Color(0xFF0000),
timestamp=datetime.now().astimezone(),
)
def plain_embed(text: str) -> hikari.Embed:
return hikari.Embed(color=0x36393F, description=text)
+1 -1
View File
@@ -8,7 +8,7 @@ class Settings(BaseSettings):
bot_token: str = Field(env="BOT_TOKEN", default="")
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
prefix: str = Field(env="PREFIX", default="./")
prefix: str = Field(env="PREFIX", default="/")
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
-7
View File
@@ -24,13 +24,6 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
EMPTY = "\u200d"
"""Zero-width joiner.
This appears as an empty message in Discord.
"""
def mention(
id: hikari.Snowflakeish,
type: t.Literal["channel", "role", "user"],
-2
View File
@@ -1,5 +1,3 @@
aiohttp # http client
aiohttp[speedups] # speedups for aiohttp
aiosqlite # database
hikari # discord framework
hikari-lightbulb # command handler
+1 -1
View File
@@ -9,7 +9,7 @@ services:
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
frontend-dev:
image: sverrirab/sleep
depends_on: [db, webdb, adminer, maildev, backend]
depends_on: [db, webdb, adminer, maildev, backend, redis]
# This DB is for the FastAPI Backend.
db:
+1
View File
@@ -5,6 +5,7 @@ COPY ./backend/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
ENV PORT 8080
EXPOSE 8080
COPY ./oasst-shared /oasst-shared
RUN pip install -e /oasst-shared
+3
View File
@@ -99,6 +99,9 @@ The main tasks are a) generation of response text and b) ranking of responses.
The following sections describe the data schemas for each of these tasks. Both
should be implementable in parquet files.
Note: These files are meant to be consumed by ML algorithms and should ideally
be produced from the above files.
## Common Data Structures
```python
+38
View File
@@ -0,0 +1,38 @@
# Train using supervised examples
Requirements
```
wandb
evaluate
datasets
transformers
torch
```
Start training reward model
```bash
python trainer.py --configs defaults galactica-125
```
## Dataset
For now we only support webgpt and summary dataset from OpenAI. Once
open-asisstant dataset are available it will be added here.
## Model
TBD
## Results
Experimental results in wandb
[here](https://wandb.ai/sanagnos/supervised-finetuning?workspace=user-sanagnos).
## TODOS
- decide on a model
- add special token to declare prompt and reply. Do nto freeze the weights for
these
- Merge utils etc with reward model
@@ -0,0 +1,37 @@
defaults:
learning_rate: 1e-5
gradient_checkpointing: false
gradient_accumulation_steps: 32
per_device_train_batch_size: 2
per_device_eval_batch_size: 2
weight_decay: 0.00
warmup_steps: 600
eval_steps: 200
save_steps: 500
max_length: 512
num_train_epochs: 3
logging_steps: 10
max_grad_norm: 2.0
save_total_limit: 4
eval_accumulation_steps:
freeze_layer:
datasets:
- webgpt
cache_dir: ~/.cache
loss_fn: CrossEntropyLoss
eval_size:
log_dir: "base"
galactica-125:
learning_rate: 5e-5
model_name: facebook/galactica-125m
weight_decay: 0.01
warmup_steps: 600
gradient_checkpointing: false
gradient_accumulation_steps: 2
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
debug:
eval_steps: 20
eval_size: 100
@@ -0,0 +1,67 @@
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
class SquadV2Dataset(Dataset):
def __init__(self, cache_dir, split):
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
data = self.dataset[idx]
# dummy return first answer
return "".join([data["title"], ". ", data["context"], " " + data["question"]]), data["answers"]["text"][0]
class WebGPT(Dataset):
def __init__(self) -> None:
super().__init__()
dataset = load_dataset("openai/webgpt_comparisons")
questions = {}
# using prompt as our index will allows us
# to add additional generated prompt later
self.index2question = {}
for row in dataset["train"]:
question = row["question"]["full_text"]
if question not in self.index2question:
self.index2question[len(self.index2question)] = question
# only keep the best answer
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
self.questions = questions
def __len__(self):
return len(self.index2question)
def __getitem__(self, index):
question = self.index2question[index]
answer = self.questions[question]
return [question, answer]
def train_val_dataset(dataset, val_split=0.2):
train_idx, val_idx = train_test_split(
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
)
return Subset(dataset, train_idx), Subset(dataset, val_idx)
def get_one_dataset(conf, dataset_name):
dataset_name = dataset_name.lower()
if dataset_name == "squadv2":
raise ValueError("SquadV2 is not diverse enough for generation .. ")
train = SquadV2Dataset(conf.cache_dir, "train")
eval = SquadV2Dataset(conf.cache_dir, "validation")
elif dataset_name == "webgpt":
dataset = WebGPT()
train, eval = train_val_dataset(dataset, val_split=0.2)
else:
raise ValueError(f"Unknown dataset {dataset_name}")
return train, eval
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from torch.nn import functional as F
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
@dataclass
class DialogueDataCollator:
"""
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
# TODO add special tokens for question and answer here
# additional_special_tokens = ['<question>', '<answer>']
prompt_tokens = ["Question: ", "Answer: "]
flatten_messages = []
label_masks = []
for messages in features:
assert len(messages) % 2 == 0, "Number of messages must be even"
messages = [
(prompt_tokens[0] if i % 2 == 0 else "") + x + ((" " + prompt_tokens[1]) if i % 2 == 0 else "")
for i, x in enumerate(messages)
]
# Add a way for the model to terminate generation, reinitialize prompter
messages.append(prompt_tokens[0])
flatten_messages.append(
self.tokenizer(
"".join(messages),
truncation=True,
max_length=self.max_length,
return_offsets_mapping=True,
)
)
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John.
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3
# If no result in next, we are predicting the last termination token(s)
message_indices = list(
map(
lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
list(map(lambda x: x[1], flatten_messages[-1]["offset_mapping"])),
)
)
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
try:
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
except IndexError:
# an aftermath of padding
pass
label_masks.append(label_mask)
flatten_messages[-1].pop("offset_mapping")
batch = self.tokenizer.pad(
flatten_messages,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
dim = batch["input_ids"].shape[-1]
batch["label_masks"] = torch.stack([F.pad(torch.tensor(x), (0, dim - len(x))) for x in label_masks])
for k in list(batch.keys()):
if k not in ["input_ids", "attention_mask", "label_masks"]:
batch.pop(k)
return batch
+15
View File
@@ -0,0 +1,15 @@
from torch import nn
class CrossEntropyLoss(nn.CrossEntropyLoss):
def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean"):
super(CrossEntropyLoss, self).__init__(weight, size_average, ignore_index, reduce, reduction)
def forward(self, input, target, mask=None):
if mask is not None:
mask = mask.view(-1)
input = input.view(-1, input.size(-1))
target = target.view(-1)
input = input[mask]
target = target[mask]
return super(CrossEntropyLoss, self).forward(input, target)
+200
View File
@@ -0,0 +1,200 @@
import argparse
import os
from dataclasses import dataclass
from distutils.util import strtobool
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import (
DataCollator,
EvalPrediction,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
TrainingArguments,
get_cosine_schedule_with_warmup,
)
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
@dataclass
class CustomTrainingArguments(TrainingArguments):
loss_function: str = "CrossEntropyLoss"
def compute_metrics(eval_pred):
pred_ids = eval_pred.predictions
labels = eval_pred.label_ids
return {"accuracy": (pred_ids[labels > 0] == labels[labels > 0]).mean()}
def preprocess_logits_for_metrics(logits, labels):
pred_ids = torch.argmax(logits, dim=-1)
return pred_ids
class SFTTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Callable[[], PreTrainedModel] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
):
super().__init__(
model,
args,
data_collator,
train_dataset,
eval_dataset,
tokenizer,
model_init,
compute_metrics,
callbacks,
optimizers,
preprocess_logits_for_metrics,
)
self.loss_fct = get_loss(args.loss_function)
def fetch_scheduler(self):
return get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps=self.args.warmup_steps,
num_training_steps=self.num_train_steps,
num_cycles=1,
last_epoch=-1,
)
def compute_loss(self, model, inputs, return_outputs=False):
labels_mask = inputs.pop("label_masks")
outputs = model(**inputs)
loss = self.loss_fct(outputs.get("logits"), torch.roll(inputs["input_ids"], -1, -1), mask=labels_mask)
return (loss, outputs) if return_outputs else loss
def _compute_loss(self, model, inputs):
labels_mask = inputs.pop("label_masks")
inputs = self._prepare_inputs(inputs)
outputs = model(**inputs)
logits = outputs.get("logits")
targets = torch.roll(inputs["input_ids"], -1, -1)
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
return loss, logits, targets, labels_mask
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.no_grad():
loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
labels[~labels_mask] = -1
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)
return (loss, logits, labels)
def _strtobool(x):
return bool(strtobool(x))
def argument_parsing(notebook=False, notebook_args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+", required=True)
if notebook:
args, remaining = parser.parse_known_args(notebook_args)
else:
args, remaining = parser.parse_known_args()
# Config from YAML
conf = {}
configs = read_yamls("./configs")
for name in args.configs:
if "," in name:
for n in name.split(","):
conf.update(configs[n])
else:
conf.update(configs[name])
# Override config from command-line
parser = argparse.ArgumentParser()
for key, value in conf.items():
type_ = type(value) if value is not None else str
if type_ == bool:
type_ = _strtobool
parser.add_argument(f"--{key}", type=type_, default=value)
return parser.parse_args(remaining)
if __name__ == "__main__":
training_conf = argument_parsing()
model = get_model(training_conf)
tokenizer = get_tokenizer(training_conf)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
args = CustomTrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
num_train_epochs=training_conf.num_train_epochs,
warmup_steps=training_conf.warmup_steps,
loss_function=training_conf.loss_fn,
learning_rate=float(training_conf.learning_rate),
fp16=True,
gradient_checkpointing=training_conf.gradient_checkpointing,
gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
per_device_train_batch_size=training_conf.per_device_train_batch_size,
per_device_eval_batch_size=training_conf.per_device_eval_batch_size,
weight_decay=training_conf.weight_decay,
max_grad_norm=training_conf.max_grad_norm,
logging_steps=training_conf.logging_steps,
save_total_limit=training_conf.save_total_limit,
evaluation_strategy="steps",
eval_steps=training_conf.eval_steps,
save_steps=training_conf.save_steps,
eval_accumulation_steps=training_conf.eval_accumulation_steps,
report_to="wandb",
)
assert len(evals) > 0
trainer = SFTTrainer(
model,
args,
train_dataset=train,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()
+111
View File
@@ -0,0 +1,111 @@
from pathlib import Path
import yaml
from custom_datasets import get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from losses import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, Subset
from transformers import AutoModelForCausalLM, AutoTokenizer
SUPPORTED_MODELS = ["galactica"]
def get_tokenizer(conf):
tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
if "galactica" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
return tokenizer
def get_model(conf):
if not any([x in conf.model_name for x in SUPPORTED_MODELS]):
raise ValueError(
f"Model {conf.model_name} not supported. Supported models: {SUPPORTED_MODELS}. "
"To include more make sure the masking is dne correctly... (decoder only supported for now)"
)
model = AutoModelForCausalLM.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
if conf.freeze_layer:
model = freeze_top_n_layers(model, conf.freeze_layer)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([p.numel() for p in model_parameters])
print("Number of trainable parameters: {}M".format(int(params / 1e6)))
return model
def get_dataset(conf, tokenizer):
train_datasets, evals = [], {}
for dataset_name in conf.datasets:
train, val = get_one_dataset(conf, dataset_name)
train_datasets.append(train)
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val
train = ConcatDataset(train_datasets)
collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
return train, evals, collate_fn
def get_loss(loss):
if loss == "CrossEntropyLoss":
return CrossEntropyLoss()
else:
raise ValueError(f"Loss {loss} not supported")
def read_yamls(dir):
conf = {}
no_conf = True
for config_file in Path(dir).glob("**/*.yaml"):
no_conf = False
with config_file.open("r") as f:
conf.update(yaml.safe_load(f))
if no_conf:
print(f"WARNING: No yaml files found in {dir}")
return conf
def train_val_dataset(dataset, val_split=0.2):
train_idx, val_idx = train_test_split(
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
)
return Subset(dataset, train_idx), Subset(dataset, val_idx)
def freeze_top_n_layers(model, target_layers):
# its possible we can simply detect which module is a ModuleList
# and simply freeze the module without doing string parsing
for name, param in model.named_parameters():
if "embed" in name:
param.requires_grad = False
elif ".layer" in name or ".h." in name:
tokens = name.split(".")
layer_ = None
for token in tokens:
if token.isdigit():
layer_ = int(token)
break
if layer_ is not None and layer_ < target_layers:
# print('freeze ', layer_, name)
param.requires_grad = False
return model
if __name__ == "__main__":
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloomz-560m")
freeze_top_n_layers(model, 10)
print(model.state_dict().keys())
File diff suppressed because one or more lines are too long
+7 -4
View File
@@ -24,10 +24,13 @@ only described in the notebook
Charts showing detailed memory usages and times for different sentence lengths
and batch sizes are inside the notebook Quick overview batch size 16, sentence
length 4k for training, batch size 128 sentence length 4k for inference | Model
name | Training memory| Training speed | Inference Memory| Inference Speed| |
:---: | :---: | :---: |:---: | :---: | |original| 11.8GB | 2.40s| 4.8GB|16.48s|
|unbiased| 12GB| 1.09s| 4.8GB | 5.59s| |multilingual|14GB| 1.00s| 5.5GB| 4.89s|
length 4k for training, batch size 128 sentence length 4k for Inference
| Model name | Training memory | Training speed | Inference Memory | Inference Speed |
| :----------: | :-------------: | :------------: | :--------------: | :-------------: |
| original | 11.8GB | 2.40s | 4.8GB | 16.48s |
| unbiased | 12GB | 1.09s | 4.8GB | 5.59s |
| multilingual | 14GB | 1.00s | 5.5GB | 4.89s |
# Filtering quality
+37 -4
View File
@@ -1,12 +1,15 @@
"""API Client for interacting with the OASST backend."""
import enum
import typing as t
from http import HTTPStatus
from typing import Optional, Type
from uuid import UUID
import aiohttp
from loguru import logger
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import ValidationError
# TODO: Move to `protocol`?
@@ -27,7 +30,7 @@ class TaskType(str, enum.Enum):
class OasstApiClient:
"""API Client for interacting with the OASST backend."""
def __init__(self, backend_url: str, api_key: str):
def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
"""Create a new OasstApiClient.
Args:
@@ -35,8 +38,12 @@ class OasstApiClient:
backend_url (str): The base backend URL.
api_key (str): The API key to use for authentication.
"""
logger.debug("Opening OasstApiClient session")
self.session = aiohttp.ClientSession()
if session is None:
logger.debug("Opening OasstApiClient session")
session = aiohttp.ClientSession()
self.session = session
self.backend_url = backend_url
self.api_key = api_key
@@ -56,7 +63,33 @@ class OasstApiClient:
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response.raise_for_status()
# If the response is not a 2XX, check to see
# if the json has the fields to create an
# OasstError.
if response.status >= 300:
data = await response.json()
try:
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
raise OasstError(
error_code=oasst_error.error_code,
message=oasst_error.message,
)
except ValidationError as e:
logger.debug(f"Got error from API but could not parse: {e}")
raw_response = await response.text()
logger.debug(f"Raw response: {raw_response}")
raise OasstError(
raw_response,
OasstErrorCode.GENERIC_ERROR,
HTTPStatus(response.status),
)
if response.status == 204:
# No content
return None
return await response.json()
def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
@@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Union
from uuid import UUID, uuid4
import pydantic
from oasst_shared.exceptions import OasstErrorCode
from pydantic import BaseModel, Field
@@ -293,3 +294,10 @@ class UserScore(BaseModel):
class LeaderboardStats(BaseModel):
leaderboard: List[UserScore]
class OasstErrorResponse(BaseModel):
"""The format of an error response from the OASST API."""
error_code: OasstErrorCode
message: str
+2
View File
@@ -11,5 +11,7 @@ setup(
author="OASST Team",
install_requires=[
"pydantic==1.9.1",
"aiohttp==3.8.3",
"aiohttp[speedups]",
],
)
@@ -1,12 +1,21 @@
from typing import Any
from unittest import mock
from uuid import uuid4
import aiohttp
import pytest
from oasst_shared.api_client import OasstApiClient
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
@pytest.fixture
def oasst_api_client_mocked():
"""
A an oasst_api_client pointed at the mocked backend.
Relies on ./scripts/backend-development/start-mock-server.sh
being run.
"""
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123")
yield client
# TODO The fixture should close this connection, but there seems to be a bug
@@ -15,6 +24,30 @@ def oasst_api_client_mocked():
# await client.close()
class MockClientSession(aiohttp.ClientSession):
response: Any
def set_response(self, response: Any):
self.response = response
async def post(self, *args, **kwargs):
return self.response
@pytest.fixture
def mock_http_session():
yield MockClientSession()
@pytest.fixture
def oasst_api_client_fake_http(mock_http_session):
"""
An oasst_api_client that uses a mocked http session. No real requests are made.
"""
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session)
yield client
@pytest.mark.asyncio
@pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType)
async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient):
@@ -49,3 +82,47 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
)
is not None
)
@pytest.mark.asyncio
async def test_can_handle_oasst_error_from_api(
oasst_api_client_fake_http: OasstApiClient,
mock_http_session: MockClientSession,
):
# Return a 400 response with an OasstErrorResponse body
response_body = protocol_schema.OasstErrorResponse(
error_code=OasstErrorCode.GENERIC_ERROR,
message="Some error",
)
status_code = 400
mock_http_session.set_response(
mock.AsyncMock(
status=status_code,
text=mock.AsyncMock(return_value=response_body.json()),
json=mock.AsyncMock(return_value=response_body.dict()),
)
)
with pytest.raises(OasstError):
await oasst_api_client_fake_http.post("/some-path", data={})
@pytest.mark.asyncio
async def test_can_handle_unknown_error_from_api(
oasst_api_client_fake_http: OasstApiClient,
mock_http_session: MockClientSession,
):
response_body = "Internal Server Error"
status_code = 500
mock_http_session.set_response(
mock.AsyncMock(
status=status_code,
text=mock.AsyncMock(return_value=response_body),
json=mock.AsyncMock(return_value=None),
)
)
with pytest.raises(OasstError):
await oasst_api_client_fake_http.post("/some-path", data={})
@@ -4,6 +4,6 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to bot directory
pushd "$parent_path/../../discord-bot"
python3 __main__.py
python3 -m bot
popd
+2
View File
@@ -4,6 +4,8 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to backend directory
pushd "$parent_path/../../oasst-shared"
set -xe
pytest .
popd
+4 -1
View File
@@ -3,7 +3,10 @@ const nextConfig = {
output: "standalone",
reactStrictMode: true,
experimental: {
scrollRestoration: true,
/* Disabling this for now only because it causes a warning in the console that cannot be silenced for eslint
If this can be resolved, we should re-enable this.
*/
// scrollRestoration: true,
},
};
+4444 -30
View File
File diff suppressed because it is too large Load Diff
+3
View File
@@ -22,6 +22,7 @@
"@dnd-kit/core": "^6.0.6",
"@dnd-kit/modifiers": "^6.0.1",
"@dnd-kit/sortable": "^7.0.1",
"@dnd-kit/utilities": "^3.2.1",
"@emotion/react": "^11.10.5",
"@emotion/styled": "^11.10.5",
"@headlessui/react": "^1.7.7",
@@ -39,9 +40,11 @@
"eslint-plugin-simple-import-sort": "^8.0.0",
"focus-visible": "^5.2.0",
"framer-motion": "^6.5.1",
"install": "^0.13.0",
"next": "13.0.6",
"next-auth": "^4.18.6",
"nodemailer": "^6.8.0",
"npm": "^9.2.0",
"postcss-focus-visible": "^7.1.0",
"react": "18.2.0",
"react-dom": "18.2.0",
+2 -2
View File
@@ -1,11 +1,11 @@
export function AuthLayout({ children }) {
return (
<main className="flex items-center justify-center sm:py-4 subpixel-antialiased">
<div className="flex items-center justify-center sm:py-4 subpixel-antialiased">
<div className="flex items-center w-full max-w-2xl flex-col px-4 sm:px-6">
<div className="flex-auto items-center justify-center w-full py-10 px-4 sm:mx-0 sm:flex-none sm:rounded-2xl sm:p-4">
{children}
</div>
</div>
</main>
</div>
);
}
+38 -6
View File
@@ -1,16 +1,48 @@
import { CircleBackground } from "./CircleBackground";
import { useColorMode } from "@chakra-ui/react";
import { useId } from "react";
import { Container } from "./Container";
export function CallToAction() {
function CircleBackground({ width = 558, height = 558, ...props }) {
const id = useId();
const { colorMode } = useColorMode();
const baseRingColor = colorMode === "light" ? "#777" : "#000";
const gradStopColor = colorMode === "light" ? "#fff" : "#000";
return (
<section id="join-us" className="relative overflow-hidden bg-gray-900 py-20 sm:py-28">
<svg viewBox="0 0 558 558" width={width} height={height} fill="none" aria-hidden="true" {...props}>
<defs>
<linearGradient id={id} x1="79" y1="16" x2="105" y2="237" gradientUnits="userSpaceOnUse">
<stop stopColor={gradStopColor} />
<stop offset="1" stopColor={baseRingColor} stopOpacity="0" />
</linearGradient>
</defs>
<path
opacity=".2"
d="M1 279C1 125.465 125.465 1 279 1s278 124.465 278 278-124.465 278-278 278S1 432.535 1 279Z"
stroke={baseRingColor}
/>
<path d="M1 279C1 125.465 125.465 1 279 1" stroke={`url(#${id})`} strokeLinecap="round" />
</svg>
);
}
export function CallToAction() {
const { colorMode } = useColorMode();
const bgColorClass = colorMode === "light" ? "bg-gray-900" : "bg-gray-50";
const headingColorClass = colorMode === "light" ? "text-white" : "text-black";
const textColorClass = colorMode === "light" ? "text-gray-300" : "text-black";
return (
<section id="join-us" className={`relative overflow-hidden py-20 sm:py-28 ${bgColorClass} ${textColorClass}`}>
<div className="absolute top-1/2 left-20 -translate-y-1/2 sm:left-1/2 sm:-translate-x-1/2">
<CircleBackground color="#fff" className="animate-spin-slower" />
<CircleBackground className="animate-spin-slower" />
</div>
<Container className="relative">
<div className="mx-auto max-w-md sm:text-center">
<h2 className="text-3xl font-medium tracking-tight text-white sm:text-4xl">Join Us</h2>
<p className="mt-4 text-lg text-gray-300">
<h2 className={`text-3xl font-medium tracking-tight sm:text-4xl ${headingColorClass}`}>Join Us</h2>
<p className="mt-4 text-lg">
All open source projects begin with people like you. Open source is the belief that if we collaborate we can
together gift our knowledge and technology to the world for the benefit of humanity. Are you in? Find us
here:
@@ -1,22 +0,0 @@
import { useId } from "react";
export function CircleBackground({ color, width = 558, height = 558, ...props }) {
const id = useId();
return (
<svg viewBox="0 0 558 558" width={width} height={height} fill="none" aria-hidden="true" {...props}>
<defs>
<linearGradient id={id} x1="79" y1="16" x2="105" y2="237" gradientUnits="userSpaceOnUse">
<stop stopColor={color} />
<stop offset="1" stopColor={color} stopOpacity="0" />
</linearGradient>
</defs>
<path
opacity=".2"
d="M1 279C1 125.465 125.465 1 279 1s278 124.465 278 278-124.465 278-278 278S1 432.535 1 279Z"
stroke={color}
/>
<path d="M1 279C1 125.465 125.465 1 279 1" stroke={`url(#${id})`} strokeLinecap="round" />
</svg>
);
}
+10 -3
View File
@@ -1,3 +1,5 @@
import { useColorMode } from "@chakra-ui/react";
import { Container } from "./Container";
const faqs = [
@@ -25,11 +27,16 @@ const faqs = [
];
export function Faq() {
const { colorMode } = useColorMode();
const headingColorClass = colorMode === "light" ? "text-gray-900" : "text-white";
const textColorClass = colorMode === "light" ? "text-gray-700" : "text-gray-100";
return (
<section id="faq" aria-labelledby="faqs-title" className="border-t border-gray-200 py-20 sm:py-32">
<Container className="">
<div className="mx-auto max-w-2xl lg:mx-0">
<h2 id="faqs-title" className="text-3xl font-medium tracking-tight text-gray-900">
<h2 id="faqs-title" className={`text-3xl font-medium tracking-tight ${headingColorClass}`}>
Frequently Asked Questions
</h2>
{/* <p className="mt-2 text-lg text-gray-600">
@@ -52,8 +59,8 @@ export function Faq() {
<ul role="list" className="space-y-10">
{column.map((faq, faqIndex) => (
<li key={faqIndex}>
<h3 className="text-lg font-semibold leading-6 text-gray-900">{faq.question}</h3>
<p className="mt-4 text-sm text-gray-700">{faq.answer}</p>
<h3 className={`text-lg font-semibold leading-6 ${headingColorClass}`}>{faq.question}</h3>
<p className={`mt-4 text-sm ${textColorClass}`}>{faq.answer}</p>
</li>
))}
</ul>
+59 -60
View File
@@ -1,71 +1,70 @@
import { useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
import { Container } from "./Container";
export function Footer() {
return (
<footer className="border-t border-gray-200 bg-white">
<main>
<Container className="">
<div className="flex flex-wrap justify-between gap-y-12 py-10 lg:items-center lg:py-16">
<div className="flex items-center text-black pr-8">
<Link href="/" aria-label="Home" className="flex items-center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
</Link>
const { colorMode } = useColorMode();
const bgColorClass = colorMode === "light" ? "bg-transparent" : "bg-gray-800";
const borderClass = colorMode === "light" ? "border-slate-200" : "border-transparent";
<div className="ml-2">
<p className="text-base font-bold">Open Assistant</p>
<p className="text-sm">Conversational AI for everyone.</p>
return (
<footer className={bgColorClass}>
<div className={`flex mx-auto max-w-7xl justify-between py-10 px-10 border-t ${borderClass}`}>
<div className="flex items-center pr-8">
<Link href="/" aria-label="Home" className="flex items-center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
</Link>
<div className="ml-2">
<p className="text-base font-bold">Open Assistant</p>
<p className="text-sm">Conversational AI for everyone.</p>
</div>
</div>
<nav className="flex justify-center gap-20">
<nav className="flex justify-center gap-20">
<div className="flex flex-col text-sm leading-7">
<b>Legal</b>
<div className="flex flex-col leading-5">
<Link href="/privacy-policy" aria-label="Privacy Policy" className="hover:underline underline-offset-2">
Privacy Policy
</Link>
<Link
href="/terms-of-service"
aria-label="Terms of Service"
className="hover:underline underline-offset-2"
>
Terms of Service
</Link>
</div>
</div>
<nav className="flex justify-center gap-20">
<div className="flex flex-col text-sm leading-7">
<b>Legal</b>
<div className="flex flex-col leading-5">
<Link
href="/privacy-policy"
aria-label="Privacy Policy"
className="hover:underline underline-offset-2"
>
Privacy Policy
</Link>
<Link
href="/terms-of-service"
aria-label="Terms of Service"
className="hover:underline underline-offset-2"
>
Terms of Service
</Link>
</div>
<div className="flex flex-col text-sm leading-7">
<b>Connect</b>
<div className="flex flex-col leading-5">
<Link
href="https://github.com/LAION-AI/Open-Assistant"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Privacy Policy"
className="hover:underline underline-offset-2"
>
Github
</Link>
<Link
href="https://discord.gg/pXtnYk9c"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Terms of Service"
className="hover:underline underline-offset-2"
>
Discord
</Link>
</div>
<div className="flex flex-col text-sm leading-7">
<b>Connect</b>
<div className="flex flex-col leading-5">
<Link
href="https://github.com/LAION-AI/Open-Assistant"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Privacy Policy"
className="hover:underline underline-offset-2"
>
Github
</Link>
<Link
href="https://discord.gg/pXtnYk9c"
rel="noopener noreferrer nofollow"
target="_blank"
aria-label="Terms of Service"
className="hover:underline underline-offset-2"
>
Discord
</Link>
</div>
</div>
</nav>
</div>
</Container>
</main>
</div>
</nav>
{/* </div> */}
</nav>
</div>
</footer>
);
}
@@ -22,4 +22,15 @@ const Template = (args) => {
};
export const Default = Template.bind({});
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" }, transparent: false };
Default.args = {
session: {
data: {
user: {
name: "StoryBook user",
},
},
status: "authenticated",
},
transparent: false,
borderClass: undefined,
};
+68 -61
View File
@@ -1,14 +1,12 @@
import { Button } from "@chakra-ui/react";
import { Box, Button, useColorMode } from "@chakra-ui/react";
import { Popover } from "@headlessui/react";
import clsx from "clsx";
import { AnimatePresence, motion } from "framer-motion";
import Image from "next/image";
import Link from "next/link";
import { useSession } from "next-auth/react";
import { FaUser } from "react-icons/fa";
import { Container } from "src/components/Container";
import { NavLinks } from "./NavLinks";
import { ColorModeIconToggle } from "../UI/ColorModeIconToggle";
import { UserMenu } from "./UserMenu";
function MenuIcon(props) {
@@ -55,63 +53,72 @@ function AccountButton() {
}
export function Header(props) {
const transparent = props.transparent ?? false;
const { colorMode } = useColorMode();
const borderClass = props.transparent
? ""
: colorMode === "light"
? "border-b border-gray-400"
: "border-b border-zinc-800";
return (
<header className={clsx(!transparent && "bg-white")}>
<nav>
<Container className="relative z-10 flex justify-between py-8">
<div className="relative z-10 flex items-center gap-16">
<Link href="/" aria-label="Home" className="flex items-center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
<span className="text-2xl font-bold ml-3">Open Assistant</span>
</Link>
</div>
<div className="flex items-center gap-4">
<Popover className="lg:hidden">
{({ open }) => (
<>
<Popover.Button
className="relative z-10 inline-flex items-center rounded-lg stroke-gray-900 p-2 hover:bg-gray-200/50 hover:stroke-gray-600 active:stroke-gray-900 [&:not(:focus-visible)]:focus:outline-none"
aria-label="Toggle site navigation"
>
{({ open }) => (open ? <ChevronUpIcon className="h-6 w-6" /> : <MenuIcon className="h-6 w-6" />)}
</Popover.Button>
<AnimatePresence initial={false}>
{open && (
<>
<Popover.Overlay
static
as={motion.div}
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
className="fixed inset-0 z-1 bg-gray-300/60 backdrop-blur"
/>
<Popover.Panel
static
as={motion.div}
initial={{ opacity: 0, y: -32 }}
animate={{ opacity: 1, y: 0 }}
exit={{
opacity: 0,
y: -32,
transition: { duration: 0.2 },
}}
className="absolute inset-x-0 top-0 z-0 origin-top rounded-b-2xl bg-white px-6 pb-6 pt-32 shadow-2xl shadow-gray-900/20"
>
<div className="mt-8 flex flex-col gap-4"></div>
</Popover.Panel>
</>
)}
</AnimatePresence>
</>
)}
</Popover>
<AccountButton />
<UserMenu />
</div>
</Container>
</nav>
</header>
<nav className={`oa-basic-theme ${borderClass}`}>
<Box className="flex mx-auto max-w-7xl justify-between py-8 px-10">
<div className="relative z-10 flex items-center gap-16">
<Link href="/" aria-label="Home" className="flex items-center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
<span className="text-2xl font-bold ml-3">Open Assistant</span>
</Link>
</div>
<div className="flex items-center gap-4">
<Popover className="lg:hidden">
{({ open }) => (
<>
<Popover.Button
className="relative z-10 inline-flex items-center rounded-lg stroke-gray-900 p-2 hover:bg-gray-200/50 hover:stroke-gray-600 active:stroke-gray-900 [&:not(:focus-visible)]:focus:outline-none"
aria-label="Toggle site navigation"
>
{({ open }) => (open ? <ChevronUpIcon className="h-6 w-6" /> : <MenuIcon className="h-6 w-6" />)}
</Popover.Button>
<AnimatePresence initial={false}>
{open && (
<>
<Popover.Overlay
static
as={motion.div}
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
className="fixed inset-0 z-1 bg-gray-300/60 backdrop-blur"
/>
<Popover.Panel
static
as={motion.div}
initial={{ opacity: 0, y: -32 }}
animate={{ opacity: 1, y: 0 }}
exit={{
opacity: 0,
y: -32,
transition: { duration: 0.2 },
}}
className="absolute inset-x-0 top-0 z-0 origin-top rounded-b-2xl bg-white px-6 pb-6 pt-32 shadow-2xl shadow-gray-900/20"
>
<div className="space-y-4">
<MobileNavLink href="/#join-us">Join Us</MobileNavLink>
<MobileNavLink href="/#faqs">FAQs</MobileNavLink>
</div>
<div className="mt-8 flex flex-col gap-4"></div>
</Popover.Panel>
</>
)}
</AnimatePresence>
</>
)}
</Popover>
<AccountButton />
<UserMenu />
<ColorModeIconToggle className="ml-5" />
</div>
</Box>
</nav>
);
}
+8 -2
View File
@@ -1,9 +1,15 @@
import { useColorMode } from "@chakra-ui/react";
import { AnimatePresence, motion } from "framer-motion";
import Link from "next/link";
import { useState } from "react";
export function NavLinks(): JSX.Element {
const [hoveredIndex, setHoveredIndex] = useState(null);
const { colorMode } = useColorMode();
const linkColor = colorMode === "light" ? "text-gray-700 hover:text-gray-900" : "text-gray-50 hover:text-white";
const hoverBgColor = colorMode === "light" ? "bg-gray-100" : "bg-gray-800";
return (
<>
@@ -14,14 +20,14 @@ export function NavLinks(): JSX.Element {
<Link
key={label}
href={href}
className="relative -my-2 -mx-3 rounded-lg px-3 py-2 text-sm text-gray-700 transition-colors delay-150 hover:text-gray-900 hover:delay-[0ms]"
className={`${linkColor} relative -my-2 -mx-3 rounded-lg px-3 py-2 text-sm transition-colors delay-150 hover:delay-[0ms]`}
onMouseEnter={() => setHoveredIndex(index)}
onMouseLeave={() => setHoveredIndex(null)}
>
<AnimatePresence>
{hoveredIndex === index && (
<motion.span
className="absolute inset-0 rounded-lg bg-gray-100"
className={`${hoverBgColor} absolute inset-0 rounded-lg`}
layoutId="hoverBackground"
initial={{ opacity: 0 }}
animate={{ opacity: 1, transition: { duration: 0.15 } }}
+8 -6
View File
@@ -1,3 +1,4 @@
import { Box, useColorModeValue } from "@chakra-ui/react";
import { Popover } from "@headlessui/react";
import { AnimatePresence, motion } from "framer-motion";
import Image from "next/image";
@@ -7,6 +8,7 @@ import { FaCog, FaSignOutAlt } from "react-icons/fa";
export function UserMenu() {
const { data: session } = useSession();
const backgroundColor = useColorModeValue("#FFFFFF", "#000000");
if (!session) {
return <></>;
@@ -26,7 +28,7 @@ export function UserMenu() {
{({ open }) => (
<>
<Popover.Button aria-label="Toggle Account Options" className="flex">
<div className="flex items-center gap-4 p-1 lg:pr-6 rounded-full bg-white border border-slate-300/70 hover:bg-gray-200/50 transition-colors duration-300">
<div className="flex items-center gap-4 p-1 lg:pr-6 rounded-full border border-slate-300/70 hover:bg-gray-200/50 transition-colors duration-300">
<Image
src="/images/temp-avatars/av1.jpg"
alt="Profile Picture"
@@ -41,7 +43,7 @@ export function UserMenu() {
</Popover.Button>
<AnimatePresence initial={false}>
{open && (
<>
<Box backgroundColor={backgroundColor}>
<Popover.Panel
static
as={motion.div}
@@ -52,9 +54,9 @@ export function UserMenu() {
y: -10,
transition: { duration: 0.2 },
}}
className="absolute right-0 mt-3 w-screen max-w-xs p-4 rounded-md bg-white border border-slate-300/70"
className="absolute right-0 mt-3 w-screen bg-inherit max-w-xs p-4 rounded-md border border-slate-300/70"
>
<div className="flex flex-col gap-1">
<Box className="flex flex-col gap-1">
{accountOptions.map((item) => (
<a
key={item.name}
@@ -81,9 +83,9 @@ export function UserMenu() {
<p>Sign Out</p>
</div>
</a>
</div>
</Box>
</Popover.Panel>
</>
</Box>
)}
</AnimatePresence>
</>
+22 -10
View File
@@ -1,3 +1,4 @@
import { useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import { useId } from "react";
@@ -6,6 +7,10 @@ import { Container } from "./Container";
function BackgroundIllustration(props) {
const id = useId();
const { colorMode } = useColorMode();
const baseRingColor = colorMode === "light" ? "#d4d4d4" : "#005a69";
const gradStopColor = colorMode === "light" ? "#06b6d4" : "#00f2ff";
return (
<div {...props}>
<svg
@@ -16,14 +21,14 @@ function BackgroundIllustration(props) {
>
<path
d="M1025 513c0 282.77-229.23 512-512 512S1 795.77 1 513 230.23 1 513 1s512 229.23 512 512Z"
stroke="#D4D4D4"
stroke={baseRingColor}
strokeOpacity="0.7"
/>
<path d="M513 1025C230.23 1025 1 795.77 1 513" stroke={`url(#${id}-gradient-1)`} strokeLinecap="round" />
<defs>
<linearGradient id={`${id}-gradient-1`} x1="1" y1="513" x2="1" y2="1025" gradientUnits="userSpaceOnUse">
<stop stopColor="#06b6d4" />
<stop offset="1" stopColor="#06b6d4" stopOpacity="0" />
<stop stopColor={gradStopColor} />
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
</linearGradient>
</defs>
</svg>
@@ -35,14 +40,14 @@ function BackgroundIllustration(props) {
>
<path
d="M913 513c0 220.914-179.086 400-400 400S113 733.914 113 513s179.086-400 400-400 400 179.086 400 400Z"
stroke="#D4D4D4"
stroke={baseRingColor}
strokeOpacity="0.7"
/>
<path d="M913 513c0 220.914-179.086 400-400 400" stroke={`url(#${id}-gradient-2)`} strokeLinecap="round" />
<defs>
<linearGradient id={`${id}-gradient-2`} x1="913" y1="513" x2="913" y2="913" gradientUnits="userSpaceOnUse">
<stop stopColor="#06b6d4" />
<stop offset="1" stopColor="#06b6d4" stopOpacity="0" />
<stop stopColor={gradStopColor} />
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
</linearGradient>
</defs>
</svg>
@@ -51,17 +56,24 @@ function BackgroundIllustration(props) {
}
export function Hero() {
const { colorMode } = useColorMode();
const pTextColor = colorMode === "light" ? "text-gray-600" : "text-white";
const fancyTextGradientClasses =
colorMode === "light" ? "from-blue-600 via-sky-400 to-blue-700" : "from-blue-500 via-sky-300 to-blue-400";
return (
<div className="overflow-hidden py-20 sm:py-32 lg:pb-32 xl:pb-36">
<Container className="">
<div className="lg:grid lg:grid-cols-12 lg:gap-x-8 lg:gap-y-20">
<div className="relative z-10 mx-auto max-w-2xl lg:col-span-7 lg:max-w-none lg:pt-6 xl:col-span-6">
<h1 className="text-5xl mb-6 font-bold tracking-tight text-gray-900">Open Assistant</h1>
<p className="mt-8 text-3xl inline bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text font-display tracking-tight text-transparent">
<h1 className="text-5xl mb-6 font-bold tracking-tight">Open Assistant</h1>
<p
className={`bg-gradient-to-r ${fancyTextGradientClasses} mt-8 text-3xl inline bg-clip-text font-display tracking-tight text-transparent`}
>
<b>Conversational AI for everyone.</b>
</p>
<p className="mt-6 text-lg text-gray-600">We believe we can create a revolution.</p>
<p className="mt-6 text-lg text-gray-600">
<p className={`mt-6 text-lg ${pTextColor}`}>We believe we can create a revolution.</p>
<p className={`mt-6 text-lg ${pTextColor}`}>
In the same way that Stable Diffusion helped the world make art and images in new ways, we want to improve
the world by providing amazing conversational AI.
</p>
@@ -1,12 +1,18 @@
import { Progress } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
export const LoadingScreen = ({ text }) => (
<div className="bg-slate-100">
<Progress size="xs" isIndeterminate />
{text && (
<div className="flex h-full">
<div className="text-xl font-bold text-gray-800 mx-auto my-auto">{text}</div>
</div>
)}
</div>
);
export const LoadingScreen = ({ text }) => {
const { colorMode } = useColorMode();
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
return (
<div className={`h-full ${mainClasses}`}>
<Progress size="sm" isIndeterminate />
{text && (
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">{text}</div>
</div>
)}
</div>
);
};
+15 -2
View File
@@ -1,4 +1,6 @@
import { Grid } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { FlaggableElement } from "./FlaggableElement";
export interface Message {
@@ -6,13 +8,24 @@ export interface Message {
is_assistant: boolean;
}
const getColor = (isAssistant: boolean) => (isAssistant ? "bg-slate-800" : "bg-sky-900");
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
if (colorMode === "light") {
return isAssistant ? "bg-slate-800" : "bg-sky-900";
} else {
return isAssistant ? "bg-black" : "bg-sky-900";
}
};
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const { colorMode } = useColorMode();
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
return (
<FlaggableElement text={text} post_id={post_id} key={i + text}>
<div key={i + text} className={`${getColor(is_assistant)} p-4 rounded-md text-white whitespace-pre-wrap`}>
<div
key={i + text}
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
>
{text}
</div>
</FlaggableElement>
+9 -6
View File
@@ -2,9 +2,9 @@ import { Flex } from "@chakra-ui/react";
import {
closestCenter,
DndContext,
KeyboardSensor,
PointerSensor,
TouchSensor,
KeyboardSensor,
useSensor,
useSensors,
} from "@dnd-kit/core";
@@ -23,6 +23,7 @@ import { SortableItem } from "./SortableItem";
export interface SortableProps {
items: ReactNode[];
onChange: (newSortedIndices: number[]) => void;
className?: string;
}
interface SortableItems {
@@ -31,18 +32,18 @@ interface SortableItems {
item: ReactNode;
}
export const Sortable = ({ items, onChange }: SortableProps) => {
export const Sortable = (props: SortableProps) => {
const [itemsWithIds, setItemsWithIds] = useState<SortableItems[]>([]);
useEffect(() => {
setItemsWithIds(
items.map((item, idx) => ({
props.items.map((item, idx) => ({
item,
id: idx + 1, // +1 because dndtoolkit has problem with "falsy" ids
originalIndex: idx,
}))
);
}, [items]);
}, [props.items]);
const sensors = useSensors(
useSensor(PointerSensor),
@@ -50,6 +51,8 @@ export const Sortable = ({ items, onChange }: SortableProps) => {
useSensor(KeyboardSensor, { coordinateGetter: sortableKeyboardCoordinates })
);
const extraClasses = props.className || "";
return (
<DndContext
sensors={sensors}
@@ -58,7 +61,7 @@ export const Sortable = ({ items, onChange }: SortableProps) => {
modifiers={[restrictToVerticalAxis]}
>
<SortableContext items={itemsWithIds} strategy={verticalListSortingStrategy}>
<Flex direction="column" gap={2}>
<Flex direction="column" gap={2} className={extraClasses}>
{itemsWithIds.map(({ id, item }) => (
<SortableItem key={id} id={id}>
{item}
@@ -78,7 +81,7 @@ export const Sortable = ({ items, onChange }: SortableProps) => {
const oldIndex = items.findIndex((x) => x.id === active.id);
const newIndex = items.findIndex((x) => x.id === over.id);
const newArray = arrayMove(items, oldIndex, newIndex);
onChange(newArray.map((item) => item.originalIndex));
props.onChange(newArray.map((item) => item.originalIndex));
return newArray;
});
}
@@ -1,8 +1,9 @@
import { Button } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useSortable } from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
import { RxDragHandleDots2 } from "react-icons/rx";
import { PropsWithChildren } from "react";
import { RxDragHandleDots2 } from "react-icons/rx";
export const SortableItem = ({ children, id }: PropsWithChildren<{ id: number }>) => {
const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ id });
@@ -13,9 +14,15 @@ export const SortableItem = ({ children, id }: PropsWithChildren<{ id: number }>
touchAction: "none",
};
const { colorMode } = useColorMode();
const themedClasses =
colorMode === "light"
? "bg-slate-600 hover:bg-slate-500 text-white"
: "bg-black hover:bg-slate-900 text-white ring-1 ring-white/30 ring-inset hover:ring-slate-200/50";
return (
<li
className="grid grid-cols-[min-content_1fr] items-center rounded-lg shadow-md gap-x-2 p-2 bg-white hover:bg-slate-50"
className={`grid grid-cols-[min-content_1fr] items-center rounded-lg shadow-md gap-x-2 p-2 ${themedClasses}`}
ref={setNodeRef}
style={style}
>
@@ -0,0 +1,20 @@
import { useColorMode } from "@chakra-ui/react";
interface SurveyCardProps {
className?: string;
children: React.ReactNode;
}
export const SurveyCard = (props: SurveyCardProps) => {
const extraClases = props.className || "";
const { colorMode } = useColorMode();
const baseCardClasses = "rounded-lg h-full block p-6";
const cardClases =
colorMode === "light"
? `${baseCardClasses} bg-slate-50 text-gray-800 shadow-lg ${extraClases}`
: // `${baseCardClasses} bg-slate-800 text-white shadow-xl${extraClases}`;
`${baseCardClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
return <div className={cardClases}>{props.children}</div>;
};
@@ -0,0 +1,40 @@
import { useColorMode } from "@chakra-ui/react";
import { Flex } from "@chakra-ui/react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
interface TaskControlsProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tasks: any[];
className?: string;
onSubmitResponse: (task: { id: string }) => void;
onSkip: () => void;
}
export const TaskControls = (props: TaskControlsProps) => {
const extraClases = props.className || "";
const { colorMode } = useColorMode();
const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto";
const taskControlClases =
colorMode === "light"
? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}`
: `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
const endTask = props.tasks[props.tasks.length - 1];
return (
<section className={taskControlClases}>
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton onClick={() => props.onSubmitResponse(props.tasks[0])}>Submit</SubmitButton>
) : (
<SubmitButton onClick={props.onSkip}>Next Task</SubmitButton>
)}
</Flex>
</section>
);
};
@@ -0,0 +1,16 @@
import { SurveyCard } from "src/components/Survey/SurveyCard";
export const TwoColumnsWithCards = ({ children }: { children: React.ReactNode[] }) => {
if (!Array.isArray(children) || children.length !== 2) {
throw new Error("TwoColumns expects 2 children");
}
const [first, second] = children;
return (
<div className="mb-8 mx-auto max-w-7xl lt-lg:mb-12 grid lg:gap-x-12 lg:grid-cols-2">
<SurveyCard>{first}</SurveyCard>
<SurveyCard className="lg:mt-0 lt-lg:mt-6">{second}</SurveyCard>
</div>
);
};
+1 -1
View File
@@ -1,6 +1,6 @@
export const TaskInfo = ({ id, output }: { id: string; output: string }) => {
return (
<div className="grid grid-cols-[min-content_auto] gap-x-2 text-gray-700">
<div className="grid grid-cols-[min-content_auto] gap-x-2 ">
<b>Prompt</b>
<span data-cy="task-id">{id}</span>
<b>Output</b>
@@ -1,12 +1,24 @@
import { Flex } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import React from "react";
import { TaskOption } from "./TaskOption";
import { TaskOptions } from "./TaskOptions";
export const TaskSelection = () => {
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
return (
<Flex gap={10} wrap="wrap" justifyContent="space-evenly" width="full" height="full" alignItems={"center"}>
<Flex
gap={10}
wrap="wrap"
justifyContent="space-evenly"
width="full"
height="full"
alignItems={"center"}
className={mainBgClasses}
>
<TaskOptions key="create" title="Create">
{/* <TaskOption
alt="Summarize Stories"
-14
View File
@@ -1,14 +0,0 @@
export const TwoColumns = ({ children }: { children: React.ReactNode[] }) => {
if (!Array.isArray(children) || children.length !== 2) {
throw new Error("TwoColumns expects 2 children");
}
const [first, second] = children;
return (
<section className="mb-8 lt-lg:mb-12 grid lg:gap-x-12 lg:grid-cols-2">
<div className="rounded-lg shadow-lg h-full block bg-white p-6">{first}</div>
<div className="rounded-lg shadow-lg h-full block bg-white p-6 mt-6 lg:mt-0">{second}</div>
</section>
);
};
@@ -0,0 +1,23 @@
import { useColorMode } from "@chakra-ui/react";
import { CiDark } from "react-icons/ci";
import { CiLight } from "react-icons/ci";
export function ColorModeIconToggle(props) {
const { colorMode, toggleColorMode } = useColorMode();
const propsClassName = props.className ?? "";
return (
<button
type="button"
className={`flex h-6 w-6 items-center justify-center rounded-md transition hover:bg-zinc-900/5 dark:hover:bg-white/5 ${propsClassName}`}
aria-label="Toggle dark mode"
onClick={toggleColorMode}
>
{colorMode === "light" ? (
<CiDark className="h-5 w-5 stroke-zinc-900" />
) : (
<CiLight className="h-5 w-5 stroke-white" />
)}
</button>
);
}
@@ -0,0 +1,16 @@
import { Switch, useColorMode } from "@chakra-ui/react";
import React from "react";
const ColorModeSwitch = () => {
const { colorMode, toggleColorMode } = useColorMode();
return (
<Switch
onChange={toggleColorMode}
defaultChecked={colorMode === "light"}
checked={colorMode === "light"}
size="lg"
/>
);
};
export default ColorModeSwitch;
+5 -28
View File
@@ -1,48 +1,25 @@
import "../styles/globals.css";
import "focus-visible";
import { ChakraProvider } from "@chakra-ui/react";
import { extendTheme } from "@chakra-ui/react";
import { Inter } from "@next/font/google";
import type { AppProps } from "next/app";
import { SessionProvider } from "next-auth/react";
import { getDefaultLayout, NextPageWithLayout } from "src/components/Layout";
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const inter = Inter({
subsets: ["latin"],
variable: "--font-inter",
});
const theme = extendTheme({
styles: {
global: {
body: {
bg: "white",
},
main: {
fontFamily: "Inter",
},
header: {
fontFamily: "Inter",
},
},
},
});
import { Chakra, getServerSideProps } from "../styles/Chakra";
type AppPropsWithLayout = AppProps & {
Component: NextPageWithLayout;
};
function MyApp({ Component, pageProps: { session, ...pageProps } }: AppPropsWithLayout) {
function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: AppPropsWithLayout) {
const getLayout = Component.getLayout ?? getDefaultLayout;
const page = getLayout(<Component {...pageProps} />);
return (
<ChakraProvider theme={theme}>
<Chakra cookies={cookies}>
<SessionProvider session={session}>{page}</SessionProvider>
</ChakraProvider>
</Chakra>
);
}
export { getServerSideProps };
export default MyApp;
@@ -53,7 +53,7 @@ const handler = async (req, res) => {
});
// Update the backend with our Task ID
const ackRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
+1 -5
View File
@@ -7,8 +7,7 @@ import prisma from "src/lib/prismadb";
* This implicity does a few things:
* 1) Stores the answer with the Task Backend.
* 2) Records the new task in our local database.
* 3) (TODO) Acks the new task with our local task ID to the Task Backend.
* 4) Returns the newly created task to the client.
* 3) Returns the newly created task to the client.
*/
const handler = async (req, res) => {
const token = await getToken({ req });
@@ -69,9 +68,6 @@ const handler = async (req, res) => {
},
});
// TODO: Ack the task with the Task Backend using the newly created local
// task ID.
// Send the next task in the sequence to the client.
res.status(200).json(newRegisteredTask);
};
+32 -21
View File
@@ -1,13 +1,16 @@
import { Button, Input, Stack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
import React, { useRef } from "react";
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
import { AuthLayout } from "src/components/AuthLayout";
import { Footer } from "src/components/Footer";
import { Header } from "src/components/Header";
// eslint-disable-next-line @typescript-eslint/no-unused-vars
export default function Signin({ csrfToken, providers }) {
function Signin({ csrfToken, providers }) {
const { discord, email, github, credentials } = providers;
const emailEl = useRef(null);
const signinWithEmail = (ev: React.FormEvent) => {
@@ -21,8 +24,14 @@ export default function Signin({ csrfToken, providers }) {
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
}
const { colorMode } = useColorMode();
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
const buttonBgColor = colorMode === "light" ? "#2563eb" : "#2563eb";
const buttonColorScheme = colorMode === "light" ? "blue" : "dark-blue-btn";
return (
<>
<div className={bgColorClass}>
<Head>
<title>Sign Up - Open Assistant</title>
<meta name="Sign Up" content="Sign up to access Open Assistant" />
@@ -30,11 +39,11 @@ export default function Signin({ csrfToken, providers }) {
<AuthLayout>
<Stack spacing="2">
{credentials && (
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-200 rounded-md p-4 relative">
<span className="text-orange-600 absolute -top-3 left-5 bg-white px-1">For Debugging Only</span>
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-600 rounded-md p-4 relative">
<span className={`text-orange-600 absolute -top-3 left-5 ${bgColorClass} px-1`}>For Debugging Only</span>
<Stack>
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
<Button size={"lg"} leftIcon={<FaBug />} colorScheme="gray" type="submit">
<Button size={"lg"} leftIcon={<FaBug />} colorScheme={buttonColorScheme} color="white" type="submit">
Continue with Debug User
</Button>
</Stack>
@@ -43,13 +52,13 @@ export default function Signin({ csrfToken, providers }) {
{email && (
<form onSubmit={signinWithEmail}>
<Stack>
<Input data-cy="email-address" variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
<Button
data-cy="signin-email-button"
size={"lg"}
leftIcon={<FaEnvelope />}
colorScheme="gray"
type="submit"
colorScheme={buttonColorScheme}
color="white"
>
Continue with Email
</Button>
@@ -58,7 +67,7 @@ export default function Signin({ csrfToken, providers }) {
)}
{discord && (
<Button
bg="#5865F2"
bg={buttonBgColor}
_hover={{ bg: "#4A57E3" }}
_active={{
bg: "#454FBF",
@@ -90,29 +99,31 @@ export default function Signin({ csrfToken, providers }) {
</Stack>
<div className="pt-10 text-center">
By signing up you agree to our <br></br>
<Link href="#" aria-label="Terms of Service" className="hover:underline underline-offset-4">
<Link href="/terms-of-service" aria-label="Terms of Service" className="hover:underline underline-offset-4">
<b>Terms of Service</b>
</Link>{" "}
and{" "}
<Link href="#" aria-label="Terms of Use" className="hover:underline underline-offset-4">
<Link href="/privacy-policy" aria-label="Privacy Policy" className="hover:underline underline-offset-4">
<b>Privacy Policy</b>
</Link>
.
</div>
<hr className="mt-14 mb-4 h-px bg-gray-200 border-0" />
<div className="text-center">
Already have an account?{" "}
<Link href="#" aria-label="Log In" className="hover:underline underline-offset-4">
<b>Log In</b>
</Link>
</div>
</AuthLayout>
</>
</div>
);
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
export async function getServerSideProps(context) {
Signin.getLayout = (page) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
{page}
<Footer />
</div>
);
export default Signin;
export async function getServerSideProps() {
const csrfToken = await getCsrfToken();
const providers = await getProviders();
return {
+14 -27
View File
@@ -1,11 +1,10 @@
import { Flex, Textarea } from "@chakra-ui/react";
import { Container, Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { TwoColumns } from "src/components/TwoColumns";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -45,43 +44,31 @@ const AssistantReply = () => {
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
const task = tasks[0].task;
const endTask = tasks[tasks.length - 1];
return (
<div className="p-6 bg-slate-100 text-gray-800">
<TwoColumns>
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Reply as the assistant</h5>
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
<Messages messages={task.conversation.messages} post_id={task.id} />
</>
<Textarea name="reply" data-cy="reply" placeholder="Reply..." ref={inputRef} />
</TwoColumns>
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
</TwoColumnsWithCards>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<TaskInfo id={tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])}>
Submit
</SubmitButton>
) : (
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
Next Task
</SubmitButton>
)}
</Flex>
</section>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
);
};
+18 -23
View File
@@ -1,11 +1,9 @@
import { Flex, Textarea } from "@chakra-ui/react";
import Head from "next/head";
import { Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { TwoColumns } from "src/components/TwoColumns";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -20,7 +18,7 @@ const SummarizeStory = () => {
// Fetch the very fist task. We can ignore everything except isLoading
// because the onSuccess handler will update `tasks` when ready.
const { isLoading } = useSWRImmutable("/api/new_task/summarize_story", fetcher, {
const { isLoading, mutate } = useSWRImmutable("/api/new_task/summarize_story", fetcher, {
onSuccess: (data) => {
setTasks([data]);
},
@@ -50,6 +48,14 @@ const SummarizeStory = () => {
});
};
const fetchNextTask = () => {
inputRef.current.value = "";
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
@@ -59,31 +65,20 @@ const SummarizeStory = () => {
}
return (
<>
<Head>
<title>Summarize A Story</title>
<meta name="description" content="Summarize a story to train our model." />
</Head>
<div className={`p-12 ${mainBgClasses}`}>
<main className="p-6 h-full mx-auto bg-slate-100 text-gray-800">
<TwoColumns>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Instruction</h5>
<p className="text-lg py-1">Summarize the following story</p>
<div className="bg-slate-800 p-6 rounded-xl text-white whitespace-pre-wrap">{tasks[0].task.story}</div>
</>
<Textarea name="summary" placeholder="Summary" ref={inputRef} />
</TwoColumns>
</TwoColumnsWithCards>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<TaskInfo id={tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
<SubmitButton onClick={() => submitResponse(tasks[0])}>Submit</SubmitButton>
</Flex>
</section>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</main>
</>
</div>
);
};
+20 -26
View File
@@ -1,11 +1,10 @@
import { Flex, Textarea } from "@chakra-ui/react";
import { Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { TwoColumns } from "src/components/TwoColumns";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -45,43 +44,38 @@ const UserReply = () => {
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
}
const task = tasks[0].task;
const endTask = tasks[tasks.length - 1];
return (
<div className="p-6 bg-slate-100 text-gray-800">
<TwoColumns>
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Reply as a user</h5>
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
<Messages messages={task.conversation.messages} post_id={task.id} />
{task.hint && <p className="text-lg py-1">Hint: {task.hint}</p>}
</>
<Textarea name="reply" data-cy="reply" placeholder="Reply..." ref={inputRef} />
</TwoColumns>
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
</TwoColumnsWithCards>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<TaskInfo id={tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])}>
Submit
</SubmitButton>
) : (
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
Next Task
</SubmitButton>
)}
</Flex>
</section>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
);
};
@@ -1,11 +1,10 @@
import { Flex } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Sortable } from "src/components/Sortable/Sortable";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControls } from "src/components/Survey/TaskControls";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -47,48 +46,42 @@ const RankAssistantReplies = () => {
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
}
const replies = tasks[0].task.replies as string[];
const endTask = tasks[tasks.length - 1];
return (
<>
<Head>
<title>Rank Assistant Replies</title>
<meta name="description" content="Rank Assistant Replies." />
</Head>
<main className="p-6 bg-slate-100 text-gray-800">
<div className="rounded-lg shadow-lg block bg-white p-6 mb-8">
<div className={`p-12 ${mainBgClasses}`}>
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
<p className="text-lg py-1">
Given the following replies, sort them from best to worst, best being first, worst being last.
</p>
<Sortable items={replies} onChange={setRanking} />
</div>
<Sortable items={replies} onChange={setRanking} className="my-8" />
</SurveyCard>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch">
<TaskInfo id={tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])} disabled={ranking.length === 0}>
Submit
</SubmitButton>
) : (
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
Next Task
</SubmitButton>
)}
</Flex>
</section>
</main>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
</>
);
};
@@ -1,11 +1,10 @@
import { Flex } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Sortable } from "src/components/Sortable/Sortable";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControls } from "src/components/Survey/TaskControls";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -18,6 +17,7 @@ const RankInitialPrompts = () => {
* The best prompt will have index 0, and the worst is the last.
*/
const [ranking, setRanking] = useState<number[]>([]);
// const bg = useColorModeValue("gray.100", "gray.800");
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_initial_prompts", fetcher, {
onSuccess: (data) => {
@@ -47,47 +47,40 @@ const RankInitialPrompts = () => {
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
}
const endTask = tasks[tasks.length - 1];
return (
<>
<Head>
<title>Rank Initial Prompts</title>
<meta name="description" content="Rank initial prompts." />
</Head>
<main className="p-6 bg-slate-100 text-gray-800">
<div className="rounded-lg shadow-lg block bg-white p-6 mb-8">
<div className={`p-12 ${mainBgClasses}`}>
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
<p className="text-lg py-1">
Given the following prompts, sort them from best to worst, best being first, worst being last.
</p>
<Sortable items={tasks[0].task.prompts} onChange={setRanking} />
</div>
<Sortable items={tasks[0].task.prompts} onChange={setRanking} className="my-8" />
</SurveyCard>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch">
<TaskInfo id={tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])} disabled={ranking.length === 0}>
Submit
</SubmitButton>
) : (
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
Next Task
</SubmitButton>
)}
</Flex>
</section>
</main>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
</>
);
};
@@ -1,11 +1,10 @@
import { Flex } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Sortable } from "src/components/Sortable/Sortable";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControls } from "src/components/Survey/TaskControls";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -47,48 +46,41 @@ const RankUserReplies = () => {
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
}
const replies = tasks[0].task.replies as string[];
const endTask = tasks[tasks.length - 1];
return (
<>
<Head>
<title>Rank User Replies</title>
<meta name="description" content="Rank User Replies." />
</Head>
<main className="p-6 bg-slate-100 text-gray-800">
<div className="rounded-lg shadow-lg block bg-white p-6 mb-8">
<div className={`p-12 ${mainBgClasses}`}>
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
<p className="text-lg py-1">
Given the following replies, sort them from best to worst, best being first, worst being last.
</p>
<Sortable items={replies} onChange={setRanking} />
</div>
<Sortable items={replies} onChange={setRanking} className="my-8" />
</SurveyCard>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<TaskInfo id={tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])} disabled={ranking.length === 0}>
Submit
</SubmitButton>
) : (
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
Next Task
</SubmitButton>
)}
</Flex>
</section>
</main>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</div>
</>
);
};
+21 -21
View File
@@ -1,13 +1,12 @@
import { Flex, Textarea } from "@chakra-ui/react";
import { Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { QuestionMarkCircleIcon } from "@heroicons/react/20/solid";
import Head from "next/head";
import { useState } from "react";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import RatingRadioGroup from "src/components/RatingRadioGroup";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
import { TwoColumns } from "src/components/TwoColumns";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -50,15 +49,27 @@ const RateSummary = () => {
});
};
const fetchNextTask = () => {
mutate();
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
}
const endTask = tasks[tasks.length - 1];
return (
<>
<Head>
@@ -66,7 +77,7 @@ const RateSummary = () => {
<meta name="description" content="Rate a proposed story summary." />
</Head>
<main className="p-6 bg-slate-100 text-gray-800">
<TwoColumns>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold mb-4">Instruction</h5>
<div className="bg-slate-800 p-6 rounded-xl text-white whitespace-pre-wrap">{tasks[0].task.full_text}</div>
@@ -89,20 +100,9 @@ const RateSummary = () => {
</ul>
<Textarea name="notes" placeholder="Optional notes" />
</section>
</TwoColumns>
</TwoColumnsWithCards>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<TaskInfo id={tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton onClick={() => submitResponse(tasks[0])}>Submit</SubmitButton>
) : (
<SubmitButton onClick={mutate}>Next Task</SubmitButton>
)}
</Flex>
</section>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
</main>
</>
);
+2 -4
View File
@@ -20,11 +20,9 @@ const Home = () => {
/>
</Head>
{session ? (
<main className="my-4">
<TaskSelection />
</main>
<TaskSelection />
) : (
<main>
<main className="oa-basic-theme">
<Hero />
<CallToAction />
<Faq />
+24
View File
@@ -0,0 +1,24 @@
import { ChakraProvider, cookieStorageManagerSSR, localStorageManager } from "@chakra-ui/react";
import { theme } from "./Theme";
export function Chakra({ cookies, children }) {
const colorModeManager = typeof cookies === "string" ? cookieStorageManagerSSR(cookies) : localStorageManager;
return (
<ChakraProvider theme={theme} colorModeManager={colorModeManager}>
{children}
</ChakraProvider>
);
}
// also export a reusable function getServerSideProps
export function getServerSideProps({ req }) {
return {
props: {
// first time users will not have any cookies and you may not return
// undefined here, hence ?? is necessary
cookies: req.headers.cookie ?? "",
},
};
}
+14
View File
@@ -0,0 +1,14 @@
export const colors = {
light: {
bg: "rgb(250,250,250)",
text: "black",
},
dark: {
bg: "gray.900",
text: "white",
},
"dark-blue-btn": {
200: "rgb(29,78,216)",
300: "blue",
},
};
@@ -0,0 +1,14 @@
import { defineStyleConfig } from "@chakra-ui/styled-system";
const baseStyle = {};
const variants = {
"no-padding": {
padding: 0,
},
};
export const containerTheme = defineStyleConfig({
baseStyle,
variants,
});
+37
View File
@@ -0,0 +1,37 @@
import { type ThemeConfig, extendTheme } from "@chakra-ui/react";
import { Styles } from "@chakra-ui/theme-tools";
import { colors } from "./colors";
import { containerTheme } from "./components/Container";
const config: ThemeConfig = {
initialColorMode: "light",
useSystemColorMode: true,
disableTransitionOnChange: false,
};
const components = {
Container: containerTheme,
};
const styles: Styles = {
global: (props) => ({
"*": {
transition: "background-color 200ms cubic-bezier(0.4, 0, 1, 1)",
// bg: props.colorMode === "light" ? colors.light.bg : colors.dark.bg,
// color: props.colorMode === "light" ? colors.light.text : colors.dark.text,
},
".oa-basic-theme": {
bg: props.colorMode === "light" ? colors.light.bg : colors.dark.bg,
color: props.colorMode === "light" ? colors.light.text : colors.dark.text,
},
main: {
fontFamily: "Inter",
},
header: {
fontFamily: "Inter",
},
}),
};
export const theme = extendTheme({ colors, config, styles, components });
+4
View File
@@ -71,6 +71,10 @@ module.exports = {
maxWidth: {
"2xl": "40rem",
},
colors: {
"chakra-gray-900": "#171923",
},
},
},
plugins: [require("@tailwindcss/forms")],