From 4fa2a6006c305694a682086b1b52608cc3b7b8ee Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 16 Jul 2019 10:52:56 -0700 Subject: [PATCH] [rllib] Remove nested import (#5204) * remove nested import * Update metrics.py --- python/ray/rllib/evaluation/metrics.py | 7 +++++-- python/ray/rllib/evaluation/sampler.py | 6 +----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index 341327608..3bba8c392 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -15,6 +15,11 @@ from ray.rllib.utils.memory import ray_get_and_free logger = logging.getLogger(__name__) +RolloutMetrics = collections.namedtuple("RolloutMetrics", [ + "episode_length", "episode_reward", "agent_rewards", "custom_metrics", + "perf_stats" +]) + @DeveloperAPI def get_learner_stats(grad_info): @@ -161,8 +166,6 @@ def summarize_episodes(episodes, new_episodes, num_dropped): def _partition(episodes): """Divides metrics data into true rollouts vs off-policy estimates.""" - from ray.rllib.evaluation.sampler import RolloutMetrics - rollouts, estimates = [], [] for e in episodes: if isinstance(e, RolloutMetrics): diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 47964c3c5..f87639bd2 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -10,6 +10,7 @@ import threading import time from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action +from ray.rllib.evaluation.metrics import RolloutMetrics from ray.rllib.evaluation.sample_batch_builder import \ MultiAgentSampleBatchBuilder from ray.rllib.policy.tf_policy import TFPolicy @@ -24,11 +25,6 @@ from ray.rllib.policy.policy import clip_action logger = logging.getLogger(__name__) -RolloutMetrics = namedtuple("RolloutMetrics", [ - "episode_length", "episode_reward", "agent_rewards", "custom_metrics", - "perf_stats" -]) - PolicyEvalData = namedtuple("PolicyEvalData", [ "env_id", "agent_id", "obs", "info", "rnn_state", "prev_action", "prev_reward"