mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:49:04 +08:00
Move TensorFlowVariables to ray.experimental.tf_utils. (#4145)
This commit is contained in:
committed by
Philipp Moritz
parent
615d5516d1
commit
7b04ed059e
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from .tfutils import TensorFlowVariables
|
||||
from .features import (
|
||||
flush_redis_unsafe, flush_task_and_object_metadata_unsafe,
|
||||
flush_finished_tasks_unsafe, flush_evicted_objects_unsafe,
|
||||
@@ -12,6 +11,13 @@ from .gcs_flush_policy import (set_flushing_policy, GcsFlushPolicy,
|
||||
from .named_actors import get_actor, register_actor
|
||||
from .api import get, wait
|
||||
|
||||
|
||||
def TensorFlowVariables(*args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"'ray.experimental.TensorFlowVariables' is deprecated. Instead, please"
|
||||
" do 'from ray.experimental.tf_utils import TensorFlowVariables'.")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TensorFlowVariables", "flush_redis_unsafe",
|
||||
"flush_task_and_object_metadata_unsafe", "flush_finished_tasks_unsafe",
|
||||
|
||||
@@ -24,7 +24,7 @@ from ray.tune import run_experiments
|
||||
from ray.tune.examples.tune_mnist_ray import deepnn
|
||||
from ray.experimental.sgd.model import Model
|
||||
from ray.experimental.sgd.sgd import DistributedSGD
|
||||
from ray.experimental.tfutils import TensorFlowVariables
|
||||
import ray.experimental.tf_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--redis-address", default=None, type=str)
|
||||
@@ -67,8 +67,8 @@ class MNISTModel(Model):
|
||||
tf.nn.softmax_cross_entropy_with_logits(
|
||||
labels=self.y_, logits=y_conv))
|
||||
self.optimizer = tf.train.AdamOptimizer(1e-4)
|
||||
self.variables = TensorFlowVariables(self.loss,
|
||||
tf.get_default_session())
|
||||
self.variables = ray.experimental.tfutils.TensorFlowVariables(
|
||||
self.loss, tf.get_default_session())
|
||||
|
||||
# For evaluating test accuracy
|
||||
correct_prediction = tf.equal(
|
||||
|
||||
@@ -6,7 +6,7 @@ import tensorflow as tf
|
||||
|
||||
from tfbench import model_config
|
||||
from ray.experimental.sgd.model import Model
|
||||
from ray.experimental.tfutils import TensorFlowVariables
|
||||
import ray.experimental.tf_utils
|
||||
|
||||
|
||||
class MockDataset():
|
||||
@@ -47,8 +47,8 @@ class TFBenchModel(Model):
|
||||
self.loss = tf.reduce_mean(loss, name='xentropy-loss')
|
||||
self.optimizer = tf.train.GradientDescentOptimizer(1e-6)
|
||||
|
||||
self.variables = TensorFlowVariables(self.loss,
|
||||
tf.get_default_session())
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
self.loss, tf.get_default_session())
|
||||
|
||||
def get_loss(self):
|
||||
return self.loss
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from collections import deque, OrderedDict
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def unflatten(vector, shapes):
|
||||
@@ -45,7 +48,6 @@ class TensorFlowVariables(object):
|
||||
input_variables (List[tf.Variables]): Variables to include in the
|
||||
list.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
if not isinstance(output, (list, tuple)):
|
||||
output = [output]
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.models import ModelCatalog
|
||||
@@ -81,7 +82,7 @@ class GenericPolicy(object):
|
||||
dist = dist_class(model.outputs)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
model.outputs, self.sess)
|
||||
|
||||
self.num_params = sum(
|
||||
|
||||
@@ -8,8 +8,9 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import _huber_loss, \
|
||||
_minimize_and_clip, _scope_vars, _postprocess_dqn
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import (
|
||||
_huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn)
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
@@ -387,7 +388,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# Note that this encompasses both the policy and Q-value networks and
|
||||
# their corresponding target networks
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
tf.group(q_tp0, q_tp1), self.sess)
|
||||
|
||||
# Hard initial update
|
||||
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
@@ -59,7 +60,7 @@ class GenericPolicy(object):
|
||||
dist = dist_class(model.outputs)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
model.outputs, self.sess)
|
||||
|
||||
self.num_params = sum(
|
||||
|
||||
@@ -9,6 +9,7 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
@@ -120,7 +121,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
for (g, v) in self.gradients(self._optimizer)
|
||||
if g is not None]
|
||||
self._grads = [g for (g, v) in self._grads_and_vars]
|
||||
self._variables = ray.experimental.TensorFlowVariables(
|
||||
self._variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
self._loss, self._sess)
|
||||
|
||||
# gather update ops for any batch norm layers
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
|
||||
|
||||
def make_linear_network(w_name=None, b_name=None):
|
||||
@@ -31,7 +32,7 @@ class LossActor(object):
|
||||
loss, init, _, _ = make_linear_network()
|
||||
sess = tf.Session()
|
||||
# Additional code for setting and getting the weights.
|
||||
weights = ray.experimental.TensorFlowVariables(
|
||||
weights = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
loss if use_loss else None, sess, input_variables=var)
|
||||
# Return all of the data needed to use the network.
|
||||
self.values = [weights, init, sess]
|
||||
@@ -53,7 +54,8 @@ class NetActor(object):
|
||||
loss, init, _, _ = make_linear_network()
|
||||
sess = tf.Session()
|
||||
# Additional code for setting and getting the weights.
|
||||
variables = ray.experimental.TensorFlowVariables(loss, sess)
|
||||
variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
loss, sess)
|
||||
# Return all of the data needed to use the network.
|
||||
self.values = [variables, init, sess]
|
||||
sess.run(init)
|
||||
@@ -73,7 +75,8 @@ class TrainActor(object):
|
||||
with tf.Graph().as_default():
|
||||
loss, init, x_data, y_data = make_linear_network()
|
||||
sess = tf.Session()
|
||||
variables = ray.experimental.TensorFlowVariables(loss, sess)
|
||||
variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
loss, sess)
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.9)
|
||||
grads = optimizer.compute_gradients(loss)
|
||||
train = optimizer.apply_gradients(grads)
|
||||
@@ -107,7 +110,7 @@ def test_tensorflow_variables(ray_start_regular):
|
||||
loss, init, _, _ = make_linear_network()
|
||||
sess.run(init)
|
||||
|
||||
variables = ray.experimental.TensorFlowVariables(loss, sess)
|
||||
variables = ray.experimental.tf_utils.TensorFlowVariables(loss, sess)
|
||||
weights = variables.get_weights()
|
||||
|
||||
for (name, val) in weights.items():
|
||||
@@ -119,7 +122,7 @@ def test_tensorflow_variables(ray_start_regular):
|
||||
loss2, init2, _, _ = make_linear_network("w", "b")
|
||||
sess.run(init2)
|
||||
|
||||
variables2 = ray.experimental.TensorFlowVariables(loss2, sess)
|
||||
variables2 = ray.experimental.tf_utils.TensorFlowVariables(loss2, sess)
|
||||
weights2 = variables2.get_weights()
|
||||
|
||||
for (name, val) in weights2.items():
|
||||
@@ -131,7 +134,7 @@ def test_tensorflow_variables(ray_start_regular):
|
||||
variables2.set_flat(flat_weights)
|
||||
assert_almost_equal(flat_weights, variables2.get_flat())
|
||||
|
||||
variables3 = ray.experimental.TensorFlowVariables([loss2])
|
||||
variables3 = ray.experimental.tf_utils.TensorFlowVariables([loss2])
|
||||
assert variables3.sess is None
|
||||
sess = tf.Session()
|
||||
variables3.set_session(sess)
|
||||
@@ -205,7 +208,7 @@ def test_network_driver_worker_independent(ray_start_regular):
|
||||
# Create a network on the driver locally.
|
||||
sess1 = tf.Session()
|
||||
loss1, init1, _, _ = make_linear_network()
|
||||
ray.experimental.TensorFlowVariables(loss1, sess1)
|
||||
ray.experimental.tf_utils.TensorFlowVariables(loss1, sess1)
|
||||
sess1.run(init1)
|
||||
|
||||
net2 = ray.remote(NetActor).remote()
|
||||
@@ -221,7 +224,7 @@ def test_variables_control_dependencies(ray_start_regular):
|
||||
sess = tf.Session()
|
||||
loss, init, _, _ = make_linear_network()
|
||||
minimizer = tf.train.MomentumOptimizer(0.9, 0.9).minimize(loss)
|
||||
net_vars = ray.experimental.TensorFlowVariables(minimizer, sess)
|
||||
net_vars = ray.experimental.tf_utils.TensorFlowVariables(minimizer, sess)
|
||||
sess.run(init)
|
||||
|
||||
# Tests if all variables are properly retrieved, 2 variables and 2
|
||||
|
||||
Reference in New Issue
Block a user