mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 08:31:42 +08:00
[rllib] Support agent.get_action in multiagent (#2543)
* support get action on policy id
* comment
* grammar fixes
* Update rllib-algorithms.rst
This commit is contained in:
@@ -212,16 +212,23 @@ class Agent(Trainable):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observation, state=None, policy_id="default"):
    """Computes an action for the specified policy.

    The observation is first passed through the filter registered under
    ``policy_id`` on the local evaluator (with ``update=False``, so the
    filter is not asked to update itself), and the filtered observation is
    then handed to that policy's ``compute_single_action``.

    Arguments:
        observation (obj): observation from the environment.
        state (list): RNN hidden state, if any. ``None`` is treated as an
            empty list.
        policy_id (str): policy to query (only applies to multi-agent).

    Returns:
        The first element of the result of the policy's
        ``compute_single_action`` call.
    """
    # Avoid a mutable default argument: build a fresh list per call.
    if state is None:
        state = []

    # Look the filter up by policy id rather than the hard-coded "default"
    # key, so that every policy in a multi-agent setup can be queried.
    filtered_obs = self.local_evaluator.filters[policy_id](
        observation, update=False)

    return self.local_evaluator.for_policy(
        lambda p: p.compute_single_action(
            filtered_obs, state, is_training=False)[0],
        policy_id=policy_id)
|
||||
|
||||
def get_weights(self, policies=None):
|
||||
"""Return a dictionary of policy ids to weights.
|
||||
|
||||
@@ -303,6 +303,46 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
return
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
def testTrainMultiCartpoleMultiPolicy(self):
    """Smoke-test PG training with two policy graphs and per-policy queries.

    Registers a 10-agent cartpole environment, configures two policy
    graphs with randomly chosen hyperparameters (every agent is mapped to
    "policy_1"), runs ten training iterations, and then checks that
    ``compute_action`` returns a valid CartPole action (0 or 1) for each
    configured policy id and raises ``KeyError`` for an unknown one.
    """
    n = 10
    register_env("multi_cartpole", lambda _: MultiCartpole(n))
    single_env = gym.make("CartPole-v0")

    def gen_policy():
        # Each policy graph gets its own randomly sampled hyperparameters.
        config = {
            "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
            "n_step": random.choice([1, 2, 3, 4, 5]),
        }
        obs_space = single_env.observation_space
        act_space = single_env.action_space
        return (PGPolicyGraph, obs_space, act_space, config)

    pg = PGAgent(
        env="multi_cartpole",
        config={
            "num_workers": 0,
            "multiagent": {
                "policy_graphs": {
                    "policy_1": gen_policy(),
                    "policy_2": gen_policy(),
                },
                "policy_mapping_fn": lambda agent_id: "policy_1",
            },
        })

    # Just check that it runs without crashing
    for i in range(10):
        result = pg.train()
        print("Iteration {}, reward {}, timesteps {}".format(
            i, result.episode_reward_mean, result.timesteps_total))

    # assertIn reports the offending action on failure, unlike
    # assertTrue(x in [...]) which only says "False is not true".
    self.assertIn(
        pg.compute_action([0, 0, 0, 0], policy_id="policy_1"), [0, 1])
    self.assertIn(
        pg.compute_action([0, 0, 0, 0], policy_id="policy_2"), [0, 1])

    # An id that was never configured must be rejected.
    with self.assertRaises(KeyError):
        pg.compute_action([0, 0, 0, 0], policy_id="policy_3")
|
||||
|
||||
def _testWithOptimizer(self, optimizer_cls):
|
||||
n = 3
|
||||
env = gym.make("CartPole-v0")
|
||||
|
||||
Reference in New Issue
Block a user