From 1f00f834ac44d06dbe7264ca712b8d587b0a60f6 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 18 Jan 2021 19:29:03 +0100 Subject: [PATCH] [RLlib] Solve PyTorch/TF-eager A3C async race condition between calling model and its value function. (#13467) --- rllib/agents/a3c/a3c.py | 8 +++--- rllib/agents/a3c/tests/test_a2c.py | 1 - rllib/agents/a3c/tests/test_a3c.py | 3 +-- rllib/policy/eager_tf_policy.py | 13 +++++++++ rllib/policy/torch_policy.py | 18 +++++++++++++ .../exploration/tests/test_explorations.py | 2 +- rllib/utils/threading.py | 27 +++++++++++++++++++ 7 files changed, 63 insertions(+), 9 deletions(-) create mode 100644 rllib/utils/threading.py diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py index 88e91bf82..9cb935568 100644 --- a/rllib/agents/a3c/a3c.py +++ b/rllib/agents/a3c/a3c.py @@ -53,11 +53,9 @@ def get_policy_class(config): def validate_config(config): if config["entropy_coeff"] < 0: - raise DeprecationWarning("`entropy_coeff` must be >= 0") - if config["sample_async"] and config["framework"] == "torch": - config["sample_async"] = False - logger.warning("`sample_async=True` is not supported for PyTorch! " - "Multithreading can lead to crashes.") + raise ValueError("`entropy_coeff` must be >= 0.0!") + if config["num_workers"] <= 0 and config["sample_async"]: + raise ValueError("`num_workers` for A3C must be >= 1!") def execution_plan(workers, config): diff --git a/rllib/agents/a3c/tests/test_a2c.py b/rllib/agents/a3c/tests/test_a2c.py index 9924755eb..e1198de04 100644 --- a/rllib/agents/a3c/tests/test_a2c.py +++ b/rllib/agents/a3c/tests/test_a2c.py @@ -25,7 +25,6 @@ class TestA2C(unittest.TestCase): # Test against all frameworks. for fw in framework_iterator(config): - config["sample_async"] = fw in ["tf", "tfe", "tf2"] for env in ["PongDeterministic-v0"]: trainer = a3c.A2CTrainer(config=config, env=env) for i in range(num_iterations): diff --git a/rllib/agents/a3c/tests/test_a3c.py b/rllib/agents/a3c/tests/test_a3c.py index 8c8c621d2..37ab55f33 100644 --- a/rllib/agents/a3c/tests/test_a3c.py +++ b/rllib/agents/a3c/tests/test_a3c.py @@ -24,8 +24,7 @@ class TestA3C(unittest.TestCase): num_iterations = 1 # Test against all frameworks. - for fw in framework_iterator(config): - config["sample_async"] = fw == "tf" + for _ in framework_iterator(config): for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]: print("env={}".format(env)) trainer = a3c.A3CTrainer(config=config, env=env) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index c21f31755..805cacaaa 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -4,6 +4,7 @@ It supports both traced and non-traced eager execution modes.""" import functools import logging +import threading from ray.util.debug import log_once from ray.rllib.models.catalog import ModelCatalog @@ -15,6 +16,7 @@ from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_ops import convert_to_non_tf_type +from ray.rllib.utils.threading import with_lock from ray.rllib.utils.tracking_dict import UsageTrackingDict tf1, tf, tfv = try_import_tf() @@ -255,6 +257,13 @@ def build_eager_tf_policy(name, config["model"], framework=self.framework, ) + # Lock used for locking some methods on the object-level. + # This prevents possible race conditions when calling the model + # first, then its value function (e.g. in a loss function), in + # between of which another model call is made (e.g. to compute an + # action). + self._lock = threading.RLock() + # Auto-update model's inference view requirements, if recurrent. self._update_model_view_requirements_from_init_state() @@ -305,6 +314,7 @@ def build_eager_tf_policy(name, episode) return sample_batch + @with_lock @override(Policy) def learn_on_batch(self, postprocessed_batch): # Callback handling. @@ -351,6 +361,7 @@ def build_eager_tf_policy(name, grads = [g for g, v in grads_and_vars] return grads, stats + @with_lock @override(Policy) @convert_eager_inputs @convert_eager_outputs @@ -448,6 +459,7 @@ def build_eager_tf_policy(name, return actions, state_out, extra_fetches + @with_lock @override(Policy) def compute_log_likelihoods(self, actions, @@ -593,6 +605,7 @@ def build_eager_tf_policy(name, self._optimizer.apply_gradients( [(g, v) for g, v in grads_and_vars if g is not None]) + @with_lock def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index f81ac03ab..19d576d37 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -3,6 +3,7 @@ import gym import logging import numpy as np import time +import threading from typing import Callable, Dict, List, Optional, Tuple, Type, Union from ray.rllib.models.modelv2 import ModelV2 @@ -15,6 +16,7 @@ from ray.rllib.utils import force_list from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule +from ray.rllib.utils.threading import with_lock from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \ convert_to_torch_tensor from ray.rllib.utils.tracking_dict import UsageTrackingDict @@ -110,6 +112,14 @@ class TorchPolicy(Policy): logger.info("TorchPolicy running on CPU.") self.device = torch.device("cpu") self.model = model.to(self.device) + + # Lock used for locking some methods on the object-level. + # This prevents possible race conditions when calling the model + # first, then its value function (e.g. in a loss function), in + # between of which another model call is made (e.g. to compute an + # action). + self._lock = threading.RLock() + self._state_inputs = self.model.get_initial_state() self._is_recurrent = len(self._state_inputs) > 0 # Auto-update model's inference view requirements, if recurrent. @@ -197,6 +207,7 @@ class TorchPolicy(Policy): return self._compute_action_helper(input_dict, state_batches, seq_lens, explore, timestep) + @with_lock def _compute_action_helper(self, input_dict, state_batches, seq_lens, explore, timestep): """Shared forward pass logic (w/ and w/o trajectory view API). @@ -206,6 +217,7 @@ class TorchPolicy(Policy): - actions, state_out, extra_fetches, logp. """ self._is_recurrent = state_batches is not None and state_batches != [] + # Switch to eval mode. if self.model: self.model.eval() @@ -274,6 +286,7 @@ class TorchPolicy(Policy): return convert_to_non_torch_type((actions, state_out, extra_fetches)) + @with_lock @override(Policy) @DeveloperAPI def compute_log_likelihoods( @@ -325,12 +338,15 @@ class TorchPolicy(Policy): action_dist = dist_class(dist_inputs, self.model) log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS]) + return log_likelihoods + @with_lock @override(Policy) @DeveloperAPI def learn_on_batch( self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]: + # Set Model to train mode. if self.model: self.model.train() @@ -348,8 +364,10 @@ class TorchPolicy(Policy): if self.model: fetches["model"] = self.model.metrics() + return fetches + @with_lock @override(Policy) @DeveloperAPI def compute_gradients(self, diff --git a/rllib/utils/exploration/tests/test_explorations.py b/rllib/utils/exploration/tests/test_explorations.py index 61ec1d3e9..47ab55aef 100644 --- a/rllib/utils/exploration/tests/test_explorations.py +++ b/rllib/utils/exploration/tests/test_explorations.py @@ -92,7 +92,7 @@ class TestExplorations(unittest.TestCase): do_test_explorations( a3c.A2CTrainer, "CartPole-v0", - a3c.DEFAULT_CONFIG, + a3c.a2c.A2C_DEFAULT_CONFIG, np.array([0.0, 0.1, 0.0, 0.0]), prev_a=np.array(1)) diff --git a/rllib/utils/threading.py b/rllib/utils/threading.py new file mode 100644 index 000000000..7361dad65 --- /dev/null +++ b/rllib/utils/threading.py @@ -0,0 +1,27 @@ +from typing import Callable + + +def with_lock(func: Callable): + """Use as decorator (@withlock) around object methods that need locking. + + Note: The object must have a self._lock = threading.Lock() property. + Locking thus works on the object level (no two locked methods of the same + object can be called asynchronously). + + Args: + func (Callable): The function to decorate/wrap. + + Returns: + Callable: The wrapped (object-level locked) function. + """ + + def wrapper(self, *a, **k): + try: + with self._lock: + return func(self, *a, **k) + except AttributeError: + raise AttributeError( + "Object {} must have a `self._lock` property (assigned to a " + "threading.Lock() object in its constructor)!".format(self)) + + return wrapper