mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[rllib] Learner should not see clipped actions (#3496)
This commit is contained in:
@@ -287,12 +287,12 @@ def _env_runner(async_vector_env,
|
||||
|
||||
# Do batched policy eval
|
||||
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
|
||||
active_episodes, clip_actions)
|
||||
active_episodes)
|
||||
|
||||
# Process results and update episode state
|
||||
actions_to_send = _process_policy_eval_results(
|
||||
to_eval, eval_results, active_episodes, active_envs,
|
||||
off_policy_actions)
|
||||
off_policy_actions, policies, clip_actions)
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
@@ -448,7 +448,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
return active_envs, to_eval, outputs
|
||||
|
||||
|
||||
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes, clip_actions):
|
||||
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
"""Call compute actions on observation batches to get next actions.
|
||||
|
||||
Returns:
|
||||
@@ -483,18 +483,12 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes, clip_actions):
|
||||
for k, v in pending_fetches.items():
|
||||
eval_results[k] = builder.get(v)
|
||||
|
||||
if clip_actions:
|
||||
for policy_id, results in eval_results.items():
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
actions, rnn_out_cols, pi_info_cols = results
|
||||
eval_results[policy_id] = (_clip_actions(
|
||||
actions, policy.action_space), rnn_out_cols, pi_info_cols)
|
||||
|
||||
return eval_results
|
||||
|
||||
|
||||
def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
||||
active_envs, off_policy_actions):
|
||||
active_envs, off_policy_actions, policies,
|
||||
clip_actions):
|
||||
"""Process the output of policy neural network evaluation.
|
||||
|
||||
Records policy evaluation results into the given episode objects and
|
||||
@@ -521,10 +515,15 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
||||
# Save output rows
|
||||
actions = _unbatch_tuple_actions(actions)
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
for i, action in enumerate(actions):
|
||||
env_id = eval_data[i].env_id
|
||||
agent_id = eval_data[i].agent_id
|
||||
actions_to_send[env_id][agent_id] = action
|
||||
if clip_actions:
|
||||
actions_to_send[env_id][agent_id] = _clip_actions(
|
||||
action, policy.action_space)
|
||||
else:
|
||||
actions_to_send[env_id][agent_id] = action
|
||||
episode = active_episodes[env_id]
|
||||
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||
episode._set_last_pi_info(
|
||||
@@ -562,7 +561,7 @@ def _clip_actions(actions, space):
|
||||
"""Called to clip actions to the specified range of this policy.
|
||||
|
||||
Arguments:
|
||||
actions: Batch of actions or TupleActions.
|
||||
actions: Single action.
|
||||
space: Action space the actions should be present in.
|
||||
|
||||
Returns:
|
||||
@@ -572,13 +571,13 @@ def _clip_actions(actions, space):
|
||||
if isinstance(space, gym.spaces.Box):
|
||||
return np.clip(actions, space.low, space.high)
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
if not isinstance(actions, TupleActions):
|
||||
if type(actions) not in (tuple, list):
|
||||
raise ValueError("Expected tuple space for actions {}: {}".format(
|
||||
actions, space))
|
||||
out = []
|
||||
for a, s in zip(actions.batches, space.spaces):
|
||||
for a, s in zip(actions, space.spaces):
|
||||
out.append(_clip_actions(a, s))
|
||||
return TupleActions(out)
|
||||
return out
|
||||
else:
|
||||
return actions
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Integration test: (1) pendulum works, (2) single-agent multi-agent works."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.test.test_multi_agent_env import make_multiagent
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
if __name__ == "__main__":
    ray.init()

    # Build the env class from Pendulum-v0 and register it with one instance
    # argument of 1 (presumably a single agent, per the module docstring --
    # confirm against make_multiagent).
    MultiPendulum = make_multiagent("Pendulum-v0")
    register_env("multi_pend", lambda _: MultiPendulum(1))

    # Reward level used both as a stopping criterion and as the final
    # pass/fail threshold of this script.
    reward_target = -200

    stop_criteria = {
        "timesteps_total": 500000,
        "episode_reward_mean": reward_target,
    }

    ppo_config = {
        "train_batch_size": 2048,
        "vf_clip_param": 10.0,
        "num_workers": 0,
        "num_envs_per_worker": 10,
        "lambda": 0.1,
        "gamma": 0.95,
        "lr": 0.0003,
        "sgd_minibatch_size": 64,
        "num_sgd_iter": 10,
        "model": {
            "fcnet_hiddens": [64, 64],
        },
        "batch_mode": "complete_episodes",
    }

    experiment = {
        "run": "PPO",
        "env": "multi_pend",
        "stop": stop_criteria,
        "config": ppo_config,
    }

    trials = run_experiments({"test": experiment})

    # Fail loudly if the first trial finished below the reward threshold.
    final_result = trials[0].last_result
    if final_result["episode_reward_mean"] < reward_target:
        raise ValueError("Did not get to -200 reward", final_result)
|
||||
@@ -5,7 +5,8 @@ pendulum-ppo:
|
||||
config:
|
||||
train_batch_size: 2048
|
||||
vf_clip_param: 10.0
|
||||
num_workers: 2
|
||||
num_workers: 0
|
||||
num_envs_per_worker: 10
|
||||
lambda: 0.1
|
||||
gamma: 0.95
|
||||
lr: 0.0003
|
||||
|
||||
Reference in New Issue
Block a user