mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:46:57 +08:00
[rllib] Add option to proceed even if some workers crashed (#4376)
This commit is contained in:
@@ -13,6 +13,7 @@ import tensorflow as tf
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayError
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
@@ -29,6 +30,10 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max number of times to retry a worker failure. We shouldn't try too many
|
||||
# times in a row since that would indicate a persistent cluster issue.
|
||||
MAX_WORKER_FAILURE_RETRIES = 3
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
COMMON_CONFIG = {
|
||||
@@ -48,6 +53,8 @@ COMMON_CONFIG = {
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
"on_train_result": None, # arg: {"agent": ..., "result": ...}
|
||||
},
|
||||
# Whether to attempt to continue training if a worker crashes.
|
||||
"ignore_worker_failures": False,
|
||||
|
||||
# === Policy ===
|
||||
# Arguments to pass to model. See models/catalog.py for a full list of the
|
||||
@@ -99,7 +106,7 @@ COMMON_CONFIG = {
|
||||
"train_batch_size": 200,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "truncate_episodes",
|
||||
# Whether to use a background thread for sampling (slightly off-policy)
|
||||
# (Deprecated) Use a background thread for sampling (slightly off-policy)
|
||||
"sample_async": False,
|
||||
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
|
||||
"observation_filter": "NoFilter",
|
||||
@@ -285,15 +292,32 @@ class Agent(Trainable):
|
||||
def train(self):
|
||||
"""Overrides super.train to synchronize global vars."""
|
||||
|
||||
if hasattr(self, "optimizer") and isinstance(self.optimizer,
|
||||
PolicyOptimizer):
|
||||
if self._has_policy_optimizer():
|
||||
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
|
||||
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
ev.set_global_vars.remote(self.global_vars)
|
||||
logger.debug("updated global vars: {}".format(self.global_vars))
|
||||
|
||||
result = Trainable.train(self)
|
||||
result = None
|
||||
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
|
||||
try:
|
||||
result = Trainable.train(self)
|
||||
except RayError as e:
|
||||
if self.config["ignore_worker_failures"]:
|
||||
logger.exception(
|
||||
"Error in train call, attempting to recover")
|
||||
self._try_recover()
|
||||
else:
|
||||
logger.info(
|
||||
"Worker crashed during call to train(). To attempt to "
|
||||
"continue training without the failed worker, set "
|
||||
"`'ignore_worker_failures': True`.")
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
if result is None:
|
||||
raise RuntimeError("Failed to recover from worker crash")
|
||||
|
||||
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
|
||||
and hasattr(self, "local_evaluator")):
|
||||
@@ -304,6 +328,9 @@ class Agent(Trainable):
|
||||
logger.debug("synchronized filters: {}".format(
|
||||
self.local_evaluator.filters))
|
||||
|
||||
if self._has_policy_optimizer():
|
||||
result["num_healthy_workers"] = len(
|
||||
self.optimizer.remote_evaluators)
|
||||
return result
|
||||
|
||||
@override(Trainable)
|
||||
@@ -558,6 +585,49 @@ class Agent(Trainable):
|
||||
"`input_evaluation` must be a list of strings, got {}".format(
|
||||
config["input_evaluation"]))
|
||||
|
||||
def _try_recover(self):
|
||||
"""Try to identify and blacklist any unhealthy workers.
|
||||
|
||||
This method is called after an unexpected remote error is encountered
|
||||
from a worker. It issues check requests to all current workers and
|
||||
blacklists any that respond with error. If no healthy workers remain,
|
||||
an error is raised.
|
||||
"""
|
||||
|
||||
if not self._has_policy_optimizer():
|
||||
raise NotImplementedError(
|
||||
"Recovery is not supported for this algorithm")
|
||||
|
||||
logger.info("Health checking all workers...")
|
||||
checks = []
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
_, obj_id = ev.sample_with_count.remote()
|
||||
checks.append(obj_id)
|
||||
|
||||
healthy_evaluators = []
|
||||
for i, obj_id in enumerate(checks):
|
||||
ev = self.optimizer.remote_evaluators[i]
|
||||
try:
|
||||
ray.get(obj_id)
|
||||
healthy_evaluators.append(ev)
|
||||
logger.info("Worker {} looks healthy".format(i + 1))
|
||||
except RayError:
|
||||
logger.exception("Blacklisting worker {}".format(i + 1))
|
||||
try:
|
||||
ev.__ray_terminate__.remote()
|
||||
except Exception:
|
||||
logger.exception("Error terminating unhealthy worker")
|
||||
|
||||
if len(healthy_evaluators) < 1:
|
||||
raise RuntimeError(
|
||||
"Not enough healthy workers remain to continue.")
|
||||
|
||||
self.optimizer.reset(healthy_evaluators)
|
||||
|
||||
def _has_policy_optimizer(self):
|
||||
return hasattr(self, "optimizer") and isinstance(
|
||||
self.optimizer, PolicyOptimizer)
|
||||
|
||||
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
|
||||
config):
|
||||
def session_creator():
|
||||
|
||||
@@ -116,7 +116,8 @@ class ImpalaAgent(Agent):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
start = time.time()
|
||||
self.optimizer.step()
|
||||
while time.time() - start < self.config["min_iter_time_s"]:
|
||||
while (time.time() - start < self.config["min_iter_time_s"]
|
||||
or self.optimizer.num_steps_sampled == prev_steps):
|
||||
self.optimizer.step()
|
||||
result = self.optimizer.collect_metrics(
|
||||
self.config["collect_metrics_timeout"])
|
||||
|
||||
@@ -171,6 +171,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
queue_putter(item)
|
||||
|
||||
def get_data(self):
|
||||
if not self.is_alive():
|
||||
raise RuntimeError("Sampling thread has died")
|
||||
rollout = self.queue.get(timeout=600.0)
|
||||
|
||||
# Propagate errors
|
||||
|
||||
@@ -133,6 +133,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
r.__ray_terminate__.remote()
|
||||
self.learner.stopped = True
|
||||
|
||||
@override(PolicyOptimizer)
def reset(self, remote_evaluators):
    """Replace the remote evaluators used by this optimizer.

    The new list is stored on the optimizer and also forwarded to
    ``self.sample_tasks.reset_evaluators``.
    """
    self.remote_evaluators = remote_evaluators
    task_pool = self.sample_tasks
    task_pool.reset_evaluators(remote_evaluators)
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def stats(self):
|
||||
replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
|
||||
|
||||
@@ -152,6 +152,11 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
def stop(self):
|
||||
self.learner.stopped = True
|
||||
|
||||
@override(PolicyOptimizer)
def reset(self, remote_evaluators):
    """Replace the remote evaluators used by this optimizer.

    The new list is stored on the optimizer and also forwarded to
    ``self.sample_tasks.reset_evaluators``.
    """
    self.remote_evaluators = remote_evaluators
    task_pool = self.sample_tasks
    task_pool.reset_evaluators(remote_evaluators)
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def stats(self):
|
||||
def timer_to_ms(timer):
|
||||
|
||||
@@ -142,6 +142,12 @@ class PolicyOptimizer(object):
|
||||
res.update(info=self.stats())
|
||||
return res
|
||||
|
||||
@DeveloperAPI
def reset(self, remote_evaluators):
    """Swap the set of remote evaluators this optimizer works with.

    Args:
        remote_evaluators: the evaluators to use from now on; stored as
            ``self.remote_evaluators``.
    """
    self.remote_evaluators = remote_evaluators
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_evaluator(self, func):
|
||||
"""Apply the given function to each evaluator instance."""
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class FaultInjectEnv(gym.Env):
    """CartPole-v0 wrapper whose ``step`` fails on selected workers.

    ``config`` must expose a ``worker_index`` attribute and a
    ``"bad_indices"`` key (presumably an RLlib env context; confirm against
    the caller). When the worker index is listed in ``bad_indices``, every
    call to ``step`` raises ``ValueError``.
    """

    def __init__(self, config):
        self.config = config
        self.env = gym.make("CartPole-v0")
        # Mirror the wrapped env's spaces so the wrapper is a drop-in.
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

    def reset(self):
        """Reset the wrapped CartPole env and return its observation."""
        return self.env.reset()

    def step(self, action):
        """Forward ``action`` to CartPole unless this worker is marked bad."""
        worker_index = self.config.worker_index
        if worker_index not in self.config["bad_indices"]:
            return self.env.step(action)
        raise ValueError(
            "This is a simulated error from {}".format(worker_index))
|
||||
|
||||
|
||||
class IgnoresWorkerFailure(unittest.TestCase):
    """Checks agent behavior when rollout workers raise during sampling.

    Each ``test*`` method runs one algorithm against ``FaultInjectEnv`` with
    ``ignore_worker_failures`` enabled, inside its own Ray session.
    """

    def doTest(self, alg, config, fn=None):
        """Run a scenario inside a fresh Ray session.

        Args:
            alg: algorithm name passed to ``get_agent_class``.
            config: agent config dict; the scenario adds its own keys to it.
            fn: scenario callable taking ``(alg, config)``; defaults to
                ``_doTestFaultRecover``.
        """
        fn = fn or self._doTestFaultRecover
        try:
            ray.init(num_cpus=6)
            fn(alg, config)
        finally:
            ray.shutdown()
            _register_all()  # re-register the evicted objects

    def _doTestFaultRecover(self, alg, config):
        """One of two workers is faulty; training must go on with the other."""
        # register_env gets a plain function, as in the original test.
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_agent_class(alg)

        # Test fault handling
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        config["env_config"] = {"bad_indices": [1]}
        a = agent_cls(config=config, env="fault_env")
        result = a.train()
        # assertEqual, not assertTrue: assertTrue(x, 1) would treat ``1`` as
        # the failure message and pass for any non-zero worker count.
        self.assertEqual(result["num_healthy_workers"], 1)
        a.stop()

    def _doTestFaultFatal(self, alg, config):
        """Both workers are faulty; ``train`` must raise."""
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_agent_class(alg)

        # Test raises real error when out of workers
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        config["env_config"] = {"bad_indices": [1, 2]}
        a = agent_cls(config=config, env="fault_env")
        with self.assertRaises(Exception):
            a.train()
        a.stop()

    def testFatal(self):
        # test the case where all workers fail
        self.doTest("PG", {"optimizer": {}}, fn=self._doTestFaultFatal)

    def testAsyncGrads(self):
        self.doTest("A3C", {"optimizer": {"grads_per_step": 1}})

    def testAsyncReplay(self):
        self.doTest(
            "APEX", {
                "timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_iter_time_s": 1,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
                "optimizer": {
                    "num_replay_buffer_shards": 1,
                },
            })

    def testAsyncSamples(self):
        self.doTest("IMPALA", {"num_gpus": 0})

    def testSyncReplay(self):
        self.doTest("DQN", {"timesteps_per_iteration": 1})

    def testMultiGPU(self):
        self.doTest(
            "PPO", {
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "sample_batch_size": 10,
                "sgd_minibatch_size": 1,
            })

    def testSyncSamples(self):
        self.doTest("PG", {"optimizer": {}})

    def testAsyncSamplingOption(self):
        self.doTest("PG", {"optimizer": {}, "sample_async": True})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
@@ -53,6 +53,18 @@ class TaskPool(object):
|
||||
remaining.append((worker, obj_id))
|
||||
self._fetching = remaining
|
||||
|
||||
def reset_evaluators(self, evaluators):
    """Forget all work that belongs to evaluators not in ``evaluators``.

    Entries of ``self._tasks`` (and the matching ``self._objects`` entries)
    whose evaluator is missing from ``evaluators`` are deleted, and
    ``self._fetching`` is filtered down to the pairs whose evaluator is
    still present.
    """
    # Collect the keys first so the dict is not mutated while iterating.
    stale_ids = [
        obj_id for obj_id, worker in self._tasks.items()
        if worker not in evaluators
    ]
    for obj_id in stale_ids:
        del self._tasks[obj_id]
        del self._objects[obj_id]

    self._fetching = [
        (worker, obj_id) for worker, obj_id in self._fetching
        if worker in evaluators
    ]
|
||||
|
||||
@property
def count(self):
    """Number of entries currently held in ``self._tasks``."""
    tracked = self._tasks
    return len(tracked)
|
||||
|
||||
Reference in New Issue
Block a user