[rllib] Fix testGetFilters in A3C (#1557)

This commit is contained in:
Richard Liaw
2018-02-19 22:44:14 -08:00
committed by GitHub
parent 73be235701
commit 0f766ae24b
+8 -5
View File
@@ -6,6 +6,7 @@ import unittest
import gym
import shutil
import tempfile
import time
import ray
from ray.rllib.a3c import DEFAULT_CONFIG
@@ -51,6 +52,7 @@ class A3CEvaluatorTest(unittest.TestCase):
def sample_and_flush(self):
e = self.e
time.sleep(2)
self.e.sample()
filters = e.get_filters(flush_after=True)
obs_f = filters["obs_filter"]
@@ -62,14 +64,15 @@ class A3CEvaluatorTest(unittest.TestCase):
return obs_f, rew_f
def testGetFilters(self):
"""Show `flush_after=False` provides does not affect the buffer."""
e = self.e
obs_f, rew_f = self.sample_and_flush()
COUNT = obs_f.rs.n
self.sample_and_flush()
filters = e.get_filters(flush_after=False)
obs_f = filters["obs_filter"]
NEW_COUNT = obs_f.rs.n
self.assertGreaterEqual(NEW_COUNT, COUNT)
self.assertLessEqual(obs_f.buffer.n, NEW_COUNT - COUNT)
filters2 = e.get_filters(flush_after=False)
obs_f2 = filters2["obs_filter"]
self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
def testSyncFilter(self):
"""Show that sync_filters rebases own buffer over input"""