mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
[rllib] Support prev_state/prev_action in rollout and fix multiagent (#4565)
* Cleaner and more correct treatment of agent states in rollout.py
* Support lstm_use_prev_action_reward in rollout.py
* Linter.
* Appease flake8
* Use _DUMMY_AGENT_ID instead of 0.
* All agents have a policy_agent_mapping. Reset the mapping cache at the start of each episode.
* Update rollout.py
* Fix rollout.py for single-agent envs.
* Use agent_id, not policy_id.
This commit is contained in:
+64
-32
@@ -5,6 +5,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
@@ -12,6 +13,8 @@ import pickle
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.tune.util import merge_dicts
|
||||
|
||||
@@ -102,17 +105,35 @@ def run(args, parser):
|
||||
rollout(agent, args.env, num_steps, args.out, args.no_render)
|
||||
|
||||
|
||||
class DefaultMapping(collections.defaultdict):
    """A ``defaultdict`` whose ``default_factory`` receives the missing key.

    A plain ``collections.defaultdict`` calls its factory with no
    arguments; here the factory is called as ``default_factory(key)`` so the
    default value can depend on the key being looked up.
    """

    def __missing__(self, key):
        """Create, store and return the default value for ``key``.

        Args:
            key: The key that was looked up but is not in the mapping.

        Returns:
            The value produced by ``default_factory(key)``; it is also
            stored under ``key`` so the factory runs once per key.

        Raises:
            KeyError: If no ``default_factory`` was supplied, matching the
                behavior of ``collections.defaultdict``.
        """
        # Without a factory, behave like a regular dict instead of failing
        # with "'NoneType' object is not callable".
        if self.default_factory is None:
            raise KeyError(key)
        value = self.default_factory(key)
        self[key] = value
        return value
|
||||
|
||||
|
||||
def default_policy_agent_mapping(unused_agent_id):
    """Map any agent to the default policy.

    The agent id argument is accepted so this function can be called like
    any other agent-to-policy mapping function, but it is ignored.

    Returns:
        ``DEFAULT_POLICY_ID``, regardless of the argument.
    """
    policy_id = DEFAULT_POLICY_ID
    return policy_id
|
||||
|
||||
|
||||
def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
||||
policy_agent_mapping = default_policy_agent_mapping
|
||||
|
||||
if hasattr(agent, "local_evaluator"):
|
||||
env = agent.local_evaluator.env
|
||||
multiagent = agent.local_evaluator.multiagent
|
||||
if multiagent:
|
||||
multiagent = isinstance(env, MultiAgentEnv)
|
||||
if agent.local_evaluator.multiagent:
|
||||
policy_agent_mapping = agent.config["multiagent"][
|
||||
"policy_mapping_fn"]
|
||||
mapping_cache = {}
|
||||
|
||||
policy_map = agent.local_evaluator.policy_map
|
||||
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
|
||||
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
|
||||
action_init = {
|
||||
p: m.action_space.sample()
|
||||
for p, m in policy_map.items()
|
||||
}
|
||||
else:
|
||||
env = gym.make(env_name)
|
||||
multiagent = False
|
||||
@@ -122,39 +143,50 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
||||
rollouts = []
|
||||
steps = 0
|
||||
while steps < (num_steps or steps + 1):
|
||||
mapping_cache = {} # in case policy_agent_mapping is stochastic
|
||||
if out is not None:
|
||||
rollout = []
|
||||
state = env.reset()
|
||||
obs = env.reset()
|
||||
agent_states = DefaultMapping(
|
||||
lambda agent_id: state_init[mapping_cache[agent_id]])
|
||||
prev_actions = DefaultMapping(
|
||||
lambda agent_id: action_init[mapping_cache[agent_id]])
|
||||
prev_rewards = collections.defaultdict(lambda: 0.)
|
||||
done = False
|
||||
reward_total = 0.0
|
||||
while not done and steps < (num_steps or steps + 1):
|
||||
if multiagent:
|
||||
action_dict = {}
|
||||
for agent_id in state.keys():
|
||||
a_state = state[agent_id]
|
||||
if a_state is not None:
|
||||
policy_id = mapping_cache.setdefault(
|
||||
agent_id, policy_agent_mapping(agent_id))
|
||||
p_use_lstm = use_lstm[policy_id]
|
||||
if p_use_lstm:
|
||||
a_action, p_state_init, _ = agent.compute_action(
|
||||
a_state,
|
||||
state=state_init[policy_id],
|
||||
policy_id=policy_id)
|
||||
state_init[policy_id] = p_state_init
|
||||
else:
|
||||
a_action = agent.compute_action(
|
||||
a_state, policy_id=policy_id)
|
||||
action_dict[agent_id] = a_action
|
||||
action = action_dict
|
||||
else:
|
||||
if use_lstm[DEFAULT_POLICY_ID]:
|
||||
action, state_init, _ = agent.compute_action(
|
||||
state, state=state_init)
|
||||
else:
|
||||
action = agent.compute_action(state)
|
||||
multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
|
||||
action_dict = {}
|
||||
for agent_id, a_obs in multi_obs.items():
|
||||
if a_obs is not None:
|
||||
policy_id = mapping_cache.setdefault(
|
||||
agent_id, policy_agent_mapping(agent_id))
|
||||
p_use_lstm = use_lstm[policy_id]
|
||||
if p_use_lstm:
|
||||
a_action, p_state, _ = agent.compute_action(
|
||||
a_obs,
|
||||
state=agent_states[agent_id],
|
||||
prev_action=prev_actions[agent_id],
|
||||
prev_reward=prev_rewards[agent_id],
|
||||
policy_id=policy_id)
|
||||
agent_states[agent_id] = p_state
|
||||
else:
|
||||
a_action = agent.compute_action(
|
||||
a_obs,
|
||||
prev_action=prev_actions[agent_id],
|
||||
prev_reward=prev_rewards[agent_id],
|
||||
policy_id=policy_id)
|
||||
action_dict[agent_id] = a_action
|
||||
prev_actions[agent_id] = a_action
|
||||
action = action_dict
|
||||
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
action = action if multiagent else action[_DUMMY_AGENT_ID]
|
||||
next_obs, reward, done, _ = env.step(action)
|
||||
if multiagent:
|
||||
for agent_id, r in reward.items():
|
||||
prev_rewards[agent_id] = r
|
||||
else:
|
||||
prev_rewards[_DUMMY_AGENT_ID] = reward
|
||||
|
||||
if multiagent:
|
||||
done = done["__all__"]
|
||||
@@ -164,9 +196,9 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
||||
if not no_render:
|
||||
env.render()
|
||||
if out is not None:
|
||||
rollout.append([state, action, next_state, reward, done])
|
||||
rollout.append([obs, action, next_obs, reward, done])
|
||||
steps += 1
|
||||
state = next_state
|
||||
obs = next_obs
|
||||
if out is not None:
|
||||
rollouts.append(rollout)
|
||||
print("Episode reward", reward_total)
|
||||
|
||||
Reference in New Issue
Block a user