mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 17:15:34 +08:00
[rllib] Enable distributed exec api for A2C, A3C, PG by default (#7580)
This commit is contained in:
+5
-5
@@ -995,7 +995,7 @@ py_test(
|
||||
py_test(
|
||||
name = "tests/test_external_multi_agent_env",
|
||||
tags = ["tests_dir", "tests_dir_E"],
|
||||
size = "large",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_external_multi_agent_env.py"]
|
||||
)
|
||||
|
||||
@@ -1016,21 +1016,21 @@ py_test(
|
||||
py_test(
|
||||
name = "tests/test_io",
|
||||
tags = ["tests_dir", "tests_dir_I"],
|
||||
size = "large",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_io.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_local",
|
||||
tags = ["tests_dir", "tests_dir_L"],
|
||||
size = "large",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_local.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_lstm",
|
||||
tags = ["tests_dir", "tests_dir_L"],
|
||||
size = "large",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_lstm.py"]
|
||||
)
|
||||
|
||||
@@ -1051,7 +1051,7 @@ py_test(
|
||||
py_test(
|
||||
name = "tests/test_nested_spaces",
|
||||
tags = ["tests_dir", "tests_dir_N"],
|
||||
size = "large",
|
||||
size = "small",
|
||||
srcs = ["tests/test_nested_spaces.py"]
|
||||
)
|
||||
|
||||
|
||||
@@ -37,6 +37,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# Workers sample async. Note that this increases the effective
|
||||
# sample_batch_size by up to 5x due to async buffering of batches.
|
||||
"sample_async": True,
|
||||
# Use the execution plan API instead of policy optimizers.
|
||||
"use_exec_api": True,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
@@ -11,6 +11,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"num_workers": 0,
|
||||
# Learning rate.
|
||||
"lr": 0.0004,
|
||||
# Use the execution plan API instead of policy optimizers.
|
||||
"use_exec_api": True,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
+16
-4
@@ -922,19 +922,25 @@ class Trainer(Trainable):
|
||||
an error is raised.
|
||||
"""
|
||||
|
||||
if not self._has_policy_optimizer():
|
||||
if (not self._has_policy_optimizer()
|
||||
and not hasattr(self, "execution_plan")):
|
||||
raise NotImplementedError(
|
||||
"Recovery is not supported for this algorithm")
|
||||
if self._has_policy_optimizer():
|
||||
workers = self.optimizer.workers
|
||||
else:
|
||||
assert hasattr(self, "execution_plan")
|
||||
workers = self.workers
|
||||
|
||||
logger.info("Health checking all workers...")
|
||||
checks = []
|
||||
for ev in self.optimizer.workers.remote_workers():
|
||||
for ev in workers.remote_workers():
|
||||
_, obj_id = ev.sample_with_count.remote()
|
||||
checks.append(obj_id)
|
||||
|
||||
healthy_workers = []
|
||||
for i, obj_id in enumerate(checks):
|
||||
w = self.optimizer.workers.remote_workers()[i]
|
||||
w = workers.remote_workers()[i]
|
||||
try:
|
||||
ray_get_and_free(obj_id)
|
||||
healthy_workers.append(w)
|
||||
@@ -950,7 +956,13 @@ class Trainer(Trainable):
|
||||
raise RuntimeError(
|
||||
"Not enough healthy workers remain to continue.")
|
||||
|
||||
self.optimizer.reset(healthy_workers)
|
||||
if self._has_policy_optimizer():
|
||||
self.optimizer.reset(healthy_workers)
|
||||
else:
|
||||
assert hasattr(self, "execution_plan")
|
||||
logger.warning("Recreating execution plan after failure")
|
||||
workers.reset(healthy_workers)
|
||||
self.train_exec_impl = self.execution_plan(workers, self.config)
|
||||
|
||||
def _has_policy_optimizer(self):
|
||||
"""Whether this Trainer has a PolicyOptimizer as `optimizer` property.
|
||||
|
||||
@@ -121,9 +121,13 @@ def build_trainer(name,
|
||||
self.config["num_workers"])
|
||||
self.train_exec_impl = None
|
||||
self.optimizer = None
|
||||
self.execution_plan = execution_plan
|
||||
|
||||
if use_exec_api:
|
||||
logger.warning("Using experimental execution plan impl.")
|
||||
logger.warning(
|
||||
"The experimental distributed execution API is enabled "
|
||||
"for this algorithm. Disable this by setting "
|
||||
"'use_exec_api': False.")
|
||||
self.train_exec_impl = execution_plan(self.workers, config)
|
||||
elif make_policy_optimizer:
|
||||
self.optimizer = make_policy_optimizer(self.workers, config)
|
||||
|
||||
@@ -74,9 +74,10 @@ class WorkerSet:
|
||||
|
||||
def sync_weights(self):
|
||||
"""Syncs weights of remote workers with the local worker."""
|
||||
weights = ray.put(self.local_worker().get_weights())
|
||||
for e in self.remote_workers():
|
||||
e.set_weights.remote(weights)
|
||||
if self.remote_workers():
|
||||
weights = ray.put(self.local_worker().get_weights())
|
||||
for e in self.remote_workers():
|
||||
e.set_weights.remote(weights)
|
||||
|
||||
def add_workers(self, num_workers):
|
||||
"""Creates and add a number of remote workers to this worker set.
|
||||
|
||||
@@ -17,7 +17,9 @@ def rollout_test(algo, env="CartPole-v0"):
|
||||
os.system("python {}/train.py --local-dir={} --run={} "
|
||||
"--checkpoint-freq=1 ".format(rllib_dir, tmp_dir, algo) +
|
||||
"--config='{\"num_workers\": 1, \"num_gpus\": 0}' "
|
||||
"--stop='{\"training_iteration\": 1}'" + " --env={}".format(env))
|
||||
"--stop='{\"training_iteration\": 1, "
|
||||
"\"timesteps_per_iter\": 10, "
|
||||
"\"min_iter_time_s\": 1}'" + " --env={}".format(env))
|
||||
|
||||
checkpoint_path = os.popen(
|
||||
"ls {}/default/*/checkpoint_1/checkpoint-1".format(tmp_dir)).read()[:
|
||||
|
||||
@@ -173,20 +173,25 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
agent = A2CTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"lr_schedule": [[0, 0.1], [400, 0.000001]],
|
||||
"num_workers": 1,
|
||||
"lr_schedule": [[0, 0.1], [100000, 0.000001]],
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
|
||||
result2 = agent.train()
|
||||
print("num_steps_sampled={}".format(
|
||||
result["info"]["num_steps_sampled"]))
|
||||
print("num_steps_trained={}".format(
|
||||
result["info"]["num_steps_trained"]))
|
||||
self.assertLess(result2["info"]["learner"]["cur_lr"], 0.09)
|
||||
print("num_steps_sampled={}".format(
|
||||
result["info"]["num_steps_sampled"]))
|
||||
print("num_steps_trained={}".format(
|
||||
result["info"]["num_steps_trained"]))
|
||||
for i in range(10):
|
||||
result = agent.train()
|
||||
print("num_steps_sampled={}".format(
|
||||
result["info"]["num_steps_sampled"]))
|
||||
print("num_steps_trained={}".format(
|
||||
result["info"]["num_steps_trained"]))
|
||||
print("num_steps_sampled={}".format(
|
||||
result["info"]["num_steps_sampled"]))
|
||||
print("num_steps_trained={}".format(
|
||||
result["info"]["num_steps_trained"]))
|
||||
if i == 0:
|
||||
self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
|
||||
if result["info"]["learner"]["cur_lr"] < 0.07:
|
||||
break
|
||||
self.assertLess(result["info"]["learner"]["cur_lr"], 0.07)
|
||||
|
||||
def test_no_step_on_init(self):
|
||||
# Allow for Unittest run.
|
||||
@@ -213,11 +218,10 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
pg.train()
|
||||
pg.train()
|
||||
pg.train()
|
||||
self.assertEqual(counts["sample"], 4)
|
||||
self.assertGreater(counts["sample"], 0)
|
||||
self.assertGreater(counts["start"], 0)
|
||||
self.assertGreater(counts["end"], 0)
|
||||
self.assertGreater(counts["step"], 200)
|
||||
self.assertLess(counts["step"], 400)
|
||||
self.assertGreater(counts["step"], 0)
|
||||
|
||||
def test_query_evaluators(self):
|
||||
# Allow for Unittest run.
|
||||
|
||||
Reference in New Issue
Block a user