mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 14:55:50 +08:00
[rllib] Fix testGetFilters in A3C (#1557)
This commit is contained in:
@@ -6,6 +6,7 @@ import unittest
|
||||
import gym
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib.a3c import DEFAULT_CONFIG
|
||||
@@ -51,6 +52,7 @@ class A3CEvaluatorTest(unittest.TestCase):
|
||||
|
||||
def sample_and_flush(self):
|
||||
e = self.e
|
||||
time.sleep(2)
|
||||
self.e.sample()
|
||||
filters = e.get_filters(flush_after=True)
|
||||
obs_f = filters["obs_filter"]
|
||||
@@ -62,14 +64,15 @@ class A3CEvaluatorTest(unittest.TestCase):
|
||||
return obs_f, rew_f
|
||||
|
||||
def testGetFilters(self):
|
||||
"""Show `flush_after=False` provides does not affect the buffer."""
|
||||
e = self.e
|
||||
obs_f, rew_f = self.sample_and_flush()
|
||||
COUNT = obs_f.rs.n
|
||||
self.sample_and_flush()
|
||||
filters = e.get_filters(flush_after=False)
|
||||
obs_f = filters["obs_filter"]
|
||||
NEW_COUNT = obs_f.rs.n
|
||||
self.assertGreaterEqual(NEW_COUNT, COUNT)
|
||||
self.assertLessEqual(obs_f.buffer.n, NEW_COUNT - COUNT)
|
||||
filters2 = e.get_filters(flush_after=False)
|
||||
obs_f2 = filters2["obs_filter"]
|
||||
self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
|
||||
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
|
||||
|
||||
def testSyncFilter(self):
|
||||
"""Show that sync_filters rebases own buffer over input"""
|
||||
|
||||
Reference in New Issue
Block a user