Add run_orpo.py (#143)

* Add `ORPOConfig`

* Add `task=orpo` and support `(prompt,chosen,rejected)` datasets

* Add missing `model_init_kwargs` and `dataset_num_proc`

* Add `run_orpo.py` (WIP)

* Update `trl` dependency from source

* Add `setup_chat_format` before `apply_chat_template`

* Add `config_full.yaml` for `mistral-7b-orpo`

* Fix comment indentation

* Use `chat_template=chatml` instead

* Add `kaist-ai/mistral-orpo-capybara-7k` recipe

* Rename `DPOTrainer` to `ORPOTrainer` in `config_full.yaml` files

* Run `black --line-length 119 src`

* Add `is_openai_format` to fix `(prompt,chosen,rejected)` formatting

* Run `black --line-length 119 src`

* Fix `isort` in `run_orpo.py`

* Update `mistral-capybara/orpo/config_full.yaml`

* Check if the `test` split is available

* Pin `trl` to `alvarobartt/trl` fork (debugging)

* Add `qwen-capybara` recipe

* Update `mistral-capybara` recipe

* Set `add_generation_prompt=True` if `task="orpo"`

* Reduce `logging_steps` to 10

* Unset `add_generation_prompt` when `task=orpo`

* Add filtering based on prompt length

Done similarly to the original implementation, in order to better reproduce their results

* Fix prompt length filtering

* Update `trl` pinned version

* Remove extra outdated config files

* Update `recipes/mistral-capybara/orpo/config_full.yaml`

* Run `make style`

* Activate BEAST MODE

* Pin deps

* Add readme

* Fix dep
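
For reference, the prompt-length filtering mentioned in the list above could look roughly like the sketch below. This is not the code from this commit: `filter_by_prompt_length`, `max_prompt_length`, the whitespace tokenizer, and the inclusive `<=` bound are all hypothetical stand-ins chosen for illustration. A real run would use the model's own tokenizer.

```python
from typing import Callable, Dict, List

Example = Dict[str, str]


def whitespace_tokenize(text: str) -> List[str]:
    # Hypothetical stand-in for a real tokenizer; splits on whitespace only.
    return text.split()


def filter_by_prompt_length(
    examples: List[Example],
    max_prompt_length: int,
    tokenize: Callable[[str], List[str]] = whitespace_tokenize,
) -> List[Example]:
    # Keep only examples whose tokenized prompt fits within the limit
    # (inclusive bound is an assumption, not taken from the commit).
    return [ex for ex in examples if len(tokenize(ex["prompt"])) <= max_prompt_length]


examples = [
    {"prompt": "Hi there", "chosen": "Hello!", "rejected": "Go away."},
    {"prompt": "one two three four five six", "chosen": "ok", "rejected": "no"},
]

# The second prompt has six whitespace tokens and is dropped.
kept = filter_by_prompt_length(examples, max_prompt_length=3)
```

Passing the tokenizer as a parameter keeps the filter independent of any particular library, so the same function works for `(prompt,chosen,rejected)` rows regardless of which model is trained.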

---------

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
Alvaro Bartolome
2024-04-11 16:02:20 +02:00
committed by GitHub
parent a83b1f617f
commit 70769f9e9b
8 changed files with 468 additions and 17 deletions
+4 -4
@@ -41,10 +41,10 @@ if stale_egg_info.exists():
 # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
 # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
 _deps = [
-    "accelerate==0.27.2",
+    "accelerate>=0.29.2",
     "bitsandbytes==0.41.2.post2",
     "black==23.1.0",
-    "datasets==2.14.6",
+    "datasets>=2.18.0",
     "deepspeed==0.12.2",
     "einops>=0.6.1",
     "evaluate==0.4.0",
@@ -65,8 +65,8 @@ _deps = [
     "scipy",
     "tensorboard",
     "torch==2.1.2",
-    "transformers @ git+https://github.com/huggingface/transformers.git@831bc25d8fdb85768402f772cf65cc3d7872b211", # Enable StarCoder2
-    "trl==0.7.10",
+    "transformers>=4.39.3",
+    "trl>=0.8.2",
     "jinja2>=3.0.0",
     "tqdm>=4.64.1",
 ]
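
The diff above relaxes four exact pins into minimum versions. A minimal sketch of what those `>=` bounds mean in practice is shown below. The helper names `parse_release` and `meets_minimum` are hypothetical, and the parsing is deliberately simplistic: it reads only the leading numeric release segment and ignores suffixes such as `.post2` or `rc1`. For real checks, `packaging.version.Version` would be the more robust choice.

```python
import re
from typing import Dict, Tuple

# Minimum versions taken from the updated dependency list.
MINIMUMS: Dict[str, str] = {
    "accelerate": "0.29.2",
    "datasets": "2.18.0",
    "transformers": "4.39.3",
    "trl": "0.8.2",
}


def parse_release(version: str) -> Tuple[int, ...]:
    # Extract the leading dotted-integer part, e.g. "0.41.2.post2" -> (0, 41, 2).
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    # Pad the shorter tuple with zeros so "4.40" compares like "4.40.0".
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# The previously pinned trl release no longer satisfies the new bound.
old_trl_ok = meets_minimum("0.7.10", MINIMUMS["trl"])
new_trl_ok = meets_minimum("0.8.2", MINIMUMS["trl"])
```

Comparing integer tuples rather than raw strings matters here: as strings, `"0.7.10"` and `"0.8.2"` happen to order correctly, but `"0.10.0"` would sort before `"0.8.2"`.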