mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:08:13 +08:00
[rllib]Update bc/policy.py (#2012)
This commit is contained in:
@@ -2,8 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
import tensorflow as tf
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.a3c.policy import Policy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
@@ -38,7 +40,16 @@ class BCPolicy(Policy):
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
def setup_loss(self, action_space):
|
||||
self.ac = tf.placeholder(tf.int64, [None], name="ac")
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
self.ac = tf.placeholder(tf.float32,
|
||||
[None] + list(action_space.shape),
|
||||
name="ac")
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
self.ac = tf.placeholder(tf.int64, [None], name="ac")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"action space" + str(type(action_space)) +
|
||||
"currently not supported")
|
||||
log_prob = self.curr_dist.logp(self.ac)
|
||||
self.pi_loss = - tf.reduce_sum(log_prob)
|
||||
self.loss = self.pi_loss
|
||||
|
||||
Reference in New Issue
Block a user