From 0f766ae24b2639db3dfa76780a32c2e794df8997 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 19 Feb 2018 22:44:14 -0800 Subject: [PATCH] [rllib] Fix testGetFilters in A3C (#1557) --- python/ray/rllib/test/test_evaluators.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/ray/rllib/test/test_evaluators.py b/python/ray/rllib/test/test_evaluators.py index 482b13326..29c054a0d 100644 --- a/python/ray/rllib/test/test_evaluators.py +++ b/python/ray/rllib/test/test_evaluators.py @@ -6,6 +6,7 @@ import unittest import gym import shutil import tempfile +import time import ray from ray.rllib.a3c import DEFAULT_CONFIG @@ -51,6 +52,7 @@ class A3CEvaluatorTest(unittest.TestCase): def sample_and_flush(self): e = self.e + time.sleep(2) self.e.sample() filters = e.get_filters(flush_after=True) obs_f = filters["obs_filter"] @@ -62,14 +64,15 @@ class A3CEvaluatorTest(unittest.TestCase): return obs_f, rew_f def testGetFilters(self): + """Show `flush_after=False` provides does not affect the buffer.""" e = self.e - obs_f, rew_f = self.sample_and_flush() - COUNT = obs_f.rs.n + self.sample_and_flush() filters = e.get_filters(flush_after=False) obs_f = filters["obs_filter"] - NEW_COUNT = obs_f.rs.n - self.assertGreaterEqual(NEW_COUNT, COUNT) - self.assertLessEqual(obs_f.buffer.n, NEW_COUNT - COUNT) + filters2 = e.get_filters(flush_after=False) + obs_f2 = filters2["obs_filter"] + self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n) + self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n) def testSyncFilter(self): """Show that sync_filters rebases own buffer over input"""