mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 22:17:21 +08:00
[rllib] Remove nested import (#5204)
* remove nested import * Update metrics.py
This commit is contained in:
@@ -15,6 +15,11 @@ from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
|
||||
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
|
||||
"perf_stats"
|
||||
])
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_learner_stats(grad_info):
|
||||
@@ -161,8 +166,6 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
def _partition(episodes):
|
||||
"""Divides metrics data into true rollouts vs off-policy estimates."""
|
||||
|
||||
from ray.rllib.evaluation.sampler import RolloutMetrics
|
||||
|
||||
rollouts, estimates = [], []
|
||||
for e in episodes:
|
||||
if isinstance(e, RolloutMetrics):
|
||||
|
||||
@@ -10,6 +10,7 @@ import threading
|
||||
import time
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
|
||||
from ray.rllib.evaluation.metrics import RolloutMetrics
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
MultiAgentSampleBatchBuilder
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
@@ -24,11 +25,6 @@ from ray.rllib.policy.policy import clip_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RolloutMetrics = namedtuple("RolloutMetrics", [
|
||||
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
|
||||
"perf_stats"
|
||||
])
|
||||
|
||||
PolicyEvalData = namedtuple("PolicyEvalData", [
|
||||
"env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
|
||||
"prev_reward"
|
||||
|
||||
Reference in New Issue
Block a user