mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 22:37:36 +08:00
[rllib] Move some inline defs to avoid deserialization errors (#5228)
* fix bug * move metrics too
This commit is contained in:
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import collections
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
@@ -15,11 +16,6 @@ from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
|
||||
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
|
||||
"perf_stats"
|
||||
])
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_learner_stats(grad_info):
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
# Define this in its own file, see #5125
|
||||
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
|
||||
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
|
||||
"perf_stats"
|
||||
])
|
||||
@@ -277,13 +277,13 @@ class RolloutWorker(EvaluatorInterface):
|
||||
dim=model_config.get("dim"),
|
||||
framestack=model_config.get("framestack"))
|
||||
if monitor_path:
|
||||
env = _monitor(env, monitor_path)
|
||||
env = gym.wrappers.Monitor(env, monitor_path, resume=True)
|
||||
return env
|
||||
else:
|
||||
|
||||
def wrap(env):
|
||||
if monitor_path:
|
||||
env = _monitor(env, monitor_path)
|
||||
env = gym.wrappers.Monitor(env, monitor_path, resume=True)
|
||||
return env
|
||||
|
||||
self.env = wrap(self.env)
|
||||
@@ -798,10 +798,6 @@ def _validate_env(env):
|
||||
return env
|
||||
|
||||
|
||||
def _monitor(env, path):
|
||||
return gym.wrappers.Monitor(env, path, resume=True)
|
||||
|
||||
|
||||
def _has_tensorflow_graph(policy_dict):
|
||||
for policy, _, _, _ in policy_dict.values():
|
||||
if issubclass(policy, TFPolicy):
|
||||
|
||||
@@ -10,7 +10,7 @@ import threading
|
||||
import time
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
|
||||
from ray.rllib.evaluation.metrics import RolloutMetrics
|
||||
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
MultiAgentSampleBatchBuilder
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
|
||||
Reference in New Issue
Block a user