mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:06:25 +08:00
[rllib] Fix APEX priorities returning zero all the time (#5980)
* fix * move example tests to end * level err * guard against none * no trace test * ignore thumbs * np * fix multi node * fix
This commit is contained in:
@@ -100,6 +100,7 @@ scripts/nodes.txt
|
||||
|
||||
# Generated documentation files
|
||||
/doc/_build
|
||||
/doc/source/_static/thumbs
|
||||
|
||||
# User-specific stuff:
|
||||
.idea/**/workspace.xml
|
||||
|
||||
@@ -171,3 +171,8 @@ If you encounter out-of-memory errors, consider setting ``redis_max_memory`` and
|
||||
For debugging unexpected hangs or performance problems, you can run ``ray stack`` to dump
|
||||
the stack traces of all Ray workers on the current node, and ``ray timeline`` to dump
|
||||
a timeline visualization of tasks to a file.
|
||||
|
||||
TensorFlow 2.0
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
RLlib currently runs in ``tf.compat.v1`` mode. This means eager execution is disabled by default, and RLlib imports TF with ``import tensorflow.compat.v1 as tf; tf.disable_v2_behaviour()``. Eager execution can be enabled manually by calling ``tf.enable_eager_execution()`` or setting the ``"eager": True`` trainer config.
|
||||
|
||||
@@ -140,9 +140,6 @@ class ComputeTDErrorMixin(object):
|
||||
@make_tf_callable(self.get_session(), dynamic_shape=True)
|
||||
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
|
||||
importance_weights):
|
||||
if not self.loss_initialized():
|
||||
return tf.zeros_like(rew_t)
|
||||
|
||||
# Do forward pass on loss to update td error attribute
|
||||
build_q_losses(
|
||||
self, self.model, None, {
|
||||
|
||||
@@ -290,9 +290,6 @@ class ComputeTDErrorMixin(object):
|
||||
@make_tf_callable(self.get_session(), dynamic_shape=True)
|
||||
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
|
||||
importance_weights):
|
||||
if not self.loss_initialized():
|
||||
return tf.zeros_like(rew_t)
|
||||
|
||||
# Do forward pass on loss to update td error attribute
|
||||
actor_critic_loss(
|
||||
self, self.model, None, {
|
||||
|
||||
@@ -31,7 +31,8 @@ def _convert_to_tf(x):
|
||||
return x
|
||||
|
||||
if x is not None:
|
||||
x = tf.nest.map_structure(tf.convert_to_tensor, x)
|
||||
x = tf.nest.map_structure(
|
||||
lambda f: tf.convert_to_tensor(f) if f is not None else None, x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -5,19 +5,21 @@ from ray import tune
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
|
||||
|
||||
def check_support(alg, config):
|
||||
def check_support(alg, config, test_trace=True):
|
||||
config["eager"] = True
|
||||
if alg in ["APEX_DDPG", "TD3", "DDPG", "SAC"]:
|
||||
config["env"] = "Pendulum-v0"
|
||||
else:
|
||||
config["env"] = "CartPole-v0"
|
||||
a = get_agent_class(alg)
|
||||
config["log_level"] = "ERROR"
|
||||
|
||||
config["eager_tracing"] = False
|
||||
tune.run(a, config=config, stop={"training_iteration": 0})
|
||||
|
||||
config["eager_tracing"] = True
|
||||
tune.run(a, config=config, stop={"training_iteration": 0})
|
||||
if test_trace:
|
||||
config["eager_tracing"] = True
|
||||
tune.run(a, config=config, stop={"training_iteration": 0})
|
||||
|
||||
|
||||
class TestEagerSupport(unittest.TestCase):
|
||||
@@ -37,7 +39,8 @@ class TestEagerSupport(unittest.TestCase):
|
||||
check_support("A2C", {"num_workers": 0})
|
||||
|
||||
def testA3C(self):
|
||||
check_support("A3C", {"num_workers": 1})
|
||||
# TODO(ekl) trace on is flaky
|
||||
check_support("A3C", {"num_workers": 1}, test_trace=False)
|
||||
|
||||
def testPG(self):
|
||||
check_support("PG", {"num_workers": 0})
|
||||
|
||||
Reference in New Issue
Block a user