[rllib] Fix APPO + continuous spaces, feed prev_rew/act to A3C properly (#4286)

This commit is contained in:
Eric Liang
2019-03-06 21:36:26 -08:00
committed by GitHub
parent f0465bc68c
commit b0332551dd
4 changed files with 74 additions and 59 deletions
@@ -50,12 +50,13 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
tf.float32, [None] + list(observation_space.shape))
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
self.prev_actions = ModelCatalog.get_action_placeholder(action_space)
self.prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")
self.model = ModelCatalog.get_model({
"obs": self.observations,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
"prev_actions": self.prev_actions,
"prev_rewards": self.prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, observation_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
@@ -83,8 +84,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
loss_in = [
("obs", self.observations),
("actions", actions),
("prev_actions", prev_actions),
("prev_rewards", prev_rewards),
("prev_actions", self.prev_actions),
("prev_rewards", self.prev_rewards),
("advantages", advantages),
("value_targets", self.v_target),
]
@@ -103,8 +104,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=prev_actions,
prev_reward_input=prev_rewards,
prev_action_input=self.prev_actions,
prev_reward_input=self.prev_rewards,
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"])
@@ -138,7 +139,9 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
next_state = []
for i in range(len(self.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = self._value(sample_batch["new_obs"][-1], *next_state)
last_r = self._value(sample_batch["new_obs"][-1],
sample_batch["actions"][-1],
sample_batch["rewards"][-1], *next_state)
return compute_advantages(sample_batch, last_r, self.config["gamma"],
self.config["lambda"])
@@ -159,8 +162,13 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
TFPolicyGraph.extra_compute_action_fetches(self),
**{"vf_preds": self.vf})
def _value(self, ob, *args):
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
def _value(self, ob, prev_action, prev_reward, *args):
feed_dict = {
self.observations: [ob],
self.prev_actions: [prev_action],
self.prev_rewards: [prev_reward],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
@@ -171,16 +171,17 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
if isinstance(action_space, gym.spaces.Discrete):
is_multidiscrete = False
actions_shape = [None]
output_hidden_shape = [action_space.n]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
actions_shape = [None, len(action_space.nvec)]
output_hidden_shape = action_space.nvec.astype(np.int32)
else:
elif self.config["vtrace"]:
raise UnsupportedSpaceException(
"Action space {} is not supported for APPO.",
"Action space {} is not supported for APPO + VTrace.",
format(action_space))
else:
is_multidiscrete = False
output_hidden_shape = 1
# Policy network model
dist_class, logit_dim = ModelCatalog.get_action_dist(
@@ -200,7 +201,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
existing_state_in = existing_inputs[9:-1]
existing_seq_lens = existing_inputs[-1]
else:
actions = tf.placeholder(tf.int64, actions_shape, name="ac")
actions = ModelCatalog.get_action_placeholder(action_space)
dones = tf.placeholder(tf.bool, [None], name="dones")
rewards = tf.placeholder(tf.float32, [None], name="rewards")
behaviour_logits = tf.placeholder(
@@ -84,9 +84,6 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
learner_queue_size)
self.learner.start()
if len(self.remote_evaluators) == 0:
logger.warning("Config num_workers=0 means training will hang!")
# Stats
self._optimizer_step_timer = TimerStat()
self.num_weight_syncs = 0
@@ -137,6 +134,8 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def step(self):
if len(self.remote_evaluators) == 0:
raise ValueError("Config num_workers=0 means training will hang!")
assert self.learner.is_alive()
with self._optimizer_step_timer:
sample_timesteps, train_timesteps = self._step()