mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[RLlib] Issue with pickle versions (breaks rollout test cases in RLlib). (#11939)
This commit is contained in:
+3
-3
@@ -8,10 +8,10 @@ from gym import wrappers as gym_wrappers
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import shelve
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
@@ -118,7 +118,7 @@ class RolloutSaver:
|
||||
self._shelf.close()
|
||||
elif self._outfile and not self._use_shelve:
|
||||
# Dump everything as one big pickle:
|
||||
pickle.dump(self._rollouts, open(self._outfile, "wb"))
|
||||
cloudpickle.dump(self._rollouts, open(self._outfile, "wb"))
|
||||
if self._update_file:
|
||||
# Remove the temp progress file:
|
||||
self._get_tmp_progress_filename().unlink()
|
||||
@@ -261,7 +261,7 @@ def run(args, parser):
|
||||
# Load the config from pickled.
|
||||
else:
|
||||
with open(config_path, "rb") as f:
|
||||
config = pickle.load(f)
|
||||
config = cloudpickle.load(f)
|
||||
|
||||
# Set num_workers to be at least 2.
|
||||
if "num_workers" in config:
|
||||
|
||||
Reference in New Issue
Block a user