diff --git a/rllib/rollout.py b/rllib/rollout.py index 5f52f74f3..fecf65b87 100755 --- a/rllib/rollout.py +++ b/rllib/rollout.py @@ -8,10 +8,10 @@ from gym import wrappers as gym_wrappers import json import os from pathlib import Path -import pickle import shelve import ray +import ray.cloudpickle as cloudpickle from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.rllib.evaluation.worker_set import WorkerSet @@ -118,7 +118,7 @@ class RolloutSaver: self._shelf.close() elif self._outfile and not self._use_shelve: # Dump everything as one big pickle: - pickle.dump(self._rollouts, open(self._outfile, "wb")) + cloudpickle.dump(self._rollouts, open(self._outfile, "wb")) if self._update_file: # Remove the temp progress file: self._get_tmp_progress_filename().unlink() @@ -261,7 +261,7 @@ def run(args, parser): # Load the config from pickled. else: with open(config_path, "rb") as f: - config = pickle.load(f) + config = cloudpickle.load(f) # Set num_workers to be at least 2. if "num_workers" in config: