Generic shared_model class (#880)

Changing `shared_model` class back to `get_model` rather than `ConvolutionalNetwork`
This commit is contained in:
Richard Liaw
2017-08-28 22:48:07 -07:00
committed by Robert Nishihara
parent 8099cdeb9d
commit 5d72818ddc
+1 -2
View File
@@ -6,7 +6,6 @@ import tensorflow as tf
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.a3c.policy import Policy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.convnet import ConvolutionalNetwork
class SharedModel(Policy):
@@ -16,7 +15,7 @@ class SharedModel(Policy):
def setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
self._model = ConvolutionalNetwork(self.x, self.logit_dim, {})
self._model = ModelCatalog.get_model(self.x, self.logit_dim)
self.logits = self._model.outputs
self.curr_dist = dist_class(self.logits)
# with tf.variable_scope("vf"):