mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into pre-commit-jupyter-black
This commit is contained in:
@@ -39,7 +39,7 @@ repos:
|
||||
# and which break the standard YAML check. The alternative would be to
|
||||
# skip any unsafe errors (and thus break YAML compatibility) or use
|
||||
# some other checker that may not work in general.
|
||||
exclude: ^copilot/web/addons/.*$
|
||||
exclude: ^copilot/.*/addons/.*$
|
||||
- id: check-json
|
||||
- id: check-case-conflict
|
||||
- id: detect-private-key
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
* @yk @andreaskoepf
|
||||
/website/ @fozziethebeat @k-nearest-neighbor @AbdBarho
|
||||
/model/ @theblackcat102 @sanagno
|
||||
/copilot/ @fozziethebeat @andreaskoepf @yk
|
||||
|
||||
+6
-1
@@ -26,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V
|
||||
@app.exception_handler(OasstError)
|
||||
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
|
||||
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
|
||||
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
|
||||
status_code=int(ex.http_status_code),
|
||||
content=protocol_schema.OasstErrorResponse(
|
||||
message=ex.message,
|
||||
error_code=OasstErrorCode(ex.error_code),
|
||||
).dict(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# The manifest for the "api" service.
|
||||
# Read the full specification for the "Load Balanced Web Service" type at:
|
||||
# https://aws.github.io/copilot-cli/docs/manifest/lb-web-service/
|
||||
|
||||
name: api
|
||||
type: Load Balanced Web Service
|
||||
|
||||
http:
|
||||
path: "/"
|
||||
healthcheck:
|
||||
path: "/docs"
|
||||
|
||||
image:
|
||||
build:
|
||||
dockerfile: docker/Dockerfile.backend
|
||||
context: ./
|
||||
port: 8080
|
||||
|
||||
cpu: 256
|
||||
memory: 512
|
||||
platform: linux/x86_64
|
||||
count: 1
|
||||
exec: true
|
||||
network:
|
||||
connect: true
|
||||
|
||||
environments:
|
||||
staging:
|
||||
variables:
|
||||
# Note: this has to be a valid JSON list for Pydantic to parse it.
|
||||
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
|
||||
DEBUG_ALLOW_ANY_API_KEY: True
|
||||
DEBUG_SKIP_API_KEY_CHECK: True
|
||||
MAX_WORKERS: 1
|
||||
|
||||
secrets:
|
||||
# Note: URI, not URL.
|
||||
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
|
||||
@@ -1,161 +0,0 @@
|
||||
Parameters:
|
||||
App:
|
||||
Type: String
|
||||
Description: Your application's name.
|
||||
Env:
|
||||
Type: String
|
||||
Description:
|
||||
The environment name your service, job, or workflow is being deployed to.
|
||||
Name:
|
||||
Type: String
|
||||
Description: The name of the service, job, or workflow being deployed.
|
||||
# Customize your Aurora Serverless cluster by setting the default value of the following parameters.
|
||||
webclusterDBName:
|
||||
Type: String
|
||||
Description:
|
||||
The name of the initial database to be created in the Aurora Serverless v2
|
||||
cluster.
|
||||
Default: oassist_web
|
||||
# Cannot have special characters
|
||||
# Naming constraints: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_Limits.html#RDS_Limits.Constraints
|
||||
Mappings:
|
||||
webclusterEnvScalingConfigurationMap:
|
||||
staging:
|
||||
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
|
||||
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
|
||||
|
||||
All:
|
||||
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
|
||||
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
|
||||
|
||||
Resources:
|
||||
webclusterDBSubnetGroup:
|
||||
Type: "AWS::RDS::DBSubnetGroup"
|
||||
Properties:
|
||||
DBSubnetGroupDescription:
|
||||
Group of Copilot private subnets for Aurora Serverless v2 cluster.
|
||||
SubnetIds:
|
||||
!Split [",", { "Fn::ImportValue": !Sub "${App}-${Env}-PrivateSubnets" }]
|
||||
webclusterSecurityGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description":
|
||||
"A security group for your workload to access the Aurora Serverless v2
|
||||
cluster webcluster"
|
||||
Type: "AWS::EC2::SecurityGroup"
|
||||
Properties:
|
||||
GroupDescription:
|
||||
!Sub "The Security Group for ${Name} to access Aurora Serverless v2
|
||||
cluster webcluster."
|
||||
VpcId:
|
||||
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
|
||||
Tags:
|
||||
- Key: Name
|
||||
Value: !Sub "copilot-${App}-${Env}-${Name}-Aurora"
|
||||
webclusterDBClusterSecurityGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description":
|
||||
"A security group for your Aurora Serverless v2 cluster webcluster"
|
||||
Type: AWS::EC2::SecurityGroup
|
||||
Properties:
|
||||
GroupDescription: The Security Group for the Aurora Serverless v2 cluster.
|
||||
SecurityGroupIngress:
|
||||
- ToPort: 5432
|
||||
FromPort: 5432
|
||||
IpProtocol: tcp
|
||||
Description:
|
||||
!Sub "From the Aurora Security Group of the workload ${Name}."
|
||||
SourceSecurityGroupId: !Ref webclusterSecurityGroup
|
||||
VpcId:
|
||||
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
|
||||
webclusterAuroraSecret:
|
||||
Metadata:
|
||||
"aws:copilot:description":
|
||||
"A Secrets Manager secret to store your DB credentials"
|
||||
Type: AWS::SecretsManager::Secret
|
||||
Properties:
|
||||
Description: !Sub Aurora main user secret for ${AWS::StackName}
|
||||
GenerateSecretString:
|
||||
SecretStringTemplate: '{"username": "postgres"}'
|
||||
GenerateStringKey: "password"
|
||||
ExcludePunctuation: true
|
||||
IncludeSpace: false
|
||||
PasswordLength: 16
|
||||
webclusterDBClusterParameterGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description":
|
||||
"A DB parameter group for engine configuration values"
|
||||
Type: "AWS::RDS::DBClusterParameterGroup"
|
||||
Properties:
|
||||
Description: !Ref "AWS::StackName"
|
||||
Family: "aurora-postgresql14"
|
||||
Parameters:
|
||||
client_encoding: "UTF8"
|
||||
webclusterDBCluster:
|
||||
Metadata:
|
||||
"aws:copilot:description":
|
||||
"The webcluster Aurora Serverless v2 database cluster"
|
||||
Type: "AWS::RDS::DBCluster"
|
||||
Properties:
|
||||
MasterUsername:
|
||||
!Join [
|
||||
"",
|
||||
[
|
||||
"{{resolve:secretsmanager:",
|
||||
!Ref webclusterAuroraSecret,
|
||||
":SecretString:username}}",
|
||||
],
|
||||
]
|
||||
MasterUserPassword:
|
||||
!Join [
|
||||
"",
|
||||
[
|
||||
"{{resolve:secretsmanager:",
|
||||
!Ref webclusterAuroraSecret,
|
||||
":SecretString:password}}",
|
||||
],
|
||||
]
|
||||
DatabaseName: !Ref webclusterDBName
|
||||
Engine: "aurora-postgresql"
|
||||
EngineVersion: "14.4"
|
||||
DBClusterParameterGroupName: !Ref webclusterDBClusterParameterGroup
|
||||
DBSubnetGroupName: !Ref webclusterDBSubnetGroup
|
||||
Port: 5432
|
||||
VpcSecurityGroupIds:
|
||||
- !Ref webclusterDBClusterSecurityGroup
|
||||
ServerlessV2ScalingConfiguration:
|
||||
# Replace "All" below with "!Ref Env" to set different autoscaling limits per environment.
|
||||
MinCapacity:
|
||||
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMinCapacity]
|
||||
MaxCapacity:
|
||||
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMaxCapacity]
|
||||
webclusterDBWriterInstance:
|
||||
Metadata:
|
||||
"aws:copilot:description":
|
||||
"The webcluster Aurora Serverless v2 writer instance"
|
||||
Type: "AWS::RDS::DBInstance"
|
||||
Properties:
|
||||
DBClusterIdentifier: !Ref webclusterDBCluster
|
||||
DBInstanceClass: db.serverless
|
||||
Engine: "aurora-postgresql"
|
||||
PromotionTier: 1
|
||||
AvailabilityZone: !Select
|
||||
- 0
|
||||
- !GetAZs
|
||||
Ref: AWS::Region
|
||||
|
||||
webclusterSecretAuroraClusterAttachment:
|
||||
Type: AWS::SecretsManager::SecretTargetAttachment
|
||||
Properties:
|
||||
SecretId: !Ref webclusterAuroraSecret
|
||||
TargetId: !Ref webclusterDBCluster
|
||||
TargetType: AWS::RDS::DBCluster
|
||||
Outputs:
|
||||
webclusterSecret: # injected as WEBCLUSTER_SECRET environment variable by Copilot.
|
||||
Description:
|
||||
"The JSON secret that holds the database username and password. Fields are
|
||||
'host', 'port', 'dbname', 'username', 'password', 'dbClusterIdentifier'
|
||||
and 'engine'"
|
||||
Value: !Ref webclusterAuroraSecret
|
||||
webclusterSecurityGroup:
|
||||
Description: "The security group to attach to the workload."
|
||||
Value: !Ref webclusterSecurityGroup
|
||||
@@ -26,6 +26,7 @@ environments:
|
||||
staging:
|
||||
variables:
|
||||
NEXTAUTH_URL: https://web.staging.open-assistant.surfacedata.org
|
||||
FASTAPI_URL: https://api.staging.open-assistant.surfacedata.org
|
||||
|
||||
secrets:
|
||||
DATABASE_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/DATABASE_URL
|
||||
@@ -37,5 +38,4 @@ secrets:
|
||||
EMAIL_SERVER_USER: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_SERVER_USER
|
||||
EMAIL_FROM: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_FROM
|
||||
FASTAPI_KEY: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_KEY
|
||||
FASTAPI_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_URL
|
||||
NEXTAUTH_SECRET: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/NEXTAUTH_SECRET
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
BOT_TOKEN=<discord bot token>
|
||||
DECLARE_GLOBAL_COMMANDS=<testing guild id>
|
||||
OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="./"
|
||||
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
|
||||
@@ -48,7 +48,7 @@ Remember to save your changes.
|
||||
below to invite your bot.
|
||||
|
||||
```
|
||||
https://discord.com/oauth2/authorize?client_id=YOUR_ID_HERE&permissions=8&scope=bot%20applications.commands
|
||||
https://discord.com/oauth2/authorize?client_id=YOUR_CLIENT_ID_HERE&permissions=8&scope=bot%20applications.commands
|
||||
```
|
||||
|
||||
### Environment Setup
|
||||
@@ -66,6 +66,9 @@ pip install -e .
|
||||
cp .env.example .env
|
||||
|
||||
# edit .env and add your bot token and other values
|
||||
# BOT_TOKEN is given by the discord developer portal when you create a bot
|
||||
# DECLARE_GLOBAL_COMMANDS is the id of the server where you added the bot (right click on the server icon and copy id)
|
||||
# OWNER_ID can be leave as an empty list
|
||||
|
||||
python -V # 3.10
|
||||
|
||||
|
||||
+13
-6
@@ -6,7 +6,7 @@ import hikari
|
||||
import lightbulb
|
||||
import miru
|
||||
from bot.settings import Settings
|
||||
from bot.utils import EMPTY, mention
|
||||
from bot.utils import mention
|
||||
from oasst_shared.api_client import OasstApiClient
|
||||
|
||||
settings = Settings()
|
||||
@@ -34,8 +34,11 @@ async def on_starting(event: hikari.StartingEvent):
|
||||
|
||||
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
|
||||
|
||||
# A set of user id's that are currently doing work.
|
||||
bot.d.currently_working = set()
|
||||
# A `dict[hikari.Message | None, UUID | None]]` that maps user IDs to (task msg ID, task UUIDs).
|
||||
# Either both are `None` or both are not `None`.
|
||||
# If both are `None`, the user is not currently selecting a task.
|
||||
# TODO: Grow this on startup so we don't have to re-allocate memory every time it needs to grow
|
||||
bot.d.currently_working = {}
|
||||
|
||||
|
||||
@bot.listen()
|
||||
@@ -50,13 +53,13 @@ async def _send_error_embed(
|
||||
) -> None:
|
||||
ctx.command
|
||||
embed = hikari.Embed(
|
||||
title=f"`{exception.__class__.__name__}` Error{f' in `{ctx.command.name}`' if ctx.command else '' }",
|
||||
title=f"`{exception.__class__.__name__}` Error{f' in `/{ctx.command.name}`' if ctx.command else '' }",
|
||||
description=content,
|
||||
color=0xFF0000,
|
||||
timestamp=datetime.now().astimezone(),
|
||||
).set_author(name=ctx.author.username, url=str(ctx.author.avatar_url))
|
||||
|
||||
await ctx.respond(EMPTY, embed=embed)
|
||||
await ctx.respond(embed=embed)
|
||||
|
||||
|
||||
@bot.listen(lightbulb.CommandErrorEvent)
|
||||
@@ -65,6 +68,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
|
||||
# Unwrap the exception to get the original cause
|
||||
exc = event.exception.__cause__ or event.exception
|
||||
ctx = event.context
|
||||
if not ctx.bot.rest.is_alive:
|
||||
return
|
||||
|
||||
if isinstance(event.exception, lightbulb.CommandInvocationError):
|
||||
if not event.context.command:
|
||||
@@ -114,6 +119,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
|
||||
ctx,
|
||||
)
|
||||
elif isinstance(exc, lightbulb.errors.MissingRequiredAttachment):
|
||||
await _send_error_embed("Not enough attachemnts were supplied to this command.", exc, ctx)
|
||||
await _send_error_embed("Not enough attachments were supplied to this command.", exc, ctx)
|
||||
elif isinstance(exc, lightbulb.errors.CommandNotFound):
|
||||
await ctx.respond(f"`/{exc.invoked_with}` is not a valid command. Use `/help` to see a list of commands.")
|
||||
else:
|
||||
raise exc
|
||||
|
||||
@@ -78,7 +78,6 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None:
|
||||
|
||||
# if the bot's permissions for this channel don't contain SEND_MESSAGE
|
||||
# This will also filter out categories and voice channels
|
||||
print(permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES)
|
||||
if not permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES:
|
||||
await ctx.respond(f"I don't have permission to send messages in {ch.mention}.")
|
||||
return
|
||||
|
||||
@@ -7,7 +7,6 @@ import lightbulb
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
@@ -74,7 +73,7 @@ class LabelModal(miru.Modal):
|
||||
)
|
||||
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel)
|
||||
await channel.send(EMPTY, embed=embed)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
|
||||
class LabelSelect(miru.View):
|
||||
@@ -164,7 +163,7 @@ async def label_message_text(ctx: lightbulb.MessageContext):
|
||||
msg.content,
|
||||
timeout=60,
|
||||
)
|
||||
resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
resp = await ctx.respond(embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
await label_select_view.start(await resp.message())
|
||||
await label_select_view.wait()
|
||||
|
||||
+184
-186
@@ -1,15 +1,27 @@
|
||||
"""Work plugin for collecting user data."""
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from bot.messages import (
|
||||
assistant_reply_message,
|
||||
confirm_ranking_response_message,
|
||||
confirm_text_response_message,
|
||||
initial_prompt_message,
|
||||
invalid_user_input_embed,
|
||||
plain_embed,
|
||||
prompter_reply_message,
|
||||
rank_assistant_reply_message,
|
||||
rank_initial_prompts_message,
|
||||
rank_prompter_reply_message,
|
||||
task_complete_embed,
|
||||
)
|
||||
from bot.settings import Settings
|
||||
from loguru import logger
|
||||
from oasst_shared.api_client import OasstApiClient, TaskType
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -20,6 +32,8 @@ plugin = lightbulb.Plugin("WorkPlugin")
|
||||
MAX_TASK_TIME = 60 * 60 # 1 hour
|
||||
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
@@ -31,31 +45,56 @@ MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def work(ctx: lightbulb.SlashContext):
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def work(ctx: lightbulb.Context):
|
||||
"""Create and handle a task."""
|
||||
# make sure the user isn't currently doing a task
|
||||
currently_working: set[hikari.Snowflakeish] = ctx.bot.d.currently_working
|
||||
# Only send this message if started from a server
|
||||
if ctx.guild_id is not None:
|
||||
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
|
||||
currently_working: dict[
|
||||
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
|
||||
] = ctx.bot.d.currently_working
|
||||
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
if ctx.author.id in currently_working:
|
||||
await ctx.respond(
|
||||
"You are already performing a task. Please complete that one first.", flags=hikari.MessageFlag.EPHEMERAL
|
||||
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"),
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
components=yn_view,
|
||||
)
|
||||
return
|
||||
await yn_view.start(msg)
|
||||
await yn_view.wait()
|
||||
|
||||
currently_working.add(ctx.author.id)
|
||||
match yn_view.choice:
|
||||
case False | None:
|
||||
return
|
||||
case True:
|
||||
old_msg, task_id = currently_working[ctx.author.id]
|
||||
if old_msg is not None:
|
||||
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
|
||||
map(lambda c: c, old_msg.components)
|
||||
await old_msg.delete()
|
||||
if task_id is not None:
|
||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||
|
||||
await msg.delete()
|
||||
|
||||
currently_working[ctx.author.id] = (None, None)
|
||||
|
||||
# Create a TaskRequestType from the stringified enum value
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
|
||||
try:
|
||||
await _handle_task(ctx, task_type)
|
||||
finally:
|
||||
currently_working.remove(ctx.author.id)
|
||||
del currently_working[ctx.author.id]
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
|
||||
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
@@ -72,38 +111,79 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
# User cancelled
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send("Please type your response here:")
|
||||
await ctx.author.send(embed=plain_embed("Please type your response here"))
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id
|
||||
and not (e.message.content or "").startswith(settings.prefix),
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
if event.content is None or not _validate_user_input(event.content, task):
|
||||
await ctx.author.send("Invalid response")
|
||||
valid, err_msg = _validate_user_input(event.content, task)
|
||||
if not valid or event.content is None:
|
||||
|
||||
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
|
||||
continue
|
||||
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Confirm user input
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||
content = confirm_ranking_response_message(event.content, task.replies)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
content = confirm_ranking_response_message(event.content, task.prompts)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
content = confirm_text_response_message(event.content)
|
||||
else:
|
||||
logger.critical(f"Unknown task type: {task.type}")
|
||||
raise ValueError(f"Unknown task type: {task.type}")
|
||||
|
||||
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
|
||||
msg = await ctx.author.send(content, components=confirm_resp_view)
|
||||
await confirm_resp_view.start(msg)
|
||||
await confirm_resp_view.wait()
|
||||
|
||||
match confirm_resp_view.choice:
|
||||
case False | None:
|
||||
continue
|
||||
case True:
|
||||
await msg.delete() # buttons are already gone
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
|
||||
reply = protocol_schema.MessageRanking(
|
||||
message_id=str(msg_id),
|
||||
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {task.type}")
|
||||
raise ValueError(f"Unexpected task type received: {task.type}")
|
||||
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
@@ -111,63 +191,55 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send("Task completed")
|
||||
await ctx.author.send(embed=plain_embed("Task completed"))
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in the log channel that the task is complete
|
||||
# TODO: Maybe do something with the msg ID so users can rate the "answer"
|
||||
assert ctx.guild_id is not None
|
||||
# Send a message in all the log channels that the task is complete
|
||||
conn: Connection = ctx.bot.d.db
|
||||
guild_settings = await GuildSettings.from_db(conn, ctx.guild_id)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT log_channel_id FROM guild_settings")
|
||||
log_channel_ids = await cursor.fetchall()
|
||||
|
||||
if guild_settings is not None and guild_settings.log_channel_id is not None:
|
||||
channels = [
|
||||
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
|
||||
for id in log_channel_ids
|
||||
]
|
||||
|
||||
channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel) # option converter
|
||||
|
||||
done_embed = (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {ctx.author.mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
await channel.send(EMPTY, embed=done_embed)
|
||||
done_embed = task_complete_embed(task, ctx.author.mention)
|
||||
# This will definitely get the bot rate limited, but that's a future problem
|
||||
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
|
||||
|
||||
# ask the user if they want to do another task
|
||||
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send("Would you like another task?", components=choice_view)
|
||||
await choice_view.start(msg)
|
||||
await choice_view.wait()
|
||||
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
|
||||
await another_task_view.start(msg)
|
||||
await another_task_view.wait()
|
||||
|
||||
match choice_view.choice:
|
||||
match another_task_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await ctx.author.send("Exiting, goodbye!")
|
||||
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
|
||||
case True:
|
||||
pass
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg_id = await _send_task(ctx, task)
|
||||
resp, msg = await _send_task(ctx, task, msg)
|
||||
msg_id = str(msg.id)
|
||||
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
@@ -179,25 +251,24 @@ async def _select_task(
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
await ctx.author.send("Sending next task...")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send("Task canceled. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.SlashContext, task: protocol_schema.Task
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
|
||||
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message.
|
||||
@@ -206,37 +277,38 @@ async def _send_task(
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.debug("sending initial prompt task")
|
||||
embed = _initial_prompt_embed(task)
|
||||
content = initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.debug("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
content = rank_initial_prompts_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
embed = _rank_prompter_reply_embed(task)
|
||||
content = rank_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
embed = _rank_assistant_reply_embed(task)
|
||||
content = rank_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
embed = _prompter_reply_embed(task)
|
||||
content = prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.debug("sending assistant reply task")
|
||||
embed = _assistant_reply_embed(task)
|
||||
content = assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
@@ -248,24 +320,34 @@ async def _send_task(
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
EMPTY,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
if not msg:
|
||||
msg = await ctx.author.send(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
else:
|
||||
await msg.edit(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
# Set the choice id as the current msg id
|
||||
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, str(msg.id)
|
||||
return view.choice, msg
|
||||
|
||||
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
|
||||
"""Returns whether the user's input is valid for the task type."""
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
|
||||
"""Returns whether the user's input is valid for the task type and an error message."""
|
||||
if content is None:
|
||||
return False
|
||||
return False, "No input provided"
|
||||
|
||||
# User message input
|
||||
if (
|
||||
@@ -277,22 +359,28 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> boo
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0
|
||||
return len(content) > 0, "Message must be at least one character long."
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
|
||||
"Message must contain numbers for all replies.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
|
||||
"Message must contain numbers for all prompts.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
@@ -316,22 +404,29 @@ class TaskAcceptView(miru.View):
|
||||
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Accept button pressed")
|
||||
self.choice = "accept"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
||||
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Next button pressed")
|
||||
self.choice = "next"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
||||
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Cancel button pressed")
|
||||
self.choice = "cancel"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
if self.message is not None:
|
||||
await self.message.edit(component=None)
|
||||
|
||||
class ChoiceView(miru.View):
|
||||
|
||||
class YesNoView(miru.View):
|
||||
"""View with two buttons: yes and no.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
@@ -342,115 +437,18 @@ class ChoiceView(miru.View):
|
||||
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
||||
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = True
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
||||
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = False
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
|
||||
################################################################
|
||||
# Template Embeds #
|
||||
################################################################
|
||||
|
||||
# TODO: Maybe implement a better way of creating embeds, like `from_json` or something
|
||||
|
||||
|
||||
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Initial Prompt",
|
||||
description="Rank the following tasks from best to worst (1,2,3,4,5)",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(task.prompts):
|
||||
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank User Reply",
|
||||
description="Rank the following user replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Assistant Reply",
|
||||
description="Rank the following assistant replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description=f"""\
|
||||
Send the next message in the conversation as if you were the user.
|
||||
{'Hint: ' if task.hint else ''}
|
||||
""",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description="Send the next message in the conversation as if you were the user.",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
async def on_timeout(self) -> None:
|
||||
if self.message is not None:
|
||||
await self.message.edit(component=None)
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""All user-facing messages and embeds."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
NUMBER_EMOJIS = [":one:", ":two:", ":three:", ":four:", ":five:", ":six:", ":seven:", ":eight:", ":nine:", ":ten:"]
|
||||
NL = "\n"
|
||||
|
||||
###
|
||||
# Reusable 'components'
|
||||
###
|
||||
|
||||
|
||||
def _h1(text: str) -> str:
|
||||
return f"\n:small_blue_diamond: __**{text}**__ :small_blue_diamond:"
|
||||
|
||||
|
||||
def _h2(text: str) -> str:
|
||||
return f"__**{text}**__"
|
||||
|
||||
|
||||
def _h3(text: str) -> str:
|
||||
return f"__{text}__"
|
||||
|
||||
|
||||
def _writing_prompt(text: str) -> str:
|
||||
return f":pencil: _{text}_"
|
||||
|
||||
|
||||
def _ranking_prompt(text: str) -> str:
|
||||
return f":trophy: _{text}_"
|
||||
|
||||
|
||||
def _response_prompt(text: str) -> str:
|
||||
return f":speech_balloon: _{text}_"
|
||||
|
||||
|
||||
def _summarize_prompt(text: str) -> str:
|
||||
return f":notepad_spiral: _{text}_"
|
||||
|
||||
|
||||
def _user(text: str | None) -> str:
|
||||
return f"""\
|
||||
:person_red_hair: {_h3("User")}:{f"{NL}> **{text}**" if text is not None else ""}
|
||||
"""
|
||||
|
||||
|
||||
def _assistant(text: str | None) -> str:
|
||||
return f"""\
|
||||
:robot: {_h3("Assistant")}:{f"{NL}> {text}" if text is not None else ""}
|
||||
"""
|
||||
|
||||
|
||||
def _make_ordered_list(items: list[str]) -> list[str]:
|
||||
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||
|
||||
|
||||
def _ordered_list(items: list[str]) -> str:
|
||||
return "\n\n".join(_make_ordered_list(items))
|
||||
|
||||
|
||||
def _conversation(conv: protocol_schema.Conversation) -> str:
|
||||
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||
|
||||
|
||||
def _hint(hint: str | None) -> str:
|
||||
return f"{NL}Hint: {hint}" if hint else ""
|
||||
|
||||
|
||||
###
|
||||
# Messages
|
||||
###
|
||||
|
||||
|
||||
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("INITIAL PROMPT")}
|
||||
|
||||
{_writing_prompt("Please provide an initial prompt to the assistant.")}
|
||||
{_hint(task.hint)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK INITIAL PROMPTS")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_ordered_list(task.prompts)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK PROMPTER REPLIES")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK ASSISTANT REPLIES")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_assistant(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
"""
|
||||
|
||||
|
||||
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("PROMPTER REPLY")}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_hint(task.hint)}
|
||||
"""
|
||||
|
||||
|
||||
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
|
||||
return f"""\
|
||||
{_h1("ASSISTANT REPLY")}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
"""
|
||||
|
||||
|
||||
def confirm_text_response_message(content: str) -> str:
|
||||
return f"""\
|
||||
{_h2("CONFIRM RESPONSE")}
|
||||
|
||||
> {content}
|
||||
"""
|
||||
|
||||
|
||||
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
|
||||
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
|
||||
original_list = _make_ordered_list(items)
|
||||
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
|
||||
|
||||
return f"""\
|
||||
{_h2("CONFIRM RESPONSE")}
|
||||
|
||||
{user_ranked_list}
|
||||
"""
|
||||
|
||||
|
||||
###
|
||||
# Embeds
|
||||
###
|
||||
|
||||
|
||||
def task_complete_embed(task: protocol_schema.Task, mention: str) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def invalid_user_input_embed(error_message: str) -> hikari.Embed:
|
||||
return hikari.Embed(
|
||||
title="Invalid User Input",
|
||||
description=error_message,
|
||||
color=hikari.Color(0xFF0000),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
|
||||
|
||||
def plain_embed(text: str) -> hikari.Embed:
|
||||
return hikari.Embed(color=0x36393F, description=text)
|
||||
@@ -8,7 +8,7 @@ class Settings(BaseSettings):
|
||||
bot_token: str = Field(env="BOT_TOKEN", default="")
|
||||
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
|
||||
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
|
||||
prefix: str = Field(env="PREFIX", default="./")
|
||||
prefix: str = Field(env="PREFIX", default="/")
|
||||
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
|
||||
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
|
||||
|
||||
|
||||
@@ -24,13 +24,6 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s
|
||||
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
|
||||
|
||||
|
||||
EMPTY = "\u200d"
|
||||
"""Zero-width joiner.
|
||||
|
||||
This appears as an empty message in Discord.
|
||||
"""
|
||||
|
||||
|
||||
def mention(
|
||||
id: hikari.Snowflakeish,
|
||||
type: t.Literal["channel", "role", "user"],
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
aiohttp # http client
|
||||
aiohttp[speedups] # speedups for aiohttp
|
||||
aiosqlite # database
|
||||
hikari # discord framework
|
||||
hikari-lightbulb # command handler
|
||||
|
||||
+1
-1
@@ -9,7 +9,7 @@ services:
|
||||
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
|
||||
frontend-dev:
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, webdb, adminer, maildev, backend]
|
||||
depends_on: [db, webdb, adminer, maildev, backend, redis]
|
||||
|
||||
# This DB is for the FastAPI Backend.
|
||||
db:
|
||||
|
||||
@@ -5,6 +5,7 @@ COPY ./backend/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
|
||||
|
||||
ENV PORT 8080
|
||||
EXPOSE 8080
|
||||
|
||||
COPY ./oasst-shared /oasst-shared
|
||||
RUN pip install -e /oasst-shared
|
||||
|
||||
@@ -99,6 +99,9 @@ The main tasks are a) generation of response text and b) ranking of responses.
|
||||
The following sections describe the data schemas for each of these tasks. Both
|
||||
should be implementable in parquet files.
|
||||
|
||||
Note: These files are meant to be consumed by ML algorithms and should ideally
|
||||
be produced from the above files.
|
||||
|
||||
## Common Data Structures
|
||||
|
||||
```python
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Train using supervised examples
|
||||
|
||||
Requirements
|
||||
|
||||
```
|
||||
wandb
|
||||
evaluate
|
||||
datasets
|
||||
transformers
|
||||
torch
|
||||
```
|
||||
|
||||
Start training reward model
|
||||
|
||||
```bash
|
||||
python trainer.py --configs defaults galactica-125
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
||||
For now we only support webgpt and summary dataset from OpenAI. Once
|
||||
open-asisstant dataset are available it will be added here.
|
||||
|
||||
## Model
|
||||
|
||||
TBD
|
||||
|
||||
## Results
|
||||
|
||||
Experimental results in wandb
|
||||
[here](https://wandb.ai/sanagnos/supervised-finetuning?workspace=user-sanagnos).
|
||||
|
||||
## TODOS
|
||||
|
||||
- decide on a model
|
||||
- add special token to declare prompt and reply. Do nto freeze the weights for
|
||||
these
|
||||
- Merge utils etc with reward model
|
||||
@@ -0,0 +1,37 @@
|
||||
defaults:
|
||||
learning_rate: 1e-5
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 2
|
||||
weight_decay: 0.00
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 3
|
||||
logging_steps: 10
|
||||
max_grad_norm: 2.0
|
||||
save_total_limit: 4
|
||||
eval_accumulation_steps:
|
||||
freeze_layer:
|
||||
datasets:
|
||||
- webgpt
|
||||
cache_dir: ~/.cache
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
log_dir: "base"
|
||||
|
||||
galactica-125:
|
||||
learning_rate: 5e-5
|
||||
model_name: facebook/galactica-125m
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 600
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 2
|
||||
per_device_train_batch_size: 4
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
debug:
|
||||
eval_steps: 20
|
||||
eval_size: 100
|
||||
@@ -0,0 +1,67 @@
|
||||
from datasets import load_dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
|
||||
class SquadV2Dataset(Dataset):
|
||||
def __init__(self, cache_dir, split):
|
||||
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data = self.dataset[idx]
|
||||
# dummy return first answer
|
||||
return "".join([data["title"], ". ", data["context"], " " + data["question"]]), data["answers"]["text"][0]
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
questions = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2question = {}
|
||||
for row in dataset["train"]:
|
||||
question = row["question"]["full_text"]
|
||||
if question not in self.index2question:
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
# only keep the best answer
|
||||
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
|
||||
|
||||
self.questions = questions
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2question)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
answer = self.questions[question]
|
||||
return [question, answer]
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
train_idx, val_idx = train_test_split(
|
||||
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
|
||||
)
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
|
||||
def get_one_dataset(conf, dataset_name):
|
||||
dataset_name = dataset_name.lower()
|
||||
|
||||
if dataset_name == "squadv2":
|
||||
raise ValueError("SquadV2 is not diverse enough for generation .. ")
|
||||
train = SquadV2Dataset(conf.cache_dir, "train")
|
||||
eval = SquadV2Dataset(conf.cache_dir, "validation")
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
return train, eval
|
||||
@@ -0,0 +1,85 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueDataCollator:
|
||||
"""
|
||||
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
# TODO add special tokens for question and answer here
|
||||
# additional_special_tokens = ['<question>', '<answer>']
|
||||
prompt_tokens = ["Question: ", "Answer: "]
|
||||
|
||||
flatten_messages = []
|
||||
label_masks = []
|
||||
|
||||
for messages in features:
|
||||
assert len(messages) % 2 == 0, "Number of messages must be even"
|
||||
messages = [
|
||||
(prompt_tokens[0] if i % 2 == 0 else "") + x + ((" " + prompt_tokens[1]) if i % 2 == 0 else "")
|
||||
for i, x in enumerate(messages)
|
||||
]
|
||||
|
||||
# Add a way for the model to terminate generation, reinitialize prompter
|
||||
messages.append(prompt_tokens[0])
|
||||
|
||||
flatten_messages.append(
|
||||
self.tokenizer(
|
||||
"".join(messages),
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_offsets_mapping=True,
|
||||
)
|
||||
)
|
||||
|
||||
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
|
||||
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
|
||||
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John.
|
||||
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3
|
||||
|
||||
# If no result in next, we are predicting the last termination token(s)
|
||||
message_indices = list(
|
||||
map(
|
||||
lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
|
||||
list(map(lambda x: x[1], flatten_messages[-1]["offset_mapping"])),
|
||||
)
|
||||
)
|
||||
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
|
||||
try:
|
||||
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
|
||||
except IndexError:
|
||||
# an aftermath of padding
|
||||
pass
|
||||
|
||||
label_masks.append(label_mask)
|
||||
flatten_messages[-1].pop("offset_mapping")
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flatten_messages,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
dim = batch["input_ids"].shape[-1]
|
||||
|
||||
batch["label_masks"] = torch.stack([F.pad(torch.tensor(x), (0, dim - len(x))) for x in label_masks])
|
||||
|
||||
for k in list(batch.keys()):
|
||||
if k not in ["input_ids", "attention_mask", "label_masks"]:
|
||||
batch.pop(k)
|
||||
|
||||
return batch
|
||||
@@ -0,0 +1,15 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.CrossEntropyLoss):
|
||||
def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean"):
|
||||
super(CrossEntropyLoss, self).__init__(weight, size_average, ignore_index, reduce, reduction)
|
||||
|
||||
def forward(self, input, target, mask=None):
|
||||
if mask is not None:
|
||||
mask = mask.view(-1)
|
||||
input = input.view(-1, input.size(-1))
|
||||
target = target.view(-1)
|
||||
input = input[mask]
|
||||
target = target[mask]
|
||||
return super(CrossEntropyLoss, self).forward(input, target)
|
||||
@@ -0,0 +1,200 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from distutils.util import strtobool
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
get_cosine_schedule_with_warmup,
|
||||
)
|
||||
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
loss_function: str = "CrossEntropyLoss"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
pred_ids = eval_pred.predictions
|
||||
labels = eval_pred.label_ids
|
||||
|
||||
return {"accuracy": (pred_ids[labels > 0] == labels[labels > 0]).mean()}
|
||||
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
pred_ids = torch.argmax(logits, dim=-1)
|
||||
return pred_ids
|
||||
|
||||
|
||||
class SFTTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = get_loss(args.loss_function)
|
||||
|
||||
def fetch_scheduler(self):
|
||||
return get_cosine_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps=self.args.warmup_steps,
|
||||
num_training_steps=self.num_train_steps,
|
||||
num_cycles=1,
|
||||
last_epoch=-1,
|
||||
)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
loss = self.loss_fct(outputs.get("logits"), torch.roll(inputs["input_ids"], -1, -1), mask=labels_mask)
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _compute_loss(self, model, inputs):
|
||||
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
logits = outputs.get("logits")
|
||||
|
||||
targets = torch.roll(inputs["input_ids"], -1, -1)
|
||||
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
|
||||
|
||||
return loss, logits, targets, labels_mask
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
|
||||
labels[~labels_mask] = -1
|
||||
|
||||
loss = loss.mean().detach()
|
||||
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
def _strtobool(x):
|
||||
return bool(strtobool(x))
|
||||
|
||||
|
||||
def argument_parsing(notebook=False, notebook_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--configs", nargs="+", required=True)
|
||||
|
||||
if notebook:
|
||||
args, remaining = parser.parse_known_args(notebook_args)
|
||||
else:
|
||||
args, remaining = parser.parse_known_args()
|
||||
|
||||
# Config from YAML
|
||||
conf = {}
|
||||
configs = read_yamls("./configs")
|
||||
for name in args.configs:
|
||||
if "," in name:
|
||||
for n in name.split(","):
|
||||
conf.update(configs[n])
|
||||
else:
|
||||
conf.update(configs[name])
|
||||
|
||||
# Override config from command-line
|
||||
parser = argparse.ArgumentParser()
|
||||
for key, value in conf.items():
|
||||
type_ = type(value) if value is not None else str
|
||||
if type_ == bool:
|
||||
type_ = _strtobool
|
||||
parser.add_argument(f"--{key}", type=type_, default=value)
|
||||
|
||||
return parser.parse_args(remaining)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing()
|
||||
|
||||
model = get_model(training_conf)
|
||||
tokenizer = get_tokenizer(training_conf)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
|
||||
args = CustomTrainingArguments(
|
||||
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
num_train_epochs=training_conf.num_train_epochs,
|
||||
warmup_steps=training_conf.warmup_steps,
|
||||
loss_function=training_conf.loss_fn,
|
||||
learning_rate=float(training_conf.learning_rate),
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf.gradient_checkpointing,
|
||||
gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
|
||||
per_device_train_batch_size=training_conf.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=training_conf.per_device_eval_batch_size,
|
||||
weight_decay=training_conf.weight_decay,
|
||||
max_grad_norm=training_conf.max_grad_norm,
|
||||
logging_steps=training_conf.logging_steps,
|
||||
save_total_limit=training_conf.save_total_limit,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf.eval_steps,
|
||||
save_steps=training_conf.save_steps,
|
||||
eval_accumulation_steps=training_conf.eval_accumulation_steps,
|
||||
report_to="wandb",
|
||||
)
|
||||
|
||||
assert len(evals) > 0
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train,
|
||||
eval_dataset=evals,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
@@ -0,0 +1,111 @@
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from custom_datasets import get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from losses import CrossEntropyLoss
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SUPPORTED_MODELS = ["galactica"]
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
|
||||
if "galactica" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(conf):
|
||||
if not any([x in conf.model_name for x in SUPPORTED_MODELS]):
|
||||
raise ValueError(
|
||||
f"Model {conf.model_name} not supported. Supported models: {SUPPORTED_MODELS}. "
|
||||
"To include more make sure the masking is dne correctly... (decoder only supported for now)"
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
|
||||
if conf.freeze_layer:
|
||||
model = freeze_top_n_layers(model, conf.freeze_layer)
|
||||
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([p.numel() for p in model_parameters])
|
||||
print("Number of trainable parameters: {}M".format(int(params / 1e6)))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_dataset(conf, tokenizer):
|
||||
train_datasets, evals = [], {}
|
||||
|
||||
for dataset_name in conf.datasets:
|
||||
train, val = get_one_dataset(conf, dataset_name)
|
||||
train_datasets.append(train)
|
||||
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val
|
||||
|
||||
train = ConcatDataset(train_datasets)
|
||||
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
|
||||
|
||||
return train, evals, collate_fn
|
||||
|
||||
|
||||
def get_loss(loss):
|
||||
if loss == "CrossEntropyLoss":
|
||||
return CrossEntropyLoss()
|
||||
else:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
|
||||
|
||||
def read_yamls(dir):
|
||||
conf = {}
|
||||
no_conf = True
|
||||
|
||||
for config_file in Path(dir).glob("**/*.yaml"):
|
||||
no_conf = False
|
||||
with config_file.open("r") as f:
|
||||
conf.update(yaml.safe_load(f))
|
||||
|
||||
if no_conf:
|
||||
print(f"WARNING: No yaml files found in {dir}")
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
train_idx, val_idx = train_test_split(
|
||||
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
|
||||
)
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
|
||||
def freeze_top_n_layers(model, target_layers):
|
||||
# its possible we can simply detect which module is a ModuleList
|
||||
# and simply freeze the module without doing string parsing
|
||||
for name, param in model.named_parameters():
|
||||
if "embed" in name:
|
||||
param.requires_grad = False
|
||||
elif ".layer" in name or ".h." in name:
|
||||
tokens = name.split(".")
|
||||
layer_ = None
|
||||
for token in tokens:
|
||||
if token.isdigit():
|
||||
layer_ = int(token)
|
||||
break
|
||||
|
||||
if layer_ is not None and layer_ < target_layers:
|
||||
# print('freeze ', layer_, name)
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloomz-560m")
|
||||
freeze_top_n_layers(model, 10)
|
||||
print(model.state_dict().keys())
|
||||
File diff suppressed because one or more lines are too long
@@ -24,10 +24,13 @@ only described in the notebook
|
||||
|
||||
Charts showing detailed memory usages and times for different sentence lengths
|
||||
and batch sizes are inside the notebook Quick overview batch size 16, sentence
|
||||
length 4k for training, batch size 128 sentence length 4k for inference | Model
|
||||
name | Training memory| Training speed | Inference Memory| Inference Speed| |
|
||||
:---: | :---: | :---: |:---: | :---: | |original| 11.8GB | 2.40s| 4.8GB|16.48s|
|
||||
|unbiased| 12GB| 1.09s| 4.8GB | 5.59s| |multilingual|14GB| 1.00s| 5.5GB| 4.89s|
|
||||
length 4k for training, batch size 128 sentence length 4k for Inference
|
||||
|
||||
| Model name | Training memory | Training speed | Inference Memory | Inference Speed |
|
||||
| :----------: | :-------------: | :------------: | :--------------: | :-------------: |
|
||||
| original | 11.8GB | 2.40s | 4.8GB | 16.48s |
|
||||
| unbiased | 12GB | 1.09s | 4.8GB | 5.59s |
|
||||
| multilingual | 14GB | 1.00s | 5.5GB | 4.89s |
|
||||
|
||||
# Filtering quality
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
import enum
|
||||
import typing as t
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Type
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
# TODO: Move to `protocol`?
|
||||
@@ -27,7 +30,7 @@ class TaskType(str, enum.Enum):
|
||||
class OasstApiClient:
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
|
||||
"""Create a new OasstApiClient.
|
||||
|
||||
Args:
|
||||
@@ -35,8 +38,12 @@ class OasstApiClient:
|
||||
backend_url (str): The base backend URL.
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
if session is None:
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
session = aiohttp.ClientSession()
|
||||
|
||||
self.session = session
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
@@ -56,7 +63,33 @@ class OasstApiClient:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
# If the response is not a 2XX, check to see
|
||||
# if the json has the fields to create an
|
||||
# OasstError.
|
||||
if response.status >= 300:
|
||||
data = await response.json()
|
||||
try:
|
||||
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
|
||||
raise OasstError(
|
||||
error_code=oasst_error.error_code,
|
||||
message=oasst_error.message,
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.debug(f"Got error from API but could not parse: {e}")
|
||||
|
||||
raw_response = await response.text()
|
||||
logger.debug(f"Raw response: {raw_response}")
|
||||
|
||||
raise OasstError(
|
||||
raw_response,
|
||||
OasstErrorCode.GENERIC_ERROR,
|
||||
HTTPStatus(response.status),
|
||||
)
|
||||
|
||||
if response.status == 204:
|
||||
# No content
|
||||
return None
|
||||
return await response.json()
|
||||
|
||||
def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pydantic
|
||||
from oasst_shared.exceptions import OasstErrorCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -293,3 +294,10 @@ class UserScore(BaseModel):
|
||||
|
||||
class LeaderboardStats(BaseModel):
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
class OasstErrorResponse(BaseModel):
|
||||
"""The format of an error response from the OASST API."""
|
||||
|
||||
error_code: OasstErrorCode
|
||||
message: str
|
||||
|
||||
@@ -11,5 +11,7 @@ setup(
|
||||
author="OASST Team",
|
||||
install_requires=[
|
||||
"pydantic==1.9.1",
|
||||
"aiohttp==3.8.3",
|
||||
"aiohttp[speedups]",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from oasst_shared.api_client import OasstApiClient
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oasst_api_client_mocked():
|
||||
"""
|
||||
A an oasst_api_client pointed at the mocked backend.
|
||||
Relies on ./scripts/backend-development/start-mock-server.sh
|
||||
being run.
|
||||
"""
|
||||
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123")
|
||||
yield client
|
||||
# TODO The fixture should close this connection, but there seems to be a bug
|
||||
@@ -15,6 +24,30 @@ def oasst_api_client_mocked():
|
||||
# await client.close()
|
||||
|
||||
|
||||
class MockClientSession(aiohttp.ClientSession):
|
||||
response: Any
|
||||
|
||||
def set_response(self, response: Any):
|
||||
self.response = response
|
||||
|
||||
async def post(self, *args, **kwargs):
|
||||
return self.response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_session():
|
||||
yield MockClientSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oasst_api_client_fake_http(mock_http_session):
|
||||
"""
|
||||
An oasst_api_client that uses a mocked http session. No real requests are made.
|
||||
"""
|
||||
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType)
|
||||
async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient):
|
||||
@@ -49,3 +82,47 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_handle_oasst_error_from_api(
|
||||
oasst_api_client_fake_http: OasstApiClient,
|
||||
mock_http_session: MockClientSession,
|
||||
):
|
||||
# Return a 400 response with an OasstErrorResponse body
|
||||
response_body = protocol_schema.OasstErrorResponse(
|
||||
error_code=OasstErrorCode.GENERIC_ERROR,
|
||||
message="Some error",
|
||||
)
|
||||
status_code = 400
|
||||
|
||||
mock_http_session.set_response(
|
||||
mock.AsyncMock(
|
||||
status=status_code,
|
||||
text=mock.AsyncMock(return_value=response_body.json()),
|
||||
json=mock.AsyncMock(return_value=response_body.dict()),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(OasstError):
|
||||
await oasst_api_client_fake_http.post("/some-path", data={})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_handle_unknown_error_from_api(
|
||||
oasst_api_client_fake_http: OasstApiClient,
|
||||
mock_http_session: MockClientSession,
|
||||
):
|
||||
response_body = "Internal Server Error"
|
||||
status_code = 500
|
||||
|
||||
mock_http_session.set_response(
|
||||
mock.AsyncMock(
|
||||
status=status_code,
|
||||
text=mock.AsyncMock(return_value=response_body),
|
||||
json=mock.AsyncMock(return_value=None),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(OasstError):
|
||||
await oasst_api_client_fake_http.post("/some-path", data={})
|
||||
|
||||
@@ -4,6 +4,6 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
# switch to bot directory
|
||||
pushd "$parent_path/../../discord-bot"
|
||||
|
||||
python3 __main__.py
|
||||
python3 -m bot
|
||||
|
||||
popd
|
||||
|
||||
@@ -4,6 +4,8 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../oasst-shared"
|
||||
|
||||
set -xe
|
||||
|
||||
pytest .
|
||||
|
||||
popd
|
||||
|
||||
@@ -3,7 +3,10 @@ const nextConfig = {
|
||||
output: "standalone",
|
||||
reactStrictMode: true,
|
||||
experimental: {
|
||||
scrollRestoration: true,
|
||||
/* Disabling this for now only because it causes a warning in the console that cannot be silenced for eslint
|
||||
If this can be resolved, we should re-enable this.
|
||||
*/
|
||||
// scrollRestoration: true,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
Generated
+4444
-30
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,7 @@
|
||||
"@dnd-kit/core": "^6.0.6",
|
||||
"@dnd-kit/modifiers": "^6.0.1",
|
||||
"@dnd-kit/sortable": "^7.0.1",
|
||||
"@dnd-kit/utilities": "^3.2.1",
|
||||
"@emotion/react": "^11.10.5",
|
||||
"@emotion/styled": "^11.10.5",
|
||||
"@headlessui/react": "^1.7.7",
|
||||
@@ -39,9 +40,11 @@
|
||||
"eslint-plugin-simple-import-sort": "^8.0.0",
|
||||
"focus-visible": "^5.2.0",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"next": "13.0.6",
|
||||
"next-auth": "^4.18.6",
|
||||
"nodemailer": "^6.8.0",
|
||||
"npm": "^9.2.0",
|
||||
"postcss-focus-visible": "^7.1.0",
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
export function AuthLayout({ children }) {
|
||||
return (
|
||||
<main className="flex items-center justify-center sm:py-4 subpixel-antialiased">
|
||||
<div className="flex items-center justify-center sm:py-4 subpixel-antialiased">
|
||||
<div className="flex items-center w-full max-w-2xl flex-col px-4 sm:px-6">
|
||||
<div className="flex-auto items-center justify-center w-full py-10 px-4 sm:mx-0 sm:flex-none sm:rounded-2xl sm:p-4">
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,16 +1,48 @@
|
||||
import { CircleBackground } from "./CircleBackground";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useId } from "react";
|
||||
|
||||
import { Container } from "./Container";
|
||||
|
||||
export function CallToAction() {
|
||||
function CircleBackground({ width = 558, height = 558, ...props }) {
|
||||
const id = useId();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const baseRingColor = colorMode === "light" ? "#777" : "#000";
|
||||
const gradStopColor = colorMode === "light" ? "#fff" : "#000";
|
||||
|
||||
return (
|
||||
<section id="join-us" className="relative overflow-hidden bg-gray-900 py-20 sm:py-28">
|
||||
<svg viewBox="0 0 558 558" width={width} height={height} fill="none" aria-hidden="true" {...props}>
|
||||
<defs>
|
||||
<linearGradient id={id} x1="79" y1="16" x2="105" y2="237" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor={gradStopColor} />
|
||||
<stop offset="1" stopColor={baseRingColor} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<path
|
||||
opacity=".2"
|
||||
d="M1 279C1 125.465 125.465 1 279 1s278 124.465 278 278-124.465 278-278 278S1 432.535 1 279Z"
|
||||
stroke={baseRingColor}
|
||||
/>
|
||||
<path d="M1 279C1 125.465 125.465 1 279 1" stroke={`url(#${id})`} strokeLinecap="round" />
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
export function CallToAction() {
|
||||
const { colorMode } = useColorMode();
|
||||
const bgColorClass = colorMode === "light" ? "bg-gray-900" : "bg-gray-50";
|
||||
const headingColorClass = colorMode === "light" ? "text-white" : "text-black";
|
||||
const textColorClass = colorMode === "light" ? "text-gray-300" : "text-black";
|
||||
|
||||
return (
|
||||
<section id="join-us" className={`relative overflow-hidden py-20 sm:py-28 ${bgColorClass} ${textColorClass}`}>
|
||||
<div className="absolute top-1/2 left-20 -translate-y-1/2 sm:left-1/2 sm:-translate-x-1/2">
|
||||
<CircleBackground color="#fff" className="animate-spin-slower" />
|
||||
<CircleBackground className="animate-spin-slower" />
|
||||
</div>
|
||||
<Container className="relative">
|
||||
<div className="mx-auto max-w-md sm:text-center">
|
||||
<h2 className="text-3xl font-medium tracking-tight text-white sm:text-4xl">Join Us</h2>
|
||||
<p className="mt-4 text-lg text-gray-300">
|
||||
<h2 className={`text-3xl font-medium tracking-tight sm:text-4xl ${headingColorClass}`}>Join Us</h2>
|
||||
<p className="mt-4 text-lg">
|
||||
All open source projects begin with people like you. Open source is the belief that if we collaborate we can
|
||||
together gift our knowledge and technology to the world for the benefit of humanity. Are you in? Find us
|
||||
here:
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { useId } from "react";
|
||||
|
||||
export function CircleBackground({ color, width = 558, height = 558, ...props }) {
|
||||
const id = useId();
|
||||
|
||||
return (
|
||||
<svg viewBox="0 0 558 558" width={width} height={height} fill="none" aria-hidden="true" {...props}>
|
||||
<defs>
|
||||
<linearGradient id={id} x1="79" y1="16" x2="105" y2="237" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor={color} />
|
||||
<stop offset="1" stopColor={color} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<path
|
||||
opacity=".2"
|
||||
d="M1 279C1 125.465 125.465 1 279 1s278 124.465 278 278-124.465 278-278 278S1 432.535 1 279Z"
|
||||
stroke={color}
|
||||
/>
|
||||
<path d="M1 279C1 125.465 125.465 1 279 1" stroke={`url(#${id})`} strokeLinecap="round" />
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
|
||||
import { Container } from "./Container";
|
||||
|
||||
const faqs = [
|
||||
@@ -25,11 +27,16 @@ const faqs = [
|
||||
];
|
||||
|
||||
export function Faq() {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const headingColorClass = colorMode === "light" ? "text-gray-900" : "text-white";
|
||||
const textColorClass = colorMode === "light" ? "text-gray-700" : "text-gray-100";
|
||||
|
||||
return (
|
||||
<section id="faq" aria-labelledby="faqs-title" className="border-t border-gray-200 py-20 sm:py-32">
|
||||
<Container className="">
|
||||
<div className="mx-auto max-w-2xl lg:mx-0">
|
||||
<h2 id="faqs-title" className="text-3xl font-medium tracking-tight text-gray-900">
|
||||
<h2 id="faqs-title" className={`text-3xl font-medium tracking-tight ${headingColorClass}`}>
|
||||
Frequently Asked Questions
|
||||
</h2>
|
||||
{/* <p className="mt-2 text-lg text-gray-600">
|
||||
@@ -52,8 +59,8 @@ export function Faq() {
|
||||
<ul role="list" className="space-y-10">
|
||||
{column.map((faq, faqIndex) => (
|
||||
<li key={faqIndex}>
|
||||
<h3 className="text-lg font-semibold leading-6 text-gray-900">{faq.question}</h3>
|
||||
<p className="mt-4 text-sm text-gray-700">{faq.answer}</p>
|
||||
<h3 className={`text-lg font-semibold leading-6 ${headingColorClass}`}>{faq.question}</h3>
|
||||
<p className={`mt-4 text-sm ${textColorClass}`}>{faq.answer}</p>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
|
||||
@@ -1,71 +1,70 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
|
||||
import { Container } from "./Container";
|
||||
|
||||
export function Footer() {
|
||||
return (
|
||||
<footer className="border-t border-gray-200 bg-white">
|
||||
<main>
|
||||
<Container className="">
|
||||
<div className="flex flex-wrap justify-between gap-y-12 py-10 lg:items-center lg:py-16">
|
||||
<div className="flex items-center text-black pr-8">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
|
||||
</Link>
|
||||
const { colorMode } = useColorMode();
|
||||
const bgColorClass = colorMode === "light" ? "bg-transparent" : "bg-gray-800";
|
||||
const borderClass = colorMode === "light" ? "border-slate-200" : "border-transparent";
|
||||
|
||||
<div className="ml-2">
|
||||
<p className="text-base font-bold">Open Assistant</p>
|
||||
<p className="text-sm">Conversational AI for everyone.</p>
|
||||
return (
|
||||
<footer className={bgColorClass}>
|
||||
<div className={`flex mx-auto max-w-7xl justify-between py-10 px-10 border-t ${borderClass}`}>
|
||||
<div className="flex items-center pr-8">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="52" height="52" alt="logo" />
|
||||
</Link>
|
||||
|
||||
<div className="ml-2">
|
||||
<p className="text-base font-bold">Open Assistant</p>
|
||||
<p className="text-sm">Conversational AI for everyone.</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<nav className="flex justify-center gap-20">
|
||||
<nav className="flex justify-center gap-20">
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Legal</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link href="/privacy-policy" aria-label="Privacy Policy" className="hover:underline underline-offset-2">
|
||||
Privacy Policy
|
||||
</Link>
|
||||
<Link
|
||||
href="/terms-of-service"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Terms of Service
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
<nav className="flex justify-center gap-20">
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Legal</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link
|
||||
href="/privacy-policy"
|
||||
aria-label="Privacy Policy"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Privacy Policy
|
||||
</Link>
|
||||
<Link
|
||||
href="/terms-of-service"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Terms of Service
|
||||
</Link>
|
||||
</div>
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Connect</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link
|
||||
href="https://github.com/LAION-AI/Open-Assistant"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Privacy Policy"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Github
|
||||
</Link>
|
||||
<Link
|
||||
href="https://discord.gg/pXtnYk9c"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Discord
|
||||
</Link>
|
||||
</div>
|
||||
<div className="flex flex-col text-sm leading-7">
|
||||
<b>Connect</b>
|
||||
<div className="flex flex-col leading-5">
|
||||
<Link
|
||||
href="https://github.com/LAION-AI/Open-Assistant"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Privacy Policy"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Github
|
||||
</Link>
|
||||
<Link
|
||||
href="https://discord.gg/pXtnYk9c"
|
||||
rel="noopener noreferrer nofollow"
|
||||
target="_blank"
|
||||
aria-label="Terms of Service"
|
||||
className="hover:underline underline-offset-2"
|
||||
>
|
||||
Discord
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
</div>
|
||||
</Container>
|
||||
</main>
|
||||
</div>
|
||||
</nav>
|
||||
{/* </div> */}
|
||||
</nav>
|
||||
</div>
|
||||
</footer>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -22,4 +22,15 @@ const Template = (args) => {
|
||||
};
|
||||
|
||||
export const Default = Template.bind({});
|
||||
Default.args = { session: { data: { user: { name: "StoryBook user" } }, status: "authenticated" }, transparent: false };
|
||||
Default.args = {
|
||||
session: {
|
||||
data: {
|
||||
user: {
|
||||
name: "StoryBook user",
|
||||
},
|
||||
},
|
||||
status: "authenticated",
|
||||
},
|
||||
transparent: false,
|
||||
borderClass: undefined,
|
||||
};
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import { Button } from "@chakra-ui/react";
|
||||
import { Box, Button, useColorMode } from "@chakra-ui/react";
|
||||
import { Popover } from "@headlessui/react";
|
||||
import clsx from "clsx";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { FaUser } from "react-icons/fa";
|
||||
import { Container } from "src/components/Container";
|
||||
|
||||
import { NavLinks } from "./NavLinks";
|
||||
import { ColorModeIconToggle } from "../UI/ColorModeIconToggle";
|
||||
import { UserMenu } from "./UserMenu";
|
||||
|
||||
function MenuIcon(props) {
|
||||
@@ -55,63 +53,72 @@ function AccountButton() {
|
||||
}
|
||||
|
||||
export function Header(props) {
|
||||
const transparent = props.transparent ?? false;
|
||||
const { colorMode } = useColorMode();
|
||||
const borderClass = props.transparent
|
||||
? ""
|
||||
: colorMode === "light"
|
||||
? "border-b border-gray-400"
|
||||
: "border-b border-zinc-800";
|
||||
|
||||
return (
|
||||
<header className={clsx(!transparent && "bg-white")}>
|
||||
<nav>
|
||||
<Container className="relative z-10 flex justify-between py-8">
|
||||
<div className="relative z-10 flex items-center gap-16">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
|
||||
<span className="text-2xl font-bold ml-3">Open Assistant</span>
|
||||
</Link>
|
||||
</div>
|
||||
<div className="flex items-center gap-4">
|
||||
<Popover className="lg:hidden">
|
||||
{({ open }) => (
|
||||
<>
|
||||
<Popover.Button
|
||||
className="relative z-10 inline-flex items-center rounded-lg stroke-gray-900 p-2 hover:bg-gray-200/50 hover:stroke-gray-600 active:stroke-gray-900 [&:not(:focus-visible)]:focus:outline-none"
|
||||
aria-label="Toggle site navigation"
|
||||
>
|
||||
{({ open }) => (open ? <ChevronUpIcon className="h-6 w-6" /> : <MenuIcon className="h-6 w-6" />)}
|
||||
</Popover.Button>
|
||||
<AnimatePresence initial={false}>
|
||||
{open && (
|
||||
<>
|
||||
<Popover.Overlay
|
||||
static
|
||||
as={motion.div}
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
className="fixed inset-0 z-1 bg-gray-300/60 backdrop-blur"
|
||||
/>
|
||||
<Popover.Panel
|
||||
static
|
||||
as={motion.div}
|
||||
initial={{ opacity: 0, y: -32 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
y: -32,
|
||||
transition: { duration: 0.2 },
|
||||
}}
|
||||
className="absolute inset-x-0 top-0 z-0 origin-top rounded-b-2xl bg-white px-6 pb-6 pt-32 shadow-2xl shadow-gray-900/20"
|
||||
>
|
||||
<div className="mt-8 flex flex-col gap-4"></div>
|
||||
</Popover.Panel>
|
||||
</>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</>
|
||||
)}
|
||||
</Popover>
|
||||
<AccountButton />
|
||||
<UserMenu />
|
||||
</div>
|
||||
</Container>
|
||||
</nav>
|
||||
</header>
|
||||
<nav className={`oa-basic-theme ${borderClass}`}>
|
||||
<Box className="flex mx-auto max-w-7xl justify-between py-8 px-10">
|
||||
<div className="relative z-10 flex items-center gap-16">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
|
||||
<span className="text-2xl font-bold ml-3">Open Assistant</span>
|
||||
</Link>
|
||||
</div>
|
||||
<div className="flex items-center gap-4">
|
||||
<Popover className="lg:hidden">
|
||||
{({ open }) => (
|
||||
<>
|
||||
<Popover.Button
|
||||
className="relative z-10 inline-flex items-center rounded-lg stroke-gray-900 p-2 hover:bg-gray-200/50 hover:stroke-gray-600 active:stroke-gray-900 [&:not(:focus-visible)]:focus:outline-none"
|
||||
aria-label="Toggle site navigation"
|
||||
>
|
||||
{({ open }) => (open ? <ChevronUpIcon className="h-6 w-6" /> : <MenuIcon className="h-6 w-6" />)}
|
||||
</Popover.Button>
|
||||
<AnimatePresence initial={false}>
|
||||
{open && (
|
||||
<>
|
||||
<Popover.Overlay
|
||||
static
|
||||
as={motion.div}
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
className="fixed inset-0 z-1 bg-gray-300/60 backdrop-blur"
|
||||
/>
|
||||
<Popover.Panel
|
||||
static
|
||||
as={motion.div}
|
||||
initial={{ opacity: 0, y: -32 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
y: -32,
|
||||
transition: { duration: 0.2 },
|
||||
}}
|
||||
className="absolute inset-x-0 top-0 z-0 origin-top rounded-b-2xl bg-white px-6 pb-6 pt-32 shadow-2xl shadow-gray-900/20"
|
||||
>
|
||||
<div className="space-y-4">
|
||||
<MobileNavLink href="/#join-us">Join Us</MobileNavLink>
|
||||
<MobileNavLink href="/#faqs">FAQs</MobileNavLink>
|
||||
</div>
|
||||
<div className="mt-8 flex flex-col gap-4"></div>
|
||||
</Popover.Panel>
|
||||
</>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</>
|
||||
)}
|
||||
</Popover>
|
||||
<AccountButton />
|
||||
<UserMenu />
|
||||
<ColorModeIconToggle className="ml-5" />
|
||||
</div>
|
||||
</Box>
|
||||
</nav>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
|
||||
export function NavLinks(): JSX.Element {
|
||||
const [hoveredIndex, setHoveredIndex] = useState(null);
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const linkColor = colorMode === "light" ? "text-gray-700 hover:text-gray-900" : "text-gray-50 hover:text-white";
|
||||
|
||||
const hoverBgColor = colorMode === "light" ? "bg-gray-100" : "bg-gray-800";
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -14,14 +20,14 @@ export function NavLinks(): JSX.Element {
|
||||
<Link
|
||||
key={label}
|
||||
href={href}
|
||||
className="relative -my-2 -mx-3 rounded-lg px-3 py-2 text-sm text-gray-700 transition-colors delay-150 hover:text-gray-900 hover:delay-[0ms]"
|
||||
className={`${linkColor} relative -my-2 -mx-3 rounded-lg px-3 py-2 text-sm transition-colors delay-150 hover:delay-[0ms]`}
|
||||
onMouseEnter={() => setHoveredIndex(index)}
|
||||
onMouseLeave={() => setHoveredIndex(null)}
|
||||
>
|
||||
<AnimatePresence>
|
||||
{hoveredIndex === index && (
|
||||
<motion.span
|
||||
className="absolute inset-0 rounded-lg bg-gray-100"
|
||||
className={`${hoverBgColor} absolute inset-0 rounded-lg`}
|
||||
layoutId="hoverBackground"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1, transition: { duration: 0.15 } }}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { Box, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Popover } from "@headlessui/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import Image from "next/image";
|
||||
@@ -7,6 +8,7 @@ import { FaCog, FaSignOutAlt } from "react-icons/fa";
|
||||
|
||||
export function UserMenu() {
|
||||
const { data: session } = useSession();
|
||||
const backgroundColor = useColorModeValue("#FFFFFF", "#000000");
|
||||
|
||||
if (!session) {
|
||||
return <></>;
|
||||
@@ -26,7 +28,7 @@ export function UserMenu() {
|
||||
{({ open }) => (
|
||||
<>
|
||||
<Popover.Button aria-label="Toggle Account Options" className="flex">
|
||||
<div className="flex items-center gap-4 p-1 lg:pr-6 rounded-full bg-white border border-slate-300/70 hover:bg-gray-200/50 transition-colors duration-300">
|
||||
<div className="flex items-center gap-4 p-1 lg:pr-6 rounded-full border border-slate-300/70 hover:bg-gray-200/50 transition-colors duration-300">
|
||||
<Image
|
||||
src="/images/temp-avatars/av1.jpg"
|
||||
alt="Profile Picture"
|
||||
@@ -41,7 +43,7 @@ export function UserMenu() {
|
||||
</Popover.Button>
|
||||
<AnimatePresence initial={false}>
|
||||
{open && (
|
||||
<>
|
||||
<Box backgroundColor={backgroundColor}>
|
||||
<Popover.Panel
|
||||
static
|
||||
as={motion.div}
|
||||
@@ -52,9 +54,9 @@ export function UserMenu() {
|
||||
y: -10,
|
||||
transition: { duration: 0.2 },
|
||||
}}
|
||||
className="absolute right-0 mt-3 w-screen max-w-xs p-4 rounded-md bg-white border border-slate-300/70"
|
||||
className="absolute right-0 mt-3 w-screen bg-inherit max-w-xs p-4 rounded-md border border-slate-300/70"
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
<Box className="flex flex-col gap-1">
|
||||
{accountOptions.map((item) => (
|
||||
<a
|
||||
key={item.name}
|
||||
@@ -81,9 +83,9 @@ export function UserMenu() {
|
||||
<p>Sign Out</p>
|
||||
</div>
|
||||
</a>
|
||||
</div>
|
||||
</Box>
|
||||
</Popover.Panel>
|
||||
</>
|
||||
</Box>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import { useId } from "react";
|
||||
|
||||
@@ -6,6 +7,10 @@ import { Container } from "./Container";
|
||||
function BackgroundIllustration(props) {
|
||||
const id = useId();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const baseRingColor = colorMode === "light" ? "#d4d4d4" : "#005a69";
|
||||
const gradStopColor = colorMode === "light" ? "#06b6d4" : "#00f2ff";
|
||||
|
||||
return (
|
||||
<div {...props}>
|
||||
<svg
|
||||
@@ -16,14 +21,14 @@ function BackgroundIllustration(props) {
|
||||
>
|
||||
<path
|
||||
d="M1025 513c0 282.77-229.23 512-512 512S1 795.77 1 513 230.23 1 513 1s512 229.23 512 512Z"
|
||||
stroke="#D4D4D4"
|
||||
stroke={baseRingColor}
|
||||
strokeOpacity="0.7"
|
||||
/>
|
||||
<path d="M513 1025C230.23 1025 1 795.77 1 513" stroke={`url(#${id}-gradient-1)`} strokeLinecap="round" />
|
||||
<defs>
|
||||
<linearGradient id={`${id}-gradient-1`} x1="1" y1="513" x2="1" y2="1025" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor="#06b6d4" />
|
||||
<stop offset="1" stopColor="#06b6d4" stopOpacity="0" />
|
||||
<stop stopColor={gradStopColor} />
|
||||
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
@@ -35,14 +40,14 @@ function BackgroundIllustration(props) {
|
||||
>
|
||||
<path
|
||||
d="M913 513c0 220.914-179.086 400-400 400S113 733.914 113 513s179.086-400 400-400 400 179.086 400 400Z"
|
||||
stroke="#D4D4D4"
|
||||
stroke={baseRingColor}
|
||||
strokeOpacity="0.7"
|
||||
/>
|
||||
<path d="M913 513c0 220.914-179.086 400-400 400" stroke={`url(#${id}-gradient-2)`} strokeLinecap="round" />
|
||||
<defs>
|
||||
<linearGradient id={`${id}-gradient-2`} x1="913" y1="513" x2="913" y2="913" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor="#06b6d4" />
|
||||
<stop offset="1" stopColor="#06b6d4" stopOpacity="0" />
|
||||
<stop stopColor={gradStopColor} />
|
||||
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
@@ -51,17 +56,24 @@ function BackgroundIllustration(props) {
|
||||
}
|
||||
|
||||
export function Hero() {
|
||||
const { colorMode } = useColorMode();
|
||||
const pTextColor = colorMode === "light" ? "text-gray-600" : "text-white";
|
||||
const fancyTextGradientClasses =
|
||||
colorMode === "light" ? "from-blue-600 via-sky-400 to-blue-700" : "from-blue-500 via-sky-300 to-blue-400";
|
||||
|
||||
return (
|
||||
<div className="overflow-hidden py-20 sm:py-32 lg:pb-32 xl:pb-36">
|
||||
<Container className="">
|
||||
<div className="lg:grid lg:grid-cols-12 lg:gap-x-8 lg:gap-y-20">
|
||||
<div className="relative z-10 mx-auto max-w-2xl lg:col-span-7 lg:max-w-none lg:pt-6 xl:col-span-6">
|
||||
<h1 className="text-5xl mb-6 font-bold tracking-tight text-gray-900">Open Assistant</h1>
|
||||
<p className="mt-8 text-3xl inline bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text font-display tracking-tight text-transparent">
|
||||
<h1 className="text-5xl mb-6 font-bold tracking-tight">Open Assistant</h1>
|
||||
<p
|
||||
className={`bg-gradient-to-r ${fancyTextGradientClasses} mt-8 text-3xl inline bg-clip-text font-display tracking-tight text-transparent`}
|
||||
>
|
||||
<b>Conversational AI for everyone.</b>
|
||||
</p>
|
||||
<p className="mt-6 text-lg text-gray-600">We believe we can create a revolution.</p>
|
||||
<p className="mt-6 text-lg text-gray-600">
|
||||
<p className={`mt-6 text-lg ${pTextColor}`}>We believe we can create a revolution.</p>
|
||||
<p className={`mt-6 text-lg ${pTextColor}`}>
|
||||
In the same way that Stable Diffusion helped the world make art and images in new ways, we want to improve
|
||||
the world by providing amazing conversational AI.
|
||||
</p>
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
import { Progress } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
|
||||
export const LoadingScreen = ({ text }) => (
|
||||
<div className="bg-slate-100">
|
||||
<Progress size="xs" isIndeterminate />
|
||||
{text && (
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold text-gray-800 mx-auto my-auto">{text}</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
export const LoadingScreen = ({ text }) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
return (
|
||||
<div className={`h-full ${mainClasses}`}>
|
||||
<Progress size="sm" isIndeterminate />
|
||||
{text && (
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">{text}</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { Grid } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
export interface Message {
|
||||
@@ -6,13 +8,24 @@ export interface Message {
|
||||
is_assistant: boolean;
|
||||
}
|
||||
|
||||
const getColor = (isAssistant: boolean) => (isAssistant ? "bg-slate-800" : "bg-sky-900");
|
||||
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
|
||||
if (colorMode === "light") {
|
||||
return isAssistant ? "bg-slate-800" : "bg-sky-900";
|
||||
} else {
|
||||
return isAssistant ? "bg-black" : "bg-sky-900";
|
||||
}
|
||||
};
|
||||
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
|
||||
return (
|
||||
<FlaggableElement text={text} post_id={post_id} key={i + text}>
|
||||
<div key={i + text} className={`${getColor(is_assistant)} p-4 rounded-md text-white whitespace-pre-wrap`}>
|
||||
<div
|
||||
key={i + text}
|
||||
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
|
||||
>
|
||||
{text}
|
||||
</div>
|
||||
</FlaggableElement>
|
||||
|
||||
@@ -2,9 +2,9 @@ import { Flex } from "@chakra-ui/react";
|
||||
import {
|
||||
closestCenter,
|
||||
DndContext,
|
||||
KeyboardSensor,
|
||||
PointerSensor,
|
||||
TouchSensor,
|
||||
KeyboardSensor,
|
||||
useSensor,
|
||||
useSensors,
|
||||
} from "@dnd-kit/core";
|
||||
@@ -23,6 +23,7 @@ import { SortableItem } from "./SortableItem";
|
||||
export interface SortableProps {
|
||||
items: ReactNode[];
|
||||
onChange: (newSortedIndices: number[]) => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
interface SortableItems {
|
||||
@@ -31,18 +32,18 @@ interface SortableItems {
|
||||
item: ReactNode;
|
||||
}
|
||||
|
||||
export const Sortable = ({ items, onChange }: SortableProps) => {
|
||||
export const Sortable = (props: SortableProps) => {
|
||||
const [itemsWithIds, setItemsWithIds] = useState<SortableItems[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
setItemsWithIds(
|
||||
items.map((item, idx) => ({
|
||||
props.items.map((item, idx) => ({
|
||||
item,
|
||||
id: idx + 1, // +1 because dndtoolkit has problem with "falsy" ids
|
||||
originalIndex: idx,
|
||||
}))
|
||||
);
|
||||
}, [items]);
|
||||
}, [props.items]);
|
||||
|
||||
const sensors = useSensors(
|
||||
useSensor(PointerSensor),
|
||||
@@ -50,6 +51,8 @@ export const Sortable = ({ items, onChange }: SortableProps) => {
|
||||
useSensor(KeyboardSensor, { coordinateGetter: sortableKeyboardCoordinates })
|
||||
);
|
||||
|
||||
const extraClasses = props.className || "";
|
||||
|
||||
return (
|
||||
<DndContext
|
||||
sensors={sensors}
|
||||
@@ -58,7 +61,7 @@ export const Sortable = ({ items, onChange }: SortableProps) => {
|
||||
modifiers={[restrictToVerticalAxis]}
|
||||
>
|
||||
<SortableContext items={itemsWithIds} strategy={verticalListSortingStrategy}>
|
||||
<Flex direction="column" gap={2}>
|
||||
<Flex direction="column" gap={2} className={extraClasses}>
|
||||
{itemsWithIds.map(({ id, item }) => (
|
||||
<SortableItem key={id} id={id}>
|
||||
{item}
|
||||
@@ -78,7 +81,7 @@ export const Sortable = ({ items, onChange }: SortableProps) => {
|
||||
const oldIndex = items.findIndex((x) => x.id === active.id);
|
||||
const newIndex = items.findIndex((x) => x.id === over.id);
|
||||
const newArray = arrayMove(items, oldIndex, newIndex);
|
||||
onChange(newArray.map((item) => item.originalIndex));
|
||||
props.onChange(newArray.map((item) => item.originalIndex));
|
||||
return newArray;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Button } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useSortable } from "@dnd-kit/sortable";
|
||||
import { CSS } from "@dnd-kit/utilities";
|
||||
import { RxDragHandleDots2 } from "react-icons/rx";
|
||||
import { PropsWithChildren } from "react";
|
||||
import { RxDragHandleDots2 } from "react-icons/rx";
|
||||
|
||||
export const SortableItem = ({ children, id }: PropsWithChildren<{ id: number }>) => {
|
||||
const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ id });
|
||||
@@ -13,9 +14,15 @@ export const SortableItem = ({ children, id }: PropsWithChildren<{ id: number }>
|
||||
touchAction: "none",
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const themedClasses =
|
||||
colorMode === "light"
|
||||
? "bg-slate-600 hover:bg-slate-500 text-white"
|
||||
: "bg-black hover:bg-slate-900 text-white ring-1 ring-white/30 ring-inset hover:ring-slate-200/50";
|
||||
|
||||
return (
|
||||
<li
|
||||
className="grid grid-cols-[min-content_1fr] items-center rounded-lg shadow-md gap-x-2 p-2 bg-white hover:bg-slate-50"
|
||||
className={`grid grid-cols-[min-content_1fr] items-center rounded-lg shadow-md gap-x-2 p-2 ${themedClasses}`}
|
||||
ref={setNodeRef}
|
||||
style={style}
|
||||
>
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
|
||||
interface SurveyCardProps {
|
||||
className?: string;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export const SurveyCard = (props: SurveyCardProps) => {
|
||||
const extraClases = props.className || "";
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const baseCardClasses = "rounded-lg h-full block p-6";
|
||||
const cardClases =
|
||||
colorMode === "light"
|
||||
? `${baseCardClasses} bg-slate-50 text-gray-800 shadow-lg ${extraClases}`
|
||||
: // `${baseCardClasses} bg-slate-800 text-white shadow-xl${extraClases}`;
|
||||
`${baseCardClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
|
||||
|
||||
return <div className={cardClases}>{props.children}</div>;
|
||||
};
|
||||
@@ -0,0 +1,40 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
|
||||
interface TaskControlsProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
className?: string;
|
||||
onSubmitResponse: (task: { id: string }) => void;
|
||||
onSkip: () => void;
|
||||
}
|
||||
|
||||
export const TaskControls = (props: TaskControlsProps) => {
|
||||
const extraClases = props.className || "";
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto";
|
||||
const taskControlClases =
|
||||
colorMode === "light"
|
||||
? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}`
|
||||
: `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
|
||||
|
||||
const endTask = props.tasks[props.tasks.length - 1];
|
||||
return (
|
||||
<section className={taskControlClases}>
|
||||
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton onClick={() => props.onSubmitResponse(props.tasks[0])}>Submit</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton onClick={props.onSkip}>Next Task</SubmitButton>
|
||||
)}
|
||||
</Flex>
|
||||
</section>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,16 @@
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
|
||||
export const TwoColumnsWithCards = ({ children }: { children: React.ReactNode[] }) => {
|
||||
if (!Array.isArray(children) || children.length !== 2) {
|
||||
throw new Error("TwoColumns expects 2 children");
|
||||
}
|
||||
|
||||
const [first, second] = children;
|
||||
|
||||
return (
|
||||
<div className="mb-8 mx-auto max-w-7xl lt-lg:mb-12 grid lg:gap-x-12 lg:grid-cols-2">
|
||||
<SurveyCard>{first}</SurveyCard>
|
||||
<SurveyCard className="lg:mt-0 lt-lg:mt-6">{second}</SurveyCard>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,6 +1,6 @@
|
||||
export const TaskInfo = ({ id, output }: { id: string; output: string }) => {
|
||||
return (
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2 text-gray-700">
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2 ">
|
||||
<b>Prompt</b>
|
||||
<span data-cy="task-id">{id}</span>
|
||||
<b>Output</b>
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
import { TaskOption } from "./TaskOption";
|
||||
import { TaskOptions } from "./TaskOptions";
|
||||
|
||||
export const TaskSelection = () => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
return (
|
||||
<Flex gap={10} wrap="wrap" justifyContent="space-evenly" width="full" height="full" alignItems={"center"}>
|
||||
<Flex
|
||||
gap={10}
|
||||
wrap="wrap"
|
||||
justifyContent="space-evenly"
|
||||
width="full"
|
||||
height="full"
|
||||
alignItems={"center"}
|
||||
className={mainBgClasses}
|
||||
>
|
||||
<TaskOptions key="create" title="Create">
|
||||
{/* <TaskOption
|
||||
alt="Summarize Stories"
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
export const TwoColumns = ({ children }: { children: React.ReactNode[] }) => {
|
||||
if (!Array.isArray(children) || children.length !== 2) {
|
||||
throw new Error("TwoColumns expects 2 children");
|
||||
}
|
||||
|
||||
const [first, second] = children;
|
||||
|
||||
return (
|
||||
<section className="mb-8 lt-lg:mb-12 grid lg:gap-x-12 lg:grid-cols-2">
|
||||
<div className="rounded-lg shadow-lg h-full block bg-white p-6">{first}</div>
|
||||
<div className="rounded-lg shadow-lg h-full block bg-white p-6 mt-6 lg:mt-0">{second}</div>
|
||||
</section>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,23 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { CiDark } from "react-icons/ci";
|
||||
import { CiLight } from "react-icons/ci";
|
||||
|
||||
export function ColorModeIconToggle(props) {
|
||||
const { colorMode, toggleColorMode } = useColorMode();
|
||||
const propsClassName = props.className ?? "";
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
className={`flex h-6 w-6 items-center justify-center rounded-md transition hover:bg-zinc-900/5 dark:hover:bg-white/5 ${propsClassName}`}
|
||||
aria-label="Toggle dark mode"
|
||||
onClick={toggleColorMode}
|
||||
>
|
||||
{colorMode === "light" ? (
|
||||
<CiDark className="h-5 w-5 stroke-zinc-900" />
|
||||
) : (
|
||||
<CiLight className="h-5 w-5 stroke-white" />
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
import { Switch, useColorMode } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
const ColorModeSwitch = () => {
|
||||
const { colorMode, toggleColorMode } = useColorMode();
|
||||
return (
|
||||
<Switch
|
||||
onChange={toggleColorMode}
|
||||
defaultChecked={colorMode === "light"}
|
||||
checked={colorMode === "light"}
|
||||
size="lg"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default ColorModeSwitch;
|
||||
@@ -1,48 +1,25 @@
|
||||
import "../styles/globals.css";
|
||||
import "focus-visible";
|
||||
|
||||
import { ChakraProvider } from "@chakra-ui/react";
|
||||
import { extendTheme } from "@chakra-ui/react";
|
||||
import { Inter } from "@next/font/google";
|
||||
import type { AppProps } from "next/app";
|
||||
import { SessionProvider } from "next-auth/react";
|
||||
import { getDefaultLayout, NextPageWithLayout } from "src/components/Layout";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const inter = Inter({
|
||||
subsets: ["latin"],
|
||||
variable: "--font-inter",
|
||||
});
|
||||
|
||||
const theme = extendTheme({
|
||||
styles: {
|
||||
global: {
|
||||
body: {
|
||||
bg: "white",
|
||||
},
|
||||
main: {
|
||||
fontFamily: "Inter",
|
||||
},
|
||||
header: {
|
||||
fontFamily: "Inter",
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
import { Chakra, getServerSideProps } from "../styles/Chakra";
|
||||
|
||||
type AppPropsWithLayout = AppProps & {
|
||||
Component: NextPageWithLayout;
|
||||
};
|
||||
|
||||
function MyApp({ Component, pageProps: { session, ...pageProps } }: AppPropsWithLayout) {
|
||||
function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: AppPropsWithLayout) {
|
||||
const getLayout = Component.getLayout ?? getDefaultLayout;
|
||||
const page = getLayout(<Component {...pageProps} />);
|
||||
|
||||
return (
|
||||
<ChakraProvider theme={theme}>
|
||||
<Chakra cookies={cookies}>
|
||||
<SessionProvider session={session}>{page}</SessionProvider>
|
||||
</ChakraProvider>
|
||||
</Chakra>
|
||||
);
|
||||
}
|
||||
|
||||
export { getServerSideProps };
|
||||
export default MyApp;
|
||||
|
||||
@@ -53,7 +53,7 @@ const handler = async (req, res) => {
|
||||
});
|
||||
|
||||
// Update the backend with our Task ID
|
||||
const ackRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
|
||||
await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
|
||||
@@ -7,8 +7,7 @@ import prisma from "src/lib/prismadb";
|
||||
* This implicity does a few things:
|
||||
* 1) Stores the answer with the Task Backend.
|
||||
* 2) Records the new task in our local database.
|
||||
* 3) (TODO) Acks the new task with our local task ID to the Task Backend.
|
||||
* 4) Returns the newly created task to the client.
|
||||
* 3) Returns the newly created task to the client.
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
@@ -69,9 +68,6 @@ const handler = async (req, res) => {
|
||||
},
|
||||
});
|
||||
|
||||
// TODO: Ack the task with the Task Backend using the newly created local
|
||||
// task ID.
|
||||
|
||||
// Send the next task in the sequence to the client.
|
||||
res.status(200).json(newRegisteredTask);
|
||||
};
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
|
||||
import React, { useRef } from "react";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
import { Footer } from "src/components/Footer";
|
||||
import { Header } from "src/components/Header";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
export default function Signin({ csrfToken, providers }) {
|
||||
function Signin({ csrfToken, providers }) {
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const emailEl = useRef(null);
|
||||
const signinWithEmail = (ev: React.FormEvent) => {
|
||||
@@ -21,8 +24,14 @@ export default function Signin({ csrfToken, providers }) {
|
||||
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
|
||||
}
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
|
||||
const buttonBgColor = colorMode === "light" ? "#2563eb" : "#2563eb";
|
||||
|
||||
const buttonColorScheme = colorMode === "light" ? "blue" : "dark-blue-btn";
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className={bgColorClass}>
|
||||
<Head>
|
||||
<title>Sign Up - Open Assistant</title>
|
||||
<meta name="Sign Up" content="Sign up to access Open Assistant" />
|
||||
@@ -30,11 +39,11 @@ export default function Signin({ csrfToken, providers }) {
|
||||
<AuthLayout>
|
||||
<Stack spacing="2">
|
||||
{credentials && (
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-200 rounded-md p-4 relative">
|
||||
<span className="text-orange-600 absolute -top-3 left-5 bg-white px-1">For Debugging Only</span>
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-600 rounded-md p-4 relative">
|
||||
<span className={`text-orange-600 absolute -top-3 left-5 ${bgColorClass} px-1`}>For Debugging Only</span>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
|
||||
<Button size={"lg"} leftIcon={<FaBug />} colorScheme="gray" type="submit">
|
||||
<Button size={"lg"} leftIcon={<FaBug />} colorScheme={buttonColorScheme} color="white" type="submit">
|
||||
Continue with Debug User
|
||||
</Button>
|
||||
</Stack>
|
||||
@@ -43,13 +52,13 @@ export default function Signin({ csrfToken, providers }) {
|
||||
{email && (
|
||||
<form onSubmit={signinWithEmail}>
|
||||
<Stack>
|
||||
<Input data-cy="email-address" variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Button
|
||||
data-cy="signin-email-button"
|
||||
size={"lg"}
|
||||
leftIcon={<FaEnvelope />}
|
||||
colorScheme="gray"
|
||||
type="submit"
|
||||
colorScheme={buttonColorScheme}
|
||||
color="white"
|
||||
>
|
||||
Continue with Email
|
||||
</Button>
|
||||
@@ -58,7 +67,7 @@ export default function Signin({ csrfToken, providers }) {
|
||||
)}
|
||||
{discord && (
|
||||
<Button
|
||||
bg="#5865F2"
|
||||
bg={buttonBgColor}
|
||||
_hover={{ bg: "#4A57E3" }}
|
||||
_active={{
|
||||
bg: "#454FBF",
|
||||
@@ -90,29 +99,31 @@ export default function Signin({ csrfToken, providers }) {
|
||||
</Stack>
|
||||
<div className="pt-10 text-center">
|
||||
By signing up you agree to our <br></br>
|
||||
<Link href="#" aria-label="Terms of Service" className="hover:underline underline-offset-4">
|
||||
<Link href="/terms-of-service" aria-label="Terms of Service" className="hover:underline underline-offset-4">
|
||||
<b>Terms of Service</b>
|
||||
</Link>{" "}
|
||||
and{" "}
|
||||
<Link href="#" aria-label="Terms of Use" className="hover:underline underline-offset-4">
|
||||
<Link href="/privacy-policy" aria-label="Privacy Policy" className="hover:underline underline-offset-4">
|
||||
<b>Privacy Policy</b>
|
||||
</Link>
|
||||
.
|
||||
</div>
|
||||
<hr className="mt-14 mb-4 h-px bg-gray-200 border-0" />
|
||||
<div className="text-center">
|
||||
Already have an account?{" "}
|
||||
<Link href="#" aria-label="Log In" className="hover:underline underline-offset-4">
|
||||
<b>Log In</b>
|
||||
</Link>
|
||||
</div>
|
||||
</AuthLayout>
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
export async function getServerSideProps(context) {
|
||||
Signin.getLayout = (page) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
{page}
|
||||
<Footer />
|
||||
</div>
|
||||
);
|
||||
|
||||
export default Signin;
|
||||
|
||||
export async function getServerSideProps() {
|
||||
const csrfToken = await getCsrfToken();
|
||||
const providers = await getProviders();
|
||||
return {
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Flex, Textarea } from "@chakra-ui/react";
|
||||
import { Container, Textarea } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -45,43 +44,31 @@ const AssistantReply = () => {
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const endTask = tasks[tasks.length - 1];
|
||||
|
||||
return (
|
||||
<div className="p-6 bg-slate-100 text-gray-800">
|
||||
<TwoColumns>
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Reply as the assistant</h5>
|
||||
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
|
||||
<Messages messages={task.conversation.messages} post_id={task.id} />
|
||||
</>
|
||||
<Textarea name="reply" data-cy="reply" placeholder="Reply..." ref={inputRef} />
|
||||
</TwoColumns>
|
||||
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<TaskInfo id={tasks[0].id} output="Submit your answer" />
|
||||
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])}>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
|
||||
Next Task
|
||||
</SubmitButton>
|
||||
)}
|
||||
</Flex>
|
||||
</section>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { Flex, Textarea } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Textarea } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -20,7 +18,7 @@ const SummarizeStory = () => {
|
||||
|
||||
// Fetch the very fist task. We can ignore everything except isLoading
|
||||
// because the onSuccess handler will update `tasks` when ready.
|
||||
const { isLoading } = useSWRImmutable("/api/new_task/summarize_story", fetcher, {
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/summarize_story", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
@@ -50,6 +48,14 @@ const SummarizeStory = () => {
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
inputRef.current.value = "";
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -59,31 +65,20 @@ const SummarizeStory = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Summarize A Story</title>
|
||||
<meta name="description" content="Summarize a story to train our model." />
|
||||
</Head>
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<main className="p-6 h-full mx-auto bg-slate-100 text-gray-800">
|
||||
<TwoColumns>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Instruction</h5>
|
||||
<p className="text-lg py-1">Summarize the following story</p>
|
||||
<div className="bg-slate-800 p-6 rounded-xl text-white whitespace-pre-wrap">{tasks[0].task.story}</div>
|
||||
</>
|
||||
<Textarea name="summary" placeholder="Summary" ref={inputRef} />
|
||||
</TwoColumns>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<TaskInfo id={tasks[0].id} output="Submit your answer" />
|
||||
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
<SubmitButton onClick={() => submitResponse(tasks[0])}>Submit</SubmitButton>
|
||||
</Flex>
|
||||
</section>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</main>
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Flex, Textarea } from "@chakra-ui/react";
|
||||
import { Textarea } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -45,43 +44,38 @@ const UserReply = () => {
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const endTask = tasks[tasks.length - 1];
|
||||
|
||||
return (
|
||||
<div className="p-6 bg-slate-100 text-gray-800">
|
||||
<TwoColumns>
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Reply as a user</h5>
|
||||
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
|
||||
<Messages messages={task.conversation.messages} post_id={task.id} />
|
||||
{task.hint && <p className="text-lg py-1">Hint: {task.hint}</p>}
|
||||
</>
|
||||
<Textarea name="reply" data-cy="reply" placeholder="Reply..." ref={inputRef} />
|
||||
</TwoColumns>
|
||||
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<TaskInfo id={tasks[0].id} output="Submit your answer" />
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])}>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
|
||||
Next Task
|
||||
</SubmitButton>
|
||||
)}
|
||||
</Flex>
|
||||
</section>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -47,48 +46,42 @@ const RankAssistantReplies = () => {
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const replies = tasks[0].task.replies as string[];
|
||||
const endTask = tasks[tasks.length - 1];
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Rank Assistant Replies</title>
|
||||
<meta name="description" content="Rank Assistant Replies." />
|
||||
</Head>
|
||||
<main className="p-6 bg-slate-100 text-gray-800">
|
||||
<div className="rounded-lg shadow-lg block bg-white p-6 mb-8">
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following replies, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
<Sortable items={replies} onChange={setRanking} />
|
||||
</div>
|
||||
<Sortable items={replies} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch">
|
||||
<TaskInfo id={tasks[0].id} output="Submit your answer" />
|
||||
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])} disabled={ranking.length === 0}>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
|
||||
Next Task
|
||||
</SubmitButton>
|
||||
)}
|
||||
</Flex>
|
||||
</section>
|
||||
</main>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -18,6 +17,7 @@ const RankInitialPrompts = () => {
|
||||
* The best prompt will have index 0, and the worst is the last.
|
||||
*/
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
// const bg = useColorModeValue("gray.100", "gray.800");
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_initial_prompts", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
@@ -47,47 +47,40 @@ const RankInitialPrompts = () => {
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const endTask = tasks[tasks.length - 1];
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Rank Initial Prompts</title>
|
||||
<meta name="description" content="Rank initial prompts." />
|
||||
</Head>
|
||||
<main className="p-6 bg-slate-100 text-gray-800">
|
||||
<div className="rounded-lg shadow-lg block bg-white p-6 mb-8">
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following prompts, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
<Sortable items={tasks[0].task.prompts} onChange={setRanking} />
|
||||
</div>
|
||||
<Sortable items={tasks[0].task.prompts} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch">
|
||||
<TaskInfo id={tasks[0].id} output="Submit your answer" />
|
||||
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])} disabled={ranking.length === 0}>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
|
||||
Next Task
|
||||
</SubmitButton>
|
||||
)}
|
||||
</Flex>
|
||||
</section>
|
||||
</main>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -47,48 +46,41 @@ const RankUserReplies = () => {
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">Loading...</div>;
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const replies = tasks[0].task.replies as string[];
|
||||
|
||||
const endTask = tasks[tasks.length - 1];
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Rank User Replies</title>
|
||||
<meta name="description" content="Rank User Replies." />
|
||||
</Head>
|
||||
<main className="p-6 bg-slate-100 text-gray-800">
|
||||
<div className="rounded-lg shadow-lg block bg-white p-6 mb-8">
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following replies, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
<Sortable items={replies} onChange={setRanking} />
|
||||
</div>
|
||||
<Sortable items={replies} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<TaskInfo id={tasks[0].id} output="Submit your answer" />
|
||||
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton data-cy="submit" onClick={() => submitResponse(tasks[0])} disabled={ranking.length === 0}>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton data-cy="next-task" onClick={fetchNextTask}>
|
||||
Next Task
|
||||
</SubmitButton>
|
||||
)}
|
||||
</Flex>
|
||||
</section>
|
||||
</main>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import { Flex, Textarea } from "@chakra-ui/react";
|
||||
import { Textarea } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { QuestionMarkCircleIcon } from "@heroicons/react/20/solid";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import RatingRadioGroup from "src/components/RatingRadioGroup";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -50,15 +49,27 @@ const RateSummary = () => {
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
mutate();
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 bg-slate-100 text-gray-800">No tasks found...</div>;
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const endTask = tasks[tasks.length - 1];
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
@@ -66,7 +77,7 @@ const RateSummary = () => {
|
||||
<meta name="description" content="Rate a proposed story summary." />
|
||||
</Head>
|
||||
<main className="p-6 bg-slate-100 text-gray-800">
|
||||
<TwoColumns>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold mb-4">Instruction</h5>
|
||||
<div className="bg-slate-800 p-6 rounded-xl text-white whitespace-pre-wrap">{tasks[0].task.full_text}</div>
|
||||
@@ -89,20 +100,9 @@ const RateSummary = () => {
|
||||
</ul>
|
||||
<Textarea name="notes" placeholder="Optional notes" />
|
||||
</section>
|
||||
</TwoColumns>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<TaskInfo id={tasks[0].id} output="Submit your answer" />
|
||||
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton onClick={() => submitResponse(tasks[0])}>Submit</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton onClick={mutate}>Next Task</SubmitButton>
|
||||
)}
|
||||
</Flex>
|
||||
</section>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -20,11 +20,9 @@ const Home = () => {
|
||||
/>
|
||||
</Head>
|
||||
{session ? (
|
||||
<main className="my-4">
|
||||
<TaskSelection />
|
||||
</main>
|
||||
<TaskSelection />
|
||||
) : (
|
||||
<main>
|
||||
<main className="oa-basic-theme">
|
||||
<Hero />
|
||||
<CallToAction />
|
||||
<Faq />
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import { ChakraProvider, cookieStorageManagerSSR, localStorageManager } from "@chakra-ui/react";
|
||||
|
||||
import { theme } from "./Theme";
|
||||
|
||||
export function Chakra({ cookies, children }) {
|
||||
const colorModeManager = typeof cookies === "string" ? cookieStorageManagerSSR(cookies) : localStorageManager;
|
||||
|
||||
return (
|
||||
<ChakraProvider theme={theme} colorModeManager={colorModeManager}>
|
||||
{children}
|
||||
</ChakraProvider>
|
||||
);
|
||||
}
|
||||
|
||||
// also export a reusable function getServerSideProps
|
||||
export function getServerSideProps({ req }) {
|
||||
return {
|
||||
props: {
|
||||
// first time users will not have any cookies and you may not return
|
||||
// undefined here, hence ?? is necessary
|
||||
cookies: req.headers.cookie ?? "",
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
export const colors = {
|
||||
light: {
|
||||
bg: "rgb(250,250,250)",
|
||||
text: "black",
|
||||
},
|
||||
dark: {
|
||||
bg: "gray.900",
|
||||
text: "white",
|
||||
},
|
||||
"dark-blue-btn": {
|
||||
200: "rgb(29,78,216)",
|
||||
300: "blue",
|
||||
},
|
||||
};
|
||||
@@ -0,0 +1,14 @@
|
||||
import { defineStyleConfig } from "@chakra-ui/styled-system";
|
||||
|
||||
const baseStyle = {};
|
||||
|
||||
const variants = {
|
||||
"no-padding": {
|
||||
padding: 0,
|
||||
},
|
||||
};
|
||||
|
||||
export const containerTheme = defineStyleConfig({
|
||||
baseStyle,
|
||||
variants,
|
||||
});
|
||||
@@ -0,0 +1,37 @@
|
||||
import { type ThemeConfig, extendTheme } from "@chakra-ui/react";
|
||||
import { Styles } from "@chakra-ui/theme-tools";
|
||||
|
||||
import { colors } from "./colors";
|
||||
import { containerTheme } from "./components/Container";
|
||||
|
||||
const config: ThemeConfig = {
|
||||
initialColorMode: "light",
|
||||
useSystemColorMode: true,
|
||||
disableTransitionOnChange: false,
|
||||
};
|
||||
|
||||
const components = {
|
||||
Container: containerTheme,
|
||||
};
|
||||
|
||||
const styles: Styles = {
|
||||
global: (props) => ({
|
||||
"*": {
|
||||
transition: "background-color 200ms cubic-bezier(0.4, 0, 1, 1)",
|
||||
// bg: props.colorMode === "light" ? colors.light.bg : colors.dark.bg,
|
||||
// color: props.colorMode === "light" ? colors.light.text : colors.dark.text,
|
||||
},
|
||||
".oa-basic-theme": {
|
||||
bg: props.colorMode === "light" ? colors.light.bg : colors.dark.bg,
|
||||
color: props.colorMode === "light" ? colors.light.text : colors.dark.text,
|
||||
},
|
||||
main: {
|
||||
fontFamily: "Inter",
|
||||
},
|
||||
header: {
|
||||
fontFamily: "Inter",
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
export const theme = extendTheme({ colors, config, styles, components });
|
||||
@@ -71,6 +71,10 @@ module.exports = {
|
||||
maxWidth: {
|
||||
"2xl": "40rem",
|
||||
},
|
||||
|
||||
colors: {
|
||||
"chakra-gray-900": "#171923",
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [require("@tailwindcss/forms")],
|
||||
|
||||
Reference in New Issue
Block a user