mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[Ray RLlib] Fix tree import (#7662)
* Rollback.
* Fix the `tree` import error by adding a meaningful error message and replacing `tree` with `tf.nest` wherever possible.
* LINT.
* LINT.
* Fix.
* Fix the log-likelihood test case failing on Travis.
This commit is contained in:
@@ -5,7 +5,6 @@ It supports both traced and non-traced eager execution modes."""
|
||||
import logging
|
||||
import functools
|
||||
import numpy as np
|
||||
import tree
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.evaluation.episode import _flatten_action
|
||||
@@ -24,12 +23,12 @@ logger = logging.getLogger(__name__)
|
||||
def _convert_to_tf(x):
|
||||
if isinstance(x, SampleBatch):
|
||||
x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
|
||||
return tree.map_structure(_convert_to_tf, x)
|
||||
return tf.nest.map_structure(_convert_to_tf, x)
|
||||
if isinstance(x, Policy):
|
||||
return x
|
||||
|
||||
if x is not None:
|
||||
x = tree.map_structure(
|
||||
x = tf.nest.map_structure(
|
||||
lambda f: tf.convert_to_tensor(f) if f is not None else None, x)
|
||||
return x
|
||||
|
||||
@@ -38,7 +37,7 @@ def _convert_to_numpy(x):
|
||||
if x is None:
|
||||
return None
|
||||
try:
|
||||
return tree.map_structure(lambda component: component.numpy(), x)
|
||||
return tf.nest.map_structure(lambda component: component.numpy(), x)
|
||||
except AttributeError:
|
||||
raise TypeError(
|
||||
("Object of type {} has no method to convert to numpy.").format(
|
||||
@@ -66,7 +65,7 @@ def convert_eager_outputs(func):
|
||||
def _func(*args, **kwargs):
|
||||
out = func(*args, **kwargs)
|
||||
if tf.executing_eagerly():
|
||||
out = tree.map_structure(_convert_to_numpy, out)
|
||||
out = tf.nest.map_structure(_convert_to_numpy, out)
|
||||
return out
|
||||
|
||||
return _func
|
||||
@@ -551,7 +550,7 @@ def build_eager_tf_policy(name,
|
||||
SampleBatch.NEXT_OBS: np.array(
|
||||
[self.observation_space.sample()]),
|
||||
SampleBatch.DONES: np.array([False], dtype=np.bool),
|
||||
SampleBatch.ACTIONS: tree.map_structure(
|
||||
SampleBatch.ACTIONS: tf.nest.map_structure(
|
||||
lambda c: np.array([c]), self.action_space.sample()),
|
||||
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
|
||||
}
|
||||
@@ -568,7 +567,8 @@ def build_eager_tf_policy(name,
|
||||
dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
|
||||
|
||||
# Convert everything to tensors.
|
||||
dummy_batch = tree.map_structure(tf.convert_to_tensor, dummy_batch)
|
||||
dummy_batch = tf.nest.map_structure(tf.convert_to_tensor,
|
||||
dummy_batch)
|
||||
|
||||
# for IMPALA which expects a certain sample batch size.
|
||||
def tile_to(tensor, n):
|
||||
@@ -576,7 +576,7 @@ def build_eager_tf_policy(name,
|
||||
[n] + [1 for _ in tensor.shape.as_list()[1:]])
|
||||
|
||||
if get_batch_divisibility_req:
|
||||
dummy_batch = tree.map_structure(
|
||||
dummy_batch = tf.nest.map_structure(
|
||||
lambda c: tile_to(c, get_batch_divisibility_req(self)),
|
||||
dummy_batch)
|
||||
|
||||
@@ -595,7 +595,7 @@ def build_eager_tf_policy(name,
|
||||
# overwrite any tensor state from that call)
|
||||
self.model.from_batch(dummy_batch)
|
||||
|
||||
postprocessed_batch = tree.map_structure(
|
||||
postprocessed_batch = tf.nest.map_structure(
|
||||
lambda c: tf.convert_to_tensor(c), postprocessed_batch.data)
|
||||
|
||||
loss_fn(self, self.model, self.dist_class, postprocessed_batch)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
from tensorflow.python.eager.context import eager_mode
|
||||
import unittest
|
||||
|
||||
import ray.rllib.agents.dqn as dqn
|
||||
@@ -43,6 +44,14 @@ def do_test_log_likelihood(run,
|
||||
config["eager"] = fw == "eager"
|
||||
config["use_pytorch"] = fw == "torch"
|
||||
|
||||
eager_ctx = None
|
||||
if fw == "eager":
|
||||
eager_ctx = eager_mode()
|
||||
eager_ctx.__enter__()
|
||||
assert tf.executing_eagerly()
|
||||
elif fw == "tf":
|
||||
assert not tf.executing_eagerly()
|
||||
|
||||
trainer = run(config=config, env=env)
|
||||
policy = trainer.get_policy()
|
||||
vars = policy.get_weights()
|
||||
@@ -104,6 +113,9 @@ def do_test_log_likelihood(run,
|
||||
prev_reward_batch=np.array([prev_r]))
|
||||
check(np.exp(logp), expected_prob, atol=0.2)
|
||||
|
||||
if eager_ctx:
|
||||
eager_ctx.__exit__(None, None, None)
|
||||
|
||||
|
||||
class TestComputeLogLikelihood(unittest.TestCase):
|
||||
def test_dqn(self):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import tree
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
@@ -212,10 +211,8 @@ class TFPolicy(Policy):
|
||||
self._loss = loss
|
||||
|
||||
self._optimizer = self.optimizer()
|
||||
self._grads_and_vars = [
|
||||
(g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
|
||||
if g is not None
|
||||
]
|
||||
self._grads_and_vars = [(g, v) for (g, v) in self.gradients(
|
||||
self._optimizer, self._loss) if g is not None]
|
||||
self._grads = [g for (g, v) in self._grads_and_vars]
|
||||
|
||||
# TODO(sven/ekl): Deprecate support for v1 models.
|
||||
@@ -493,7 +490,7 @@ class TFPolicy(Policy):
|
||||
|
||||
# build output signatures
|
||||
output_signature = self._extra_output_signature_def()
|
||||
for i, a in enumerate(tree.flatten(self._sampled_action)):
|
||||
for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
|
||||
output_signature["actions_{}".format(i)] = \
|
||||
tf.saved_model.utils.build_tensor_info(a)
|
||||
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
import tree
|
||||
import logging
|
||||
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import tree
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning("`dm-tree` is not installed! Run `pip install dm-tree`.")
|
||||
raise e
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen, dtype=None):
|
||||
"""
|
||||
@@ -34,6 +42,7 @@ def convert_to_non_torch_type(stats):
|
||||
dict: A new dict with the same structure as stats_dict, but with all
|
||||
values converted to non-torch Tensor types.
|
||||
"""
|
||||
|
||||
# The mapping function used to numpyize torch Tensors.
|
||||
def mapping(item):
|
||||
if isinstance(item, torch.Tensor):
|
||||
|
||||
Reference in New Issue
Block a user