mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 05:16:30 +08:00
[RLlib] Fix two test cases that only fail on Travis. (#11435)
This commit is contained in:
@@ -14,7 +14,7 @@ class ObservationFunction:
|
||||
in multi-agent scenarios.
|
||||
|
||||
Observation functions can be specified in the multi-agent config by
|
||||
specifying ``{"observation_function": your_obs_func}``. Note that
|
||||
specifying ``{"observation_fn": your_obs_func}``. Note that
|
||||
``your_obs_func`` can be a plain Python function.
|
||||
|
||||
This API is **experimental**.
|
||||
|
||||
@@ -10,6 +10,14 @@ from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class EvalTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
ray.init(object_store_memory=1000 * 1024 * 1024, num_cpus=4)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_dqn_n_step(self):
|
||||
obs = [1, 2, 3, 4, 5, 6, 7]
|
||||
actions = ["a", "b", "a", "a", "a", "b", "a"]
|
||||
@@ -32,7 +40,6 @@ class EvalTest(unittest.TestCase):
|
||||
|
||||
for agent_cls in agent_classes:
|
||||
for fw in framework_iterator(frameworks=("torch", "tf")):
|
||||
ray.init(object_store_memory=1000 * 1024 * 1024)
|
||||
register_env("CartPoleWrapped-v0", env_creator)
|
||||
agent = agent_cls(
|
||||
env="CartPoleWrapped-v0",
|
||||
@@ -53,6 +60,7 @@ class EvalTest(unittest.TestCase):
|
||||
r1 = agent.train()
|
||||
r2 = agent.train()
|
||||
r3 = agent.train()
|
||||
agent.stop()
|
||||
|
||||
self.assertTrue("evaluation" in r1)
|
||||
self.assertTrue("evaluation" in r3)
|
||||
@@ -60,7 +68,6 @@ class EvalTest(unittest.TestCase):
|
||||
self.assertFalse("evaluation" in r2)
|
||||
self.assertTrue("episode_reward_mean" in r1["evaluation"])
|
||||
self.assertNotEqual(r1["evaluation"], r3["evaluation"])
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,7 +10,7 @@ class TestDistributedExecution(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
ray.init(ignore_reinit_error=True)
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
||||
@@ -82,7 +82,7 @@ class TestExplorations(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
ray.init(ignore_reinit_error=True)
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user