[rllib] Support agent.get_action in multiagent (#2543)

* support get action on policy id

* comment

* grammar fixes

* Update rllib-algorithms.rst
This commit is contained in:
Eric Liang
2018-08-02 13:35:53 -07:00
committed by GitHub
parent d2ebe4d9a3
commit f7ec292360
3 changed files with 54 additions and 7 deletions
+12 -5
View File
@@ -212,16 +212,23 @@ class Agent(Trainable):
raise NotImplementedError
def compute_action(self, observation, state=None):
"""Computes an action using the current trained policy."""
def compute_action(self, observation, state=None, policy_id="default"):
"""Computes an action for the specified policy.
Arguments:
observation (obj): observation from the environment.
state (list): RNN hidden state, if any.
policy_id (str): policy to query (only applies to multi-agent).
"""
if state is None:
state = []
obs = self.local_evaluator.filters["default"](
filtered_obs = self.local_evaluator.filters[policy_id](
observation, update=False)
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(obs, state, is_training=False)[0]
)
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False)[0],
policy_id=policy_id)
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
@@ -303,6 +303,46 @@ class TestMultiAgentEnv(unittest.TestCase):
return
raise Exception("failed to improve reward")
def testTrainMultiCartpoleMultiPolicy(self):
n = 10
register_env("multi_cartpole", lambda _: MultiCartpole(n))
single_env = gym.make("CartPole-v0")
def gen_policy():
config = {
"gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
"n_step": random.choice([1, 2, 3, 4, 5]),
}
obs_space = single_env.observation_space
act_space = single_env.action_space
return (PGPolicyGraph, obs_space, act_space, config)
pg = PGAgent(
env="multi_cartpole",
config={
"num_workers": 0,
"multiagent": {
"policy_graphs": {
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
"policy_mapping_fn": lambda agent_id: "policy_1",
},
})
# Just check that it runs without crashing
for i in range(10):
result = pg.train()
print("Iteration {}, reward {}, timesteps {}".format(
i, result.episode_reward_mean, result.timesteps_total))
self.assertTrue(
pg.compute_action([0, 0, 0, 0], policy_id="policy_1") in [0, 1])
self.assertTrue(
pg.compute_action([0, 0, 0, 0], policy_id="policy_2") in [0, 1])
self.assertRaises(
KeyError,
lambda: pg.compute_action([0, 0, 0, 0], policy_id="policy_3"))
def _testWithOptimizer(self, optimizer_cls):
n = 3
env = gym.make("CartPole-v0")