diff --git a/on_policy_data_gen/decode.py b/on_policy_data_gen/decode.py index eb01f42..18094ed 100644 --- a/on_policy_data_gen/decode.py +++ b/on_policy_data_gen/decode.py @@ -30,7 +30,7 @@ tokenizer = llm.get_tokenizer() train_dataset= load_dataset(data_dir, split='train_prefs') -prompts = list(set(train_dataset['prompt'])) +prompts = sorted(list(set(train_dataset['prompt']))) conversations = [tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) for prompt in prompts]