added data downloading

This commit is contained in:
Joshuaclymer
2023-11-11 23:01:53 +00:00
parent 5203d96677
commit 84cd19f839
17 changed files with 757 additions and 160 deletions
+2
View File
@@ -2,3 +2,5 @@ models/
configs/credentials.json configs/credentials.json
build/ build/
env/ env/
upload_data.py
genies-datasets.tar
+2 -2
View File
@@ -8,7 +8,7 @@ As AI systems become more capable and are deployed in complex environments, it m
## Quickstart ## Quickstart
This repository contains: This repository contains:
- Our datasets (`./distributions`) along with pairing specifications (`./distribution_shifts`). - Our datasets (`./distributions`) along with pairing specifications (`./distribution_shifts`). Download our datasets [here](https://genies-data.s3.us-east-2.amazonaws.com/genies-datasets.tar) or run the setup command after cloning the repo.
- Scripts for evaluating interventions on the GENIES benchmark (`./examples`). - Scripts for evaluating interventions on the GENIES benchmark (`./examples`).
- Our results (`./results`). - Our results (`./results`).
- Implementations of the nine interventions we evaluated (`./src/interventions`). - Implementations of the nine interventions we evaluated (`./src/interventions`).
@@ -18,7 +18,7 @@ All of the models we fine-tuned with Lora can be found on [huggingface](https://
**Setup:** **Setup:**
``` ```
conda create --name env python=3.10 conda create --name env python=3.10
pip install -e . pip install .
python download_model_from_hf.py EleutherAI/pythia-410m models/pythia-410m python download_model_from_hf.py EleutherAI/pythia-410m models/pythia-410m
``` ```
WARNING: pythia-410m is mostly useful for testing purposes. Most tuning interventions perform poorly with this model. WARNING: pythia-410m is mostly useful for testing purposes. Most tuning interventions perform poorly with this model.
Binary file not shown.
+43
View File
@@ -0,0 +1,43 @@
import requests
import tarfile
import os
import fire
def download_tarfile(filename: str, download_dir: str) -> None:
    """Download a tar archive from the GENIES S3 bucket and unpack it.

    The archive is fetched from
    ``https://genies-data.s3.us-east-2.amazonaws.com/{filename}``, written to
    ``filename`` relative to the current working directory, and then
    extracted into ``download_dir`` (created if it does not exist).

    If the server answers with anything other than HTTP 200, nothing is
    written or extracted and a failure message is printed instead.

    Args:
        filename: Object key inside the bucket; also used as the local path
            of the downloaded archive.
        download_dir: Directory the archive contents are extracted into.
    """
    url = f"https://genies-data.s3.us-east-2.amazonaws.com/{filename}"
    downloaded_file_path = filename
    extracted_dir = download_dir

    # Stream the body to disk in chunks so a large archive is never held in
    # memory all at once.  The (connect, read) timeout keeps a stalled
    # connection from hanging the caller forever, and the context manager
    # releases the connection even if writing fails.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            print(f"Failed to download the file. Status code: {response.status_code}")
            return
        with open(downloaded_file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)
    print(f"File '{downloaded_file_path}' has been downloaded successfully.")

    os.makedirs(extracted_dir, exist_ok=True)

    with tarfile.open(downloaded_file_path, "r") as tar:
        # The archive comes from the network: where the interpreter supports
        # extraction filters, use the "data" filter so members cannot escape
        # extracted_dir (absolute paths, "..", links pointing outside).
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extracted_dir, filter="data")
        else:
            tar.extractall(path=extracted_dir)
    print(f"Contents of '{downloaded_file_path}' have been extracted to '{extracted_dir}'.")
def download_data():
    """Fetch the GENIES dataset archive and unpack it into the current directory."""
    print("Downloading distributions...")
    download_tarfile(filename="genies-datasets.tar", download_dir=".")


# Running this file as a script exposes download_data as a CLI via python-fire.
if __name__ == "__main__":
    fire.Fire(download_data)
+120 -154
View File
@@ -1,237 +1,203 @@
absl-py==1.4.0 absl-py==2.0.0
accelerate==0.22.0 accelerate==0.24.1
aiohttp==3.8.5 aiohttp==3.8.6
aiosignal==1.3.1 aiosignal==1.3.1
altair==5.1.2 annotated-types==0.6.0
ansiwrap==0.8.4 anyio==3.7.1
anyio==4.0.0
appdirs==1.4.4 appdirs==1.4.4
argon2-cffi==23.1.0 argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0 argon2-cffi-bindings==21.2.0
arrow==1.2.3 arrow==1.3.0
astor==0.8.1 asttokens==2.4.1
asttokens==2.4.0
async-lru==2.0.4 async-lru==2.0.4
async-timeout==4.0.3 async-timeout==4.0.3
attrs==23.1.0 attrs==23.1.0
awscli==1.29.43 Babel==2.13.1
Babel==2.12.1
backcall==0.2.0
base58==2.1.1
baukit @ git+https://github.com/davidbau/baukit@5e23007c02fd58f063200c5dc9033e90f092630d
bcrypt==4.0.1 bcrypt==4.0.1
beautifulsoup4==4.12.2 beautifulsoup4==4.12.2
bitsandbytes==0.41.1 bitsandbytes==0.41.2.post2
black==21.12b0 bleach==6.1.0
bleach==6.0.0 boto3==1.28.84
blinker==1.6.3 botocore==1.31.84
boto3==1.28.62
botocore==1.31.62
Brotli==1.1.0
cachetools==5.3.2
certifi==2023.7.22 certifi==2023.7.22
cffi==1.15.1 cffi==1.16.0
charset-normalizer==3.2.0 charset-normalizer==3.3.2
click==7.1.2 click==8.1.7
cmake==3.27.4.1 comm==0.2.0
colorama==0.4.4 contourpy==1.2.0
comm==0.1.4 cryptography==41.0.5
contourpy==1.1.0 cycler==0.12.1
cryptography==41.0.3
cycler==0.11.0
Cython==0.29.36 Cython==0.29.36
datasets==2.14.5 datasets==2.14.6
debugpy==1.7.0 debugpy==1.8.0
decorator==5.1.1 decorator==5.1.1
deepspeed==0.10.2 deepspeed==0.12.2
defusedxml==0.7.1 defusedxml==0.7.1
dill==0.3.7 dill==0.3.7
distro==1.8.0
docker-pycreds==0.4.0 docker-pycreds==0.4.0
docutils==0.16 docopt==0.6.2
einops==0.6.1 docstring-parser==0.15
einops==0.7.0
entrypoints==0.4 entrypoints==0.4
exceptiongroup==1.1.3 exceptiongroup==1.1.3
executing==1.2.0 executing==2.0.1
fastjsonschema==2.18.0 fastjsonschema==2.18.1
filelock==3.12.3 filelock==3.13.1
fire==0.5.0 fire==0.5.0
flake8==6.1.0 fonttools==4.44.0
fonttools==4.42.1
fqdn==1.5.1 fqdn==1.5.1
frozenlist==1.4.0 frozenlist==1.4.0
fsspec==2023.6.0 fsspec==2023.10.0
gitdb==4.0.10 gitdb==4.0.11
GitPython==3.1.35 GitPython==3.1.40
h11==0.14.0
hjson==3.1.0 hjson==3.1.0
huggingface-hub==0.16.4 httpcore==1.0.2
httpx==0.25.1
huggingface-hub==0.17.3
idna==3.4 idna==3.4
inflate64==0.3.1 ipykernel==6.26.0
iniconfig==2.0.0 ipython==8.17.2
ipykernel==6.25.2 ipywidgets==8.1.1
ipython==8.15.0
ipython-genutils==0.2.0
ipywidgets==8.1.0
isoduration==20.11.0 isoduration==20.11.0
isort==5.8.0 jedi==0.19.1
jedi==0.19.0
Jinja2==3.1.2 Jinja2==3.1.2
jmespath==1.0.1 jmespath==1.0.1
joblib==1.3.2 joblib==1.3.2
json5==0.9.14 json5==0.9.14
jsonpointer==2.4 jsonpointer==2.4
jsonschema==4.19.0 jsonschema==4.19.2
jsonschema-specifications==2023.7.1 jsonschema-specifications==2023.7.1
jupyter==1.0.0 jupyter==1.0.0
jupyter-console==6.6.3 jupyter-console==6.6.3
jupyter-events==0.7.0 jupyter-events==0.9.0
jupyter-lsp==2.2.0 jupyter-lsp==2.2.0
jupyter_client==8.3.1 jupyter_client==8.6.0
jupyter_core==5.3.1 jupyter_core==5.5.0
jupyter_server==2.7.3 jupyter_server==2.10.0
jupyter_server_terminals==0.4.4 jupyter_server_terminals==0.4.4
jupyterlab==4.0.5 jupyterlab==4.0.8
jupyterlab-pygments==0.2.2 jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.8 jupyterlab-widgets==3.0.9
jupyterlab_server==2.24.0 jupyterlab_server==2.25.1
kiwisolver==1.4.5 kiwisolver==1.4.5
lit==16.0.6 markdown-it-py==3.0.0
MarkupSafe==2.1.3 MarkupSafe==2.1.3
matplotlib==3.7.2 matplotlib==3.8.1
matplotlib-inline==0.1.6 matplotlib-inline==0.1.6
mccabe==0.7.0 mdurl==0.1.2
mistune==3.0.1 mistune==3.0.2
mpi4py @ file:///croot/mpi4py_1671223370575/work
mpmath==1.3.0 mpmath==1.3.0
multidict==6.0.4 multidict==6.0.4
multiprocess==0.70.15 multiprocess==0.70.15
multivolumefile==0.2.3 nbclient==0.9.0
mypy-extensions==1.0.0 nbconvert==7.11.0
nbclient==0.8.0
nbconvert==7.8.0
nbformat==5.9.2 nbformat==5.9.2
nest-asyncio==1.5.7 nest-asyncio==1.5.8
networkx==3.1 networkx==3.2.1
ninja==1.11.1 ninja==1.11.1.1
nltk==3.8.1 nltk==3.8.1
notebook==7.0.3 notebook==7.0.6
notebook_shim==0.2.3 notebook_shim==0.2.3
numpy==1.25.2 numpy==1.26.1
nvidia-cublas-cu11==11.10.3.66 nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu11==11.7.101 nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu11==11.7.99 nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu11==11.7.99 nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu11==8.5.0.96 nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu11==10.9.0.58 nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu11==10.2.10.91 nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu11==11.4.0.1 nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu11==11.7.4.91 nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu11==2.14.3 nvidia-nccl-cu12==2.18.1
nvidia-nvtx-cu11==11.7.91 nvidia-nvjitlink-cu12==12.3.52
openai==0.28.0 nvidia-nvtx-cu12==12.1.105
openai==1.2.3
overrides==7.4.0 overrides==7.4.0
packaging==23.1 packaging==23.2
pandas==2.1.0 pandas==2.1.3
pandocfilters==1.5.0 pandocfilters==1.5.0
papermill==2.4.0 papermill==2.5.0
paramiko==3.3.1 paramiko==3.3.1
parso==0.8.3 parso==0.8.3
pathspec==0.11.2 peft==0.6.1
pathtools==0.1.2
patsy==0.5.3
peft==0.5.0
pexpect==4.8.0 pexpect==4.8.0
pickleshare==0.7.5 Pillow==10.1.0
Pillow==10.0.0 pipreqs==0.4.13
platformdirs==3.10.0 platformdirs==4.0.0
plotly==5.18.0 prometheus-client==0.18.0
pluggy==1.3.0 prompt-toolkit==3.0.40
prometheus-client==0.17.1
prompt-toolkit==3.0.39
promptsource @ git+https://github.com/bigscience-workshop/promptsource@7dab96a3eeb3717cea633705135ebc488885d709
protobuf==3.20.3 protobuf==3.20.3
psutil==5.9.5 psutil==5.9.6
ptyprocess==0.7.0 ptyprocess==0.7.0
pure-eval==0.2.2 pure-eval==0.2.2
py-cpuinfo==9.0.0 py-cpuinfo==9.0.0
py7zr==0.20.6 pyarrow==14.0.1
pyarrow==13.0.0
pyasn1==0.5.0
pybcj==1.0.1
pycodestyle==2.11.1
pycparser==2.21 pycparser==2.21
pycryptodomex==3.19.0 pydantic==2.4.2
pydantic==1.10.12 pydantic_core==2.10.1
pydeck==0.8.1b0
pyflakes==3.1.0
Pygments==2.16.1 Pygments==2.16.1
PyNaCl==1.5.0 PyNaCl==1.5.0
pyparsing==3.0.9 pynvml==11.5.0
pyppmd==1.0.0 pyparsing==3.1.1
pytest==7.4.3
python-dateutil==2.8.2 python-dateutil==2.8.2
python-json-logger==2.0.7 python-json-logger==2.0.7
pytz==2023.3.post1 pytz==2023.3.post1
PyYAML==6.0.1 PyYAML==6.0.1
pyzmq==25.1.1 pyzmq==25.1.1
pyzstd==0.15.9 qtconsole==5.5.0
qtconsole==5.4.4 QtPy==2.4.1
QtPy==2.4.0
referencing==0.30.2 referencing==0.30.2
regex==2023.8.8 regex==2023.10.3
requests==2.31.0 requests==2.31.0
rfc3339-validator==0.1.4 rfc3339-validator==0.1.4
rfc3986-validator==0.1.1 rfc3986-validator==0.1.1
rich==13.6.0
rouge-score==0.1.2 rouge-score==0.1.2
rpds-py==0.10.2 rpds-py==0.12.0
rsa==4.7.2
s3transfer==0.7.0 s3transfer==0.7.0
safetensors==0.3.3 safetensors==0.4.0
scikit-learn==1.3.0 scikit-learn==1.3.2
scipy==1.11.2 scipy==1.11.3
seaborn==0.12.2 seaborn==0.13.0
Send2Trash==1.8.2 Send2Trash==1.8.2
sentencepiece==0.1.99 sentencepiece==0.1.99
sentry-sdk==1.30.0 sentry-sdk==1.34.0
setproctitle==1.3.2 setproctitle==1.3.3
shtab==1.6.4
six==1.16.0 six==1.16.0
smmap==5.0.0 smmap==5.0.1
sniffio==1.3.0 sniffio==1.3.0
soupsieve==2.5 soupsieve==2.5
stack-data==0.6.2 stack-data==0.6.3
statsmodels==0.14.0
streamlit==0.82.0
sympy==1.12 sympy==1.12
tenacity==8.2.3 tenacity==8.2.3
termcolor==2.3.0 termcolor==2.3.0
terminado==0.17.1 terminado==0.18.0
texttable==1.7.0
textwrap3==0.9.2
threadpoolctl==3.2.0 threadpoolctl==3.2.0
tinycss2==1.2.1 tinycss2==1.2.1
tokenizers==0.13.3 tokenizers==0.14.1
toml==0.10.2 tomli==2.0.1
tomli==1.2.3 torch==2.1.0
toolz==0.12.0
torch==2.0.1
torchvision==0.15.2
tornado==6.3.3 tornado==6.3.3
tqdm==4.66.1 tqdm==4.66.1
traitlets==5.9.0 traitlets==5.13.0
transformers==4.33.1 transformers==4.35.0
triton==2.0.0 triton==2.1.0
trl==0.7.1 trl==0.7.4
typing_extensions==4.7.1 types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
tyro==0.5.12
tzdata==2023.3 tzdata==2023.3
tzlocal==5.2
uri-template==1.3.0 uri-template==1.3.0
urllib3==1.26.16 urllib3==2.0.7
validators==0.22.0 wandb==0.16.0
wandb==0.15.11 wcwidth==0.2.9
watchdog==3.0.0
wcwidth==0.2.6
webcolors==1.13 webcolors==1.13
webencodings==0.5.1 webencodings==0.5.1
websocket-client==1.6.2 websocket-client==1.6.4
widgetsnbextension==4.0.8 widgetsnbextension==4.0.9
xxhash==3.3.0 xxhash==3.4.1
yarg==0.1.9
yarl==1.9.2 yarl==1.9.2
+16 -3
View File
@@ -1,7 +1,9 @@
import setuptools import setuptools
import os import os
from io import open from io import open
from setuptools.command.install import install
from setuptools import find_packages from setuptools import find_packages
from download_data import download_data
src_dir = os.path.abspath(os.path.dirname(__file__)) src_dir = os.path.abspath(os.path.dirname(__file__))
@@ -15,24 +17,35 @@ if os.path.isfile(requirements_path):
with open(requirements_path) as f: with open(requirements_path) as f:
requirements = f.read().splitlines() requirements = f.read().splitlines()
class PostInstallCommand(install):
    """Custom ``install`` command that calls ``download_data`` after installing."""

    def run(self):
        """Run the regular install step, then call ``download_data``."""
        print("RUNNING POST INSTALL 1")
        super().run()
        print("RUNNING POST INSTALL")
        download_data()
setuptools.setup( setuptools.setup(
name="fig-benchmark", name="genies-benchmark",
version="0.0.1", version="0.0.1",
author="Joshua Clymer, Garrett Baker, Rohan Subramani, and Sam Wang", author="Joshua Clymer, Garrett Baker, Rohan Subramani, and Sam Wang",
author_email="joshuamclymer@gmail.com", author_email="joshuamclymer@gmail.com",
description="The fig benchmark repository contains datasets and tooling for evaluating the generalization of preferrence models.", description="The fig benchmark repository contains datasets and tooling for evaluating the generalization of preferrence models.",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/Joshuaclymer/FIG-benchmark", url="https://github.com/Joshuaclymer/GENIES",
classifiers=[ classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: MIT License", "License :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
], ],
cmdclass={
'install': PostInstallCommand,
},
python_requires='>=3.10', python_requires='>=3.10',
install_requires=requirements, install_requires=requirements,
packages=find_packages(where='src'), # Specify 'src' as the root packages=find_packages(where='src'), # Specify 'src' as the root
package_dir={'': 'src'}, package_dir={'': 'src'},
package_data={'fig-benchmark': ['LICENCE', 'requirements.txt']}, package_data={'genies-benchmark': ['LICENCE', 'requirements.txt']},
) )
+52
View File
@@ -0,0 +1,52 @@
Metadata-Version: 2.1
Name: fig-benchmark
Version: 0.0.1
Summary: The fig benchmark repository contains datasets and tooling for evaluating the generalization of preference models.
Home-page: https://github.com/Joshuaclymer/FIG-benchmark
Author: Joshua Clymer, Garrett Baker, Rohan Subramani, and Sam Wang
Author-email: joshuamclymer@gmail.com
Classifier: Programming Language :: Python :: 3
Classifier: License :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.10
Description-Content-Type: text/markdown
# Generalization Analogies: A Testbed for Generalizing AI Oversight to Hard-To-Measure Domains
Read our paper [here](TODO). Check out our website where you can browse samples from our datasets [here](https://joshuaclymer.github.io/generalization-analogies-website/).
![Hero](assets/hero_horizontal.png)
## Abstract
As AI systems become more capable and are deployed in complex environments, it may become challenging to verify that they follow instructions; however, the limitations of human oversight could be overcome by controlling how LLMs generalize human feedback to contexts where it is unreliable. To better understand how Reward Models generalize human feedback, we craft 69 distribution shifts spanning 8 different categories. We find that Reward Models do not learn to evaluate instruction-following by default and instead favor personas that resemble internet text. Techniques for interpreting Reward Models' internal representations achieve better generalization, but still frequently fail to distinguish instruction-following from conflated behaviors. We consolidate the 15 most challenging distribution shifts into the **GEN**eralization analog**IES** (GENIES) benchmark, which we hope will enable progress toward controlling Reward Model generalization.
## Quickstart
This repository contains:
- Our datasets (`./distributions`) along with pairing specifications (`./distribution_shifts`).
- Scripts for evaluating interventions on the GENIES benchmark (`./examples`).
- Our results (`./results`).
- Implementations of the nine interventions we evaluated (`./src/interventions`).
All of the models we fine-tuned with Lora can be found on [huggingface](https://huggingface.co/genies-models).
**Setup:**
```
conda create --name env python=3.10
pip install -e .
python download_model_from_hf.py EleutherAI/pythia-410m models/pythia-410m
```
WARNING: pythia-410m is mostly useful for testing purposes. Most tuning interventions perform poorly with this model.
## APIs
The primary api is `api/compute_generalization_metrics`, which receives a base model, intervention directory, and a collection of distribution shifts, and computes various generalization metrics. See `examples/compute_generalization_metrics.sh` for example usage.
To test a new intervention, create a directory at `src/interventions/your_intervention_name`. This directory must contain a `train.py` file and an `eval.py` file.
`src/interventions/your_intervention_name/train.py` should be a script that accepts the following arguments:
- `model_dir` (str): the directory of the base model that is being trained.
- `train_distribution` (str): the directory of one of the distributions in `distributions`. For example: `distributions/alpaca_mmlu`.
- `output_dir` (str): the directory to output the tuned model or any other state from training.
`src/interventions/your_intervention_name/eval.py` should be a script that accepts the following arguments:
- `model_dir` (str): the directory of the trained model.
- `distribution_dirs` (List\[str\]): a list of subdirectories of `distributions`.
- `output_paths` (List\[str\]): where to save the results. The results should be json files. The only required key is `eval_accuracy`. Evaluation results are stored in `results/evaluations`.# GENIES
+15
View File
@@ -0,0 +1,15 @@
README.md
setup.py
src/api/__init__.py
src/api/compute_generalization_metrics.py
src/api/data_classes.py
src/api/evaluate.py
src/api/hyperparameter_sweep.py
src/api/model.py
src/api/train.py
src/api/util.py
src/fig_benchmark.egg-info/PKG-INFO
src/fig_benchmark.egg-info/SOURCES.txt
src/fig_benchmark.egg-info/dependency_links.txt
src/fig_benchmark.egg-info/requires.txt
src/fig_benchmark.egg-info/top_level.txt
@@ -0,0 +1 @@
+203
View File
@@ -0,0 +1,203 @@
absl-py==2.0.0
accelerate==0.24.1
aiohttp==3.8.6
aiosignal==1.3.1
annotated-types==0.6.0
anyio==3.7.1
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.1.0
Babel==2.13.1
bcrypt==4.0.1
beautifulsoup4==4.12.2
bitsandbytes==0.41.2.post2
bleach==6.1.0
boto3==1.28.84
botocore==1.31.84
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.0
contourpy==1.2.0
cryptography==41.0.5
cycler==0.12.1
Cython==0.29.36
datasets==2.14.6
debugpy==1.8.0
decorator==5.1.1
deepspeed==0.12.2
defusedxml==0.7.1
dill==0.3.7
distro==1.8.0
docker-pycreds==0.4.0
docopt==0.6.2
docstring-parser==0.15
einops==0.7.0
entrypoints==0.4
exceptiongroup==1.1.3
executing==2.0.1
fastjsonschema==2.18.1
filelock==3.13.1
fire==0.5.0
fonttools==4.44.0
fqdn==1.5.1
frozenlist==1.4.0
fsspec==2023.10.0
gitdb==4.0.11
GitPython==3.1.40
h11==0.14.0
hjson==3.1.0
httpcore==1.0.2
httpx==0.25.1
huggingface-hub==0.17.3
idna==3.4
ipykernel==6.26.0
ipython==8.17.2
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.9.0
jupyter-lsp==2.2.0
jupyter_client==8.6.0
jupyter_core==5.5.0
jupyter_server==2.10.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.8
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.1
kiwisolver==1.4.5
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.1
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
nbclient==0.9.0
nbconvert==7.11.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
notebook==7.0.6
notebook_shim==0.2.3
numpy==1.26.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
openai==1.2.3
overrides==7.4.0
packaging==23.2
pandas==2.1.3
pandocfilters==1.5.0
papermill==2.5.0
paramiko==3.3.1
parso==0.8.3
peft==0.6.1
pexpect==4.8.0
Pillow==10.1.0
pipreqs==0.4.13
platformdirs==4.0.0
prometheus-client==0.18.0
prompt-toolkit==3.0.40
protobuf==3.20.3
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==14.0.1
pycparser==2.21
pydantic==2.4.2
pydantic_core==2.10.1
Pygments==2.16.1
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==3.1.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.1
qtconsole==5.5.0
QtPy==2.4.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.6.0
rouge-score==0.1.2
rpds-py==0.12.0
s3transfer==0.7.0
safetensors==0.4.0
scikit-learn==1.3.2
scipy==1.11.3
seaborn==0.13.0
Send2Trash==1.8.2
sentencepiece==0.1.99
sentry-sdk==1.34.0
setproctitle==1.3.3
shtab==1.6.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
tenacity==8.2.3
termcolor==2.3.0
terminado==0.18.0
threadpoolctl==3.2.0
tinycss2==1.2.1
tokenizers==0.14.1
tomli==2.0.1
torch==2.1.0
tornado==6.3.3
tqdm==4.66.1
traitlets==5.13.0
transformers==4.35.0
triton==2.1.0
trl==0.7.4
types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
tyro==0.5.12
tzdata==2023.3
uri-template==1.3.0
urllib3==2.0.7
wandb==0.16.0
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
widgetsnbextension==4.0.9
xxhash==3.4.1
yarg==0.1.9
yarl==1.9.2
+1
View File
@@ -0,0 +1 @@
api
+52
View File
@@ -0,0 +1,52 @@
Metadata-Version: 2.1
Name: genies-benchmark
Version: 0.0.1
Summary: The fig benchmark repository contains datasets and tooling for evaluating the generalization of preference models.
Home-page: https://github.com/Joshuaclymer/GENIES
Author: Joshua Clymer, Garrett Baker, Rohan Subramani, and Sam Wang
Author-email: joshuamclymer@gmail.com
Classifier: Programming Language :: Python :: 3
Classifier: License :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.10
Description-Content-Type: text/markdown
# Generalization Analogies: A Testbed for Generalizing AI Oversight to Hard-To-Measure Domains
Read our paper [here](TODO). Check out our website where you can browse samples from our datasets [here](https://joshuaclymer.github.io/generalization-analogies-website/).
![Hero](assets/hero_horizontal.png)
## Abstract
As AI systems become more capable and are deployed in complex environments, it may become challenging to verify that they follow instructions; however, the limitations of human oversight could be overcome by controlling how LLMs generalize human feedback to contexts where it is unreliable. To better understand how Reward Models generalize human feedback, we craft 69 distribution shifts spanning 8 different categories. We find that Reward Models do not learn to evaluate instruction-following by default and instead favor personas that resemble internet text. Techniques for interpreting Reward Models' internal representations achieve better generalization, but still frequently fail to distinguish instruction-following from conflated behaviors. We consolidate the 15 most challenging distribution shifts into the **GEN**eralization analog**IES** (GENIES) benchmark, which we hope will enable progress toward controlling Reward Model generalization.
## Quickstart
This repository contains:
- Our datasets (`./distributions`) along with pairing specifications (`./distribution_shifts`). Download our datasets [here](https://genies-data.s3.us-east-2.amazonaws.com/genies-datasets.tar) or run the setup command after cloning the repo.
- Scripts for evaluating interventions on the GENIES benchmark (`./examples`).
- Our results (`./results`).
- Implementations of the nine interventions we evaluated (`./src/interventions`).
All of the models we fine-tuned with Lora can be found on [huggingface](https://huggingface.co/genies-models).
**Setup:**
```
conda create --name env python=3.10
pip install .
python download_model_from_hf.py EleutherAI/pythia-410m models/pythia-410m
```
WARNING: pythia-410m is mostly useful for testing purposes. Most tuning interventions perform poorly with this model.
## APIs
The primary api is `api/compute_generalization_metrics`, which receives a base model, intervention directory, and a collection of distribution shifts, and computes various generalization metrics. See `examples/compute_generalization_metrics.sh` for example usage.
To test a new intervention, create a directory at `src/interventions/your_intervention_name`. This directory must contain a `train.py` file and an `eval.py` file.
`src/interventions/your_intervention_name/train.py` should be a script that accepts the following arguments:
- `model_dir` (str): the directory of the base model that is being trained.
- `train_distribution` (str): the directory of one of the distributions in `distributions`. For example: `distributions/alpaca_mmlu`.
- `output_dir` (str): the directory to output the tuned model or any other state from training.
`src/interventions/your_intervention_name/eval.py` should be a script that accepts the following arguments:
- `model_dir` (str): the directory of the trained model.
- `distribution_dirs` (List\[str\]): a list of subdirectories of `distributions`.
- `output_paths` (List\[str\]): where to save the results. The results should be json files. The only required key is `eval_accuracy`. Evaluation results are stored in `results/evaluations`.# GENIES
+15
View File
@@ -0,0 +1,15 @@
README.md
setup.py
src/api/__init__.py
src/api/compute_generalization_metrics.py
src/api/data_classes.py
src/api/evaluate.py
src/api/hyperparameter_sweep.py
src/api/model.py
src/api/train.py
src/api/util.py
src/genies_benchmark.egg-info/PKG-INFO
src/genies_benchmark.egg-info/SOURCES.txt
src/genies_benchmark.egg-info/dependency_links.txt
src/genies_benchmark.egg-info/requires.txt
src/genies_benchmark.egg-info/top_level.txt
@@ -0,0 +1 @@
+203
View File
@@ -0,0 +1,203 @@
absl-py==2.0.0
accelerate==0.24.1
aiohttp==3.8.6
aiosignal==1.3.1
annotated-types==0.6.0
anyio==3.7.1
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.1.0
Babel==2.13.1
bcrypt==4.0.1
beautifulsoup4==4.12.2
bitsandbytes==0.41.2.post2
bleach==6.1.0
boto3==1.28.84
botocore==1.31.84
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.0
contourpy==1.2.0
cryptography==41.0.5
cycler==0.12.1
Cython==0.29.36
datasets==2.14.6
debugpy==1.8.0
decorator==5.1.1
deepspeed==0.12.2
defusedxml==0.7.1
dill==0.3.7
distro==1.8.0
docker-pycreds==0.4.0
docopt==0.6.2
docstring-parser==0.15
einops==0.7.0
entrypoints==0.4
exceptiongroup==1.1.3
executing==2.0.1
fastjsonschema==2.18.1
filelock==3.13.1
fire==0.5.0
fonttools==4.44.0
fqdn==1.5.1
frozenlist==1.4.0
fsspec==2023.10.0
gitdb==4.0.11
GitPython==3.1.40
h11==0.14.0
hjson==3.1.0
httpcore==1.0.2
httpx==0.25.1
huggingface-hub==0.17.3
idna==3.4
ipykernel==6.26.0
ipython==8.17.2
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.9.0
jupyter-lsp==2.2.0
jupyter_client==8.6.0
jupyter_core==5.5.0
jupyter_server==2.10.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.8
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.1
kiwisolver==1.4.5
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.1
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
nbclient==0.9.0
nbconvert==7.11.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
notebook==7.0.6
notebook_shim==0.2.3
numpy==1.26.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
openai==1.2.3
overrides==7.4.0
packaging==23.2
pandas==2.1.3
pandocfilters==1.5.0
papermill==2.5.0
paramiko==3.3.1
parso==0.8.3
peft==0.6.1
pexpect==4.8.0
Pillow==10.1.0
pipreqs==0.4.13
platformdirs==4.0.0
prometheus-client==0.18.0
prompt-toolkit==3.0.40
protobuf==3.20.3
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==14.0.1
pycparser==2.21
pydantic==2.4.2
pydantic_core==2.10.1
Pygments==2.16.1
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==3.1.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.1
qtconsole==5.5.0
QtPy==2.4.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.6.0
rouge-score==0.1.2
rpds-py==0.12.0
s3transfer==0.7.0
safetensors==0.4.0
scikit-learn==1.3.2
scipy==1.11.3
seaborn==0.13.0
Send2Trash==1.8.2
sentencepiece==0.1.99
sentry-sdk==1.34.0
setproctitle==1.3.3
shtab==1.6.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
tenacity==8.2.3
termcolor==2.3.0
terminado==0.18.0
threadpoolctl==3.2.0
tinycss2==1.2.1
tokenizers==0.14.1
tomli==2.0.1
torch==2.1.0
tornado==6.3.3
tqdm==4.66.1
traitlets==5.13.0
transformers==4.35.0
triton==2.1.0
trl==0.7.4
types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
tyro==0.5.12
tzdata==2023.3
uri-template==1.3.0
urllib3==2.0.7
wandb==0.16.0
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
widgetsnbextension==4.0.9
xxhash==3.4.1
yarg==0.1.9
yarl==1.9.2
@@ -0,0 +1 @@
api
+29
View File
@@ -0,0 +1,29 @@
import fire
import os
import api.util as util
import json
import boto3
import os
import tarfile
def make_tar(directory_path, output_filename):
    """Bundle a directory into an uncompressed tar archive.

    The directory is stored under its own base name, so extracting the
    archive recreates ``<basename>/...`` inside the extraction directory.

    Args:
        directory_path: Directory to archive.  A trailing path separator is
            tolerated.
        output_filename: Path of the tar file to create (overwritten if it
            already exists).
    """
    # Normalise first: os.path.basename("distributions/") is "", which would
    # silently drop the top-level folder from the archive.
    top_level_name = os.path.basename(os.path.normpath(directory_path))
    with tarfile.open(output_filename, "w") as tar:
        tar.add(directory_path, arcname=top_level_name)
def upload_directory_to_s3(directory_path, bucket_name, tar_name):
    """Archive a directory and upload the archive to an S3 bucket.

    The directory is packed into ``tar_name`` with ``make_tar`` and uploaded
    to ``bucket_name`` under the key ``tar_name``, using the default boto3
    S3 client.  The resulting download link is printed.

    Args:
        directory_path: Directory to archive and upload.
        bucket_name: Destination S3 bucket.
        tar_name: Local archive file name; also used as the object key.
    """
    make_tar(directory_path, tar_name)

    # Upload to S3
    s3 = boto3.client('s3')
    s3.upload_file(tar_name, bucket_name, tar_name)
    print(f"Uploaded {tar_name} to {bucket_name}")

    # Build the link from the actual bucket and key instead of a fixed string,
    # so it stays correct when the function is called with other arguments.
    # NOTE(review): the region (us-east-2) is still hard-coded — confirm it
    # matches every bucket this is used with.
    link = f"https://{bucket_name}.s3.us-east-2.amazonaws.com/{tar_name}"
    print("The link is: ", link)
def upload_data():
    """Archive the ``distributions`` directory and upload it to the ``genies-data`` bucket."""
    source_dir = "distributions"
    bucket = "genies-data"
    archive_name = "genies-datasets.tar"
    upload_directory_to_s3(source_dir, bucket, archive_name)


# Running this file as a script exposes upload_data as a CLI via python-fire.
if __name__ == "__main__":
    fire.Fire(upload_data)