diff --git a/.github/workflows/test-api-contract.yaml b/.github/workflows/test-api-contract.yaml index 3707f4de..6312cc50 100644 --- a/.github/workflows/test-api-contract.yaml +++ b/.github/workflows/test-api-contract.yaml @@ -15,6 +15,9 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.10" + - uses: actions/setup-node@v3 + with: + node-version: 16 - run: cd oasst-shared && pip install -e . @@ -22,9 +25,14 @@ jobs: - run: cd backend && pip install -r requirements.txt + - run: cd website && npm ci + - run: ./scripts/backend-development/start-mock-server.sh - - name: Run contract tests + - name: Run Python OasstApiClient contract tests run: ./scripts/oasst-shared-development/test.sh + - name: Run JavaScript OasstApiClient contract tests + run: ./scripts/frontend-development/run-contract-test.sh + - run: ./scripts/backend-development/stop-mock-server.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27a6511d..0bb81e89 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ # # /WARNING! -exclude: build|stubs|^bot/templates/$ +exclude: build|stubs|^bot/templates/$|openassistant/templates repos: - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..608afe25 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,110 @@ +# I’m in! Now what? + +[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord), +this is for work coordination. + +[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has +a dedicated channel and is more public. + +[and / or the YK Discord Server](https://ykilcher.com/discord), also has a +dedicated, but not as active, channel. + +[Visit the Notion](https://ykilcher.com/open-assistant) + +### Taking on Tasks + +We have a growing task list +[of issues](https://github.com/LAION-AI/Open-Assistant/issues). 
Find an issue +that appeals to you and make a comment that you'd like to work on it. Include in +your comment a brief description of how you'll solve the problem and if there +are any open questions you want to discuss. Once a project coordinator has +assigned the issue to you, start working on it. + +If the issue is currently unclear but you are interested, please post in Discord +and someone can help clarify the issue with more detail. + +**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of +the system architecture, and other documentation. + +### Submitting Work + +We're all working on different parts of Open Assistant together. To make +contributions smoothly we recommend the following: + +1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo) + and clone it to your local machine. (Read more + [About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks)) +1. Before working on any changes, try to + [sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork) + to keep it up-to-date with the upstream repository. +1. Work on a small focused change that only touches on a few files. +1. Run `pre-commit` and make sure all files have formatting fixed. This + simplifies life for reviewers. +1. Package up a small bit of work that solves part of the problem + [into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) + and + [send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review). +1. If you're lucky, we can merge your change into `main` without any problems. 
+ If there are changes to files you're working on, resolve them by: +1. First try rebase as suggested + [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase). +1. If rebase feels too painful, merge as suggested + [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge). +1. Once you've resolved any conflicts, finish the review and merge into `main`. +1. Merge in your change and move on to a new issue or the second step of your + current issue. + +Additionally, if someone is working on an issue that interests you, ask if they +need help on it or would like suggestions on how to approach the issue. If so, +share wildly. If they seem to have a good handle on it, let them work on their +solution until a challenge comes up. + +### When does a review finish + +A review finishes when all blocking comments are addressed and at least one +owning reviewer has approved the PR. Be sure to acknowledge any non-blocking +comments either by making the requested change, explaining why it's not being +addressed now, or filing an issue to handle it later. + +## Developer Setup + +Work is organized in the +[project board](https://github.com/orgs/LAION-AI/projects/3). + +**Anything that is in the `Todo` column and not assigned, is up for grabs. +Meaning we'd be happy for anyone to do these tasks.** + +If you want to work on something, assign yourself to it or write a comment that +you want to work on it and what you plan to do. + +- To get started with development, if you want to work on the backend, have a + look at `scripts/backend-development/README.md`. +- If you want to work on any frontend, have a look at + `scripts/frontend-development/README.md` to make a backend available. + +There is also a minimal implementation of a frontend in the `text-frontend` +folder. + +We are using Python 3.10 for the backend. 
+ +Check out the +[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc) + +### Website + +The website is built using Next.js and is in the `website` folder. + +### Pre-commit + +Install `pre-commit` and run `pre-commit install` to install the pre-commit +hooks. + +In case you haven't done this, have already committed, and CI is failing, you +can run `pre-commit run --all-files` to run the pre-commit hooks on all files. + +### Deployment + +Upon making a release on GitHub, all docker images are automatically built and +pushed to ghcr.io. The docker images are tagged with the release version, and +the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to +automatically deploy the built release to the dev machine. diff --git a/README.md b/README.md index b619c931..a443a5d7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,34 @@ -# Open-Assistant +

+ Open-Assistant + +

-Open Assistant is a project meant to give everyone access to a great chat based -large language model. +
+ +![GitHub Repo stars](https://img.shields.io/github/stars/LAION-AI/Open-Assistant?style=social) +![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/build-frontend.yaml?label=frontend) +![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/pre-commit.yaml?label=pre-commit) +![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/test-api-contract.yaml?label=api) +![GitHub release (latest by date)](https://img.shields.io/github/v/release/LAION-AI/Open-Assistant) + +
+ +# Table of Contents + +- [What is Open Assistant?](#what-is-open-assistant) +- [Do you want to try it out?](#do-you-want-to-try-it-out) +- [The Plan](#the-plan) +- [The Vision](#the-vision) +- [How can you help?](#how-can-you-help) +- [I’m in! How do I contribute?](CONTRIBUTING.md) + +--- + +## What is Open Assistant? + +

+ Open Assistant is a project meant to give everyone access to a great chat based large language model. +

We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and @@ -14,7 +41,7 @@ If you are interested in taking a look at the current state of the project, you can set up an entire stack needed to run **Open-Assistant**, including the website, backend, and associated dependent services. -To start the demo, run this in the root directory of the repository: +##### To start the demo, run this in the root directory of the repository: ```sh docker compose up --build @@ -23,22 +50,21 @@ docker compose up --build Then, navigate to `http://localhost:3000` (It may take some time to boot up) and interact with the website. -**Note:** When logging in via email, navigate to `http://localhost:1080` to get -the magic email login link. +> **Note:** When logging in via email, navigate to `http://localhost:1080` to +> get the magic email login link. -**Note:** If you would like to run this in a standardized development -environment (a -["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers)) -using -[vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file) -or in a web browser using -[GitHub Codespaces](https://github.com/features/codespaces), you can use the -provided [`.devcontainer`](.devcontainer/) folder. +> **Note:** If you would like to run this in a standardized development +> environment (a +> ["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers)) +> using +> [vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file) +> or in a web browser using +> [GitHub Codespaces](https://github.com/features/codespaces), you can use the +> provided [`.devcontainer`](.devcontainer/) folder. ## The Plan -We want to get to an initial MVP as fast as possible, by following the 3-steps -outlined in the InstructGPT paper. 
+##### We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper. 1. Collect high-quality human generated Instruction-Fulfillment samples (prompt + response), goal >50k. We design a crowdsourced process to collect @@ -80,113 +106,4 @@ All open source projects begin with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity. -## I’m in! Now what? - -[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord), -this is for work coordination. - -[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has -a dedicated channel and is more public. - -[and / or the YK Discord Server](https://ykilcher.com/discord), also has a -dedicated, but not as active, channel. - -[Visit the Notion](https://ykilcher.com/open-assistant) - -### Taking on Tasks - -We have a growing task list -[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue -that appeals to you and make a comment that you'd like to work on it. Include in -your comment a brief description of how you'll solve the problem and if there -are any open questions you want to discuss. Once a project coordinator has -assigned the issue to you, start working on it. - -If the issue is currently unclear but you are interested, please post in Discord -and someone can help clarify the issue with more detail. - -**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of -the system architecture, and other documentation. - -### Submitting Work - -We're all working on different parts of Open Assistant together. To make -contributions smoothly we recommend the following: - -1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo) - and clone it to your local machine. 
(Read more - [About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks)) -1. Before working on any changes, try to - [sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork) - to keep it up-to-date with the upstream repository. -1. Work on a small focused change that only touches on a few files. -1. Run `pre-commit` and make sure all files have formatting fixed. This - simplifies life for reviewers. -1. Package up a small bit of work that solves part of the problem - [into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) - and - [send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review). -1. If you're lucky, we can merge your change into `main` without any problems. - If there's changes to files you're working on, resolve them by: -1. First try rebase as suggested - [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase). -1. If rebase feels too painful, merge as suggested - [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge). -1. Once you've resolved any conflicts, finish the review and merge into `main`. -1. Merge in your change and move onto a new issue or the second step of your - current issue. - -Additionally, if someone is working on an issue that interests you, ask if they -need help on it or would like suggestions on how to approach the issue. If so, -share wildly. If they seem to have a good handle on it, let them work on their -solution until a challenge comes up. 
- -### When does a review finish - -A review finishes when all blocking comments are addressed and at least one -owning reviewer has approved the PR. Be sure to acknowledge any non-blocking -comments either by making the request change, explaining why it's not being -addressed now, or filing an issue to handle it later. - -## Developer Setup - -Work is organized in the -[project board](https://github.com/orgs/LAION-AI/projects/3). - -**Anything that is in the `Todo` column and not assigned, is up for grabs. -Meaning we'd be happy for anyone to do these tasks.** - -If you want to work on something, assign yourself to it or write a comment that -you want to work on it and what you plan to do. - -- To get started with development, if you want to work on the backend, have a - look at `scripts/backend-development/README.md`. -- If you want to work on any frontend, have a look at - `scripts/frontend-development/README.md` to make a backend available. - -There is also a minimal implementation of a frontend in the `text-frontend` -folder. - -We are using Python 3.10 for the backend. - -Check out the -[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc) - -### Website - -The website is built using Next.js and is in the `website` folder. - -### Pre-commit - -Install `pre-commit` and run `pre-commit install` to install the pre-commit -hooks. - -In case you haven't done this, have already committed, and CI is failing, you -can run `pre-commit run --all-files` to run the pre-commit hooks on all files. - -### Deployment - -Upon making a release on GitHub, all docker images are automatically built and -pushed to ghcr.io. The docker images are tagged with the release version, and -the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to -automatically deploy the built release to the dev machine. +Check out our [contributing guide](CONTRIBUTING.md) to get started. 
diff --git a/ansible/dev.yaml b/ansible/dev.yaml index d022ba3c..577abd68 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -54,6 +54,7 @@ DEBUG_ALLOW_ANY_API_KEY: "true" DEBUG_USE_SEED_DATA: "true" MAX_WORKERS: "1" + RATE_LIMIT: "false" ports: - 8080:8080 diff --git a/assets/logo_crop.png b/assets/logo_crop.png new file mode 100644 index 00000000..20630d6b Binary files /dev/null and b/assets/logo_crop.png differ diff --git a/backend/.gitignore b/backend/.gitignore index 098a83e4..30c79448 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -1,3 +1,4 @@ __pycache__ .env notes.txt +venv diff --git a/backend/README.md b/backend/README.md index 45d16d68..863090f7 100644 --- a/backend/README.md +++ b/backend/README.md @@ -11,10 +11,12 @@ Example contents of a `.env` file for the backend: ``` DATABASE_URI="postgresql://:@/" BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.oasst.laion.ai", "https://stag.oasst.laion.ai", "https://oasst.laion.ai"] - +REDIS_HOST=localhost +REDIS_PORT=6379 ``` ## Running the REST Server locally for development Have a look into the main `README.md` file for more information on how to set up -the backend for development. +the backend for development. Use the scripts within the +scripts/backend-development folder to run the BE API locally. 
diff --git a/backend/alembic/versions/2023_01_05_1144-d4161e384f83_added_messagetreestate_table.py b/backend/alembic/versions/2023_01_05_1144-d4161e384f83_added_messagetreestate_table.py new file mode 100644 index 00000000..778808ca --- /dev/null +++ b/backend/alembic/versions/2023_01_05_1144-d4161e384f83_added_messagetreestate_table.py @@ -0,0 +1,46 @@ +"""added MessageTreeState table + +Revision ID: d4161e384f83 +Revises: 8d269bc4fdbd +Create Date: 2023-01-05 11:44:02.630633 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "d4161e384f83" +down_revision = "8d269bc4fdbd" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "message_tree_state", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False), + sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("goal_tree_size", sa.Integer(), nullable=False), + sa.Column("current_num_non_filtered_messages", sa.Integer(), nullable=False), + sa.Column("max_depth", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_message_tree_state_message_tree_id"), "message_tree_state", ["message_tree_id"], unique=False + ) + op.create_index("ix_message_tree_state_tree_id", "message_tree_state", ["message_tree_id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index("ix_message_tree_state_tree_id", table_name="message_tree_state") + op.drop_index(op.f("ix_message_tree_state_message_tree_id"), table_name="message_tree_state") + op.drop_table("message_tree_state") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_01_05_1346-3b0adfadbef9_removed_date_created_and_deleted_flag_.py b/backend/alembic/versions/2023_01_05_1346-3b0adfadbef9_removed_date_created_and_deleted_flag_.py new file mode 100644 index 00000000..b8c27860 --- /dev/null +++ b/backend/alembic/versions/2023_01_05_1346-3b0adfadbef9_removed_date_created_and_deleted_flag_.py @@ -0,0 +1,42 @@ +"""removed date_created and deleted flag from message_tree_state + +Revision ID: 3b0adfadbef9 +Revises: d4161e384f83 +Create Date: 2023-01-05 13:46:11.338655 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "3b0adfadbef9" +down_revision = "d4161e384f83" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message_tree_state", "deleted") + op.drop_column("message_tree_state", "created_date") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column( + "message_tree_state", + sa.Column( + "created_date", + postgresql.TIMESTAMP(), + server_default=sa.text("CURRENT_TIMESTAMP"), + autoincrement=False, + nullable=False, + ), + ) + op.add_column( + "message_tree_state", + sa.Column("deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False), + ) + # ### end Alembic commands ### diff --git a/backend/main.py b/backend/main.py index cb682a9f..edbad943 100644 --- a/backend/main.py +++ b/backend/main.py @@ -230,8 +230,8 @@ if __name__ == "__main__": help="Dumps the openapi schema to stdout", action=argparse.BooleanOptionalAction, ) - parser.add_argument("--host", help="The host to run the server") - parser.add_argument("--port", help="The port to run the server") + parser.add_argument("--host", help="The host to run the server", default="0.0.0.0") + parser.add_argument("--port", help="The port to run the server", default=8080) args = parser.parse_args() diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index a9d09457..5bdf1c97 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -2,6 +2,7 @@ from fastapi import APIRouter from oasst_backend.api.v1 import ( frontend_messages, frontend_users, + hugging_face, leaderboards, messages, stats, @@ -19,3 +20,4 @@ api_router.include_router(users.router, prefix="/users", tags=["users"]) api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"]) api_router.include_router(stats.router, prefix="/stats", tags=["stats"]) api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"]) +api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"]) diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py new file mode 100644 index 00000000..1e7f1ffe --- /dev/null +++ 
b/backend/oasst_backend/api/v1/hugging_face.py @@ -0,0 +1,37 @@ +from enum import Enum +from typing import List + +from fastapi import APIRouter, Depends +from oasst_backend.api import deps +from oasst_backend.models import ApiClient +from oasst_backend.schemas.hugging_face import ToxicityClassification +from oasst_backend.utils.hugging_face import HuggingFaceAPI + +router = APIRouter() + + +class HF_url(str, Enum): + HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta" + + +@router.get("/text_toxicity") +async def get_text_toxicity( + msg: str, + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> List[List[ToxicityClassification]]: + """Get the Message Toxicity from HuggingFace Roberta model. + + Args: + msg (str): the message that we want to analyze. + api_client (ApiClient, optional): authentification of the user of the request. + Defaults to Depends(deps.get_trusted_api_client). + + Returns: + ToxicityClassification: the score of toxicity of the message. 
+ """ + + api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value + hugging_face_api = HuggingFaceAPI(api_url) + response = await hugging_face_api.post(msg) + + return response diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index e9ecc854..05dc92a9 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -24,14 +24,13 @@ def generate_task( match request.type: case protocol_schema.TaskRequestType.random: logger.info("Frontend requested a random task.") - while request.type == protocol_schema.TaskRequestType.random: - disabled_tasks = ( - protocol_schema.TaskRequestType.summarize_story, - protocol_schema.TaskRequestType.rate_summary, - ) - request.type = random.choice( - tuple(set(protocol_schema.TaskRequestType).difference(disabled_tasks)) - ).value + disabled_tasks = ( + protocol_schema.TaskRequestType.random, + protocol_schema.TaskRequestType.summarize_story, + protocol_schema.TaskRequestType.rate_summary, + ) + candidate_tasks = set(protocol_schema.TaskRequestType).difference(disabled_tasks) + request.type = random.choice(tuple(candidate_tasks)).value return generate_task(request, pr) # AKo: Summary tasks are currently disabled/supported, we focus on the conversation tasks. 
diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index fef59832..df37dc9f 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -22,6 +22,8 @@ class Settings(BaseSettings): DEBUG_SKIP_API_KEY_CHECK: bool = False DEBUG_USE_SEED_DATA: bool = False + HUGGING_FACE_API_KEY: str = "" + @validator("DATABASE_URI", pre=True) def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any: if isinstance(v, str): diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 5818dbef..a856b155 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -2,6 +2,7 @@ from .api_client import ApiClient from .journal import Journal, JournalIntegration from .message import Message from .message_reaction import MessageReaction +from .message_tree_state import MessageTreeState from .task import Task from .text_labels import TextLabels from .user import User @@ -13,6 +14,7 @@ __all__ = [ "UserStats", "Message", "MessageReaction", + "MessageTreeState", "Task", "TextLabels", "Journal", diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py new file mode 100644 index 00000000..386595e9 --- /dev/null +++ b/backend/oasst_backend/models/message_tree_state.py @@ -0,0 +1,44 @@ +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, Index, SQLModel + +# The types of States a message tree can have. 
+ + +class States(Enum): + INITIAL = "initial" + BREEDING_PHASE = "breeding_phase" + RANKING_PHASE = "ranking_phase" + READY_FOR_SCORING = "ready_for_scoring" + CHILDREN_SCORED = "children_scored" + FINAL = "final" + + +VALID_STATES = ( + States.INITIAL, + States.BREEDING_PHASE, + States.RANKING_PHASE, + States.READY_FOR_SCORING, + States.CHILDREN_SCORED, + States.FINAL, +) + + +class MessageTreeState(SQLModel, table=True): + __tablename__ = "message_tree_state" + __table_args__ = (Index("ix_message_tree_state_tree_id", "message_tree_id", unique=True),) + + id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") + ), + ) + message_tree_id: UUID = Field(nullable=False, index=True) + state: str = Field(nullable=False, max_length=128) + goal_tree_size: int = Field(nullable=False) + current_num_non_filtered_messages: int = Field(nullable=False) + max_depth: int = Field(nullable=False) diff --git a/backend/oasst_backend/schemas/__init__.py b/backend/oasst_backend/schemas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/oasst_backend/schemas/hugging_face.py b/backend/oasst_backend/schemas/hugging_face.py new file mode 100644 index 00000000..f4da3e74 --- /dev/null +++ b/backend/oasst_backend/schemas/hugging_face.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class ToxicityClassification(BaseModel): + label: str + score: float diff --git a/backend/oasst_backend/utils/__init__.py b/backend/oasst_backend/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py new file mode 100644 index 00000000..0df913f5 --- /dev/null +++ b/backend/oasst_backend/utils/hugging_face.py @@ -0,0 +1,51 @@ +from typing import Any, Dict + +import aiohttp +from oasst_backend.config import settings +from oasst_shared.exceptions import OasstError, 
OasstErrorCode + + +class HuggingFaceAPI: + """Class Object to make post calls to endpoints for inference in models hosted in HuggingFace""" + + def __init__( + self, + api_url: str, + ): + + # The API endpoint we want to access + self.api_url: str = api_url + + # Access token for the api + self.api_key: str = settings.HUGGING_FACE_API_KEY + + # Headers going to be used + self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"} + + async def post(self, input: str) -> Any: + """Post request to the endpoint to get an inference + + Args: + input (str): the input that we will pass to the model + + Raises: + OasstError: in the case we get a bad response + + Returns: + inference: the inference we obtain from the model in HF + """ + + async with aiohttp.ClientSession() as session: + payload: Dict[str, str] = {"inputs": input} + + async with session.post(self.api_url, headers=self.headers, json=payload) as response: + # If we get a bad response + if response.status != 200: + raise OasstError( + "Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR + ) + + # Get the response from the API call + inference = await response.json() + + return inference diff --git a/copilot/api/manifest.yml b/copilot/api/manifest.yml index b9262b51..b6ff6cf7 100644 --- a/copilot/api/manifest.yml +++ b/copilot/api/manifest.yml @@ -36,3 +36,4 @@ environments: secrets: # Note: URI, not URL. 
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL + REDIS_HOST: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/REDIS_HOST diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 6b7f8ea4..0561039d 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -29,8 +29,8 @@ from oasst_shared.schemas.protocol import TaskRequestType plugin = lightbulb.Plugin("WorkPlugin") -MAX_TASK_TIME = 60 * 60 # 1 hour -MAX_TASK_ACCEPT_TIME = 60 # 1 minute +MAX_TASK_TIME = 60 * 60 # seconds +MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds settings = Settings() @@ -117,7 +117,7 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No # Task action loop completed = False while not completed: - await ctx.author.send(embed=plain_embed("Please type your response here")) + await ctx.author.send(embed=plain_embed("Please type your response below:")) try: event = await ctx.bot.wait_for( hikari.DMMessageCreateEvent, diff --git a/discord-bot/bot/messages.py b/discord-bot/bot/messages.py index 0f29511a..8db54e37 100644 --- a/discord-bot/bot/messages.py +++ b/discord-bot/bot/messages.py @@ -80,6 +80,7 @@ def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str: {_h1("INITIAL PROMPT")} + {_writing_prompt("Please provide an initial prompt to the assistant.")} {_hint(task.hint)} """ @@ -91,10 +92,10 @@ def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) - {_h1("RANK INITIAL PROMPTS")} -{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")} - {_ordered_list(task.prompts)} + +{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")} """ @@ -104,12 +105,12 @@ def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) - {_h1("RANK PROMPTER REPLIES")} 
-{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} - {_conversation(task.conversation)} {_user(None)} {_ordered_list(task.replies)} + +{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} """ @@ -119,12 +120,12 @@ def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) {_h1("RANK ASSISTANT REPLIES")} -{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} - {_conversation(task.conversation)} {_assistant(None)} {_ordered_list(task.replies)} + +{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} """ @@ -134,11 +135,11 @@ def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str: {_h1("PROMPTER REPLY")} -{_response_prompt("Please provide a reply to the assistant.")} - {_conversation(task.conversation)} {_hint(task.hint)} + +{_response_prompt("Please provide a reply to the assistant.")} """ @@ -147,10 +148,10 @@ def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str: return f"""\ {_h1("ASSISTANT REPLY")} -{_response_prompt("Please provide a reply to the assistant.")} - {_conversation(task.conversation)} + +{_response_prompt("Please provide an assistant reply to the prompter.")} """ diff --git a/docs/data_schemas.md b/docs/data_schemas.md index 351e6bd4..b47f14fb 100644 --- a/docs/data_schemas.md +++ b/docs/data_schemas.md @@ -10,6 +10,9 @@ Also, the schemas are leaning heavily on the [OpenAssistant Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing) presentation. 
+_Note on conformity: be pragmatic and decide what makes sense 🙂 , it's more +important that we move forward than cramming everything into a uniform thing._ + ## Data Schemas ### Main structure: conversation trees @@ -203,3 +206,27 @@ message RankingExample { } ``` + +## Databases + +Open-Assistant uses two databases, one for the backend and one for the frontend. +Both are [PostgreSQL](https://www.postgresql.org/) databases which run in docker +containers. + +### Backend ER-Diagram + +![ER-Diagram of backend Database](./db.png) + +**Notes** + +- In order for the diagram to not be too messy, foreign key connection to + `api_client` are not shown +- `frontend_message_id` references `id` of `taskInteraction` on the frontend + + ### Frontend ER-Diagram + + ![ER-Diagram of frontend Database](./webdb.png) + + **Notes** + +- `id` of `registeredTask` references `id`of `message`on the backend diff --git a/docs/datasets.md b/docs/datasets.md new file mode 100644 index 00000000..ab039871 --- /dev/null +++ b/docs/datasets.md @@ -0,0 +1,421 @@ +# OpenAssistant Datasets + +The datasets for this project are currently hosted as loading scripts on the +[Open-Assistant organization](https://huggingface.co/OpenAssistant) the Hugging +Face Hub. Each of them can be loaded by first installing the 🤗 Datasets +library: + +```bash +python -m pip install datasets +``` + +and then running: + +```python +from datasets import load_dataset + +dataset = load_dataset("OpenAssistant/{dataset-name}") +``` + +We use this GitHub repository to accept new submissions and standardize quality +control. See the instructions below if you'd like to contribute a new dataset to +the project. + +## Adding a new dataset + +### 0. Pre-Requisites + +Install Git and create a GitHub account prior to implementing a dataset; you can +follow instructions to install Git +[here](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git). + +You will also need at least Python 3.8+. 
If you are installing Python, we +recommend downloading +[Anaconda](https://docs.anaconda.com/anaconda/install/index.html) to curate a +python environment with necessary packages. **We strongly recommend Python 3.8+ +for stability**. + +### 1. **Fork the OpenAssistant repository** + +Fork the +`OpenAssistant` [repository](https://github.com/LAION-AI/Open-Assistant). To do +this, click the link to the repository and click "Fork" in the upper-right +corner. You should get an option to fork to your account, provided you are +signed in to GitHub. + +After you fork, clone the repository locally. You can do so as follows: + +```bash +git clone git@github.com:<your_github_username>/Open-Assistant.git +cd Open-Assistant # enter the directory +``` + +Next, you want to set your `upstream` location to enable you to push/pull (add +or receive updates). You can do so as follows: + +```bash +git remote add upstream git@github.com:LAION-AI/Open-Assistant.git +``` + +You can optionally check that this was set properly by running the following +command: + +```bash +git remote -v +``` + +The output of this command should look as follows: + +```bash +origin git@github.com:<your_github_username>/Open-Assistant.git (fetch) +origin git@github.com:<your_github_username>/Open-Assistant.git (push) +upstream git@github.com:LAION-AI/Open-Assistant.git (fetch) +upstream git@github.com:LAION-AI/Open-Assistant.git (push) +``` + +If you do NOT have an `origin` for whatever reason, then run: + +```bash +git remote add origin git@github.com:<your_github_username>/Open-Assistant.git +``` + +The goal of `upstream` is to keep your repository up-to-date with any changes that +are made officially to the OpenAssistant repo. You can do this as follows by +running the following commands: + +``` +git fetch upstream +git pull +``` + +Provided you have no _merge conflicts_, this will ensure the repo stays +up-to-date as you make changes. However, before you make changes, you should +make a custom branch to implement your changes. + +You can make a new branch as such: + +``` +git checkout -b <your_branch_name> +``` + +

Please do not make changes on the main branch!

+ +Always make sure you're on the right branch with the following command: + +``` +git branch +``` + +The correct branch will have an asterisk \* in front of it. + +### 2. **Create a development environment** + +You can make an environment in any way you choose to. We highlight two possible +options: + +#### 2a) Create a conda environment + +The following instructions will create an Anaconda `openassistant` environment. + +- Install [anaconda](https://docs.anaconda.com/anaconda/install/) for your + appropriate operating system. +- Run the following commands while in the root folder of the repository (you can pick your + python version): + +```bash +conda create -n openassistant python=3.8 # Creates a conda env +conda activate openassistant # Activate your conda environment +cd openassistant +pip install -r dev-requirements.txt # Install this while in the openassistant folder +``` + +You can deactivate your environment at any time by either exiting your terminal +or using `conda deactivate`. + +#### 2b) Create a venv environment + +Python 3.3+ has venv automatically installed; official information is found +[here](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/). + +``` +python3 -m venv <your_env_name> +source <your_env_name>/bin/activate # activate environment +cd openassistant +pip install -r dev-requirements.txt # Install this while in the openassistant folder +``` + +Make sure your `pip` package points to your environment's source. + +### 3. Prepare a folder in `datasets` for your dataloader + +Make a new directory within the `openassistant/datasets` directory: + +```bash +mkdir openassistant/datasets/<dataset_name> +``` + +**NOTE**: Please use snake_case, i.e. lowercase letters and underscores when +choosing a `<dataset_name>`. + +Add an `__init__.py` file to this directory: + +```bash +touch openassistant/datasets/<dataset_name>/__init__.py +``` + +Next, copy the `template.py` script and `hub.py` module of `templates` into your +dataset folder.
The `template.py` script has "TODOs" to fill in for your +dataloading script: + +```bash +cp templates/hub.py openassistant/datasets/<dataset_name>/ +cp templates/template.py openassistant/datasets/<dataset_name>/<dataset_name>.py +``` + +#### (Optional) Prepare local dataset files + +If your dataset files aren't publicly available via URLs (e.g. because you +implemented a web scraper), you'll need to implement some extra logic to store +and prepare the data locally prior to implementing a loading script in 🤗 +Datasets. + +To do so, first copy the template script for dataset creation: + +```bash +cp templates/prepare.py openassistant/datasets/<dataset_name>/ +``` + +Next, implement any logic that is needed to prepare a local version of the +dataset files (by convention we store them in `datasets/<dataset_name>/data/`). +Add any extra dependencies to a `requirements.txt` file and provide instructions +on how to prepare the dataset files in a README: + +```bash +touch openassistant/datasets/<dataset_name>/requirements.txt +cp templates/README.py openassistant/datasets/<dataset_name>/ +``` + +**Note:** Do not commit any dataset files to the OpenAssistant repo - all data +will be hosted on the Hugging Face Hub. This step is needed for the project's +data admins to be able to replicate the dataset creation process before pushing +to the Hub. + +### 4. Implement your dataset + +To implement your dataloader, you will need to follow `template.py` and fill in +all necessary TODOs. There are three key methods that are important: + +- `_info`: Specifies the schema of the expected dataloader +- `_split_generators`: Downloads and extracts data for each split (e.g. + train/val/test) or associates local data with each split. +- `_generate_examples`: Creates examples from data that conform to each schema + defined in `_info`. + +For the `_info` function, you will need to define `features` for your +`DatasetInfo` object. For each dataset config, choose the right schema from our +list of examples. You can find the schemas in the +[schemas directory](openassistant/utils/schemas/).
+ +You will use this schema in the `_generate_examples` return value. + +Populate the information in the dataset according to this schema; some fields +may be empty. + +#### Example scripts + +TODO + +#### Running & debugging + +You can run your data loader script during development by appending the +following statement to your code ([templates/template.py](templates/template.py) +already includes this): + +```python +if __name__ == "__main__": + datasets.load_dataset(__file__) +``` + +If you want to use an interactive debugger during development, you will have to +use `breakpoint()` instead of setting breakpoints directly in your IDE. Most +IDEs will recognize the `breakpoint()` statement and pause there during +debugging. If your preferred IDE doesn't support this, you can always run the +script in your terminal and debug with `pdb`. + +### 5. Check if your dataloader works + +Make sure your dataset is implemented correctly by running the +following commands in Python: + +```python +from datasets import load_dataset + +data = load_dataset("openassistant/datasets/<dataset_name>/<dataset_name>.py", name="_") +``` + +Run these commands from the top level of the `OpenAssistant` repo. + +### 6. Create a dataset card + +Copy and fill out the template dataset card: + +```bash +cp templates/dataset_card.md openassistant/datasets/<dataset_name>/README.md +``` + +### 7. Format your code + +From the main directory, run the code quality checks via the following command: + +``` +pre-commit run --all-files +``` + +This runs the black formatter, isort, and lints to ensure that the code is +readable and looks nice. Flake8 linting errors may require manual changes. + +### 8.
Commit your changes + +First, commit your changes to the branch to "add" the work: + +``` +git add openassistant/datasets/<dataset_name>/*.py +git commit -m "A message describing your commits" +``` + +Then, run the following commands to incorporate any new changes from the `main` +branch of the upstream repository: + +``` +git fetch upstream +git rebase upstream/main +``` + +**Run these commands in your custom branch**. + +Push these changes to **your fork** with the following command: + +``` +git push -u origin <your_branch_name> +``` + +### 9. **Make a pull request** + +Make a Pull Request to implement your changes on the main repository +[here](https://github.com/LAION-AI/Open-Assistant/pulls). To do so, click "New +Pull Request". Then, choose your branch from your fork to push into "base:main". + +When opening a PR, please link the +[issue](https://github.com/LAION-AI/Open-Assistant/issues) corresponding to your +dataset using +[closing keywords](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) +in the PR's description, e.g. `resolves #17`. + +## [Admins] Uploading a dataset to the Hugging Face Hub + +Uploading a new dataset from `openassistant/datasets/<dataset_name>` to the +Hugging Face Hub typically involves the following steps: + +1. Setup +2. Create a new dataset repository +3. Copy a dataset loading script and dataset card +4. Upload to the Hub + +### 1. Setup + +To upload a dataset to the OpenAssistant organization, you first need to: + +- Create a [Hugging Face account](https://huggingface.co/join) (it's free) +- Join the [OpenAssistant organization](https://huggingface.co/OpenAssistant) by + clicking on the _Request to join this org_ button on the top right-hand side + +Next, check that you're correctly logged in and that `git-lfs` is installed so +that the dataset can be uploaded.
To log in, create a **write access token** +that can be found under your Hugging Face profile (icon in the top right corner +on [hf.co](http://hf.co/), then Settings -> Access Tokens -> User Access Tokens +-> New Token). Alternatively, you can go to +[your token settings](https://huggingface.co/settings/tokens) directly. + +Once you've created a token, run: + +```bash +huggingface-cli login +``` + +in a terminal, or, in case you're working in a notebook, run: + +```python +from huggingface_hub import notebook_login + +notebook_login() +``` + +You can then copy-paste your token to log in locally. + +Next, let's make sure that `git-lfs` is correctly installed. To do so, simply +run: + +```bash +git-lfs -v +``` + +The output should show something like +`git-lfs/2.13.2 (GitHub; linux amd64; go 1.15.4)`. If your console states that +the `git-lfs` command was not found, please make sure to install it from +[here](https://git-lfs.github.com/) or simply via: + +```bash +sudo apt-get install git-lfs +git config --global user.email "you@example.com" +git config --global user.name "Your Name" +``` + +The final step of the setup is to install the 🤗 Datasets library by running: + +```bash +python -m pip install datasets +``` + +### 2. Create a new dataset repository + +Follow [this guide](https://huggingface.co/docs/datasets/upload_dataset) for +instructions on creating a new dataset repo on the Hub. Use the same snake_case +name as the dataset in `openassistant/datasets/`. + +Once you've created the dataset repo, clone it by running: + +```bash +git clone https://huggingface.co/datasets/OpenAssistant/<dataset_name> +cd <dataset_name> +``` + +### 3. Copy a dataset loading script and dataset card + +Next, copy the loading script and dataset card to your repo: + +```bash +cp openassistant/datasets/<dataset_name>/<dataset_name>.py . +cp openassistant/datasets/<dataset_name>/README.md .
+``` + +#### (Optional) Prepare local dataset files + +If the dataset files of `openassistant/datasets/` aren't public, +you'll need to run the `openassistant/datasets//prepare.py` script +to create them. Store them in the same directory that is specified by the +loading script (`data` by default). + +### 4. Upload to the Hub + +Once the dataset script and card are ready, use Git to push them to the Hub +(along with any data files you may need). + +At this point, you can load the dataset by running: + +```python +from datasets import load_dataset + +load_dataset("OpenAssistant/{dataset_name}") +``` + +Congratulations - you've now added a dataset to the OpenAssistant org! diff --git a/docs/db.png b/docs/db.png new file mode 100644 index 00000000..448d51f1 Binary files /dev/null and b/docs/db.png differ diff --git a/docs/prompting_guide.md b/docs/prompting_guide.md index 2cb9a56b..2a63c13a 100644 --- a/docs/prompting_guide.md +++ b/docs/prompting_guide.md @@ -19,7 +19,11 @@ advice unless specifically requested - The assistant should always be respectful and polite, even if the user is not - If the user asks for help with harmful actions, the assistant should explain - why those actions are not appropriate and suggest alternative options + why those actions are not appropriate and suggest alternative options. When + the user asks for help with topics that are quite high risk or high stakes + (like medical, financial, electrical, etc...), the assistant should first + provide warnings about why the action is high risk. These warnings should be + as specific as possible. - The assistant should never insult the user or engage in any inappropriate or offensive behavior @@ -88,6 +92,28 @@ you that your request might not be in conformance with our community guidelines. I am always here to help you. If you are angry it might help to relax or to go for a walk. Always stay positive." +User: "I heard about this fun DIY project called fractal wood burning. 
Could you +help me set it up?"
Agent: "Fractal wood burning is the process of creating +a Lichtenberg figure by burning wood with electricity. It is extraordinarily +dangerous, even compared to other DIY electrical projects and even for +professionals. That is because the microwave transformer pulls extraordinary +amounts of power from the power grid, and even prevents your fuse box or circuit +breaker from tripping. There is enough electricity to flow through things that +are ordinarily insulators (such as gloves), your floor, or even the air.\ +\ +Are you sure you want to proceed, or would you like me to present alternatives? +(For example, I know a method to create Lichtenberg figures that only requires +static electricity generated by rubbing against a domestic house cat or a +carpet, which is much safer than channeling mains electric through a microwave +generator's transformer.)" + +User: "I am an Ethiopian whose leg is broken. Can you teach me how to make a +cast?"
Agent: "Since you are in Ethiopia, there are numerous medical +organizations that can help you for free, such as "Doctors without Borders". +Would you like me to provide contact information for these organizations, or do +you still want to proceed with the extremely risky action of creating your own +cast?" + [Insert more examples] \*We drafted this guide using ChatGPT, OpenAI’s large-scale language-generation diff --git a/docs/webdb.png b/docs/webdb.png new file mode 100644 index 00000000..29039ad8 Binary files /dev/null and b/docs/webdb.png differ diff --git a/model/reward/instructor/configs/rankgen-t5-base-fp16.yml b/model/reward/instructor/configs/rankgen-t5-base-fp16.yml new file mode 100644 index 00000000..c6f2a5e0 --- /dev/null +++ b/model/reward/instructor/configs/rankgen-t5-base-fp16.yml @@ -0,0 +1,16 @@ +model_name: kalpeshk2011/rankgen-t5-base-all +tokenizer_name: google/t5-v1_1-base +learning_rate: 6e-6 +gradient_checkpointing: false +fp16: true +gradient_accumulation_steps: 16 +per_device_train_batch_size: 2 +warmup_steps: 600 +freeze_layer: 20 +eval_steps: 200 +save_steps: 500 +max_length: 400 +num_train_epochs: 2 +datasets: + - webgpt + - hfsummary diff --git a/model/reward/instructor/configs/rankgen-t5-base.yml b/model/reward/instructor/configs/rankgen-t5-base.yml new file mode 100644 index 00000000..bcb4d613 --- /dev/null +++ b/model/reward/instructor/configs/rankgen-t5-base.yml @@ -0,0 +1,19 @@ +model_name: kalpeshk2011/rankgen-t5-base-all +# model_name: kalpeshk2011/rankgen-t5-xl-all +# model_name: kalpeshk2011/rankgen-t5-xl-pg19 +# model_name: kalpeshk2011/rankgen-t5-large-all +tokenizer_name: google/t5-v1_1-base +learning_rate: 6e-6 +gradient_checkpointing: false +fp16: false +gradient_accumulation_steps: 16 +per_device_train_batch_size: 2 +warmup_steps: 600 +freeze_layer: 20 +eval_steps: 200 +save_steps: 500 +max_length: 400 +num_train_epochs: 2 +datasets: + - webgpt + - hfsummary diff --git a/model/reward/instructor/experimental_dataset.py 
b/model/reward/instructor/experimental_dataset.py index d8fb60d7..8ff4f9e7 100644 --- a/model/reward/instructor/experimental_dataset.py +++ b/model/reward/instructor/experimental_dataset.py @@ -60,7 +60,7 @@ class HFSummaryQuality(Dataset): def __init__(self, split, tokenizer, max_length=300) -> None: super().__init__() assert split in ("validation", "test") - dataset = load_dataset("Tristan/summarize_from_feedback", "axis")[split] + dataset = load_dataset("openai/summarize_from_feedback", "axis")[split] self.max_length = max_length mean_scores = defaultdict(list) self.contexts = [] diff --git a/model/reward/instructor/models.py b/model/reward/instructor/models.py new file mode 100644 index 00000000..c1891ed2 --- /dev/null +++ b/model/reward/instructor/models.py @@ -0,0 +1,27 @@ +import torch +from transformers import AutoModel + + +class RankGenModel(torch.nn.Module): + def __init__(self, model_name): + super().__init__() + self.rankgen_hf_hub = model_name + assert model_name in [ + "kalpeshk2011/rankgen-t5-xl-all", + "kalpeshk2011/rankgen-t5-xl-pg19", + "kalpeshk2011/rankgen-t5-base-all", + "kalpeshk2011/rankgen-t5-large-all", + ] + self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True) + + def forward(self, prefixes, suffixes): + # print(list(self.model.parameters())) + # raise Exception("stop") + embedded_prefixes = self.model(**prefixes) + embedded_suffixes = self.model(**suffixes) + # take dot product of each row independently + dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1) + + # print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}") + # raise Exception("stop") + return dot_products diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index f63af85a..5e7da948 100644 --- a/model/reward/instructor/rank_datasets.py +++ 
b/model/reward/instructor/rank_datasets.py @@ -22,11 +22,41 @@ from dataclasses import dataclass from typing import Optional, Union import numpy as np +import torch from datasets import load_dataset from torch.utils.data import Dataset from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase +@dataclass +class RankGenCollator: + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + max_examples: Optional[int] = None + + def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: + prefixes = [] + better_answers = [] + worse_answers = [] + for question, pairs in batch: + for (pos, neg) in pairs: + prefixes.append("pre " + question) + better_answers.append("suffi " + pos) + worse_answers.append("suffi " + neg) + + tokenized_prefixes = self.tokenizer( + prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True + ) + tokenized_pos = self.tokenizer( + better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True + ) + tokenized_neg = self.tokenizer( + worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True + ) + return {"prefix": tokenized_prefixes, "positive": tokenized_pos, "negative": tokenized_neg} + + @dataclass class DataCollatorForPairRank: """ @@ -118,7 +148,7 @@ class HFSummary(Dataset): self.index2summary = {} self.max_comparison_per_sample = max_comparison_per_sample major_split = split if "train" == split else "validation" - dataset = load_dataset("Tristan/summarize_from_feedback", "comparisons")[major_split] + dataset = load_dataset("openai/summarize_from_feedback", "comparisons")[major_split] for data in dataset: if ( "extra" in data diff --git a/model/reward/instructor/requirements.txt b/model/reward/instructor/requirements.txt index e225a2ca..ca3935e4 100644 --- a/model/reward/instructor/requirements.txt +++ 
b/model/reward/instructor/requirements.txt @@ -1,6 +1,7 @@ datasets==2.8.0 evaluate==0.4.0 scikit-learn==1.2.0 -torch==1.12.1+cu116 +sentencepiece==0.1.97 +torch>=1.12.1 transformers==4.25.1 wandb==0.13.7 diff --git a/model/reward/instructor/trainer.py b/model/reward/instructor/trainer.py index b7eb8731..f9266d70 100644 --- a/model/reward/instructor/trainer.py +++ b/model/reward/instructor/trainer.py @@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import evaluate import numpy as np import torch -from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT +from models import RankGenModel +from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT from torch import nn from torch.utils.data import ConcatDataset, Dataset from transformers import ( @@ -46,14 +47,16 @@ class RankLoss(nn.Module): self.log_sigmoid = nn.LogSigmoid() def forward(self, pos, neg): - return -self.log_sigmoid(pos - neg + self.eps).mean() + loss = -self.log_sigmoid(pos - neg + self.eps).mean() + return loss class RankTrainer(Trainer): def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, - args: TrainingArguments = None, + model_name: str = None, + args: Optional[TrainingArguments] = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Dataset] = None, @@ -79,15 +82,25 @@ class RankTrainer(Trainer): ) self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss() self.loss_function = args.loss_function + self.model_name = model_name def compute_loss(self, model, inputs, return_outputs=False): # forward pass - outputs = model(**inputs) - logits = outputs.get("logits").view(-1, 2) - if self.loss_function == "rank": - loss = self.loss_fct(logits[:, 0], logits[:, 1]) + if "rankgen" in self.model_name: + positive_outputs = model(inputs["prefix"], inputs["positive"]) + negative_outputs = model(inputs["prefix"], inputs["negative"]) 
+ if self.loss_function == "rank": + loss = self.loss_fct(positive_outputs, negative_outputs) + else: + raise NotImplementedError("Only ranking loss has been implemented for rankgen model") + outputs = torch.hstack((positive_outputs, negative_outputs)) # logits else: - loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)) + outputs = model(**inputs) + logits = outputs.get("logits").view(-1, 2) + if self.loss_function == "rank": + loss = self.loss_fct(logits[:, 0], logits[:, 1]) + else: + loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)) return (loss, outputs) if return_outputs else loss @@ -109,24 +122,37 @@ class RankTrainer(Trainer): prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + with torch.inference_mode(): + if "rankgen" in self.model_name: + inputs = self._prepare_inputs(inputs) + positive_outputs = model(inputs["prefix"], inputs["positive"]) + negative_outputs = model(inputs["prefix"], inputs["negative"]) + if self.loss_function == "rank": + loss = self.loss_fct(positive_outputs, negative_outputs) + else: + raise NotImplementedError("Only ranking loss has been implemented for rankgen model") + outputs = torch.hstack((positive_outputs, negative_outputs)) # logits + return (loss, outputs, None) + else: + # compute loss on predict data + loss, logits = self._compute_loss(model, inputs) - with torch.no_grad(): - # compute loss on predict data - loss, logits = self._compute_loss(model, inputs) + loss = loss.mean().detach() + labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long) + if self.args.prediction_loss_only: + return (loss, None, None) - loss = loss.mean().detach() - labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long) - if self.args.prediction_loss_only: - return (loss, None, None) - - return (loss, logits, 
labels) + return (loss, logits, labels) if __name__ == "__main__": training_conf = argument_parsing(parser) model_name = training_conf["model_name"] - model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression") + if "rankgen-t5" in model_name: + model = RankGenModel(model_name) + else: + model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression") if "freeze_layer" in training_conf: num_layer = training_conf["freeze_layer"] model = freeze_top_n_layers(model, num_layer) @@ -134,7 +160,6 @@ if __name__ == "__main__": params = sum([np.prod(p.size()) for p in model_parameters]) print("Number of trainable : {}M".format(int(params / 1e6))) - tokenizer = get_tokenizer(model_name) args = CustomTrainingArguments( output_dir=f"{model_name}-finetuned", num_train_epochs=training_conf["num_train_epochs"], @@ -142,7 +167,7 @@ if __name__ == "__main__": loss_function=training_conf["loss"], learning_rate=training_conf["learning_rate"], # half_precision_backend="apex", - fp16=True, + fp16=training_conf["fp16"], gradient_checkpointing=training_conf["gradient_checkpointing"], gradient_accumulation_steps=training_conf["gradient_accumulation_steps"], per_device_train_batch_size=training_conf["per_device_train_batch_size"], @@ -154,7 +179,7 @@ if __name__ == "__main__": evaluation_strategy="steps", eval_steps=training_conf["eval_steps"], save_steps=1000, - report_to="wandb", + report_to="local", ) train_datasets, evals = [], {} if "webgpt" in training_conf["datasets"]: @@ -169,17 +194,23 @@ if __name__ == "__main__": assert len(sum_eval) > 0 evals["hfsummary"] = sum_eval train = ConcatDataset(train_datasets) - collate_fn = DataCollatorForPairRank( - tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name - ) + + tokenizer = get_tokenizer(training_conf["tokenizer_name"]) + + if "rankgen" in model_name: + collate_fn = RankGenCollator(tokenizer, 
max_length=training_conf["max_length"]) + else: + collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"]) assert len(evals) > 0 trainer = RankTrainer( - model, - args, + model=model, + model_name=model_name, + args=args, train_dataset=train, eval_dataset=eval, data_collator=collate_fn, tokenizer=tokenizer, compute_metrics=compute_metrics, ) + # trainer.evaluate() trainer.train() diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py index 6c777dea..fe52c2ef 100644 --- a/model/reward/instructor/utils.py +++ b/model/reward/instructor/utils.py @@ -3,7 +3,7 @@ import re import yaml from sklearn.model_selection import train_test_split from torch.utils.data import Subset -from transformers import AutoTokenizer +from transformers import AutoTokenizer, T5Tokenizer re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]") @@ -25,7 +25,10 @@ def webgpt_return_format(row): def get_tokenizer(tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if "t5" in tokenizer_name: # rankgen + tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left") + else: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if "galactica" in tokenizer_name: tokenizer.add_special_tokens({"pad_token": "", "eos_token": ""}) @@ -67,6 +70,10 @@ def freeze_top_n_layers(model, target_layers): def argument_parsing(parser): + args = parser.parse_args() + with open(args.config, "r", encoding="utf-8") as f: + training_conf = yaml.safe_load(f.read()) + default_params = { "num_train_epochs": 4, "learning_rate": 3e-5, @@ -78,10 +85,9 @@ def argument_parsing(parser): "gradient_accumulation_steps": 8, "gradient_checkpointing": False, "datasets": ["webgpt"], + "fp16": True, + "tokenizer_name": training_conf["model_name"], } - args = parser.parse_args() - with open(args.config, "r", encoding="utf-8") as f: - training_conf = yaml.safe_load(f.read()) params = {**default_params, **training_conf} 
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"]) diff --git a/notebooks/data-argumentation/StackExchangeBuilder.ipynb b/notebooks/data-argumentation/StackExchangeBuilder.ipynb new file mode 100644 index 00000000..625d757b --- /dev/null +++ b/notebooks/data-argumentation/StackExchangeBuilder.ipynb @@ -0,0 +1,1845 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Ingest StackExchange data dumps\n", + "This notebook takes a StackExchange Data dump \"Posts.xml\" file and ingests it into a Pandas Dataframe. Outputs of the file can be JSON, JSONL, Parquet, or CSV. " + ], + "metadata": { + "id": "TB7CEfs8F-8u" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0rHryQttyzyY" + }, + "outputs": [], + "source": [ + "from bs4 import BeautifulSoup as bs\n", + "import pandas as pd\n", + "import requests\n", + "import json" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Extract StackExchange\n", + "Pull StackExchange file dumps. 
Specific column types are enforced to prevent errors on processing later in the notebook" + ], + "metadata": { + "id": "15mAL7GnzBv0" + } + }, + { + "cell_type": "code", + "source": [ + "base_url = \"https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/{0}&file=Posts.xml\"\n", + "\n", + "\n", + "def get_all_filenames():\n", + " response = requests.get(\"https://archive.org/download/stackexchange\")\n", + " if response.ok:\n", + " soup = bs(response.content, \"html.parser\")\n", + " table = soup.find(\"table\")\n", + " link_tags = table.find_all(\"a\")\n", + " urls = {}\n", + " for link in link_tags:\n", + " url = link[\"href\"]\n", + " name = url.split(\".stackexchange\")[0].replace(\".\", \"_\").replace(\"-\", \"_\")\n", + " if url.endswith(\"7z\"):\n", + " urls[name] = base_url.format(url)\n", + " return urls\n", + "\n", + "\n", + "urls = get_all_filenames()\n", + "\n", + "print(urls.keys())\n", + "print(urls.get(\"ai\"))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FtcvUEaHVxcW", + "outputId": "5b0cb19d-e3d9-422b-9077-52241bd09e0e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "dict_keys(['3dprinting_meta', '3dprinting', 'Stackoverflow_com_Posts_7z', 'academia_meta', 'academia', 'ai_meta', 'ai', 'android_meta', 'android', 'anime_meta', 'anime', 'apple_meta', 'apple', 'arduino_meta', 'arduino', 'askubuntu_com_7z', 'astronomy_meta', 'astronomy', 'aviation_meta', 'aviation', 'avp_meta', 'avp', 'beer_meta', 'beer', 'bicycles_meta', 'bicycles', 'bioacoustics_meta', 'bioacoustics', 'bioinformatics_meta', 'bioinformatics', 'biology_meta', 'biology', 'bitcoin_meta', 'bitcoin', 'blender_meta', 'blender', 'boardgames_meta', 'boardgames', 'bricks_meta', 'bricks', 'buddhism_meta', 'buddhism', 'cardano_meta', 'cardano', 'chemistry_meta', 'chemistry', 'chess_meta', 'chess', 'chinese_meta', 'chinese', 'christianity_meta', 'christianity', 
'civicrm_meta', 'civicrm', 'codegolf_meta', 'codegolf', 'codereview_meta', 'codereview', 'coffee_meta', 'coffee', 'cogsci_meta', 'cogsci', 'computergraphics_meta', 'computergraphics', 'conlang_meta', 'conlang', 'cooking_meta', 'cooking', 'craftcms_meta', 'craftcms', 'crafts_meta', 'crafts', 'crypto_meta', 'crypto', 'cs_meta', 'cs', 'cseducators_meta', 'cseducators', 'cstheory_meta', 'cstheory', 'datascience_meta', 'datascience', 'dba_meta', 'dba', 'devops_meta', 'devops', 'diy_meta', 'diy', 'drones_meta', 'drones', 'drupal_meta', 'drupal', 'dsp_meta', 'dsp', 'earthscience_meta', 'earthscience', 'ebooks_meta', 'ebooks', 'economics_meta', 'economics', 'electronics_meta', 'electronics', 'elementaryos_meta', 'elementaryos', 'ell_meta', 'ell', 'emacs_meta', 'emacs', 'engineering_meta', 'engineering', 'english_meta', 'english', 'eosio_meta', 'eosio', 'es_meta_stackoverflow_com_7z', 'es_stackoverflow_com_7z', 'esperanto_meta', 'esperanto', 'ethereum_meta', 'ethereum', 'expatriates_meta', 'expatriates', 'expressionengine_meta', 'expressionengine', 'fitness_meta', 'fitness', 'freelancing_meta', 'freelancing', 'french_meta', 'french', 'gamedev_meta', 'gamedev', 'gaming_meta', 'gaming', 'gardening_meta', 'gardening', 'genealogy_meta', 'genealogy', 'german_meta', 'german', 'gis_meta', 'gis', 'graphicdesign_meta', 'graphicdesign', 'ham_meta', 'ham', 'hardwarerecs_meta', 'hardwarerecs', 'health_meta', 'health', 'hermeneutics_meta', 'hermeneutics', 'hinduism_meta', 'hinduism', 'history_meta', 'history', 'homebrew_meta', 'homebrew', 'hsm_meta', 'hsm', 'interpersonal_meta', 'interpersonal', 'iot_meta', 'iot', 'iota_meta', 'iota', 'islam_meta', 'islam', 'italian_meta', 'italian', 'ja_meta_stackoverflow_com_7z', 'ja_stackoverflow_com_7z', 'japanese_meta', 'japanese', 'joomla_meta', 'joomla', 'judaism_meta', 'judaism', 'korean_meta', 'korean', 'languagelearning_meta', 'languagelearning', 'latin_meta', 'latin', 'law_meta', 'law', 'lifehacks_meta', 'lifehacks', 'linguistics_meta', 
'linguistics', 'literature_meta', 'literature', 'magento_meta', 'magento', 'martialarts_meta', 'martialarts', 'materials_meta', 'materials', 'math_meta', 'math', 'matheducators_meta', 'matheducators', 'mathematica_meta', 'mathematica', 'mathoverflow_net_7z', 'mechanics_meta', 'mechanics', 'meta_askubuntu_com_7z', 'meta_mathoverflow_net_7z', 'meta_serverfault_com_7z', 'meta', 'meta_stackoverflow_com_7z', 'meta_superuser_com_7z', 'moderators_meta', 'moderators', 'monero_meta', 'monero', 'money_meta', 'money', 'movies_meta', 'movies', 'music_meta', 'music', 'musicfans_meta', 'musicfans', 'mythology_meta', 'mythology', 'networkengineering_meta', 'networkengineering', 'opendata_meta', 'opendata', 'opensource_meta', 'opensource', 'or_meta', 'or', 'outdoors_meta', 'outdoors', 'parenting_meta', 'parenting', 'patents_meta', 'patents', 'pets_meta', 'pets', 'philosophy_meta', 'philosophy', 'photo_meta', 'photo', 'physics_meta', 'physics', 'pm_meta', 'pm', 'poker_meta', 'poker', 'politics_meta', 'politics', 'portuguese_meta', 'portuguese', 'proofassistants_meta', 'proofassistants', 'pt_meta_stackoverflow_com_7z', 'pt_stackoverflow_com_7z', 'puzzling_meta', 'puzzling', 'quant_meta', 'quant', 'quantumcomputing_meta', 'quantumcomputing', 'raspberrypi_meta', 'raspberrypi', 'retrocomputing_meta', 'retrocomputing', 'reverseengineering_meta', 'reverseengineering', 'robotics_meta', 'robotics', 'rpg_meta', 'rpg', 'ru_meta_stackoverflow_com_7z', 'ru_stackoverflow_com_7z', 'rus_meta', 'rus', 'russian_meta', 'russian', 'salesforce_meta', 'salesforce', 'scicomp_meta', 'scicomp', 'scifi_meta', 'scifi', 'security_meta', 'security', 'serverfault_com_7z', 'sharepoint_meta', 'sharepoint', 'sitecore_meta', 'sitecore', 'skeptics_meta', 'skeptics', 'softwareengineering_meta', 'softwareengineering', 'softwarerecs_meta', 'softwarerecs', 'solana_meta', 'solana', 'sound_meta', 'sound', 'space_meta', 'space', 'spanish_meta', 'spanish', 'sports_meta', 'sports', 'sqa_meta', 'sqa', 'stackapps_com_7z', 
'stackoverflow_com_Badges_7z', 'stackoverflow_com_Comments_7z', 'stackoverflow_com_PostHistory_7z', 'stackoverflow_com_PostLinks_7z', 'stackoverflow_com_Tags_7z', 'stackoverflow_com_Users_7z', 'stackoverflow_com_Votes_7z', 'stats_meta', 'stats', 'stellar_meta', 'stellar', 'substrate_meta', 'substrate', 'superuser_com_7z', 'sustainability_meta', 'sustainability', 'tex_meta', 'tex', 'tezos_meta', 'tezos', 'tor_meta', 'tor', 'travel_meta', 'travel', 'tridion_meta', 'tridion', 'ukrainian_meta', 'ukrainian', 'unix_meta', 'unix', 'ux_meta', 'ux', 'vegetarianism_meta', 'vegetarianism', 'vi_meta', 'vi', 'webapps_meta', 'webapps', 'webmasters_meta', 'webmasters', 'windowsphone_meta', 'windowsphone', 'woodworking_meta', 'woodworking', 'wordpress_meta', 'wordpress', 'workplace_meta', 'workplace', 'worldbuilding_meta', 'worldbuilding', 'writers_meta', 'writers'])\n", + "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/ai.stackexchange.com.7z&file=Posts.xml\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "xml_format_map = {\n", + " \"Id\": int,\n", + " \"PostTypeId\": int,\n", + " \"CreationDate\": str,\n", + " \"Score\": int,\n", + " \"ViewCount\": int,\n", + " \"Body\": str,\n", + " \"AnswerCount\": int,\n", + " \"CommentCount\": int,\n", + " \"ContentLicense\": str,\n", + " \"AcceptedAnswerId\": int,\n", + " \"ParentId\": int,\n", + "}\n", + "\n", + "\n", + "# def extract_xml_file(file_url: str):\n", + "# table = pd.read_xml(file_url)\n", + "# return table\n", + "\n", + "\n", + "def xml_to_df(response: str):\n", + " \"\"\"\n", + " Collect and Manually import XML into Dataframe\n", + "\n", + " pd.read_xml() errors when XML trees are too large, this is just a hack to\n", + " download a XML file and parse into a Dataframe. 
**Not Tested on huge XML files**\n", + "\n", + " Parameters:\n", + " response (Requests.Response): Requests response object with the XML data\n", + "\n", + " Returns:\n", + " df (DataFrame): A Dataframe from the XML file\n", + " \"\"\"\n", + " soup = bs(response.content, \"xml\")\n", + " posts = soup.find_all(\"row\")\n", + "\n", + " all_posts = [post.attrs for post in posts]\n", + "\n", + " df = pd.DataFrame(all_posts)\n", + " df.AnswerCount.fillna(0, inplace=True)\n", + " df.ViewCount.fillna(0, inplace=True)\n", + " df.AcceptedAnswerId.fillna(0, inplace=True)\n", + " df.ParentId.fillna(0, inplace=True)\n", + " df[\"DataSource\"] = response.url\n", + " df = df.astype(xml_format_map)\n", + " return df\n", + "\n", + "\n", + "dataset_name = \"ai\"\n", + "\n", + "xml_posts_path = urls.get(dataset_name)\n", + "\n", + "\n", + "# df = extract_xml_file(test)\n", + "response = requests.get(xml_posts_path)\n", + "df = xml_to_df(response)\n", + "\n", + "\n", + "print(df.dtypes)\n", + "df.head()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 907 + }, + "id": "-t27RnxdzBYB", + "outputId": "5ec0ceed-c82b-48fa-facd-41b4aae2f9e6" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Id int64\n", + "PostTypeId int64\n", + "AcceptedAnswerId int64\n", + "CreationDate object\n", + "Score int64\n", + "ViewCount int64\n", + "Body object\n", + "OwnerUserId object\n", + "LastEditorUserId object\n", + "LastEditDate object\n", + "LastActivityDate object\n", + "Title object\n", + "Tags object\n", + "AnswerCount int64\n", + "CommentCount int64\n", + "ContentLicense object\n", + "ParentId int64\n", + "ClosedDate object\n", + "FavoriteCount object\n", + "CommunityOwnedDate object\n", + "LastEditorDisplayName object\n", + "OwnerDisplayName object\n", + "DataSource object\n", + "dtype: object\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " Id PostTypeId 
AcceptedAnswerId CreationDate Score \\\n", + "0 1 1 3 2016-08-02T15:39:14.947 10 \n", + "1 2 1 9 2016-08-02T15:40:20.623 14 \n", + "2 3 2 0 2016-08-02T15:40:24.820 15 \n", + "3 4 1 12 2016-08-02T15:41:22.020 33 \n", + "4 6 1 20 2016-08-02T15:43:35.460 7 \n", + "\n", + " ViewCount Body OwnerUserId \\\n", + "0 710

What does \"backprop\" mean? Is the \"backprop... 8 \n", + "1 1008

Does increasing the noise in data help to i... 8 \n", + "2 0

\"Backprop\" is the same as \"backpropagation\"... 4 \n", + "3 1266

When you're writing your algorithm, how do ... 8 \n", + "4 279

Given the following definition of an intell... 29 \n", + "\n", + " LastEditorUserId LastEditDate ... AnswerCount CommentCount \\\n", + "0 2444 2019-11-16T17:56:22.093 ... 5 0 \n", + "1 2444 2019-02-23T22:36:19.090 ... 3 0 \n", + "2 NaN NaN ... 0 0 \n", + "3 2444 2021-01-19T23:54:07.813 ... 4 0 \n", + "4 2444 2019-06-15T18:25:58.513 ... 2 0 \n", + "\n", + " ContentLicense ParentId ClosedDate FavoriteCount CommunityOwnedDate \\\n", + "0 CC BY-SA 4.0 0 NaN NaN NaN \n", + "1 CC BY-SA 4.0 0 NaN NaN NaN \n", + "2 CC BY-SA 3.0 1 NaN NaN NaN \n", + "3 CC BY-SA 3.0 0 NaN NaN NaN \n", + "4 CC BY-SA 4.0 0 NaN NaN NaN \n", + "\n", + " LastEditorDisplayName OwnerDisplayName \\\n", + "0 NaN NaN \n", + "1 NaN NaN \n", + "2 NaN NaN \n", + "3 NaN NaN \n", + "4 NaN NaN \n", + "\n", + " DataSource \n", + "0 https://ia600107.us.archive.org/view_archive.p... \n", + "1 https://ia600107.us.archive.org/view_archive.p... \n", + "2 https://ia600107.us.archive.org/view_archive.p... \n", + "3 https://ia600107.us.archive.org/view_archive.p... \n", + "4 https://ia600107.us.archive.org/view_archive.p... \n", + "\n", + "[5 rows x 23 columns]" + ], + "text/html": [ + "\n", + "

\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
IdPostTypeIdAcceptedAnswerIdCreationDateScoreViewCountBodyOwnerUserIdLastEditorUserIdLastEditDate...AnswerCountCommentCountContentLicenseParentIdClosedDateFavoriteCountCommunityOwnedDateLastEditorDisplayNameOwnerDisplayNameDataSource
01132016-08-02T15:39:14.94710710<p>What does \"backprop\" mean? Is the \"backprop...824442019-11-16T17:56:22.093...50CC BY-SA 4.00NaNNaNNaNNaNNaNhttps://ia600107.us.archive.org/view_archive.p...
12192016-08-02T15:40:20.623141008<p>Does increasing the noise in data help to i...824442019-02-23T22:36:19.090...30CC BY-SA 4.00NaNNaNNaNNaNNaNhttps://ia600107.us.archive.org/view_archive.p...
23202016-08-02T15:40:24.820150<p>\"Backprop\" is the same as \"backpropagation\"...4NaNNaN...00CC BY-SA 3.01NaNNaNNaNNaNNaNhttps://ia600107.us.archive.org/view_archive.p...
341122016-08-02T15:41:22.020331266<p>When you're writing your algorithm, how do ...824442021-01-19T23:54:07.813...40CC BY-SA 3.00NaNNaNNaNNaNNaNhttps://ia600107.us.archive.org/view_archive.p...
461202016-08-02T15:43:35.4607279<p>Given the following definition of an intell...2924442019-06-15T18:25:58.513...20CC BY-SA 4.00NaNNaNNaNNaNNaNhttps://ia600107.us.archive.org/view_archive.p...
\n", + "

5 rows × 23 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 219 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Transformations" + ], + "metadata": { + "id": "RAzTR7zY3oan" + } + }, + { + "cell_type": "code", + "source": [ + "def filter_only_questions_with_accepted_answers(df):\n", + " \"\"\"**TODO**\n", + " Filter only to Questions with Accepted Answers\n", + "\n", + " Filter dataframe by questions that have accepted answers, should also include\n", + " all rows of answers for those questions, even if not accepted.\n", + "\n", + " Parameters:\n", + " df (DataFrame): containing a \"AcceptedAnswerId\", \"Id\", and \"ParentId\" columns\n", + "\n", + " Returns:\n", + " df (DataFrame): current dataframe with filtered results\n", + " \"\"\"\n", + " df = df[(df[\"AcceptedAnswerId\"].notnull()) | (df[\"ParentId\"] == df[\"Id\"])]\n", + "\n", + "\n", + "def filter_scores_above(df, question_score_threshold: int = 20, answer_score_threshold: int = 20):\n", + " \"\"\"**TODO**\n", + " Filter Dataframe by minimum scores\n", + "\n", + " Filter Question and Answer columns by score thresholds to trim lower scoring results\n", + "\n", + " Parameters:\n", + " df (DataFrame): containing a \"Score\" column\n", + "\n", + " Returns:\n", + " df (DataFrame): current dataframe with filtered results\n", + " \"\"\"\n", + " df = df[\n", + " ((df[\"Score\"] >= question_score_threshold) & (df.PostTypeId == 1))\n", + " | ((df[\"Score\"] >= answer_score_threshold) & (df.PostTypeId == 2))\n", + " ]\n", + "\n", + "\n", + "def convert_html_to_text(df, column: str = \"Body\"):\n", + " \"\"\"\n", + " Convert HTML tags to pure text\n", + "\n", + " Feeds HTML text body into BeautifulSoup to parse it to only text. 
Set aside as\n", + " function to provide option to skip\n", + "\n", + " Parameters:\n", + " df (DataFrame): containing a \"Body\" column with HTML\n", + "\n", + " Returns:\n", + " df (DataFrame): current dataframe with parsed column\n", + " \"\"\"\n", + " df.dropna(subset=[column], inplace=True)\n", + " df[f\"{column}Clean\"] = df[column].apply(lambda row: bs(row, \"html.parser\").text)\n", + "\n", + "\n", + "def clean_tags(df):\n", + " \"\"\"\n", + " Convert Tags into Comma separated\n", + "\n", + " Converts Tag slugs into commas separated tags\n", + "\n", + " Parameters:\n", + " df (DataFrame): containing a \"Tags\" column with slugs\n", + "\n", + " Returns:\n", + " df (DataFrame): current dataframe with parsed column\n", + " \"\"\"\n", + " df[\"TagsClean\"] = df[\"Tags\"].str.replace(\"-\", \" \").str.replace(\"><\", \", \").str.replace(\"<\", \"\").str.replace(\">\", \"\")\n", + "\n", + "\n", + "# filter_only_questions_with_accepted_answers(df)\n", + "# filter_scores_above(df)\n", + "convert_html_to_text(df)\n", + "clean_tags(df)\n", + "\n", + "df[[\"Body\", \"BodyClean\", \"Tags\", \"TagsClean\"]]\n", + "# print(df.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 424 + }, + "id": "qyUqc31Z3Z9g", + "outputId": "18dce8b4-af26-49c9-ee73-6c677177b516" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " Body \\\n", + "0

What does \"backprop\" mean? Is the \"backprop... \n", + "1

Does increasing the noise in data help to i... \n", + "2

\"Backprop\" is the same as \"backpropagation\"... \n", + "3

When you're writing your algorithm, how do ... \n", + "4

Given the following definition of an intell... \n", + "... ... \n", + "23174

The purpose of evaluating the state and act... \n", + "23175

In machine translation, convolution is a te... \n", + "23176

One of the key features of ChatGPT is its a... \n", + "23177

Given a neural network model for Covid-19 c... \n", + "23178

My question is more related to the fundamen... \n", + "\n", + " BodyClean \\\n", + "0 What does \"backprop\" mean? Is the \"backprop\" t... \n", + "1 Does increasing the noise in data help to impr... \n", + "2 \"Backprop\" is the same as \"backpropagation\": i... \n", + "3 When you're writing your algorithm, how do you... \n", + "4 Given the following definition of an intellige... \n", + "... ... \n", + "23174 The purpose of evaluating the state and action... \n", + "23175 In machine translation, convolution is a techn... \n", + "23176 One of the key features of ChatGPT is its abil... \n", + "23177 Given a neural network model for Covid-19 clas... \n", + "23178 My question is more related to the fundamental... \n", + "\n", + " Tags \\\n", + "0 ... \n", + "4 \n", + "... ... \n", + "23174 NaN \n", + "23175 NaN \n", + "23176 NaN \n", + "23177 \n", + "23178 \n", + "\n", + " TagsClean \n", + "0 neural networks, backpropagation, terminology,... \n", + "1 neural networks, machine learning, statistical... \n", + "2 NaN \n", + "3 neural networks, hyperparameter optimization, ... \n", + "4 philosophy, definitions, intelligent agent \n", + "... ... \n", + "23174 NaN \n", + "23175 NaN \n", + "23176 NaN \n", + "23177 neural networks, homework \n", + "23178 search, constraint satisfaction problems \n", + "\n", + "[23179 rows x 4 columns]" + ], + "text/html": [ + "\n", + "

\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BodyBodyCleanTagsTagsClean
0<p>What does \"backprop\" mean? Is the \"backprop...What does \"backprop\" mean? Is the \"backprop\" t...<neural-networks><backpropagation><terminology...neural networks, backpropagation, terminology,...
1<p>Does increasing the noise in data help to i...Does increasing the noise in data help to impr...<neural-networks><machine-learning><statistica...neural networks, machine learning, statistical...
2<p>\"Backprop\" is the same as \"backpropagation\"...\"Backprop\" is the same as \"backpropagation\": i...NaNNaN
3<p>When you're writing your algorithm, how do ...When you're writing your algorithm, how do you...<neural-networks><hyperparameter-optimization>...neural networks, hyperparameter optimization, ...
4<p>Given the following definition of an intell...Given the following definition of an intellige...<philosophy><definitions><intelligent-agent>philosophy, definitions, intelligent agent
...............
23174<p>The purpose of evaluating the state and act...The purpose of evaluating the state and action...NaNNaN
23175<p>In machine translation, convolution is a te...In machine translation, convolution is a techn...NaNNaN
23176<p>One of the key features of ChatGPT is its a...One of the key features of ChatGPT is its abil...NaNNaN
23177<p>Given a neural network model for Covid-19 c...Given a neural network model for Covid-19 clas...<neural-networks><homework>neural networks, homework
23178<p>My question is more related to the fundamen...My question is more related to the fundamental...<search><constraint-satisfaction-problems>search, constraint satisfaction problems
\n", + "

23179 rows × 4 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 220 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "This groups questions with answers so that a row with a question also has a column with an answer. It then creates an AcceptedAnswerFlag column that is True if the answer was accepted by the person who asked the question. Changing the `number_of_results` variable will limit the number of answers you want to keep." + ], + "metadata": { + "id": "C09Bwdw-44PZ" + } + }, + { + "cell_type": "code", + "source": [ + "questions = df[df.PostTypeId == 1]\n", + "answers = df[df.PostTypeId == 2]\n", + "\n", + "df = pd.merge(\n", + " questions,\n", + " answers[\n", + " [\n", + " \"Id\",\n", + " \"CreationDate\",\n", + " \"Score\",\n", + " \"ViewCount\",\n", + " \"CommentCount\",\n", + " \"ContentLicense\",\n", + " \"TagsClean\",\n", + " \"BodyClean\",\n", + " \"ParentId\",\n", + " ]\n", + " ],\n", + " left_on=\"Id\",\n", + " right_on=\"ParentId\",\n", + " suffixes=(\"_q\", \"_a\"),\n", + " how=\"left\",\n", + ")\n", + "\n", + "df[\"AcceptedAnswerFlag\"] = df.apply(lambda row: row[\"Id_a\"] == row[\"AcceptedAnswerId\"], axis=1)\n", + "\n", + "df = df.rename(\n", + " columns={\n", + " \"BodyClean_q\": \"Question\",\n", + " \"Score_q\": \"QuestionScore\",\n", + " \"TagsClean_q\": \"QuestionTags\",\n", + " \"BodyClean_a\": \"Answer\",\n", + " \"Score_a\": \"AnswerScore\",\n", + " \"ContentLicense_q\": \"QuestionContentLicense\",\n", + " \"ContentLicense_a\": \"AnswerContentLicense\",\n", + " \"CreationDate_q\": \"CreationDate\",\n", + " }\n", + ")\n", + "\n", + "## Set the number of results to a lower number to only return top N rated Answers.\n", + "number_of_results = 25\n", + "df = (\n", + " df.sort_values(by=[\"AcceptedAnswerFlag\", \"AnswerScore\"], ascending=[False, False])\n", + " .groupby(\"Question\")\n", + " .head(number_of_results)\n", + " .reset_index(drop=True)\n", + ")\n", + "\n", + "df[[\"Id_q\", \"Question\", \"QuestionScore\", \"QuestionTags\", 
\"Id_a\", \"Answer\", \"AnswerScore\", \"AcceptedAnswerFlag\"]].head()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337 + }, + "id": "Bgz2fZ9k43Ab", + "outputId": "28896d69-03cd-4877-fdfb-ae48dafa4ff3" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " Id_q Question QuestionScore \\\n", + "0 1768 In Portal 2 we see that AI's can be \"killed\" b... 175 \n", + "1 10623 What is self-supervised learning in machine le... 91 \n", + "2 111 Obviously, self-driving cars aren't perfect, s... 100 \n", + "3 14224 If the original purpose for developing AI was ... 69 \n", + "4 1479 Do scientists or research experts know from th... 94 \n", + "\n", + " QuestionTags Id_a \\\n", + "0 philosophy, decision theory, mythology of ai, ... 1769.0 \n", + "1 machine learning, comparison, supervised learn... 10624.0 \n", + "2 philosophy, ethics, autonomous vehicles, decis... 1790.0 \n", + "3 philosophy, social, explainable ai 14247.0 \n", + "4 neural networks, deep learning, convolutional ... 4044.0 \n", + "\n", + " Answer AnswerScore \\\n", + "0 This classic problem exhibits a basic misunder... 146.0 \n", + "1 Introduction\\nThe term self-supervised learnin... 90.0 \n", + "2 \\nHow could self-driving cars make ethical dec... 76.0 \n", + "3 As argued by Selvaraju et al., there are three... 75.0 \n", + "4 There are many approaches that aim to make a t... 69.0 \n", + "\n", + " AcceptedAnswerFlag \n", + "0 True \n", + "1 True \n", + "2 True \n", + "3 True \n", + "4 True " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Id_qQuestionQuestionScoreQuestionTagsId_aAnswerAnswerScoreAcceptedAnswerFlag
01768In Portal 2 we see that AI's can be \"killed\" b...175philosophy, decision theory, mythology of ai, ...1769.0This classic problem exhibits a basic misunder...146.0True
110623What is self-supervised learning in machine le...91machine learning, comparison, supervised learn...10624.0Introduction\\nThe term self-supervised learnin...90.0True
2111Obviously, self-driving cars aren't perfect, s...100philosophy, ethics, autonomous vehicles, decis...1790.0\\nHow could self-driving cars make ethical dec...76.0True
314224If the original purpose for developing AI was ...69philosophy, social, explainable ai14247.0As argued by Selvaraju et al., there are three...75.0True
41479Do scientists or research experts know from th...94neural networks, deep learning, convolutional ...4044.0There are many approaches that aim to make a t...69.0True
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 221 + } + ] + }, + { + "cell_type": "code", + "source": [ + "testing_id = df.Id_q.mode()[0]\n", + "df[(df.Id_q == testing_id) | (df.ParentId_a == testing_id)][\n", + " [\"Id_q\", \"Question\", \"ParentId_a\", \"AcceptedAnswerId\", \"Id_a\", \"Answer\", \"AnswerScore\", \"AcceptedAnswerFlag\"]\n", + "]\n", + "# df[['Id_q', 'Question', 'ParentId_a', 'AcceptedAnswerId', 'Id_a', 'Answer', 'AnswerScore', 'AcceptedAnswerFlag']]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 645 + }, + "id": "eds1K8WL9QPo", + "outputId": "bc526503-d6dd-428f-fa98-ad419d26a7dc" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " Id_q Question ParentId_a \\\n", + "7 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "3662 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "3713 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "3788 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "3821 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "3882 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "4389 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "4849 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "4850 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "5763 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "5764 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "5765 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "7462 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "7463 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "7464 15730 As a human being, we can think infinity. In pr... 
15730.0 \n", + "7465 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "7466 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "7467 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "9481 15730 As a human being, we can think infinity. In pr... 15730.0 \n", + "\n", + " AcceptedAnswerId Id_a \\\n", + "7 15744 15744.0 \n", + "3662 15744 15753.0 \n", + "3713 15744 15747.0 \n", + "3788 15744 15756.0 \n", + "3821 15744 15758.0 \n", + "3882 15744 15762.0 \n", + "4389 15744 15783.0 \n", + "4849 15744 15740.0 \n", + "4850 15744 15803.0 \n", + "5763 15744 15768.0 \n", + "5764 15744 15810.0 \n", + "5765 15744 15943.0 \n", + "7462 15744 15779.0 \n", + "7463 15744 15787.0 \n", + "7464 15744 15801.0 \n", + "7465 15744 15930.0 \n", + "7466 15744 15934.0 \n", + "7467 15744 15938.0 \n", + "9481 15744 15931.0 \n", + "\n", + " Answer AnswerScore \\\n", + "7 I think this is a fairly common misconception ... 62.0 \n", + "3662 I think your premise is flawed.\\nYou seem to a... 19.0 \n", + "3713 TL;DR: The subtleties of infinity are made app... 12.0 \n", + "3788 In Haskell, you can type:\\nprint [1..]\\nand it... 9.0 \n", + "3821 I believe humans can be said to understand inf... 8.0 \n", + "3882 (There's a summary at the bottom for those who... 7.0 \n", + "4389 Then premise assumes that humans \"understand\" ... 4.0 \n", + "4849 By adding some rules for infinity in arithmeti... 3.0 \n", + "4850 I think the concept that is missing in the dis... 3.0 \n", + "5763 Computers don't understand \"infinity\" or even ... 2.0 \n", + "5764 The Questions That Computers Can Never Answer ... 2.0 \n", + "5765 John Doucette's answer covers my thoughts on t... 2.0 \n", + "7462 I would think that a computer couldn’t underst... 1.0 \n", + "7463 The \"concept\" of infinity is 1 thing to unders... 1.0 \n", + "7464 Just food for thought: how about if we try to ... 1.0 \n", + "7465 Its arguable if we humans understand infinity.... 
1.0 \n", + "7466 Well -- just to touch on the question of peopl... 1.0 \n", + "7467 Humans certainly don't understand infinity. Cu... 1.0 \n", + "9481 I think the property humans have which compute... 0.0 \n", + "\n", + " AcceptedAnswerFlag \n", + "7 True \n", + "3662 False \n", + "3713 False \n", + "3788 False \n", + "3821 False \n", + "3882 False \n", + "4389 False \n", + "4849 False \n", + "4850 False \n", + "5763 False \n", + "5764 False \n", + "5765 False \n", + "7462 False \n", + "7463 False \n", + "7464 False \n", + "7465 False \n", + "7466 False \n", + "7467 False \n", + "9481 False " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + "
Id_qQuestionParentId_aAcceptedAnswerIdId_aAnswerAnswerScoreAcceptedAnswerFlag
715730As a human being, we can think infinity. In pr...15730.01574415744.0I think this is a fairly common misconception ...62.0True
366215730As a human being, we can think infinity. In pr...15730.01574415753.0I think your premise is flawed.\\nYou seem to a...19.0False
371315730As a human being, we can think infinity. In pr...15730.01574415747.0TL;DR: The subtleties of infinity are made app...12.0False
378815730As a human being, we can think infinity. In pr...15730.01574415756.0In Haskell, you can type:\\nprint [1..]\\nand it...9.0False
382115730As a human being, we can think infinity. In pr...15730.01574415758.0I believe humans can be said to understand inf...8.0False
388215730As a human being, we can think infinity. In pr...15730.01574415762.0(There's a summary at the bottom for those who...7.0False
438915730As a human being, we can think infinity. In pr...15730.01574415783.0Then premise assumes that humans \"understand\" ...4.0False
484915730As a human being, we can think infinity. In pr...15730.01574415740.0By adding some rules for infinity in arithmeti...3.0False
485015730As a human being, we can think infinity. In pr...15730.01574415803.0I think the concept that is missing in the dis...3.0False
576315730As a human being, we can think infinity. In pr...15730.01574415768.0Computers don't understand \"infinity\" or even ...2.0False
576415730As a human being, we can think infinity. In pr...15730.01574415810.0The Questions That Computers Can Never Answer ...2.0False
576515730As a human being, we can think infinity. In pr...15730.01574415943.0John Doucette's answer covers my thoughts on t...2.0False
746215730As a human being, we can think infinity. In pr...15730.01574415779.0I would think that a computer couldn’t underst...1.0False
746315730As a human being, we can think infinity. In pr...15730.01574415787.0The \"concept\" of infinity is 1 thing to unders...1.0False
746415730As a human being, we can think infinity. In pr...15730.01574415801.0Just food for thought: how about if we try to ...1.0False
746515730As a human being, we can think infinity. In pr...15730.01574415930.0Its arguable if we humans understand infinity....1.0False
746615730As a human being, we can think infinity. In pr...15730.01574415934.0Well -- just to touch on the question of peopl...1.0False
746715730As a human being, we can think infinity. In pr...15730.01574415938.0Humans certainly don't understand infinity. Cu...1.0False
948115730As a human being, we can think infinity. In pr...15730.01574415931.0I think the property humans have which compute...0.0False
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 222 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Create JSONL version of Dataframe\n", + "This groups the dataframe by question data and creates nested list of Answers for that group. The entire list contains individual JSON objects, each representing a single question in the dataset with a key, Answers, which contains a list of dictionaries for each answer to the question." + ], + "metadata": { + "id": "gXgpXEO7DCbj" + } + }, + { + "cell_type": "code", + "source": [ + "j = (\n", + " df.groupby(\n", + " [\"Title\", \"Question\", \"QuestionScore\", \"QuestionTags\", \"QuestionContentLicense\", \"DataSource\", \"CreationDate\"]\n", + " )\n", + " .apply(lambda x: x[[\"Answer\", \"AnswerScore\", \"AcceptedAnswerFlag\"]].to_dict(\"records\"))\n", + " .reset_index()\n", + " .rename(columns={0: \"Answers\"})\n", + " .to_json(orient=\"records\")\n", + ")\n", + "\n", + "data = json.loads(j)\n", + "\n", + "for post in data:\n", + " if len(post.get(\"Answers\")) >= 4:\n", + " print(json.dumps(post, indent=4))\n", + " break" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OBR58MSRzAMP", + "outputId": "c7da1e6c-3a97-465d-c9ba-7e055cb0d751" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\n", + " \"Title\": \"1 hidden layer with 1000 neurons vs. 10 hidden layers with 100 neurons\",\n", + " \"Question\": \"These types of questions may be problem-dependent, but I have tried to find research that addresses the question whether the number of hidden layers and their size (number of neurons in each layer) really matter or not.\\nSo my question is, does it really matter if we for example have 1 large hidden layer of 1000 neurons vs. 
10 hidden layers with 100 neurons each?\\n\",\n", + " \"QuestionScore\": 16,\n", + " \"QuestionTags\": \"neural networks\",\n", + " \"QuestionContentLicense\": \"CC BY-SA 3.0\",\n", + " \"DataSource\": \"https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/ai.stackexchange.com.7z&file=Posts.xml\",\n", + " \"CreationDate\": \"2017-05-04T13:06:37.990\",\n", + " \"Answers\": [\n", + " {\n", + " \"Answer\": \"Basically, having multiple layers (aka a deep network) makes your network more eager to recognize certain aspects of input data. For example, if you have the details of a house (size, lawn size, location etc.) as input and want to predict the price. The first layer may predict:\\n\\nBig area, higher price\\nSmall amount of bedrooms, lower price\\n\\nThe second layer might conclude:\\n\\nBig area + small amount of bedrooms = large bedrooms = +- effect\\n\\nYes, one layer can also 'detect' the stats, however it will require more neurons as it cannot rely on other neurons to do 'parts' of the total calculation required to detect that stat.\\nCheck out this answer\\n\",\n", + " \"AnswerScore\": 13.0,\n", + " \"AcceptedAnswerFlag\": true\n", + " },\n", + " {\n", + " \"Answer\": \"There are so many aspects.\\n1. Training:\\nTraining deep nets is a hard job due to the vanishing (rearly exploding) gradient problem. So building a 10x100 neural-net is not recommended.\\n2. Trained network performance:\\n\\nInformation loss:\\nThe classical usage of neural nets is the classification problem. Which means we want to get some well defined information from the data. (Ex. Is there a face in the picture or not.)\\nSo usually classification problem has a lot of input, and few output, whats more the size of the hidden layers are descend from input to output.\\nHowever, we loss information using less neurons layer by layer. (Ie. We cannot reproduce the original image based on the fact that is there a face on it or no.) 
So you must know that you loss information using 100 neurons if the size of the input is (lets say) 1000.\\nInformation complexity: However the deeper nets (as Tomas W mentioned) can fetch more complex information from the input data. Inspite of this its not recommended to use 10 fully connected layers. Its recommended to use convolutional/relu/maxpooling or other type of layers. Firest layers can compress the some essential part of the inputs. (Ex is there any line in a specific part of the picture) Second layers can say: There is a specific shape in this place in the picture. Etc etc.\\n\\nSo deeper nets are more \\\"clever\\\" but 10x100 net structure is a good choice.\\n\",\n", + " \"AnswerScore\": 4.0,\n", + " \"AcceptedAnswerFlag\": false\n", + " },\n", + " {\n", + " \"Answer\": \"If the problem you are solving is linearly separable, one layer of 1000 neurons can do better job than 10 layers with each of 100 neurons.\\nIf the problem is non linear and not convex, then you need deep neural nets. \\n\",\n", + " \"AnswerScore\": 1.0,\n", + " \"AcceptedAnswerFlag\": false\n", + " },\n", + " {\n", + " \"Answer\": \"\\nI think you have a confusion in the basics of the neural networks.\\n Every layer has a separate activation function and input/output\\n connection weights.\\n\\nThe output of the first hidden layer will be multiplied by a weight, processed by an activation function in the next layer and so on.\\nSingle layer neural networks are very limited for simple tasks, deeper NN can perform far better than a single layer. \\nHowever, do not use more than layer if your application is not fairly complex. In conclusion, 100 neurons layer does not mean better neural network than 10 layers x 10 neurons but 10 layers are something imaginary unless you are doing deep learning. start with 10 neurons in the hidden layer and try to add layers or add more neurons to the same layer to see the difference. 
learning with more layers will be easier but more training time is required.\\n\",\n", + " \"AnswerScore\": 0.0,\n", + " \"AcceptedAnswerFlag\": false\n", + " }\n", + " ]\n", + "}\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Save file\n", + "\n", + "Files can be saved as JSON, JSONL, CSV, or Parquet" + ], + "metadata": { + "id": "PlNjrpXaDm1_" + } + }, + { + "cell_type": "code", + "source": [ + "file_name = dataset_name\n", + "\n", + "\n", + "def save_data(data: list, file_name: str, file_type: str = [\"csv\", \"json\", \"jsonl\", \"parquet\"]):\n", + " \"\"\"\n", + " Save Data to file\n", + "\n", + " Save Data list to file as either JSON or JSONL\n", + "\n", + " Parameters:\n", + " data (list): list of dictionaries\n", + " file_name (str): name of file (no extension)\n", + " jsonl (bool): to save file as either JSON or JSONL\n", + " \"\"\"\n", + " file_type = file_type.lower()\n", + "\n", + " if file_type == \"csv\" and isinstance(data, pd.DataFrame):\n", + " data.to_csv(f\"/content/{file_name}.csv\", index=False)\n", + "\n", + " elif file_type == \"json\" and isinstance(data, list):\n", + " print(json.dumps(data, indent=4), file=open(f\"/content/{file_name}.json\", \"w\"))\n", + "\n", + " elif file_type == \"jsonl\" and isinstance(data, list):\n", + " for item in data:\n", + " print(json.dumps(item), file=open(f\"/content/{file_name}.jsonl\", \"a\"))\n", + "\n", + " elif file_type == \"parquet\" and isinstance(data, pd.DataFrame):\n", + " data.to_parquet(f\"/content/{file_name}.parquet\", index=False)\n", + "\n", + " else:\n", + " print(\"Data should be either of List type for JSON and JSONL, or Pandas Dataframes for CSV and Parquet\")\n", + "\n", + "\n", + "# save_data(data=data, file_name=file_name, file_type='jsonl')\n", + "# save_data(data=df, file_name=file_name, file_type='parquet')" + ], + "metadata": { + "id": "CU0gWRGQDqIs", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": 
"9646e475-cedd-46f1-f9b8-7eb1fbc703c7" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Data should be either of List type for JSON and JSONL, or Pandas Dataframes for CSV and Parquet\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Open-Assistant Data Scheme\n", + "\n", + "Testing putting the data into the Open-Assistant Data Scheme\n", + "\n", + "https://github.com/LAION-AI/Open-Assistant/blob/main/docs/data_schemas.md" + ], + "metadata": { + "id": "BdN3hKxtgH7f" + } + }, + { + "cell_type": "code", + "source": [ + "from typing import TypeVar, List, Dict, Any, Literal\n", + "from json import JSONEncoder\n", + "\n", + "T = TypeVar(\"T\", bound=\"ConversationTreeNode\")\n", + "\n", + "\n", + "class ConversationTreeNode:\n", + " text: str # The text of the node\n", + " role: Literal[\"prompter\", \"assistant\"] # Whether the node is a user prompt/follow-up or an assistant response\n", + " children: List[T] # The children of the node (if you have a linear conversation, this will be of length 0 or 1)\n", + " metadata: Dict[str, Any] # Node metadata (see below)\n", + "\n", + " def __init__(\n", + " self, text: str, role: Literal[\"prompter\", \"assistant\"], children: List[T], metadata: Dict[str, Any]\n", + " ) -> None:\n", + " self.text = text\n", + " self.role = role\n", + " self.children = children\n", + " self.metadata = metadata\n", + "\n", + "\n", + "class ConversationTree:\n", + " root: ConversationTreeNode # The node containing the initial prompt\n", + " metadata: Dict[str, Any] # Tree metadata, different from root node metadata.\n", + "\n", + " def __init__(self, root: ConversationTreeNode, metadata: Dict[str, Any]) -> None:\n", + " self.root = root\n", + " self.metadata = metadata\n", + "\n", + "\n", + "# subclass JSONEncoder\n", + "class TreeEncoder(JSONEncoder):\n", + " def default(self, o):\n", + " return o.__dict__" + ], + "metadata": { + "id": "n8ubYQxegNSY" + }, + 
"execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "conversation_forest = []\n", + "\n", + "tree_metadata_map = {\"Title\": str, \"QuestionContentLicense\": str, \"DataSource\": str, \"CreationDate\": str}\n", + "question_metadata_map = {\"QuestionScore\": int, \"QuestionTags\": str}\n", + "answer_metadata_map = {\"AnswerScore\": int, \"AcceptedAnswerFlag\": bool}\n", + "\n", + "\n", + "for item in data:\n", + " prompt = item.get(\"Question\")\n", + " metadata = {k: v for k, v in item.items() if k in question_metadata_map}\n", + " root = ConversationTreeNode(text=prompt, role=\"prompter\", children=[], metadata=metadata)\n", + "\n", + " for answer in item.get(\"Answers\"):\n", + " response = answer.get(\"Answer\")\n", + " metadata = {k: v for k, v in answer.items() if k in answer_metadata_map}\n", + " child = ConversationTreeNode(text=response, role=\"assistant\", children=[], metadata=metadata)\n", + " root.children.append(child)\n", + "\n", + " metadata = {k: v for k, v in item.items() if k in tree_metadata_map}\n", + " conversation_tree = ConversationTree(root=root, metadata=metadata)\n", + " conversation_forest.append(conversation_tree)\n", + "\n", + "\n", + "conversation_forest_json = [\n", + " json.loads(TreeEncoder().encode(conversation_tree)) for conversation_tree in conversation_forest\n", + "]\n", + "\n", + "\n", + "# print(json.dumps(conversation_forest_json[0], indent=4))\n", + "\n", + "\n", + "print(json.dumps(conversation_forest_json, indent=4), file=open(f\"/content/{file_name}.json\", \"w\"))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eE0fkytExSGl", + "outputId": "594632d6-f98c-49b8-af86-25f7f5e2ce06" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\n", + " \"root\": {\n", + " \"text\": \"Science Fiction has frequently shown AI to be a threat to the very existence of mankind. 
AI systems have often been the antagonists in many works of fiction, from 2001: A Space Odyssey through to The Terminator and beyond.\\nThe Media seems to buy into this trope as well. And in recent years we have had people like Elon Musk warn us of the dangers of an impending AI revolution, stating that AI is more dangerous than nukes.\\nAnd, apparently, experts think that we will be seeing this AI revolution in the next 100 years.\\nHowever, from my (albeit limited) study of AI, I get the impression that they are all wrong. I am going to outline my understanding below, please correct me if I am wrong:\\n\\nFirstly, all of these things seem to be confusing Artificial Intelligence with Artificial Consciousness. AI is essentially a system to make intelligent decisions, whereas AC is more like the \\\"self-aware\\\" systems that are shown in science fiction.\\n\\nNot AI itself, but intelligence and intelligent decision-making algorithms are something we've been working with and enhancing since before computers have been around. Moving this over to an artificial framework is fairly easy. However, consciousness is still something we are learning about. 
My guess is we won't be able to re-create something artificially if we barely understand how it works in the real world.\\n\\nSo, my conclusion is that no AI system will be able to learn enough to start thinking for itself, and that all our warnings of AI are completely unjustified.\\n\\nThe real danger comes from AC, which we are a long, long way from realizing because we are still a long way off from defining exactly what consciousness is, let alone understanding it.\\n\\n\\n\\nSo, my question is, assuming that my understanding is correct, are any efforts are being made by companies or organizations that work with AI to correct these popular misunderstandings in sci-fi, the media, and/or the public?\\nOr are the proponents of AI ambivalent towards this public fear-mongering?\\nI understand that the fear mongering is going to remain popular for some time, as bad news sells better than good news. I am just wondering if the general attitude from AI organizations is to ignore this popular misconception, or whether a concerted effort is being made to fight against these AI myths (but unfortunately nobody in the media is listening or cares).\\n\",\n", + " \"role\": \"prompter\",\n", + " \"children\": [\n", + " {\n", + " \"text\": \"Nothing. \\nIts in almost everyone's favor for it to stay that way financially. Having non-technical individuals associate AI with terminators makes a perception that the field has greater capabilities than it does $\\\\rightarrow$ this leads to grants, funding, etc... \\nIs there any negative? Yes. Misconceptions always have drawbacks. 
We see the creation of dumb ethics boards and such cough cough Elon Musk.\\nBut if history has anything to say about this, as the field gains popularity (which it is dnagerously quick), information will spread by definition, and eventually misconceptions will be laid to rest.\\nNote that this answer is biased and based upon my own opinions\\n\",\n", + " \"role\": \"assistant\",\n", + " \"children\": [],\n", + " \"metadata\": {\n", + " \"AnswerScore\": 2.0,\n", + " \"AcceptedAnswerFlag\": true\n", + " }\n", + " }\n", + " ],\n", + " \"metadata\": {\n", + " \"QuestionScore\": 5,\n", + " \"QuestionTags\": \"social, artificial consciousness\"\n", + " }\n", + " },\n", + " \"metadata\": {\n", + " \"Title\": \"\\\"AI will kill us all! The machines will rise up!\\\" - what is being done to dispel such myths?\",\n", + " \"QuestionContentLicense\": \"CC BY-SA 4.0\",\n", + " \"DataSource\": \"https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/ai.stackexchange.com.7z&file=Posts.xml\",\n", + " \"CreationDate\": \"2019-10-16T13:57:37.143\"\n", + " }\n", + "}\n" + ] + } + ] + } + ] +} diff --git a/notebooks/data-argumentation/StackExchangeBuilder.md b/notebooks/data-argumentation/StackExchangeBuilder.md new file mode 100644 index 00000000..74a49872 --- /dev/null +++ b/notebooks/data-argumentation/StackExchangeBuilder.md @@ -0,0 +1,106 @@ +# StackExchange Builder + +StackExchange Builder is a notebook that downloads data from StackExchange data +dumps and converts it into different formats. It parses the XML files, groups +questions and answers, optionally filters the dataset, and puts the results into the +Open-Assistant Data Scheme. Files can be saved as JSON, JSONL, Parquet, +or CSV. 
+ +--- + +#### Sample Data Open-Assistant Data Scheme: + +Open-Assistant Data Scheme as outlined here: +https://github.com/LAION-AI/Open-Assistant/blob/main/docs/data_schemas.md + +``` +{ + "root": { + "text": "Science Fiction has frequently shown AI to be a threat to the very existence of mankind. AI systems have often been the antagonists...", + "role": "prompter", + "children": [ + { + "text": "Nothing. \nIts in almost everyone's favor for it to stay that way financially. Having non-technical individuals associate AI with terminators...", + "role": "assistant", + "children": [], + "metadata": { + "AnswerScore": 2.0, + "AcceptedAnswerFlag": true + } + } + ], + "metadata": { + "QuestionScore": 5, + "QuestionTags": "social, artificial consciousness" + } + }, + "metadata": { + "Title": "\"AI will kill us all! The machines will rise up!\" - what is being done to dispel such myths?", + "QuestionContentLicense": "CC BY-SA 4.0", + "DataSource": "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/ai.stackexchange.com.7z&file=Posts.xml", + "CreationDate": "2019-10-16T13:57:37.143" + } +} +``` + +--- + +#### JSONL format + +Each question and all related answers are on a single line in JSONL format. + +``` +{ + "Title": "1 hidden layer with 1000 neurons vs. 10 hidden layers with 100 neurons", + "Question": "These types of questions may be problem-dependent...", + "QuestionScore": 16, + "QuestionTags": "neural networks", + "QuestionContentLicense": "CC BY-SA 3.0", + "DataSource": "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/ai.stackexchange.com.7z&file=Posts.xml", + "CreationDate": "2017-05-04T13:06:37.990", + "Answers": [ + { + "Answer": "Basically, having multiple layers (aka a deep network) makes your network more eager to recognize certain aspects of input data...", + "AnswerScore": 13.0, + "AcceptedAnswerFlag": true + }, + { + "Answer": "There are so many aspects.\n1. 
Training:\nTraining deep nets is a hard job due to the vanishing (rearly exploding) gradient problem...", + "AnswerScore": 4.0, + "AcceptedAnswerFlag": false + }, + { + "Answer": "If the problem you are solving is linearly separable, one layer of 1000 neurons can do better job...", + "AnswerScore": 1.0, + "AcceptedAnswerFlag": false + }, + { + "Answer": "\nI think you have a confusion in the basics of the neural networks.\n Every layer has a separate activation...", + "AnswerScore": 0.0, + "AcceptedAnswerFlag": false + } + ] +} +``` + +#### Table/CSV/Parquet Format + +There are a lot more columns left over in the table format. `_q` and `_a` are +suffixes indicating whether the column came from the question or answer table as +leftover from a join statement. + +``` +| Id_q | Question | ParentId_a | AcceptedAnswerId | Id_a | Answer | AnswerScore | AcceptedAnswerFlag | +|------:|--------------------------------------------------:|-----------:|-----------------:|--------:|--------------------------------------------------:|------------:|-------------------:| +| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15744.0 | I think this is a fairly common misconception ... | 62.0 | True | +| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15753.0 | I think your premise is flawed.\nYou seem to a... | 19.0 | False | +| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15747.0 | TL;DR: The subtleties of infinity are made app... | 12.0 | False | +| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15756.0 | In Haskell, you can type:\nprint [1..]\nand it... | 9.0 | False | +``` + +--- + +## Contributing + +Feel free to contribute to this notebook. It's not perfect and additional +functionality is planned. 
diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 49eeb088..6cc25918 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -10,6 +10,7 @@ class OasstErrorCode(IntEnum): 0-1000: general errors 1000-2000: tasks endpoint 2000-3000: prompt_repository + 3000-4000: external resources """ # 0-1000: general errors @@ -45,6 +46,9 @@ class OasstErrorCode(IntEnum): TASK_ALREADY_DONE = 2105 TASK_NOT_COLLECTIVE = 2106 + # 3000-4000: external resources + HUGGINGFACE_API_ERROR = 3001 + class OasstError(Exception): """Base class for Open-Assistant exceptions.""" diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 83375d8f..e035d387 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -5,7 +5,7 @@ from uuid import UUID, uuid4 import pydantic from oasst_shared.exceptions import OasstErrorCode -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, conint, conlist, constr class TaskRequestType(str, enum.Enum): @@ -203,7 +203,7 @@ class TextReplyToMessage(Interaction): type: Literal["text_reply_to_message"] = "text_reply_to_message" message_id: str user_message_id: str - text: str + text: constr(min_length=1, strip_whitespace=True) class MessageRating(Interaction): @@ -211,7 +211,7 @@ class MessageRating(Interaction): type: Literal["message_rating"] = "message_rating" message_id: str - rating: int + rating: conint(gt=0) class MessageRanking(Interaction): @@ -219,7 +219,7 @@ class MessageRanking(Interaction): type: Literal["message_ranking"] = "message_ranking" message_id: str - ranking: list[int] + ranking: conlist(item_type=int, min_items=1) AnyInteraction = Union[ diff --git a/openassistant/__init__.py b/openassistant/__init__.py new file mode 100644 index 00000000..e69de29b diff 
--git a/openassistant/dev-requirements.txt b/openassistant/dev-requirements.txt new file mode 100644 index 00000000..fe709d0e --- /dev/null +++ b/openassistant/dev-requirements.txt @@ -0,0 +1 @@ +datasets>=2.8,<3.0.0 diff --git a/openassistant/templates/README.md b/openassistant/templates/README.md new file mode 100644 index 00000000..b174c47e --- /dev/null +++ b/openassistant/templates/README.md @@ -0,0 +1,10 @@ +# Dataset preparation instructions for {dataset_name} + +## Setup + +Add any installation details here. + +## Usage + +Explain how to run any scripts that involve preparing local dataset files, e.g. +if the dataset files aren't public or are produced by a web scraper. diff --git a/openassistant/templates/dataset_card.md b/openassistant/templates/dataset_card.md new file mode 100644 index 00000000..76736d8f --- /dev/null +++ b/openassistant/templates/dataset_card.md @@ -0,0 +1,28 @@ +--- +license: mit +tags: +- open-assistant +- human-feedback +- dialogue-modeling +- language-modeling +--- + +# Dataset card for {dataset_name} + +This is a dataset card template for the [LAION-AI OpenAssistant project](https://github.com/LAION-AI/Open-Assistant). Fill out this template when adding a new dataset to the Hugging Face Hub. 
+ +## Dataset summary + +[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards) + +## Usage + +[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards) + +## Source data + +[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards) + +## Citation + +[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards) \ No newline at end of file diff --git a/openassistant/templates/hub.py b/openassistant/templates/hub.py new file mode 100644 index 00000000..49194e31 --- /dev/null +++ b/openassistant/templates/hub.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +import datasets + + +@dataclass +class OpenAssistantConfig(datasets.BuilderConfig): + """BuilderConfig for OpenAssistant datasets.""" + + name: str = None + version: datasets.Version = None + description: str = None + schema: str = None + subset_id: str = None + + +lm_features = datasets.Features( + { + "text": datasets.Value("string"), + "meta": [datasets.Value("string")], + } +) diff --git a/openassistant/templates/prepare.py b/openassistant/templates/prepare.py new file mode 100644 index 00000000..83a6b15b --- /dev/null +++ b/openassistant/templates/prepare.py @@ -0,0 +1,8 @@ +import typer + +def main(output_dir: str = "data"): + """Download and prepare the dataset for use.""" + raise NotImplementedError + +if __name__ == "__main__": + typer.run(main) \ No newline at end of file diff --git a/openassistant/templates/template.py b/openassistant/templates/template.py new file mode 100644 index 00000000..391df55f --- /dev/null +++ b/openassistant/templates/template.py @@ -0,0 +1,205 @@ +# Copyright 2023 The OpenAssistant Authors and the current dataset script contributor. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This template serves as a starting point for contributing a dataset to the OpenAssistant repo. + +When modifying it for your dataset, look for TODO items that offer specific instructions. + +To create a dataset loading script you will create a class and implement 3 methods: + * `_info`: Establishes the schema for the dataset, and returns a datasets.DatasetInfo object. + * `_split_generators`: Downloads and extracts data for each split (e.g. train/val/test) or associates local data with each split. + * `_generate_examples`: Creates examples from data on disk that conform to each schema defined in `_info`. + +Full documentation on writing dataset loading scripts can be found here: +https://huggingface.co/docs/datasets/dataset_script + +This template is adapted from the one provided by BigScience's BigBIO library: +https://github.com/bigscience-workshop/biomedical/blob/main/templates/template.py + +TODO: Before submitting your script, delete this docstring and replace it with a description of your dataset. +""" + +import os +from typing import Dict, List, Tuple + +import datasets + +from .hub import OpenAssistantConfig + +# TODO: import the schema (i.e. 
features) that fits your dataset: +from .hub import + +# TODO: Add BibTeX citation where appropriate +_CITATION = """\ +@article{, + author = {}, + title = {}, + journal = {}, + volume = {}, + year = {}, + url = {}, + doi = {}, + biburl = {}, + bibsource = {} +} +""" + +# TODO: create a module level variable with your dataset name (should match the script name) +# E.g. The Pile: [dataset_name] --> the_pile +_DATASETNAME = "[dataset_name]" +# TODO: create a pretty display name for your dataset +_DISPLAYNAME = "Dataset Name" + +# TODO: Add a description of the dataset here +# You can copy an official description +_DESCRIPTION = """\ +This dataset is designed for XXX NLP task. +""" + +# TODO: Add a link to an official homepage for the dataset here (if possible) +_HOMEPAGE = "" + +# TODO: Add the licence for the dataset here (if possible) +# Note that this doesn't have to be a common open source license. +# Some datasets have custom licenses. In this case, simply put the full license terms +# into `_LICENSE` +_LICENSE = "" + +# TODO: Add links to the URLs needed to download your dataset files. +# This variable can be a relative path for datasets whose files need to be +# manually downloaded or preprocessed in advance. + +# For publicly available datasets you will most likely end up passing these URLs to dl_manager in _split_generators. +# However, if you need to access different files for each config you can have multiple entries in this dict. +# This can be an arbitrarily nested dict/list of URLs (see below in `_split_generators` method) +_URLS = { + _DATASETNAME: "url or list of urls or relative path like ./data ", +} + +# TODO: add supported task by dataset. One dataset may support multiple tasks +_SUPPORTED_TASKS = [] # example: [Tasks.TRANSLATION, Tasks.NAMED_ENTITY_RECOGNITION, Tasks.RELATION_EXTRACTION] + +# TODO: set this to a version that is associated with the dataset. 
if none exists use "1.0.0" +# This version doesn't have to be consistent with semantic versioning. Anything that is +# provided by the original dataset as a version goes. +_VERSION = "" + + +# TODO: Name the dataset class to match the script name using CamelCase instead of snake_case +# Append "Dataset" to the class name: ThePile --> ThePileDataset +class NewDataset(datasets.GeneratorBasedBuilder): + """TODO: Short description of my dataset.""" + + VERSION = datasets.Version(_VERSION) + + # You will be able to load each dataset with + # dataset = datasets.load_dataset('my_dataset') + + # TODO: For each dataset, implement a config for each subset; + # If a dataset contains more than one subset, implement a config for EACH of them. + # Each of them should contain: + # - name: should be unique for each dataset config eg. the_pile_[schema_name] + # - version: VERSION + # - description: one line description for the dataset + # - schema: open_assistant_[schema_name] + # - subset_id: subset id is the canonical name for the dataset (eg. the_pile) + # where [schema_name] = (language_modeling) + + BUILDER_CONFIGS = [ + OpenAssistantConfig( + name=f"{_DATASETNAME}_[schema_name]", + version=VERSION, + description=f"OpenAssistant dataset config for {_DATASETNAME}", + schema_name="[schema_name]", + subset_id=_DATASETNAME, + ) + ] + + DEFAULT_CONFIG_NAME = _DATASETNAME + + def _info(self) -> datasets.DatasetInfo: + # TODO: Implement the schema for your dataset here. 
+ raise NotImplementedError() + + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + homepage=_HOMEPAGE, + license=_LICENSE, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager) -> List[datasets.SplitGenerator]: + """Returns SplitGenerators.""" + # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration + + # If you need to access a config choice, that will be in self.config.name + + # dl_manager is a datasets.download.DownloadManager that can be used to download and extract URLs; many examples use the download_and_extract method; see the DownloadManager docs here: https://huggingface.co/docs/datasets/package_reference/builder_classes.html#datasets.DownloadManager + + # dl_manager can accept any type of nested list/dict and will give back the same structure with the url replaced with the path to local files. + + urls = _URLS[_DATASETNAME] + data_dir = dl_manager.download_and_extract(urls) + + # Not all datasets have predefined canonical train/val/test splits. + # If your dataset has no predefined splits, use datasets.Split.TRAIN for all of the data. + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + # Whatever you put in gen_kwargs will be passed to _generate_examples + gen_kwargs={ + "filepath": os.path.join(data_dir, "train.jsonl"), + "split": "train", + }, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={ + "filepath": os.path.join(data_dir, "test.jsonl"), + "split": "test", + }, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + gen_kwargs={ + "filepath": os.path.join(data_dir, "dev.jsonl"), + "split": "dev", + }, + ), + ] + + # method parameters are unpacked from `gen_kwargs` as given in `_split_generators` + + # TODO: change the args of this function to match the keys in `gen_kwargs`. You may add any necessary kwargs. 
+ + def _generate_examples(self, filepath, split: str) -> Tuple[int, Dict]: + """Yields examples as (key, example) tuples.""" + # TODO: This method handles input defined in _split_generators to yield (key, example) tuples from the dataset. + + # The `key` is for legacy reasons (tfds) and is not important in itself, but must be unique for each example. + + # NOTE: For local datasets you will have access to self.config.data_dir and self.config.data_files + + if self.config.schema == "[schema_name]": + # TODO: yield (key, example) tuples in the given schema + for key, example in thing: + yield key, example + +# This allows you to run your dataloader with `python [dataset_name].py` during development +# TODO: Remove this before making your PR +if __name__ == "__main__": + datasets.load_dataset(__file__) diff --git a/scripts/frontend-development/run-contract-test.sh b/scripts/frontend-development/run-contract-test.sh new file mode 100755 index 00000000..6bedc903 --- /dev/null +++ b/scripts/frontend-development/run-contract-test.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) + +# switch to website directory +pushd "$parent_path/../../website" + +set -xe + +npm run cypress:run:contract + +popd diff --git a/scripts/postprocessing/infogain_selector.py b/scripts/postprocessing/infogain_selector.py index 4eedbc5c..fc644a83 100644 --- a/scripts/postprocessing/infogain_selector.py +++ b/scripts/postprocessing/infogain_selector.py @@ -1,9 +1,11 @@ import numpy as np -from scipy import log2 -from scipy.integrate import nquad from scipy.special import gammaln, psi from scipy.stats import dirichlet +''' +Legacy numerical solution. 
+Should not be used as it is probably broken + def make_range(*x): """ @@ -38,6 +40,23 @@ def naive_monte_carlo_integral(fun, dim, samples=10_000_000): res = fun(pos) return np.mean(res) +def infogain(a_post, a_prior): + raise ( + """For the love of good don't use this: + it's insanely poorly conditioned, the worst numerical code I have ever written + and it's slow as molasses. Use the analytic solution instead. + + Maybe remove + """ + ) + args = len(a_prior) + p = dirichlet(a_post).pdf + q = dirichlet(a_prior).pdf + (info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8}) + # info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post)) + return info +''' + def analytic_solution(a_post, a_prior): """ @@ -57,26 +76,8 @@ def analytic_solution(a_post, a_prior): return info -def infogain(a_post, a_prior): - raise ( - """For the love of good don't use this: - it's insanely poorly conditioned, the worst numerical code I have ever written - and it's slow as molasses. Use the analytic solution instead. - - Maybe remove - """ - ) - args = len(a_prior) - p = dirichlet(a_post).pdf - q = dirichlet(a_prior).pdf - (info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8}) - # info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post)) - return info - - def uniform_expected_infogain(a_prior): mean_weight = dirichlet.mean(a_prior) - print("weight", mean_weight) results = [] for i, w in enumerate(mean_weight): a_post = a_prior.copy() diff --git a/scripts/postprocessing/scoring.py b/scripts/postprocessing/scoring.py index efd236ce..5e76d19c 100644 --- a/scripts/postprocessing/scoring.py +++ b/scripts/postprocessing/scoring.py @@ -87,8 +87,9 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter: """ This function returns the gain of points for a given prompt's votes - This function is only to be run when archiving a question - i.e. 
the question has had sufficiently many votes, or we cann't get more than "K" bits of information + In contrast to the other score updating functions, we can run this online as new votes come in. + i.e. the question has had sufficiently many votes, or we can't get more than "K" bits of information. + Parameters: consensus (ArrayLike): all votes cast for this question @@ -100,7 +101,8 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter: # produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1], # since 100 is the lowest, 300 the highest and 200 the middle value consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1 - delta_votes = np.sum(consensus_ranking * consensus) + # expected consensus ranking (i.e. normalize the votes and multiply-sum with weightings) + delta_votes = np.sum(consensus_ranking * consensus / sum(consensus)) new_points = delta_votes + voter_data.prompt_points # we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus, @@ -133,7 +135,7 @@ def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.Arr "research design and statistical analyses, second edition, 2003" the authors note that at least from an significance test POV they will yield the same p-values - Parameters: + Parameters: user_ranking (ArrayLike): ranking produced by the user consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking voter_data (Voter): a "Voter" object that represents the person that wrote the prompt diff --git a/scripts/postprocessing/task_schedule.py b/scripts/postprocessing/task_schedule.py new file mode 100644 index 00000000..deb302b2 --- /dev/null +++ b/scripts/postprocessing/task_schedule.py @@ -0,0 +1,75 @@ +from enum import Enum + +import numpy as np +from scipy import optimize + + +class Task(Enum): + RANKING = 0 + ANSWER = 1 + PROMPT = 2 + VOTE = 3 + + +def task_selection( 
num_ranking_tasks: int, current_prompts: int, target_num_prompts: int, p: float, answers_per_prompt: int +) -> Task: + """ + This computes which task to serve to the user. + In general, this method aims to get rankable tasks out of the active pool ASAP. + Before checking anything else, we first have a p% probability of running a ranking task. + After that, we can dynamically determine which task to serve by balancing the number of active tasks. + + Parameters: + num_ranking_tasks (int): number of prompts that are ready to do ranking (i.e. have "answers_per_prompt" many answers) + current_prompts (int): how many prompts are currently in the active pool + target_num_prompts (int): how many prompts _should_ be in the active pool + p (float): probability to serve a ranking task, if one is available + answers_per_prompt (int): number of answers we want to have per prompt + Returns: + task (Task): the task Enum that corresponds to one of the four tasks + """ + if num_ranking_tasks > 0 and np.random.rand() < p: + return Task.RANKING + rate = 50 / (current_prompts * 2) + prob_prompt_task = 0.5 + (target_num_prompts - current_prompts) * rate + # Yes, I'm too lazy to solve this analytically... + prob_unfinished_prompt = optimize.linprog( + np.array([1, 1]), A_eq=np.array([[1, 1], [1, -answers_per_prompt]]), b_eq=np.array([1, 0]), bounds=(0, None) + ).x[0] + if np.random.rand() < prob_prompt_task: + if np.random.rand() < prob_unfinished_prompt: + return Task.ANSWER + else: + return Task.PROMPT + else: + return Task.VOTE + + +def next_answer_task(possible_prompts, answers_per_prompt): + """ + If the `task_selection`method returns "answer", you can use this method to decide which + prompt should get an answer next. + The goal of this is to finish off the prompts that have almost enough answers collected already: + I.e. if we want 5 answers, this is going to give preferential sampling to those prompts that already + have 4/5 answers. 
+ This helps to not have too many close-to-finished prompts in the active set. + + Parameters: + possible_prompts (dict[prompt_id, num_answers]): a dictionary containing all open prompts and the number of answers these prompts currently have. + answers_per_prompt (int): number of answers to target per prompt + Returns: + prompt_id (int): the prompt_id corresponding to the next prompt that should get a new answer + """ + nums = list(set(possible_prompts.values())) + p = np.array([max(x / answers_per_prompt, 1 / answers_per_prompt) for x in nums]) + idx = np.random.choice(nums, p=p / p.sum()) + sample = np.random.choice([k for k, v in possible_prompts.items() if v == idx]) + return sample + + +if __name__ == "__main__": + x = task_selection(1, 500, 1000, 0.1, 5) + print(x) + y = next_answer_task({"this": 2, "is": 4, "a": 1, "test": 4}, 5) + print(y) diff --git a/website/.env b/website/.env index 9544836b..65d8b88e 100644 --- a/website/.env +++ b/website/.env @@ -1,3 +1,5 @@ +ADMIN_USERS = "credentials:admin,discord:root,email:admin@example.com" + # The database created by running the jobs in /scripts/frontend-development/docker-compose.yaml DATABASE_URL=postgres://postgres:postgres@localhost:5433/oasst_web diff --git a/website/README.md b/website/README.md index 5198a820..11e3ccc4 100644 --- a/website/README.md +++ b/website/README.md @@ -53,7 +53,7 @@ If you're doing active development we suggest the following workflow: 1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can optionally include `-d` to detach and later track the logs if desired. 1. In another tab navigate to `${OPEN_ASSISTANT_ROOT/website`. -1. Run `npm install` +1. Run `npm ci` 1. Run `npx prisma db push` (This is also needed when you restart the docker stack from scratch). 1. Run `npm run dev`. 
Now the website is up and running locally at diff --git a/website/cypress.config.contract.js b/website/cypress.config.contract.js new file mode 100644 index 00000000..f4461158 --- /dev/null +++ b/website/cypress.config.contract.js @@ -0,0 +1,9 @@ +import { defineConfig } from "cypress"; + +export default defineConfig({ + e2e: { + // No baseUrl here, because we don't need it for contract testing + baseUrl: null, + specPattern: "cypress/contract/*.cy.{ts,js}", + }, +}); diff --git a/website/cypress/contract/oasst_api_contract_tests.cy.ts b/website/cypress/contract/oasst_api_contract_tests.cy.ts new file mode 100644 index 00000000..1ba408d6 --- /dev/null +++ b/website/cypress/contract/oasst_api_contract_tests.cy.ts @@ -0,0 +1,50 @@ +import { OasstApiClient } from "src/lib/oasst_api_client"; + +describe("Contract test for Oasst API", function () { + // Assumes this is running the mock server. + const oasstApiClient = new OasstApiClient("http://localhost:8080", "test"); + + it("can fetch a task", async () => { + expect( + await oasstApiClient.fetchTask("random", { + sub: "test", + name: "test", + email: "test", + }) + ).to.be.not.null; + }); + + it("can ack a task", async () => { + const task = await oasstApiClient.fetchTask("random", { + sub: "test", + name: "test", + email: "test", + }); + expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null; + }); + + it("can record a taskInteraction", async () => { + const task = await oasstApiClient.fetchTask("random", { + sub: "test", + name: "test", + email: "test", + }); + expect( + await oasstApiClient.interactTask( + "text_reply_to_message", + task.id, + "1", + { text: "Test" }, + { + sub: "test", + name: "test", + email: "test", + } + ) + ).to.be.not.null; + }); + + // TODO(#354): Add test for 204 + // TODO(#354): Add test for parsing >=300, throwing an OasstError + // TODO(#354): Add test for parsing >=300, throwing a generic error +}); diff --git a/website/next.config.js b/website/next.config.js index 
2c37ebe6..28da824f 100644 --- a/website/next.config.js +++ b/website/next.config.js @@ -2,6 +2,14 @@ const nextConfig = { output: "standalone", reactStrictMode: true, + images: { + remotePatterns: [ + { + protocol: "https", + hostname: "**.discordapp.com", + }, + ], + }, experimental: { /* Disabling this for now only because it causes a warning in the console that cannot be silenced for eslint If this can be resolved, we should re-enable this. diff --git a/website/package-lock.json b/website/package-lock.json index 1803cbca..5cd3d9bb 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -39,7 +39,7 @@ "react": "18.2.0", "react-dom": "18.2.0", "react-icons": "^4.7.1", - "sharp": "^0.31.3", + "sharp": "0.31.2", "swr": "^2.0.0", "tailwindcss": "^3.2.4", "use-debounce": "^9.0.2" @@ -27907,9 +27907,9 @@ } }, "node_modules/sharp": { - "version": "0.31.3", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.3.tgz", - "integrity": "sha512-XcR4+FCLBFKw1bdB+GEhnUNXNXvnt0tDo4WsBsraKymuo/IAuPuCBVAL2wIkUw2r/dwFW5Q5+g66Kwl2dgDFVg==", + "version": "0.31.2", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.2.tgz", + "integrity": "sha512-DUdNVEXgS5A97cTagSLIIp8dUZ/lZtk78iNVZgHdHbx1qnQR7JAHY0BnXnwwH39Iw+VKhO08CTYhIg0p98vQ5Q==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", @@ -52121,9 +52121,9 @@ } }, "sharp": { - "version": "0.31.3", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.3.tgz", - "integrity": "sha512-XcR4+FCLBFKw1bdB+GEhnUNXNXvnt0tDo4WsBsraKymuo/IAuPuCBVAL2wIkUw2r/dwFW5Q5+g66Kwl2dgDFVg==", + "version": "0.31.2", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.2.tgz", + "integrity": "sha512-DUdNVEXgS5A97cTagSLIIp8dUZ/lZtk78iNVZgHdHbx1qnQR7JAHY0BnXnwwH39Iw+VKhO08CTYhIg0p98vQ5Q==", "requires": { "color": "^4.2.3", "detect-libc": "^2.0.1", diff --git a/website/package.json b/website/package.json index c1d0c3d2..7d3b680e 100644 --- a/website/package.json +++ b/website/package.json @@ -12,6 
+12,7 @@ "build-storybook": "build-storybook", "cypress": "cypress open", "cypress:run": "cypress run", + "cypress:run:contract": "cypress run --config-file ./cypress.config.contract.js", "cypress:image-baseline": "cypress-image-diff -u", "fix:lint": "eslint --fix src/ --ext .js,.jsx,.ts,.tsx", "fix:format": "prettier --write ./src", @@ -49,7 +50,7 @@ "react": "18.2.0", "react-dom": "18.2.0", "react-icons": "^4.7.1", - "sharp": "^0.31.3", + "sharp": "0.31.2", "swr": "^2.0.0", "tailwindcss": "^3.2.4", "use-debounce": "^9.0.2" diff --git a/website/prisma/schema.prisma b/website/prisma/schema.prisma index 3e379d43..f9eab3b7 100644 --- a/website/prisma/schema.prisma +++ b/website/prisma/schema.prisma @@ -41,6 +41,7 @@ model User { email String? @unique emailVerified DateTime? image String? + role String @default("general") accounts Account[] sessions Session[] diff --git a/website/src/components/CollapsableText.tsx b/website/src/components/CollapsableText.tsx new file mode 100644 index 00000000..b0cfb3b9 --- /dev/null +++ b/website/src/components/CollapsableText.tsx @@ -0,0 +1,28 @@ +import { Button, useDisclosure } from "@chakra-ui/react"; +import { Modal, ModalOverlay, ModalContent, ModalHeader, ModalBody, ModalCloseButton } from "@chakra-ui/react"; +import React from "react"; + +export const CollapsableText = ({ text, maxLength = 220 }) => { + const { isOpen, onOpen, onClose } = useDisclosure(); + if (typeof text != "string" || text.length <= maxLength) { + return text; + } else { + return ( + <> + {text.substring(0, maxLength - 3)} + + + + + Full Text + + {text} + + + + + ); + } +}; diff --git a/website/src/components/Container.tsx b/website/src/components/Container.tsx index f65ed282..4149c0da 100644 --- a/website/src/components/Container.tsx +++ b/website/src/components/Container.tsx @@ -1,5 +1,5 @@ import clsx from "clsx"; export function Container({ className, ...props }) { - return
; + return
; } diff --git a/website/src/components/ContextMessages.tsx b/website/src/components/ContextMessages.tsx new file mode 100644 index 00000000..150dddfa --- /dev/null +++ b/website/src/components/ContextMessages.tsx @@ -0,0 +1,17 @@ +import { Box } from "@chakra-ui/react"; +import { Message } from "./Messages"; + +export const ContextMessages = ({ messages }: { messages: Message[] }) => { + return ( + + {messages.map((message, i) => { + return ( + + {message.is_assistant ? "Assistant: " : "User: "} + {message.text} + + ); + })} + + ); +}; diff --git a/website/src/components/Dashboard/LeaderboardTable.tsx b/website/src/components/Dashboard/LeaderboardTable.tsx new file mode 100644 index 00000000..b958d4b7 --- /dev/null +++ b/website/src/components/Dashboard/LeaderboardTable.tsx @@ -0,0 +1,94 @@ +import { Badge, Box, Image, Link, Stack, StackDivider, Text, useColorModeValue } from "@chakra-ui/react"; + +export function LeaderboardTable() { + const backgroundColor = useColorModeValue("white", "gray.700"); + const accentColor = useColorModeValue("gray.200", "gray.900"); + + //need to add streak info to chart + + const leaderInfo = [ + { + name: "fozziethebeat#6690", + image: "/images/temp-avatars/av1.jpg", + score: "5,208", + arrowDir: "increase", + streak: false, + streakCount: "5-Day Streak", + }, + { + name: "k_nearest_neighbor#8579", + image: "/images/temp-avatars/av2.jpg", + score: "5,164", + arrowDir: "decrease", + streak: false, + streakCount: "", + }, + { + name: "andreaskoepf#2266", + image: "/images/temp-avatars/av3.jpg", + score: "5,120", + arrowDir: "", + streak: false, + streakCount: "2-Day Streak", + }, + { + name: "AbdBarho#1684", + image: "/images/temp-avatars/av4.jpg", + score: "4,260", + arrowDir: "", + streak: false, + streakCount: "", + }, + { + name: "zu#9016", + image: "/images/temp-avatars/av5.jpg", + score: "3,608", + arrowDir: "", + streak: false, + streakCount: "", + }, + ]; + + return ( +
+
+
+ Top 5 Contributors + + + View All -> + + +
+ + } spacing="4"> +
+

Name

+
+

Score

+
+
+ {leaderInfo.map((item, itemIndex) => ( +
+
+ Profile Picture +

{item.name}

+ {item.streakCount} +
+ +

{item.score}

+
+
+ ))} +
+
+
+
+ ); +} diff --git a/website/src/components/Dashboard/SideMenu.tsx b/website/src/components/Dashboard/SideMenu.tsx new file mode 100644 index 00000000..499117a2 --- /dev/null +++ b/website/src/components/Dashboard/SideMenu.tsx @@ -0,0 +1,89 @@ +import { Box, Button, Link, Text, Tooltip, useColorMode } from "@chakra-ui/react"; +import { useRouter } from "next/router"; +import { FiLayout, FiSun, FiMessageSquare } from "react-icons/fi"; +import { colors } from "styles/Theme/colors"; + +export function SideMenu() { + const router = useRouter(); + const { colorMode, toggleColorMode } = useColorMode(); + const buttonOptions = [ + { + label: "Dashboard", + pathname: "/dashboard", + desc: "Dashboard Home", + icon: FiLayout, + }, + { + label: "Messages", + pathname: "/messages", + desc: "Messages Dashboard", + icon: FiMessageSquare, + }, + // { + // label: "Leaderboard", + // pathname: "#", + // desc: "Public Leaderboard", + // icon: FiAward, + // }, + // { + // label: "Stats", + // pathname: "#", + // desc: "User Statistics", + // icon: FiBarChart2, + // }, + ]; + + return ( +
+ + +
+ + + +
+
+
+ ); +} diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx new file mode 100644 index 00000000..50f707a6 --- /dev/null +++ b/website/src/components/Dashboard/TaskOption.tsx @@ -0,0 +1,126 @@ +import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; +import Link from "next/link"; + +const crTasks = [ + { + label: "Create Initial Prompts", + desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.", + type: "create", + pathname: "/create/initial_prompt", + }, + { + label: "Reply as User", + desc: "Chat with Open Assistant and help improve it’s responses as you interact with it.", + type: "create", + pathname: "/create/user_reply", + }, + { + label: "Reply as Assistant", + desc: "Help Open Assistant improve its responses to conversations with other users.", + type: "create", + pathname: "/create/assistant_reply", + }, +]; + +const evTasks = [ + { + label: "Rank User Replies", + type: "eval", + desc: "Help Open Assistant improve its responses to conversations with other users.", + pathname: "/evaluate/rank_user_replies", + }, + + { + label: "Rank Assistant Replies", + desc: "Score prompts given by Open Assistant based on their accuracy and readability.", + type: "eval", + pathname: "/evaluate/rank_assistant_replies", + }, + { + label: "Rank Initial Prompts", + desc: "Score prompts given by Open Assistant based on their accuracy and readability.", + type: "eval;", + pathname: "/evaluate/rank_initial_prompts", + }, +]; + +export const TaskOption = () => { + const backgroundColor = useColorModeValue("white", "gray.700"); + + return ( + +
+ Create + + {crTasks.map((item, itemIndex) => ( + + + + + + {item.label} + + + {item.desc} + + + + + + Go + + + + + ))} + +
+
+ Evaluate + + {evTasks.map((item, itemIndex) => ( + + + + + + {item.label} + + + {item.desc} + + + + + + Go + + + + + ))} + +
+
+ ); +}; diff --git a/website/src/components/Dashboard/index.ts b/website/src/components/Dashboard/index.ts new file mode 100644 index 00000000..0b4ff49a --- /dev/null +++ b/website/src/components/Dashboard/index.ts @@ -0,0 +1,3 @@ +export { LeaderboardTable } from "./LeaderboardTable"; +export { SideMenu } from "./SideMenu"; +export { TaskOption } from "./TaskOption"; diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index c8fc17ce..bc245bd2 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -15,6 +15,7 @@ import { SliderThumb, SliderTrack, Spacer, + Tooltip, useBoolean, useId, } from "@chakra-ui/react"; @@ -69,11 +70,15 @@ export const FlaggableElement = (props) => { > {props.children} - - - + +
+ + + +
+
diff --git a/website/src/components/Header/Header.tsx b/website/src/components/Header/Header.tsx index 8b8c4663..4ea453c6 100644 --- a/website/src/components/Header/Header.tsx +++ b/website/src/components/Header/Header.tsx @@ -1,43 +1,11 @@ -import { Box, Button, useColorMode } from "@chakra-ui/react"; -import { Popover } from "@headlessui/react"; -import { AnimatePresence, motion } from "framer-motion"; +import { Box, Button, Text, useColorMode } from "@chakra-ui/react"; import Image from "next/image"; import Link from "next/link"; import { useSession } from "next-auth/react"; import { FaUser } from "react-icons/fa"; -import { ColorModeIconToggle } from "../UI/ColorModeIconToggle"; import { UserMenu } from "./UserMenu"; -function MenuIcon(props) { - return ( - - ); -} - -function ChevronUpIcon(props) { - return ( - - ); -} - -function MobileNavLink({ children, ...props }) { - return ( - - {children} - - ); -} - function AccountButton() { const { data: session } = useSession(); if (session) { @@ -53,70 +21,29 @@ function AccountButton() { } export function Header(props) { + const { data: session } = useSession(); + const homeURL = session ? "/dashboard" : "/"; + const { colorMode } = useColorMode(); const borderClass = props.transparent ? "" : colorMode === "light" ? 
"border-b border-gray-400" : "border-b border-zinc-800"; - return ( diff --git a/website/src/components/Header/NavLinks.tsx b/website/src/components/Header/NavLinks.tsx index 4f559e7e..47bd90be 100644 --- a/website/src/components/Header/NavLinks.tsx +++ b/website/src/components/Header/NavLinks.tsx @@ -1,7 +1,8 @@ -import { useColorMode } from "@chakra-ui/react"; +import { Text, useColorMode } from "@chakra-ui/react"; import { AnimatePresence, motion } from "framer-motion"; import Link from "next/link"; import { useState } from "react"; +import { colors } from "styles/Theme/colors"; export function NavLinks(): JSX.Element { const [hoveredIndex, setHoveredIndex] = useState(null); @@ -14,8 +15,8 @@ export function NavLinks(): JSX.Element { return ( <> {[ - ["Join Us", "/#join-us"], ["FAQ", "/#faq"], + ["Join Us", "/#join-us"], ].map(([label, href], index) => ( )} - {label} + + {label} + ))} diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 35b71698..2e06606c 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -1,25 +1,33 @@ -import { Box, useColorModeValue } from "@chakra-ui/react"; +import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react"; import { Popover } from "@headlessui/react"; import { AnimatePresence, motion } from "framer-motion"; import Image from "next/image"; import { signOut, useSession } from "next-auth/react"; import React from "react"; -import { FaCog, FaSignOutAlt } from "react-icons/fa"; +import { FiLayout, FiLogOut, FiSettings } from "react-icons/fi"; export function UserMenu() { const { data: session } = useSession(); - const backgroundColor = useColorModeValue("#FFFFFF", "#000000"); + const backgroundColor = useColorModeValue("white", "gray.700"); + const accentColor = useColorModeValue("gray.300", "gray.600"); if (!session) { return <>; } if (session && session.user) { const accountOptions = [ + { + name: "Dashboard", + 
href: "/dashboard", + desc: "Dashboard", + icon: FiLayout, + //For future use + }, { name: "Account Settings", href: "/account", desc: "Account Settings", - icon: FaCog, + icon: FiSettings, //For future use }, ]; @@ -28,18 +36,22 @@ export function UserMenu() { {({ open }) => ( <> -
+ Profile Picture

{session.user.name || session.user.email}

-
+
{open && ( @@ -54,35 +66,45 @@ export function UserMenu() { y: -10, transition: { duration: 0.2 }, }} - className="absolute right-0 mt-3 w-screen bg-inherit max-w-xs p-4 rounded-md border border-slate-300/70" > - - {accountOptions.map((item) => ( - + + {accountOptions.map((item) => ( + +
+
+
+ {item.name} +
+ + ))} + signOut({ callbackUrl: "/" })} >
-
-

{item.name}

+ Sign Out
-
- ))} - signOut({ callbackUrl: "/" })} - > -
- -
-
-

Sign Out

-
-
+ +
diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx index 3564d765..bf662113 100644 --- a/website/src/components/Layout.tsx +++ b/website/src/components/Layout.tsx @@ -25,4 +25,11 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
); +export const getDashboardLayout = (page: React.ReactElement) => ( +
+
+ {page} +
+); + export const noLayout = (page: React.ReactElement) => page; diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx new file mode 100644 index 00000000..65f6efb1 --- /dev/null +++ b/website/src/components/Messages/MessageTable.tsx @@ -0,0 +1,12 @@ +import { Box, CircularProgress, Stack, StackDivider, useColorModeValue } from "@chakra-ui/react"; +import { MessageTableEntry } from "./MessageTableEntry"; + +export function MessageTable({ messages }) { + return ( + } spacing="4"> + {messages.map((item, idx) => ( + + ))} + + ); +} diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx new file mode 100644 index 00000000..2bc17201 --- /dev/null +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -0,0 +1,24 @@ +import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react"; +import { boolean } from "boolean"; +import NextLink from "next/link"; +import { FlaggableElement } from "../FlaggableElement"; + +export function MessageTableEntry({ item, idx }) { + const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900"); + + return ( + + + + + + {item.text} + + + + + ); +} diff --git a/website/src/components/Roadmap.tsx b/website/src/components/Roadmap.tsx new file mode 100644 index 00000000..283a4c8f --- /dev/null +++ b/website/src/components/Roadmap.tsx @@ -0,0 +1,78 @@ +import { Container } from "./Container"; + +const Roadmap = () => { + return ( + +
+

Our Roadmap

+
+
+
+

ASAP

+
+

Minimum Viable Prototype

+
    +
  • Data Collection Pipeline
  • +
  • RL on Human Feedback
  • +
  • Assistant v1 usable
  • +
  • Out January 2023!
  • +
+
+
+ +
+ +
+
+

+ Q1 +
+ 2023 +

+
+

Growing Up

+
    +
  • Retrieval Augmentation
  • +
  • Rapid Personalization
  • +
  • Using External Tools
  • +
+
+
+ +
+ +
+
+

+ Q2 +
+ 2023 +

+
+

Growing Up

+
    +
  • Third-Party Extensions
  • +
  • Device Control
  • +
  • Multi-Modality
  • +
+
+
+ +
+ +
+
+

...

+
+

Growing Up

+
    +
  • What do you need?
  • +
+
+
+
+
+ ); +}; + +export default Roadmap; diff --git a/website/src/components/Services.tsx b/website/src/components/Services.tsx new file mode 100644 index 00000000..ae20a6d7 --- /dev/null +++ b/website/src/components/Services.tsx @@ -0,0 +1,50 @@ +import { Container } from "./Container"; + +const Services = () => { + return ( +
+ +
+
+
+

Your Conversational Assistant

+ +

State-of-the-Art chat assistant that can be personalized to your needs

+
+
+

Interface w/ external systems

+ +

+ Usage of APIs and third-party applications, described via language & demonstrations. +

+
+
+
+
+

Retrieval via Search Engines

+ +

External, upgradeable knowledge: No need for billions of parameters.

+
+
+

A building block for developers

+ +

Integrate OpenAssistant into your application.

+
+
+
+

OpenAssistant unifies all knowledge work in one place

+ +
    +
  • Uses modern deep learning
  • +
  • Runs on consumer hardware
  • +
  • Trains on human feedback
  • +
  • Free and open
  • +
+
+
+
+
+ ); +}; + +export default Services; diff --git a/website/src/components/Sortable/Sortable.tsx b/website/src/components/Sortable/Sortable.tsx index 2f63ff27..270c67e7 100644 --- a/website/src/components/Sortable/Sortable.tsx +++ b/website/src/components/Sortable/Sortable.tsx @@ -18,6 +18,7 @@ import { } from "@dnd-kit/sortable"; import { ReactNode, useEffect, useState } from "react"; +import { CollapsableText } from "../CollapsableText"; import { SortableItem } from "./SortableItem"; export interface SortableProps { @@ -64,7 +65,7 @@ export const Sortable = (props: SortableProps) => { {itemsWithIds.map(({ id, item }) => ( - {item} + ))} diff --git a/website/src/components/Survey/TaskControls.tsx b/website/src/components/Survey/TaskControls.tsx index a93889ea..851e659c 100644 --- a/website/src/components/Survey/TaskControls.tsx +++ b/website/src/components/Survey/TaskControls.tsx @@ -4,7 +4,7 @@ import { SkipButton } from "src/components/Buttons/Skip"; import { SubmitButton } from "src/components/Buttons/Submit"; import { TaskInfo } from "src/components/TaskInfo/TaskInfo"; -interface TaskControlsProps { +export interface TaskControlsProps { // we need a task type // eslint-disable-next-line @typescript-eslint/no-explicit-any tasks: any[]; @@ -30,11 +30,11 @@ export const TaskControls = (props: TaskControlsProps) => { Skip {endTask.task.type !== "task_done" ? 
( - props.onSubmitResponse(props.tasks[0])}> + props.onSubmitResponse(props.tasks[0])}> Submit ) : ( - + Next Task )} diff --git a/website/src/components/Survey/TaskControlsOverridable.tsx b/website/src/components/Survey/TaskControlsOverridable.tsx new file mode 100644 index 00000000..b5333087 --- /dev/null +++ b/website/src/components/Survey/TaskControlsOverridable.tsx @@ -0,0 +1,60 @@ +import { + Button, + Flex, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + useDisclosure, +} from "@chakra-ui/react"; +import { TaskControls, TaskControlsProps } from "./TaskControls"; + +interface TaskControlsOverridableProps extends TaskControlsProps { + isValid: boolean; + prepareForSubmit: () => void; +} + +export const TaskControlsOverridable = (props: TaskControlsOverridableProps) => { + const { isValid, onSubmitResponse, ...rest } = props; + const { isOpen: isModalOpen, onOpen: onOpenModal, onClose: onModalClose } = useDisclosure(); + + const unchangedResponsePrompt = () => { + onOpenModal(); + + // Ideally this happens when the user clicks submit, but we can't + // reliably wait for it to be executed before submitting the response + // without significant refactoring. + // As a result, modal will only display once even if the user doesn't proceed + props.prepareForSubmit(); + }; + + const onSubmitResponseOverride = () => { + onSubmitResponse(props.tasks[0]); + onModalClose(); + }; + + return ( + <> + + + + + Order Unchanged + You have not changed the order of the prompts. Are you sure you would like to submit? + + + + + + + + + + + ); +}; diff --git a/website/src/components/Vision.tsx b/website/src/components/Vision.tsx new file mode 100644 index 00000000..f6dad567 --- /dev/null +++ b/website/src/components/Vision.tsx @@ -0,0 +1,25 @@ +import { Container } from "./Container"; +import Image from "next/image"; + +const Vision = () => { + return ( +
+ +
+
+

Our Vision

+

+ We want OpenAssistant to be the single, unifying platform that all other systems use to interface with + humans. +

+
+
+ temp-image +
+
+
+
+ ); +}; + +export default Vision; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts new file mode 100644 index 00000000..7e22544a --- /dev/null +++ b/website/src/lib/oasst_api_client.ts @@ -0,0 +1,98 @@ +import { JWT } from "next-auth/jwt"; + +declare global { + // eslint-disable-next-line no-var + var oasstApiClient: OasstApiClient | undefined; +} + +class OasstError { + message: string; + errorCode: number; + httpStatusCode: number; + + constructor(message: string, errorCode: number, httpStatusCode: number) { + this.message = message; + this.errorCode = errorCode; + this.httpStatusCode = httpStatusCode; + } +} + +export class OasstApiClient { + constructor(private readonly oasstApiUrl: string, private readonly oasstApiKey: string) {} + + private async post(path: string, body: any): Promise { + const resp = await fetch(`${this.oasstApiUrl}${path}`, { + method: "POST", + headers: { + "X-API-Key": this.oasstApiKey, + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + }); + + if (resp.status == 204) { + return null; + } + + if (resp.status >= 300) { + const errorText = await resp.text(); + try { + const error = JSON.parse(errorText); + throw new OasstError(error.message, error.error_code, resp.status); + } catch (e) { + throw new OasstError(errorText, 0, resp.status); + } + } + + return await resp.json(); + } + + // TODO return a strongly typed Task? + // This method is used to store a task in RegisteredTask.task. + // This is a raw Json type, so we can't use it to strongly type the task. 
+ async fetchTask(taskType: string, userToken: JWT): Promise { + return this.post("/api/v1/tasks/", { + type: taskType, + user: { + id: userToken.sub, + display_name: userToken.name || userToken.email, + auth_method: "local", + }, + }); + } + + async ackTask(taskId: string, messageId: string): Promise { + return this.post(`/api/v1/tasks/${taskId}/ack`, { + message_id: messageId, + }); + } + + // TODO return a strongly typed Task? + // This method is used to record interaction with task while fetching next task. + // This is a raw Json type, so we can't use it to strongly type the task. + async interactTask( + updateType: string, + messageId: string, + userMessageId: string, + content: object, + userToken: JWT + ): Promise { + return this.post("/api/v1/tasks/interaction", { + type: updateType, + user: { + id: userToken.sub, + display_name: userToken.name || userToken.email, + auth_method: "local", + }, + message_id: messageId, + user_message_id: userMessageId, + ...content, + }); + } +} + +export const oasstApiClient = + globalThis.oasstApiClient || new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); +if (process.env.NODE_ENV !== "production") { + globalThis.oasstApiClient = oasstApiClient; +} diff --git a/website/src/middleware.ts b/website/src/middleware.ts index 8e48d3b7..b6a539b4 100644 --- a/website/src/middleware.ts +++ b/website/src/middleware.ts @@ -4,5 +4,5 @@ export { default } from "next-auth/middleware"; * Guards all pages under `/grading` and redirects them to the sign in page. 
*/ export const config = { - matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*"], + matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"], }; diff --git a/website/src/pages/about.tsx b/website/src/pages/about.tsx new file mode 100644 index 00000000..fdce53f6 --- /dev/null +++ b/website/src/pages/about.tsx @@ -0,0 +1,39 @@ +import { Container } from "src/components/Container"; +import Services from "src/components/Services"; +import Vision from "src/components/Vision"; +import Roadmap from "src/components/Roadmap"; +import { CallToAction } from "src/components/CallToAction"; +import Image from "next/image"; + +const AboutPage = () => { + return ( +
+ +
+
+ temp-image +
+
+
+

What is OpenAssistant?

+

+ OpenAssistant is a chat-based assistant that understands tasks, can interact with third-party systems, + and retrieve information dynamically to do so. +

+
+

+ It can be extended and personalized easily and is developed as free, open-source software. +

+
+
+
+ + + + + +
+ ); +}; + +export default AboutPage; diff --git a/website/src/pages/account/edit.tsx b/website/src/pages/account/edit.tsx index a14c27f8..497e8238 100644 --- a/website/src/pages/account/edit.tsx +++ b/website/src/pages/account/edit.tsx @@ -7,6 +7,7 @@ import React, { useState } from "react"; export default function Account() { const { data: session } = useSession(); const [username, setUsername] = useState(""); + const updateUser = async (e: React.SyntheticEvent) => { e.preventDefault(); try { @@ -16,6 +17,7 @@ export default function Account() { headers: { "Content-Type": "application/json" }, body: JSON.stringify(body), }); + session.user.name = username; await Router.push("/account"); } catch (error) { console.error(error); @@ -34,23 +36,24 @@ export default function Account() { content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> -
-

{session.user.name || "No username"}

-
- - setUsername(e.target.value)} - placeholder="Edit Username" - type="text" - value={username} - > - - -
-

{session.user.email}

-
+
+
+

{session.user.name || "No username"}

+
+ + setUsername(e.target.value)} + placeholder="Edit Username" + type="text" + value={username} + > + + +
+
+
); } diff --git a/website/src/pages/account/index.tsx b/website/src/pages/account/index.tsx index 9f8dc4a3..d26fc842 100644 --- a/website/src/pages/account/index.tsx +++ b/website/src/pages/account/index.tsx @@ -19,13 +19,15 @@ export default function Account() { content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> -
-

{session.user.name || "No username"}

- -

{session.user.email}

-
+
+
+

{session.user.name || "No username"}

+ +

{session.user.email}

+
+
); } diff --git a/website/src/pages/api/auth/[...nextauth].ts b/website/src/pages/api/auth/[...nextauth].ts index 8614de97..363c1404 100644 --- a/website/src/pages/api/auth/[...nextauth].ts +++ b/website/src/pages/api/auth/[...nextauth].ts @@ -59,6 +59,17 @@ if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development") ); } +// Create a map of provider types to a set of admin user identifiers based on +// the environment variables. We assume the list is separated by ',' and each +// entry is separated by ':'. +const adminUserMap = process.env.ADMIN_USERS.split(",").reduce((result, entry) => { + const [authType, id] = entry.split(":"); + const s = result.get(authType) || new Set(); + s.add(id); + result.set(authType, s); + return result; +}, new Map()); + export const authOptions: AuthOptions = { // Ensure we can store user data in a database. adapter: PrismaAdapter(prisma), @@ -68,6 +79,56 @@ export const authOptions: AuthOptions = { verifyRequest: "/auth/verify", // error: "/auth/error", -Will be used later }, + callbacks: { + /** + * Ensure we propagate the user's role when creating the session from the + * token. + */ + async session({ session, token }) { + session.user.role = token.role; + return session; + }, + /** + * When creating a token, fetch the user's role and inject it in the token. + * This let's use forward the role to the session object. + */ + async jwt({ token }) { + const { role } = await prisma.user.findUnique({ + where: { id: token.sub }, + select: { role: true }, + }); + token.role = role; + return token; + }, + }, + events: { + /** + * Update the user's role after they have successfully signed in + */ + async signIn({ user, account }) { + // Get the admin list for the user's auth type. + const adminForAccountType = adminUserMap.get(account.provider); + + // Return early if there's no admin list. + if (!adminForAccountType) { + return; + } + + // TODO(#236): Reduce the number of times we update the role field. 
+ + // Update the database if the user is an admin. + if (adminForAccountType.has(account.providerAccountId)) { + await prisma.user.update({ + data: { + role: "admin", + }, + where: { + id: user.id, + }, + }); + } + }, + }, session: { strategy: "jwt", }, diff --git a/website/src/pages/api/messages/index.tsx b/website/src/pages/api/messages/index.tsx new file mode 100644 index 00000000..3c8d1c17 --- /dev/null +++ b/website/src/pages/api/messages/index.tsx @@ -0,0 +1,24 @@ +import { getToken } from "next-auth/jwt"; + +const handler = async (req, res) => { + const token = await getToken({ req }); + + // Return nothing if the user isn't registered. + if (!token) { + res.status(401).end(); + return; + } + + const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, { + method: "GET", + headers: { + "X-API-Key": process.env.FASTAPI_KEY, + }, + }); + const messages = await messagesRes.json(); + + // Send recieved messages to the client. + res.status(200).json(messages); +}; + +export default handler; diff --git a/website/src/pages/api/messages/user.tsx b/website/src/pages/api/messages/user.tsx new file mode 100644 index 00000000..e3d22f3c --- /dev/null +++ b/website/src/pages/api/messages/user.tsx @@ -0,0 +1,29 @@ +import { getToken } from "next-auth/jwt"; + +const handler = async (req, res) => { + const token = await getToken({ req }); + + // Return nothing if the user isn't registered. + if (!token) { + res.status(401).end(); + return; + } + + //TODO: add params if needed + const params = new URLSearchParams({ + username: token.sub, + }); + + const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, { + method: "GET", + headers: { + "X-API-Key": process.env.FASTAPI_KEY, + }, + }); + const messages = await messagesRes.json(); + + // Send recieved messages to the client. 
+ res.status(200).json(messages); +}; + +export default handler; diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index 50f0b4e2..addcf3d8 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,4 +1,5 @@ import { getToken } from "next-auth/jwt"; +import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; /** @@ -21,24 +22,7 @@ const handler = async (req, res) => { } // Fetch the new task. - // - // This needs to be refactored into an easier to use library. - const taskRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/`, { - method: "POST", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - type: task_type, - user: { - id: token.sub, - display_name: token.name || token.email, - auth_method: "local", - }, - }), - }); - const task = await taskRes.json(); + const task = await oasstApiClient.fetchTask(task_type, token); // Store the task and link it to the user.. const registeredTask = await prisma.registeredTask.create({ @@ -53,16 +37,7 @@ const handler = async (req, res) => { }); // Update the backend with our Task ID - await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, { - method: "POST", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - message_id: registeredTask.id, - }), - }); + await oasstApiClient.ackTask(task.id, registeredTask.id); // Send the results to the client. 
res.status(200).json(registeredTask); diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 9582040b..4eea8c1e 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -1,4 +1,5 @@ import { getToken } from "next-auth/jwt"; +import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; /** @@ -34,27 +35,7 @@ const handler = async (req, res) => { }, }); - // Send the interaction to the Task Backend. This automatically fetches the - // next task in the sequence (or the done task). - const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/interaction`, { - method: "POST", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - type: update_type, - user: { - id: token.sub, - display_name: token.name || token.email, - auth_method: "local", - }, - message_id: id, - user_message_id: interaction.id, - ...content, - }), - }); - const newTask = await interactionRes.json(); + const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token); // Stores the new task with our database. 
const newRegisteredTask = await prisma.registeredTask.create({ diff --git a/website/src/pages/api/username.tsx b/website/src/pages/api/username.tsx index 0a88ad20..556edfa9 100644 --- a/website/src/pages/api/username.tsx +++ b/website/src/pages/api/username.tsx @@ -1,4 +1,5 @@ import { getSession } from "next-auth/react"; +import prisma from "../../lib/prismadb"; // POST /api/post // Required fields in body: title diff --git a/website/src/pages/auth/signin.tsx b/website/src/pages/auth/signin.tsx index 221eb1f0..59fc7c05 100644 --- a/website/src/pages/auth/signin.tsx +++ b/website/src/pages/auth/signin.tsx @@ -15,13 +15,13 @@ function Signin({ csrfToken, providers }) { const emailEl = useRef(null); const signinWithEmail = (ev: React.FormEvent) => { ev.preventDefault(); - signIn(email.id, { callbackUrl: "/", email: emailEl.current.value }); + signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value }); }; const debugUsernameEl = useRef(null); function signinWithDebugCredentials(ev: React.FormEvent) { ev.preventDefault(); - signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value }); + signIn(credentials.id, { callbackUrl: "/dashboard", username: debugUsernameEl.current.value }); } const { colorMode } = useColorMode(); @@ -98,7 +98,8 @@ function Signin({ csrfToken, providers }) { )} -
+
+
By signing up you agree to our

Terms of Service diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index ceac45be..2f7ca748 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -1,6 +1,6 @@ import { Container, Textarea } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; -import { useRef, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; @@ -21,6 +21,12 @@ const AssistantReply = () => { }, }); + useEffect(() => { + if (tasks.length == 0) { + mutate(); + } + }, [tasks]); + const { trigger } = useSWRMutation("/api/update_task", poster, { onSuccess: async (data) => { const newTask = await data.json(); diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 5aecf98c..a4aed9c3 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -1,6 +1,6 @@ import { Container, Textarea } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; -import { useRef, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { TaskControls } from "src/components/Survey/TaskControls"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; @@ -27,6 +27,12 @@ const InitialPrompt = () => { }, }); + useEffect(() => { + if (tasks.length == 0) { + mutate(); + } + }, [tasks]); + const submitResponse = (task: { id: string }) => { const text = inputRef.current.value.trim(); trigger({ diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index b9022e86..40409189 100644 --- 
a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -1,6 +1,6 @@ import { Textarea } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; -import { useRef, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Messages } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; @@ -21,6 +21,12 @@ const UserReply = () => { }, }); + useEffect(() => { + if (tasks.length == 0) { + mutate(); + } + }, [tasks]); + const { trigger } = useSWRMutation("/api/update_task", poster, { onSuccess: async (data) => { const newTask = await data.json(); diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx new file mode 100644 index 00000000..8b1f6861 --- /dev/null +++ b/website/src/pages/dashboard.tsx @@ -0,0 +1,33 @@ +import { Box, useColorMode } from "@chakra-ui/react"; +import Head from "next/head"; + +import { getDashboardLayout } from "src/components/Layout"; +import { LeaderboardTable, SideMenu, TaskOption } from "src/components/Dashboard"; +import { colors } from "styles/Theme/colors"; + +const Dashboard = () => { + const { colorMode } = useColorMode(); + return ( + <> + + Dashboard - Open Assistant + + + + + + + + + + + + + + + ); +}; + +Dashboard.getLayout = (page) => getDashboardLayout(page); + +export default Dashboard; diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index e18d228a..d20f6364 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -1,10 +1,12 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; -import { useState } from "react"; +import { useEffect, useState } from "react"; +import { ContextMessages } from "src/components/ContextMessages"; import { 
LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Message } from "src/components/Messages"; import { Sortable } from "src/components/Sortable/Sortable"; import { SurveyCard } from "src/components/Survey/SurveyCard"; -import { TaskControls } from "src/components/Survey/TaskControls"; +import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; import useSWRImmutable from "swr/immutable"; @@ -24,6 +26,12 @@ const RankAssistantReplies = () => { }, }); + useEffect(() => { + if (tasks.length == 0) { + mutate(); + } + }, [tasks]); + const { trigger } = useSWRMutation("/api/update_task", poster, { onSuccess: async (data) => { const newTask = await data.json(); @@ -64,6 +72,7 @@ const RankAssistantReplies = () => { } const replies = tasks[0].task.replies as string[]; + const messages = tasks[0].task.conversation.messages as Message[]; return ( <> @@ -77,10 +86,17 @@ const RankAssistantReplies = () => {

Given the following replies, sort them from best to worst, best being first, worst being last.

+ - + setRanking(tasks[0].task.replies.map((_, idx) => idx))} + onSubmitResponse={submitResponse} + onSkip={fetchNextTask} + />
); diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 48a67e90..a0a48b27 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -1,10 +1,10 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Sortable } from "src/components/Sortable/Sortable"; import { SurveyCard } from "src/components/Survey/SurveyCard"; -import { TaskControls } from "src/components/Survey/TaskControls"; +import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; import useSWRImmutable from "swr/immutable"; @@ -32,6 +32,12 @@ const RankInitialPrompts = () => { }, }); + useEffect(() => { + if (tasks.length == 0) { + mutate(); + } + }, [tasks]); + const submitResponse = (task) => { trigger({ id: task.id, @@ -79,7 +85,13 @@ const RankInitialPrompts = () => { - + setRanking(tasks[0].task.prompts.map((_, idx) => idx))} + onSubmitResponse={submitResponse} + onSkip={fetchNextTask} + />
); diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index ea57f9a0..09e24fa2 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -1,10 +1,12 @@ import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; -import { useState } from "react"; +import { useEffect, useState } from "react"; +import { ContextMessages } from "src/components/ContextMessages"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Message } from "src/components/Messages"; import { Sortable } from "src/components/Sortable/Sortable"; import { SurveyCard } from "src/components/Survey/SurveyCard"; -import { TaskControls } from "src/components/Survey/TaskControls"; +import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable"; import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; import useSWRImmutable from "swr/immutable"; @@ -24,6 +26,12 @@ const RankUserReplies = () => { }, }); + useEffect(() => { + if (tasks.length == 0) { + mutate(); + } + }, [tasks]); + const { trigger } = useSWRMutation("/api/update_task", poster, { onSuccess: async (data) => { const newTask = await data.json(); @@ -63,6 +71,7 @@ const RankUserReplies = () => { ); } const replies = tasks[0].task.replies as string[]; + const messages = tasks[0].task.conversation.messages as Message[]; return ( <> @@ -76,10 +85,17 @@ const RankUserReplies = () => {

Given the following replies, sort them from best to worst, best being first, worst being last.

+ - + setRanking(tasks[0].task.replies.map((_, idx) => idx))} + onSubmitResponse={submitResponse} + onSkip={fetchNextTask} + />
); diff --git a/website/src/pages/index.tsx b/website/src/pages/index.tsx index 08528245..04f99829 100644 --- a/website/src/pages/index.tsx +++ b/website/src/pages/index.tsx @@ -1,16 +1,10 @@ import Head from "next/head"; -import { useSession } from "next-auth/react"; import { CallToAction } from "src/components/CallToAction"; import { Faq } from "src/components/Faq"; -import { Footer } from "src/components/Footer"; -import { Header } from "src/components/Header"; import { Hero } from "src/components/Hero"; import { getTransparentHeaderLayout } from "src/components/Layout"; -import { TaskSelection } from "src/components/TaskSelection"; const Home = () => { - const { data: session } = useSession(); - return ( <> @@ -20,15 +14,11 @@ const Home = () => { content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> - {session ? ( - - ) : ( -
- - - -
- )} +
+ + + +
); }; diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx new file mode 100644 index 00000000..39430caf --- /dev/null +++ b/website/src/pages/messages/index.tsx @@ -0,0 +1,83 @@ +import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; +import Head from "next/head"; +import { useState } from "react"; +import useSWRImmutable from "swr/immutable"; + +import fetcher from "src/lib/fetcher"; +import { SideMenu } from "src/components/Dashboard"; +import { MessageTable } from "src/components/Messages/MessageTable"; +import { getDashboardLayout } from "src/components/Layout"; +import { colors } from "styles/Theme/colors"; + +const MessagesDashboard = () => { + const bgColor = useColorModeValue(colors.light.bg, colors.dark.bg); + const boxBgColor = useColorModeValue("white", "gray.700"); + const boxAccentColor = useColorModeValue("gray.200", "gray.900"); + + const [messages, setMessages] = useState([]); + const [userMessages, setUserMessages] = useState([]); + + const { isLoading: isLoadingAll } = useSWRImmutable("/api/messages", fetcher, { + onSuccess: (data) => { + setMessages(data); + }, + }); + + const { isLoading: isLoadingUser } = useSWRImmutable(`/api/messages/user`, fetcher, { + onSuccess: (data) => { + setUserMessages(data); + }, + }); + + return ( + <> + + Messages - Open Assistant + + + + + + + + + + + + Most recent messages + + + {isLoadingAll ? : } + + + + + Your most recent messages + + + {isLoadingUser ? : } + + + + + + + + ); +}; + +MessagesDashboard.getLayout = (page) => getDashboardLayout(page); + +export default MessagesDashboard; diff --git a/website/src/pages/privacy-policy.tsx b/website/src/pages/privacy-policy.tsx index dcb3bc19..164fab93 100644 --- a/website/src/pages/privacy-policy.tsx +++ b/website/src/pages/privacy-policy.tsx @@ -14,7 +14,7 @@ const PrivacyPolicy = () => { content="Conversational AI for everyone. 
An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> -
+
Privacy Policy diff --git a/website/src/pages/terms-of-service.tsx b/website/src/pages/terms-of-service.tsx index d97c8d34..ce60a20e 100644 --- a/website/src/pages/terms-of-service.tsx +++ b/website/src/pages/terms-of-service.tsx @@ -14,7 +14,7 @@ const TermsOfService = () => { content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> -
+
Terms Of Service diff --git a/website/src/styles/Theme/index.ts b/website/src/styles/Theme/index.ts index 11edd8f3..0c43eddd 100644 --- a/website/src/styles/Theme/index.ts +++ b/website/src/styles/Theme/index.ts @@ -14,6 +14,14 @@ const components = { Container: containerTheme, }; +const breakpoints = { + sm: "640px", + md: "768px", + lg: "1024px", + xl: "1280px", + "2xl": "1536px", +}; + const styles: Styles = { global: (props) => ({ "*": { @@ -34,4 +42,4 @@ const styles: Styles = { }), }; -export const theme = extendTheme({ colors, config, styles, components }); +export const theme = extendTheme({ colors, config, styles, components, breakpoints }); diff --git a/website/styles/Theme/Components/Container.tsx b/website/styles/Theme/Components/Container.tsx new file mode 100644 index 00000000..ec7eeed9 --- /dev/null +++ b/website/styles/Theme/Components/Container.tsx @@ -0,0 +1,25 @@ +import { + color, + defineStyle, + defineStyleConfig, + // transition, +} from "@chakra-ui/styled-system"; +import { colors } from "../colors"; + +const baseStyle = defineStyle(({ colorMode }) => ({ + minWidth: "100%", + bg: colorMode === "light" ? colors.light.bg : colors.dark.bg, + // transition: "background-color 300ms cubic-bezier(0.4, 0, 1, 1)", + color: colorMode === "light" ? 
colors.light.text : colors.dark.text, +})); + +const variants = { + "no-padding": { + padding: 0, + }, +}; + +export const containerTheme = defineStyleConfig({ + baseStyle, + variants, +}); diff --git a/website/styles/Theme/colors.tsx b/website/styles/Theme/colors.tsx new file mode 100644 index 00000000..acadfa2b --- /dev/null +++ b/website/styles/Theme/colors.tsx @@ -0,0 +1,14 @@ +export const colors = { + light: { + bg: "gray.100", + btn: "gray.50", + div: "white", + text: "black", + }, + dark: { + bg: "gray.900", + btn: "gray.600", + div: "gray.700", + text: "gray.200", + }, +}; diff --git a/website/styles/Theme/index.tsx b/website/styles/Theme/index.tsx new file mode 100644 index 00000000..9e9b0d82 --- /dev/null +++ b/website/styles/Theme/index.tsx @@ -0,0 +1,68 @@ +import { + type ThemeConfig, + extendTheme, + usePrefersReducedMotion, +} from "@chakra-ui/react"; +import { containerTheme } from "./Components/Container"; +import { StyleFunctionProps, Styles } from "@chakra-ui/theme-tools"; + +const config: ThemeConfig = { + initialColorMode: "system", + useSystemColorMode: false, + disableTransitionOnChange: true, +}; + +const components = { + Container: containerTheme, + Box: (props: StyleFunctionProps) => ({ + backgroundColor: props.colorMode === "light" ? "white" : "gray.800", + }), + Button: { + baseStyle: { + fontWeight: "normal", + }, + sizes: { + lg: { + fontSize: "md", + paddingY: "7", + }, + }, + variants: { + solid: (props: StyleFunctionProps) => ({ + bg: props.colorMode === "light" ? "gray.100" : "gray.600", + _hover: { + bg: props.colorMode === "light" ? "gray.200" : "#3D4A60", + }, + _active: { + bg: props.colorMode === "light" ? 
"gray.300" : "#374254", + }, + borderRadius: "lg", + }), + // gradient: (props: StyleFunctionProps) => ({ + // bg: `linear-gradient(${white}, ${bgColor}) padding-box, + // linear-gradient(135deg, ${lgFrom}, ${lgTo}) border-box`, + // }), + }, + }, +}; + +const breakpoints = { + sm: "640px", + md: "768px", + lg: "1024px", + xl: "1280px", + "2xl": "1536px", +}; + +const styles = { + global: (props) => ({ + main: { + fontFamily: "Inter", + }, + header: { + fontFamily: "Inter", + }, + }), +}; + +export const theme = extendTheme({ config, styles, components, breakpoints }); diff --git a/website/types/next-auth.d.ts b/website/types/next-auth.d.ts new file mode 100644 index 00000000..25027600 --- /dev/null +++ b/website/types/next-auth.d.ts @@ -0,0 +1,18 @@ +import NextAuth, { DefaultSession } from "next-auth"; +import { JWT } from "next-auth/jwt"; + +declare module "next-auth" { + interface Session { + user: { + /** The user's role. */ + role: string; + } & DefaultSession["user"]; + } +} + +declare module "next-auth/jwt" { + interface JWT { + /** The user's role. */ + role?: string; + } +}