[rllib] Add helper to iterate over envs in a vectorized environment (#4001)

* add foreach env func

* fix

* add test
This commit is contained in:
Eric Liang
2019-02-11 10:40:47 -08:00
committed by GitHub
parent a70ae1687b
commit c4182463f6
3 changed files with 25 additions and 6 deletions
@@ -494,6 +494,16 @@ class PolicyEvaluator(EvaluatorInterface):
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
return grad_fetch
@DeveloperAPI
def foreach_env(self, func):
    """Call ``func`` on every underlying env and collect the results.

    The env instances are taken from ``self.async_env.get_unwrapped()``.
    When that call returns nothing (an empty or otherwise falsy value),
    ``func`` is invoked once on ``self.async_env`` itself instead.

    Args:
        func: Callable taking a single env argument.

    Returns:
        A list holding the return value of ``func`` for each env visited,
        in iteration order.
    """
    unwrapped = self.async_env.get_unwrapped()
    # Fall back to the async env wrapper when no unwrapped envs are reported.
    targets = unwrapped if unwrapped else [self.async_env]
    return [func(env) for env in targets]
@DeveloperAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
@@ -209,16 +209,21 @@ class TestPolicyEvaluator(unittest.TestCase):
def testQueryEvaluators(self):
# Exercises the optimizer's foreach_evaluator / foreach_evaluator_with_index
# helpers and the evaluator's foreach_env helper on a PGAgent.
# NOTE(review): this listing appears to come from a diff whose +/- markers
# were stripped, so superseded lines (the two `env="test"` spellings, the two
# `sample_batch_size` entries, and both sets of expected values) sit next to
# their replacements -- confirm against the real file before relying on it.
register_env("test", lambda _: gym.make("CartPole-v0"))
pg = PGAgent(
env="test", config={
env="test",
config={
"num_workers": 2,
"sample_batch_size": 5
"sample_batch_size": 5,
"num_envs_per_worker": 2,
})
# Read sample_batch_size from every evaluator, without and with its index.
results = pg.optimizer.foreach_evaluator(
lambda ev: ev.sample_batch_size)
results2 = pg.optimizer.foreach_evaluator_with_index(
lambda ev, i: (i, ev.sample_batch_size))
# Three entries are expected although num_workers is 2 -- presumably the
# local evaluator is included alongside the two workers; TODO confirm.
self.assertEqual(results, [5, 5, 5])
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
# Map a constant function over each evaluator's envs via foreach_env.
results3 = pg.optimizer.foreach_evaluator(
lambda ev: ev.foreach_env(lambda env: 1))
self.assertEqual(results, [10, 10, 10])
self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
# Two results per evaluator -- presumably one per env given
# num_envs_per_worker=2; TODO confirm.
self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
def testRewardClipping(self):
# clipping on