[RLlib] Fix broken tune tests in master due to framework=auto errors. (#8672)

This commit is contained in:
Sven Mika
2020-05-29 11:55:47 +02:00
committed by GitHub
parent c64b694560
commit d483ed28ba
6 changed files with 18 additions and 7 deletions
@@ -77,6 +77,7 @@ class TestMemoryScheduling(unittest.TestCase):
config={
"env": "CartPole-v0",
"memory": 100 * 1024 * 1024, # too little
"framework": "tf",
},
raise_on_failed_trial=False)
self.assertEqual(result.trials[0].status, "ERROR")
@@ -99,6 +100,7 @@ class TestMemoryScheduling(unittest.TestCase):
"env": "CartPole-v0",
# too large
"object_store_memory": 10000 * 1024 * 1024,
"framework": "tf",
}))
finally:
ray.shutdown()
@@ -113,6 +115,7 @@ class TestMemoryScheduling(unittest.TestCase):
"env": "CartPole-v0",
"num_workers": 1,
"memory_per_worker": 100 * 1024 * 1024, # too little
"framework": "tf",
},
raise_on_failed_trial=False)
self.assertEqual(result.trials[0].status, "ERROR")
@@ -134,6 +137,7 @@ class TestMemoryScheduling(unittest.TestCase):
"num_workers": 1,
# too large
"object_store_memory_per_worker": 10000 * 1024 * 1024,
"framework": "tf",
}))
finally:
ray.shutdown()
+2 -2
View File
@@ -554,7 +554,7 @@ ray.init(address="{address}")
tune.run(
"PG",
name="experiment",
config=dict(env="CartPole-v1"),
config=dict(env="CartPole-v1", framework="tf"),
stop=dict(training_iteration=10),
local_dir="{checkpoint_dir}",
checkpoint_freq=1,
@@ -593,7 +593,7 @@ tune.run(
"experiment": {
"run": "PG",
"checkpoint_freq": 1,
"local_dir": dirpath
"local_dir": dirpath,
}
},
resume=True)
@@ -36,6 +36,7 @@ class TuneRestoreTest(unittest.TestCase):
local_dir=tmpdir,
config={
"env": "CartPole-v0",
"framework": "tf",
},
)
@@ -58,6 +59,7 @@ class TuneRestoreTest(unittest.TestCase):
restore=self.checkpoint_path, # Restore the checkpoint
config={
"env": "CartPole-v0",
"framework": "tf",
},
)
@@ -73,6 +75,7 @@ class TuneRestoreTest(unittest.TestCase):
restore=self.checkpoint_path,
config={
"env": "CartPole-v0",
"framework": "tf",
},
)
self.assertTrue(os.path.isfile(self.checkpoint_path))
+6 -3
View File
@@ -717,7 +717,8 @@ py_test(
py_test(
name = "test_impala_cartpole_v0_buffers_2_lstm",
main = "train.py", srcs = ["train.py"],
main = "train.py",
srcs = ["train.py"],
tags = ["quick_train"],
args = [
"--env", "CartPole-v0",
@@ -730,12 +731,14 @@ py_test(
py_test(
name = "test_impala_pong_deterministic_v4_40k_ts_1G_obj_store",
main = "train.py", srcs = ["train.py"],
main = "train.py",
srcs = ["train.py"],
tags = ["quick_train"],
size = "medium",
args = [
"--env", "PongDeterministic-v4",
"--run", "IMPALA",
"--stop", "'{\"timesteps_total\": 40000}'",
"--stop", "'{\"timesteps_total\": 30000}'",
"--ray-object-store-memory=1000000000",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"rollout_fragment_length\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
]
+1
View File
@@ -18,6 +18,7 @@ class _MockTrainer(Trainer):
"user_checkpoint_freq": 0,
"object_store_memory_per_worker": 0,
"object_store_memory": 0,
"framework": "tf",
})
@classmethod
+2 -2
View File
@@ -123,7 +123,7 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support(
"ARS", {
"num_workers": 1,
"noise_size": 100000,
"noise_size": 1500000,
"num_rollouts": 1,
"rollouts_used": 1
})
@@ -147,7 +147,7 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support(
"ES", {
"num_workers": 1,
"noise_size": 100000,
"noise_size": 1500000,
"episodes_per_batch": 1,
"train_batch_size": 1
})