[rllib] Move some inline defs to avoid deserialization errors (#5228)

* fix bug

* move metrics too
This commit is contained in:
Eric Liang
2019-07-18 21:01:16 -07:00
committed by GitHub
parent aa42328874
commit 28e5c5555d
4 changed files with 15 additions and 12 deletions
+1 -5
View File
@@ -7,6 +7,7 @@ import numpy as np
import collections
import ray
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.policy.policy import LEARNER_STATS_KEY
@@ -15,11 +16,6 @@ from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
"perf_stats"
])
@DeveloperAPI
def get_learner_stats(grad_info):
@@ -0,0 +1,11 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
# Define this in its own file, see #5125
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
"perf_stats"
])
@@ -277,13 +277,13 @@ class RolloutWorker(EvaluatorInterface):
dim=model_config.get("dim"),
framestack=model_config.get("framestack"))
if monitor_path:
env = _monitor(env, monitor_path)
env = gym.wrappers.Monitor(env, monitor_path, resume=True)
return env
else:
def wrap(env):
if monitor_path:
env = _monitor(env, monitor_path)
env = gym.wrappers.Monitor(env, monitor_path, resume=True)
return env
self.env = wrap(self.env)
@@ -798,10 +798,6 @@ def _validate_env(env):
return env
def _monitor(env, path):
return gym.wrappers.Monitor(env, path, resume=True)
def _has_tensorflow_graph(policy_dict):
for policy, _, _, _ in policy_dict.values():
if issubclass(policy, TFPolicy):
+1 -1
View File
@@ -10,7 +10,7 @@ import threading
import time
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.policy.tf_policy import TFPolicy