[rllib] Update bc/policy.py (#2012)

This commit is contained in:
SunYiran
2018-05-16 02:52:24 +08:00
committed by Richard Liaw
parent 8fbb88485b
commit 79b45c6cfd
+13 -2
View File
@@ -2,8 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
import tensorflow as tf
import gym
import ray
from ray.rllib.a3c.policy import Policy
from ray.rllib.models.catalog import ModelCatalog
@@ -38,7 +40,16 @@ class BCPolicy(Policy):
tf.get_variable_scope().name)
def setup_loss(self, action_space):
    """Build the behavior-cloning loss for the given action space.

    Creates the ``self.ac`` placeholder that receives the actions to
    imitate, then defines the loss as the negative sum of the
    log-probabilities ``self.curr_dist`` assigns to those actions.

    Args:
        action_space: The environment's action space. Only
            ``gym.spaces.Box`` and ``gym.spaces.Discrete`` are handled.

    Raises:
        NotImplementedError: If ``action_space`` is of any other type.
    """
    # Exactly one placeholder named "ac" is created. Building a default
    # int64 placeholder first and then overwriting it would leave a dead
    # node in the graph and cause the real one to be renamed.
    if isinstance(action_space, gym.spaces.Box):
        # Continuous actions: one float tensor per sample, shaped like
        # the action space itself, with a leading batch dimension.
        self.ac = tf.placeholder(
            tf.float32,
            [None] + list(action_space.shape),
            name="ac")
    elif isinstance(action_space, gym.spaces.Discrete):
        # Discrete actions: a single integer index per sample.
        self.ac = tf.placeholder(tf.int64, [None], name="ac")
    else:
        raise NotImplementedError(
            "action space {} currently not supported".format(
                type(action_space)))

    # NOTE(review): self.curr_dist is set up outside this method;
    # presumably it is the policy's action distribution -- confirm.
    log_prob = self.curr_dist.logp(self.ac)
    self.pi_loss = -tf.reduce_sum(log_prob)
    self.loss = self.pi_loss