mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 19:25:51 +08:00
Generic shared_model class (#880)
Changing `shared_model` class back to `get_model` rather than `ConvolutionalNetwork`
This commit is contained in:
committed by
Robert Nishihara
parent
8099cdeb9d
commit
5d72818ddc
@@ -6,7 +6,6 @@ import tensorflow as tf
|
||||
from ray.rllib.models.misc import linear, normc_initializer
|
||||
from ray.rllib.a3c.policy import Policy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.convnet import ConvolutionalNetwork
|
||||
|
||||
|
||||
class SharedModel(Policy):
|
||||
@@ -16,7 +15,7 @@ class SharedModel(Policy):
|
||||
def setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
self._model = ConvolutionalNetwork(self.x, self.logit_dim, {})
|
||||
self._model = ModelCatalog.get_model(self.x, self.logit_dim)
|
||||
self.logits = self._model.outputs
|
||||
self.curr_dist = dist_class(self.logits)
|
||||
# with tf.variable_scope("vf"):
|
||||
|
||||
Reference in New Issue
Block a user