mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 04:19:23 +08:00
[rllib] Add helper to iterate over envs in a vectorized environment (#4001)
* add foreach env func * fix * add test
This commit is contained in:
@@ -494,6 +494,16 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
|
||||
return grad_fetch
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_env(self, func):
|
||||
"""Apply the given function to each underlying env instance."""
|
||||
|
||||
envs = self.async_env.get_unwrapped()
|
||||
if not envs:
|
||||
return [func(self.async_env)]
|
||||
else:
|
||||
return [func(e) for e in envs]
|
||||
|
||||
@DeveloperAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
|
||||
@@ -209,16 +209,21 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testQueryEvaluators(self):
|
||||
register_env("test", lambda _: gym.make("CartPole-v0"))
|
||||
pg = PGAgent(
|
||||
env="test", config={
|
||||
env="test",
|
||||
config={
|
||||
"num_workers": 2,
|
||||
"sample_batch_size": 5
|
||||
"sample_batch_size": 5,
|
||||
"num_envs_per_worker": 2,
|
||||
})
|
||||
results = pg.optimizer.foreach_evaluator(
|
||||
lambda ev: ev.sample_batch_size)
|
||||
results2 = pg.optimizer.foreach_evaluator_with_index(
|
||||
lambda ev, i: (i, ev.sample_batch_size))
|
||||
self.assertEqual(results, [5, 5, 5])
|
||||
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
|
||||
results3 = pg.optimizer.foreach_evaluator(
|
||||
lambda ev: ev.foreach_env(lambda env: 1))
|
||||
self.assertEqual(results, [10, 10, 10])
|
||||
self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
|
||||
self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
|
||||
|
||||
def testRewardClipping(self):
|
||||
# clipping on
|
||||
|
||||
Reference in New Issue
Block a user