mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 23:48:33 +08:00
[RLlib] Fix broken tune tests in master due to framework=auto errors. (#8672)
This commit is contained in:
@@ -77,6 +77,7 @@ class TestMemoryScheduling(unittest.TestCase):
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"memory": 100 * 1024 * 1024, # too little
|
||||
"framework": "tf",
|
||||
},
|
||||
raise_on_failed_trial=False)
|
||||
self.assertEqual(result.trials[0].status, "ERROR")
|
||||
@@ -99,6 +100,7 @@ class TestMemoryScheduling(unittest.TestCase):
|
||||
"env": "CartPole-v0",
|
||||
# too large
|
||||
"object_store_memory": 10000 * 1024 * 1024,
|
||||
"framework": "tf",
|
||||
}))
|
||||
finally:
|
||||
ray.shutdown()
|
||||
@@ -113,6 +115,7 @@ class TestMemoryScheduling(unittest.TestCase):
|
||||
"env": "CartPole-v0",
|
||||
"num_workers": 1,
|
||||
"memory_per_worker": 100 * 1024 * 1024, # too little
|
||||
"framework": "tf",
|
||||
},
|
||||
raise_on_failed_trial=False)
|
||||
self.assertEqual(result.trials[0].status, "ERROR")
|
||||
@@ -134,6 +137,7 @@ class TestMemoryScheduling(unittest.TestCase):
|
||||
"num_workers": 1,
|
||||
# too large
|
||||
"object_store_memory_per_worker": 10000 * 1024 * 1024,
|
||||
"framework": "tf",
|
||||
}))
|
||||
finally:
|
||||
ray.shutdown()
|
||||
|
||||
@@ -554,7 +554,7 @@ ray.init(address="{address}")
|
||||
tune.run(
|
||||
"PG",
|
||||
name="experiment",
|
||||
config=dict(env="CartPole-v1"),
|
||||
config=dict(env="CartPole-v1", framework="tf"),
|
||||
stop=dict(training_iteration=10),
|
||||
local_dir="{checkpoint_dir}",
|
||||
checkpoint_freq=1,
|
||||
@@ -593,7 +593,7 @@ tune.run(
|
||||
"experiment": {
|
||||
"run": "PG",
|
||||
"checkpoint_freq": 1,
|
||||
"local_dir": dirpath
|
||||
"local_dir": dirpath,
|
||||
}
|
||||
},
|
||||
resume=True)
|
||||
|
||||
@@ -36,6 +36,7 @@ class TuneRestoreTest(unittest.TestCase):
|
||||
local_dir=tmpdir,
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"framework": "tf",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -58,6 +59,7 @@ class TuneRestoreTest(unittest.TestCase):
|
||||
restore=self.checkpoint_path, # Restore the checkpoint
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"framework": "tf",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -73,6 +75,7 @@ class TuneRestoreTest(unittest.TestCase):
|
||||
restore=self.checkpoint_path,
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"framework": "tf",
|
||||
},
|
||||
)
|
||||
self.assertTrue(os.path.isfile(self.checkpoint_path))
|
||||
|
||||
+6
-3
@@ -717,7 +717,8 @@ py_test(
|
||||
|
||||
py_test(
|
||||
name = "test_impala_cartpole_v0_buffers_2_lstm",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
main = "train.py",
|
||||
srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v0",
|
||||
@@ -730,12 +731,14 @@ py_test(
|
||||
|
||||
py_test(
|
||||
name = "test_impala_pong_deterministic_v4_40k_ts_1G_obj_store",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
main = "train.py",
|
||||
srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
size = "medium",
|
||||
args = [
|
||||
"--env", "PongDeterministic-v4",
|
||||
"--run", "IMPALA",
|
||||
"--stop", "'{\"timesteps_total\": 40000}'",
|
||||
"--stop", "'{\"timesteps_total\": 30000}'",
|
||||
"--ray-object-store-memory=1000000000",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"rollout_fragment_length\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
|
||||
]
|
||||
|
||||
@@ -18,6 +18,7 @@ class _MockTrainer(Trainer):
|
||||
"user_checkpoint_freq": 0,
|
||||
"object_store_memory_per_worker": 0,
|
||||
"object_store_memory": 0,
|
||||
"framework": "tf",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -123,7 +123,7 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
check_support(
|
||||
"ARS", {
|
||||
"num_workers": 1,
|
||||
"noise_size": 100000,
|
||||
"noise_size": 1500000,
|
||||
"num_rollouts": 1,
|
||||
"rollouts_used": 1
|
||||
})
|
||||
@@ -147,7 +147,7 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
check_support(
|
||||
"ES", {
|
||||
"num_workers": 1,
|
||||
"noise_size": 100000,
|
||||
"noise_size": 1500000,
|
||||
"episodes_per_batch": 1,
|
||||
"train_batch_size": 1
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user