[rllib] Enable distributed exec api for A2C, A3C, PG by default (#7580)

This commit is contained in:
Eric Liang
2020-03-13 18:48:41 -07:00
committed by GitHub
parent 094125cf03
commit c3a8ba399f
8 changed files with 56 additions and 29 deletions
+5 -5
View File
@@ -995,7 +995,7 @@ py_test(
py_test(
name = "tests/test_external_multi_agent_env",
tags = ["tests_dir", "tests_dir_E"],
size = "large",
size = "medium",
srcs = ["tests/test_external_multi_agent_env.py"]
)
@@ -1016,21 +1016,21 @@ py_test(
py_test(
name = "tests/test_io",
tags = ["tests_dir", "tests_dir_I"],
size = "large",
size = "medium",
srcs = ["tests/test_io.py"]
)
py_test(
name = "tests/test_local",
tags = ["tests_dir", "tests_dir_L"],
size = "large",
size = "medium",
srcs = ["tests/test_local.py"]
)
py_test(
name = "tests/test_lstm",
tags = ["tests_dir", "tests_dir_L"],
size = "large",
size = "medium",
srcs = ["tests/test_lstm.py"]
)
@@ -1051,7 +1051,7 @@ py_test(
py_test(
name = "tests/test_nested_spaces",
tags = ["tests_dir", "tests_dir_N"],
size = "large",
size = "small",
srcs = ["tests/test_nested_spaces.py"]
)
+2
View File
@@ -37,6 +37,8 @@ DEFAULT_CONFIG = with_common_config({
# Workers sample async. Note that this increases the effective
# sample_batch_size by up to 5x due to async buffering of batches.
"sample_async": True,
# Use the execution plan API instead of policy optimizers.
"use_exec_api": True,
})
# __sphinx_doc_end__
# yapf: enable
+2
View File
@@ -11,6 +11,8 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
# Use the execution plan API instead of policy optimizers.
"use_exec_api": True,
})
# __sphinx_doc_end__
# yapf: enable
+16 -4
View File
@@ -922,19 +922,25 @@ class Trainer(Trainable):
an error is raised.
"""
if not self._has_policy_optimizer():
if (not self._has_policy_optimizer()
and not hasattr(self, "execution_plan")):
raise NotImplementedError(
"Recovery is not supported for this algorithm")
if self._has_policy_optimizer():
workers = self.optimizer.workers
else:
assert hasattr(self, "execution_plan")
workers = self.workers
logger.info("Health checking all workers...")
checks = []
for ev in self.optimizer.workers.remote_workers():
for ev in workers.remote_workers():
_, obj_id = ev.sample_with_count.remote()
checks.append(obj_id)
healthy_workers = []
for i, obj_id in enumerate(checks):
w = self.optimizer.workers.remote_workers()[i]
w = workers.remote_workers()[i]
try:
ray_get_and_free(obj_id)
healthy_workers.append(w)
@@ -950,7 +956,13 @@ class Trainer(Trainable):
raise RuntimeError(
"Not enough healthy workers remain to continue.")
self.optimizer.reset(healthy_workers)
if self._has_policy_optimizer():
self.optimizer.reset(healthy_workers)
else:
assert hasattr(self, "execution_plan")
logger.warning("Recreating execution plan after failure")
workers.reset(healthy_workers)
self.train_exec_impl = self.execution_plan(workers, self.config)
def _has_policy_optimizer(self):
"""Whether this Trainer has a PolicyOptimizer as `optimizer` property.
+5 -1
View File
@@ -121,9 +121,13 @@ def build_trainer(name,
self.config["num_workers"])
self.train_exec_impl = None
self.optimizer = None
self.execution_plan = execution_plan
if use_exec_api:
logger.warning("Using experimental execution plan impl.")
logger.warning(
"The experimental distributed execution API is enabled "
"for this algorithm. Disable this by setting "
"'use_exec_api': False.")
self.train_exec_impl = execution_plan(self.workers, config)
elif make_policy_optimizer:
self.optimizer = make_policy_optimizer(self.workers, config)
+4 -3
View File
@@ -74,9 +74,10 @@ class WorkerSet:
def sync_weights(self):
"""Syncs weights of remote workers with the local worker."""
weights = ray.put(self.local_worker().get_weights())
for e in self.remote_workers():
e.set_weights.remote(weights)
if self.remote_workers():
weights = ray.put(self.local_worker().get_weights())
for e in self.remote_workers():
e.set_weights.remote(weights)
def add_workers(self, num_workers):
"""Creates and add a number of remote workers to this worker set.
+3 -1
View File
@@ -17,7 +17,9 @@ def rollout_test(algo, env="CartPole-v0"):
os.system("python {}/train.py --local-dir={} --run={} "
"--checkpoint-freq=1 ".format(rllib_dir, tmp_dir, algo) +
"--config='{\"num_workers\": 1, \"num_gpus\": 0}' "
"--stop='{\"training_iteration\": 1}'" + " --env={}".format(env))
"--stop='{\"training_iteration\": 1, "
"\"timesteps_per_iter\": 10, "
"\"min_iter_time_s\": 1}'" + " --env={}".format(env))
checkpoint_path = os.popen(
"ls {}/default/*/checkpoint_1/checkpoint-1".format(tmp_dir)).read()[:
+19 -15
View File
@@ -173,20 +173,25 @@ class TestRolloutWorker(unittest.TestCase):
agent = A2CTrainer(
env="CartPole-v0",
config={
"lr_schedule": [[0, 0.1], [400, 0.000001]],
"num_workers": 1,
"lr_schedule": [[0, 0.1], [100000, 0.000001]],
})
result = agent.train()
self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
result2 = agent.train()
print("num_steps_sampled={}".format(
result["info"]["num_steps_sampled"]))
print("num_steps_trained={}".format(
result["info"]["num_steps_trained"]))
self.assertLess(result2["info"]["learner"]["cur_lr"], 0.09)
print("num_steps_sampled={}".format(
result["info"]["num_steps_sampled"]))
print("num_steps_trained={}".format(
result["info"]["num_steps_trained"]))
for i in range(10):
result = agent.train()
print("num_steps_sampled={}".format(
result["info"]["num_steps_sampled"]))
print("num_steps_trained={}".format(
result["info"]["num_steps_trained"]))
print("num_steps_sampled={}".format(
result["info"]["num_steps_sampled"]))
print("num_steps_trained={}".format(
result["info"]["num_steps_trained"]))
if i == 0:
self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
if result["info"]["learner"]["cur_lr"] < 0.07:
break
self.assertLess(result["info"]["learner"]["cur_lr"], 0.07)
def test_no_step_on_init(self):
# Allow for Unittest run.
@@ -213,11 +218,10 @@ class TestRolloutWorker(unittest.TestCase):
pg.train()
pg.train()
pg.train()
self.assertEqual(counts["sample"], 4)
self.assertGreater(counts["sample"], 0)
self.assertGreater(counts["start"], 0)
self.assertGreater(counts["end"], 0)
self.assertGreater(counts["step"], 200)
self.assertLess(counts["step"], 400)
self.assertGreater(counts["step"], 0)
def test_query_evaluators(self):
# Allow for Unittest run.