[RLlib] Fix two test cases that only fail on Travis. (#11435)

This commit is contained in:
Sven Mika
2020-10-16 20:53:30 +02:00
committed by GitHub
parent f890808c14
commit 2aec77e305
4 changed files with 12 additions and 5 deletions
+1 -1
View File
@@ -14,7 +14,7 @@ class ObservationFunction:
in multi-agent scenarios.
Observation functions can be specified in the multi-agent config by
specifying ``{"observation_function": your_obs_func}``. Note that
specifying ``{"observation_fn": your_obs_func}``. Note that
``your_obs_func`` can be a plain Python function.
This API is **experimental**.
+9 -2
View File
@@ -10,6 +10,14 @@ from ray.tune.registry import register_env
class EvalTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(object_store_memory=1000 * 1024 * 1024, num_cpus=4)
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_dqn_n_step(self):
obs = [1, 2, 3, 4, 5, 6, 7]
actions = ["a", "b", "a", "a", "a", "b", "a"]
@@ -32,7 +40,6 @@ class EvalTest(unittest.TestCase):
for agent_cls in agent_classes:
for fw in framework_iterator(frameworks=("torch", "tf")):
ray.init(object_store_memory=1000 * 1024 * 1024)
register_env("CartPoleWrapped-v0", env_creator)
agent = agent_cls(
env="CartPoleWrapped-v0",
@@ -53,6 +60,7 @@ class EvalTest(unittest.TestCase):
r1 = agent.train()
r2 = agent.train()
r3 = agent.train()
agent.stop()
self.assertTrue("evaluation" in r1)
self.assertTrue("evaluation" in r3)
@@ -60,7 +68,6 @@ class EvalTest(unittest.TestCase):
self.assertFalse("evaluation" in r2)
self.assertTrue("episode_reward_mean" in r1["evaluation"])
self.assertNotEqual(r1["evaluation"], r3["evaluation"])
ray.shutdown()
if __name__ == "__main__":
+1 -1
View File
@@ -10,7 +10,7 @@ class TestDistributedExecution(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True)
ray.init(num_cpus=4)
@classmethod
def tearDownClass(cls):
@@ -82,7 +82,7 @@ class TestExplorations(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True)
ray.init(num_cpus=4)
@classmethod
def tearDownClass(cls):