From 28e5c5555d67854b71cee2ff3b26acc1b66f5cd2 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 18 Jul 2019 21:01:16 -0700 Subject: [PATCH] [rllib] Move some inline defs to avoid deserialization errors (#5228) * fix bug * move metrics too --- python/ray/rllib/evaluation/metrics.py | 6 +----- python/ray/rllib/evaluation/rollout_metrics.py | 11 +++++++++++ python/ray/rllib/evaluation/rollout_worker.py | 8 ++------ python/ray/rllib/evaluation/sampler.py | 2 +- 4 files changed, 15 insertions(+), 12 deletions(-) create mode 100644 python/ray/rllib/evaluation/rollout_metrics.py diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index 3bba8c392..817f27b54 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -7,6 +7,7 @@ import numpy as np import collections import ray +from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate from ray.rllib.policy.policy import LEARNER_STATS_KEY @@ -15,11 +16,6 @@ from ray.rllib.utils.memory import ray_get_and_free logger = logging.getLogger(__name__) -RolloutMetrics = collections.namedtuple("RolloutMetrics", [ - "episode_length", "episode_reward", "agent_rewards", "custom_metrics", - "perf_stats" -]) - @DeveloperAPI def get_learner_stats(grad_info): diff --git a/python/ray/rllib/evaluation/rollout_metrics.py b/python/ray/rllib/evaluation/rollout_metrics.py new file mode 100644 index 000000000..fbb57f953 --- /dev/null +++ b/python/ray/rllib/evaluation/rollout_metrics.py @@ -0,0 +1,11 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +# Define this in its own file, see #5125 +RolloutMetrics = collections.namedtuple("RolloutMetrics", [ + "episode_length", "episode_reward", "agent_rewards", "custom_metrics", + "perf_stats" +]) diff --git a/python/ray/rllib/evaluation/rollout_worker.py b/python/ray/rllib/evaluation/rollout_worker.py index d3e97ec26..2e1269e74 100644 --- a/python/ray/rllib/evaluation/rollout_worker.py +++ b/python/ray/rllib/evaluation/rollout_worker.py @@ -277,13 +277,13 @@ class RolloutWorker(EvaluatorInterface): dim=model_config.get("dim"), framestack=model_config.get("framestack")) if monitor_path: - env = _monitor(env, monitor_path) + env = gym.wrappers.Monitor(env, monitor_path, resume=True) return env else: def wrap(env): if monitor_path: - env = _monitor(env, monitor_path) + env = gym.wrappers.Monitor(env, monitor_path, resume=True) return env self.env = wrap(self.env) @@ -798,10 +798,6 @@ def _validate_env(env): return env -def _monitor(env, path): - return gym.wrappers.Monitor(env, path, resume=True) - - def _has_tensorflow_graph(policy_dict): for policy, _, _, _ in policy_dict.values(): if issubclass(policy, TFPolicy): diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index f87639bd2..e2058c4d6 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -10,7 +10,7 @@ import threading import time from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action -from ray.rllib.evaluation.metrics import RolloutMetrics +from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.evaluation.sample_batch_builder import \ MultiAgentSampleBatchBuilder from ray.rllib.policy.tf_policy import TFPolicy