From 5d72818ddcc8c4043af845a8f0ea2d6e7a7dc62c Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 28 Aug 2017 22:48:07 -0700 Subject: [PATCH] Generic `shared_model` class (#880) Changing `shared_model` class back to `get_model` rather than `ConvolutionalNetwork` --- python/ray/rllib/a3c/shared_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/rllib/a3c/shared_model.py b/python/ray/rllib/a3c/shared_model.py index 2c0e934fc..ea96beef4 100644 --- a/python/ray/rllib/a3c/shared_model.py +++ b/python/ray/rllib/a3c/shared_model.py @@ -6,7 +6,6 @@ import tensorflow as tf from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.a3c.policy import Policy from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.convnet import ConvolutionalNetwork class SharedModel(Policy): @@ -16,7 +15,7 @@ class SharedModel(Policy): def setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space) - self._model = ConvolutionalNetwork(self.x, self.logit_dim, {}) + self._model = ModelCatalog.get_model(self.x, self.logit_dim) self.logits = self._model.outputs self.curr_dist = dist_class(self.logits) # with tf.variable_scope("vf"):