mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 16:54:21 +08:00
[rllib] Fix APPO + continuous spaces, feed prev_rew/act to A3C properly (#4286)
This commit is contained in:
@@ -50,12 +50,13 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
tf.float32, [None] + list(observation_space.shape))
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
prev_actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
|
||||
self.prev_actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
self.prev_rewards = tf.placeholder(
|
||||
tf.float32, [None], name="prev_reward")
|
||||
self.model = ModelCatalog.get_model({
|
||||
"obs": self.observations,
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards,
|
||||
"prev_actions": self.prev_actions,
|
||||
"prev_rewards": self.prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, logit_dim, self.config["model"])
|
||||
action_dist = dist_class(self.model.outputs)
|
||||
@@ -83,8 +84,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
loss_in = [
|
||||
("obs", self.observations),
|
||||
("actions", actions),
|
||||
("prev_actions", prev_actions),
|
||||
("prev_rewards", prev_rewards),
|
||||
("prev_actions", self.prev_actions),
|
||||
("prev_rewards", self.prev_rewards),
|
||||
("advantages", advantages),
|
||||
("value_targets", self.v_target),
|
||||
]
|
||||
@@ -103,8 +104,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
prev_action_input=prev_actions,
|
||||
prev_reward_input=prev_rewards,
|
||||
prev_action_input=self.prev_actions,
|
||||
prev_reward_input=self.prev_rewards,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=self.config["model"]["max_seq_len"])
|
||||
|
||||
@@ -138,7 +139,9 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
next_state = []
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self._value(sample_batch["new_obs"][-1], *next_state)
|
||||
last_r = self._value(sample_batch["new_obs"][-1],
|
||||
sample_batch["actions"][-1],
|
||||
sample_batch["rewards"][-1], *next_state)
|
||||
return compute_advantages(sample_batch, last_r, self.config["gamma"],
|
||||
self.config["lambda"])
|
||||
|
||||
@@ -159,8 +162,13 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"vf_preds": self.vf})
|
||||
|
||||
def _value(self, ob, *args):
|
||||
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||
def _value(self, ob, prev_action, prev_reward, *args):
|
||||
feed_dict = {
|
||||
self.observations: [ob],
|
||||
self.prev_actions: [prev_action],
|
||||
self.prev_rewards: [prev_reward],
|
||||
self.model.seq_lens: [1]
|
||||
}
|
||||
assert len(args) == len(self.model.state_in), \
|
||||
(args, self.model.state_in)
|
||||
for k, v in zip(self.model.state_in, args):
|
||||
|
||||
@@ -171,16 +171,17 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
if isinstance(action_space, gym.spaces.Discrete):
|
||||
is_multidiscrete = False
|
||||
actions_shape = [None]
|
||||
output_hidden_shape = [action_space.n]
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
is_multidiscrete = True
|
||||
actions_shape = [None, len(action_space.nvec)]
|
||||
output_hidden_shape = action_space.nvec.astype(np.int32)
|
||||
else:
|
||||
elif self.config["vtrace"]:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for APPO.",
|
||||
"Action space {} is not supported for APPO + VTrace.",
|
||||
format(action_space))
|
||||
else:
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
|
||||
# Policy network model
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
@@ -200,7 +201,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
existing_state_in = existing_inputs[9:-1]
|
||||
existing_seq_lens = existing_inputs[-1]
|
||||
else:
|
||||
actions = tf.placeholder(tf.int64, actions_shape, name="ac")
|
||||
actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
dones = tf.placeholder(tf.bool, [None], name="dones")
|
||||
rewards = tf.placeholder(tf.float32, [None], name="rewards")
|
||||
behaviour_logits = tf.placeholder(
|
||||
|
||||
@@ -84,9 +84,6 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
learner_queue_size)
|
||||
self.learner.start()
|
||||
|
||||
if len(self.remote_evaluators) == 0:
|
||||
logger.warning("Config num_workers=0 means training will hang!")
|
||||
|
||||
# Stats
|
||||
self._optimizer_step_timer = TimerStat()
|
||||
self.num_weight_syncs = 0
|
||||
@@ -137,6 +134,8 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
if len(self.remote_evaluators) == 0:
|
||||
raise ValueError("Config num_workers=0 means training will hang!")
|
||||
assert self.learner.is_alive()
|
||||
with self._optimizer_step_timer:
|
||||
sample_timesteps, train_timesteps = self._step()
|
||||
|
||||
Reference in New Issue
Block a user