From e2edca45d40e78ab7b8345956868136b51604fec Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sat, 22 Feb 2020 20:02:31 +0100 Subject: [PATCH] [RLlib] PPO torch memory leak and unnecessary torch.Tensor creation and gc'ing. (#7238) * Take out stats to analyze memory leak in torch PPO. * WIP * WIP * WIP * WIP * WIP * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * LINT. * Fix determine_tests_to_run.py. * minor change to re-test after determine_tests_to_run.py. * LINT. * update comments. * WIP * WIP * WIP * FIX. * Fix sequence_mask being dependent on torch being installed. * Fix strange ray-core tf-error in test_memory_scheduling test case. * Fix strange ray-core tf-error in test_memory_scheduling test case. * Fix strange ray-core tf-error in test_memory_scheduling test case. * Fix strange ray-core tf-error in test_memory_scheduling test case. --- .travis.yml | 2 +- rllib/BUILD | 5 +-- rllib/agents/a3c/a3c_torch_policy.py | 16 ++++---- rllib/agents/ppo/ppo.py | 5 ++- rllib/agents/ppo/ppo_torch_policy.py | 16 ++++---- rllib/agents/ppo/tests/test_ppo.py | 1 - rllib/policy/torch_policy_template.py | 40 ++++++++++++------- .../utils/exploration/stochastic_sampling.py | 36 ++++++++++------- rllib/utils/test_utils.py | 2 +- rllib/utils/torch_ops.py | 24 ++++++++++- 10 files changed, 93 insertions(+), 54 deletions(-) diff --git a/.travis.yml b/.travis.yml index 535ee11ed..5b9b0be49 100644 --- a/.travis.yml +++ b/.travis.yml @@ -199,7 +199,7 @@ matrix: - ./ci/suppress_output ./ci/travis/install-ray.sh script: - if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi - - travis_wait 30 bazel test --build_tests_only --test_tag_filters=quick_train --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... + - travis_wait 60 bazel test --build_tests_only --test_tag_filters=quick_train --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... # Test everything that does not have any of the "main" labels: # "learning_tests|quick_train|examples|tests_dir". - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=-learning_tests,-quick_train,-examples,-tests_dir --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... diff --git a/rllib/BUILD b/rllib/BUILD index 99d0ae6c7..a066c3010 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -99,7 +99,6 @@ py_test( srcs = ["agents/impala/tests/test_vtrace.py"] ) - # -------------------------------------------------------------------- # contrib Agents # -------------------------------------------------------------------- @@ -1028,7 +1027,7 @@ py_test( py_test( - name = "autoregressive_action_dist", main = "examples/autoregressive_action_dist.py", + name = "examples/autoregressive_action_dist", main = "examples/autoregressive_action_dist.py", tags = ["examples", "examples_A"], size = "large", srcs = ["examples/autoregressive_action_dist.py"], @@ -1177,7 +1176,7 @@ py_test( py_test( name = "examples/custom_tf_policy", tags = ["examples", "examples_C"], - size = "small", + size = "medium", srcs = ["examples/custom_tf_policy.py"], args = ["--iters=2", "--num-cpus=4"] ) diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 7ab577993..0e4ede629 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -38,19 +38,22 @@ def add_advantages(policy, sample_batch, other_agent_batches=None, episode=None): + completed = sample_batch[SampleBatch.DONES][-1] if completed: last_r = 0.0 else: last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1]) - return compute_advantages( - sample_batch, last_r, policy.config["gamma"], policy.config["lambda"], - policy.config["use_gae"], policy.config["use_critic"]) + + return compute_advantages(sample_batch, last_r, policy.config["gamma"], + policy.config["lambda"], + policy.config["use_gae"], + policy.config["use_critic"]) def model_value_predictions(policy, input_dict, state_batches, model, action_dist): - return {SampleBatch.VF_PREDS: model.value_function().cpu().numpy()} + return {SampleBatch.VF_PREDS: model.value_function()} def apply_grad_clipping(policy): @@ -68,9 +71,8 @@ def torch_optimizer(policy, config): class ValueNetworkMixin: def _value(self, obs): - obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) - _ = self.model({"obs": obs}, [], [1]) - return self.model.value_function().detach().cpu().numpy().squeeze() + _ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1]) + return self.model.value_function()[0] A3CTorchPolicy = build_torch_policy( diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index 4de31a3e2..9831771e1 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -96,10 +96,12 @@ def choose_policy_optimizer(workers, config): def update_kl(trainer, fetches): + # Single-agent. if "kl" in fetches: - # single-agent trainer.workers.local_worker().for_policy( lambda pi: pi.update_kl(fetches["kl"])) + + # Multi-agent. else: def update(pi, pi_id): @@ -108,7 +110,6 @@ def update_kl(trainer, fetches): else: logger.debug("No data for {}, not updating kl".format(pi_id)) - # multi-agent trainer.workers.local_worker().foreach_trainable_policy(update) diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index ad9ae27dc..af6cd0ec5 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -144,15 +144,15 @@ def kl_and_loss_stats(policy, train_batch): return { "cur_kl_coeff": policy.kl_coeff, "cur_lr": policy.cur_lr, - "total_loss": policy.loss_obj.loss.cpu().detach().numpy(), - "policy_loss": policy.loss_obj.mean_policy_loss.cpu().detach().numpy(), - "vf_loss": policy.loss_obj.mean_vf_loss.cpu().detach().numpy(), + "total_loss": policy.loss_obj.loss, + "policy_loss": policy.loss_obj.mean_policy_loss, + "vf_loss": policy.loss_obj.mean_vf_loss, "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], policy.model.value_function(), - framework="torch").cpu().detach().numpy(), - "kl": policy.loss_obj.mean_kl.cpu().detach().numpy(), - "entropy": policy.loss_obj.mean_entropy.cpu().detach().numpy(), + framework="torch"), + "kl": policy.loss_obj.mean_kl, + "entropy": policy.loss_obj.mean_entropy, "entropy_coeff": policy.entropy_coeff, } @@ -161,8 +161,8 @@ def vf_preds_and_logits_fetches(policy, input_dict, state_batches, model, action_dist): """Adds value function and logits outputs to experience train_batches.""" return { - SampleBatch.VF_PREDS: policy.model.value_function().cpu().numpy(), - BEHAVIOUR_LOGITS: policy.model.last_output().cpu().numpy(), + SampleBatch.VF_PREDS: policy.model.value_function(), + BEHAVIOUR_LOGITS: policy.model.last_output(), } diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index ca75d395d..f5d7b24e3 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -39,7 +39,6 @@ class TestPPO(unittest.TestCase): # Torch. config["use_pytorch"] = True - config["simple_optimizer"] = True trainer = ppo.PPOTrainer(config=config, env="CartPole-v0") for i in range(num_iterations): trainer.train() diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index 4e08b43b5..01d963ea7 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -4,6 +4,10 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import convert_to_non_torch_type + +torch, _ = try_import_torch() @DeveloperAPI @@ -97,8 +101,12 @@ def build_torch_policy(name, episode=None): if not postprocess_fn: return sample_batch - return postprocess_fn(self, sample_batch, other_agent_batches, - episode) + + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak (issue #6962). + with torch.no_grad(): + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) @override(TorchPolicy) def extra_grad_process(self): @@ -110,14 +118,16 @@ def build_torch_policy(name, @override(TorchPolicy) def extra_action_out(self, input_dict, state_batches, model, action_dist=None): - if extra_action_out_fn: - return extra_action_out_fn( - self, input_dict, state_batches, model, action_dist - ) - else: - return TorchPolicy.extra_action_out( - self, input_dict, state_batches, model, action_dist - ) + with torch.no_grad(): + if extra_action_out_fn: + stats_dict = extra_action_out_fn( + self, input_dict, state_batches, model, action_dist + ) + else: + stats_dict = TorchPolicy.extra_action_out( + self, input_dict, state_batches, model, action_dist + ) + return convert_to_non_torch_type(stats_dict) @override(TorchPolicy) def optimizer(self): @@ -128,10 +138,12 @@ def build_torch_policy(name, @override(TorchPolicy) def extra_grad_info(self, train_batch): - if stats_fn: - return stats_fn(self, train_batch) - else: - return TorchPolicy.extra_grad_info(self, train_batch) + with torch.no_grad(): + if stats_fn: + stats_dict = stats_fn(self, train_batch) + else: + stats_dict = TorchPolicy.extra_grad_info(self, train_batch) + return convert_to_non_torch_type(stats_dict) def with_updates(**overrides): return build_torch_policy(**dict(original_kwargs, **overrides)) diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index 23f00db97..880784d4d 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -1,8 +1,7 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import Exploration -from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ - tf_function +from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.tuple_actions import TupleActions tf = try_import_tf() @@ -73,22 +72,29 @@ class StochasticSampling(Exploration): else: return self._get_tf_exploration_action_op(action_dist, explore) - @staticmethod - @tf_function(tf) - def _get_tf_exploration_action_op(action_dist, explore): - if explore: - action = action_dist.sample() - # TODO(sven): Change `sample` to accept `sample(logp=True|False)` - logp = action_dist.sampled_action_logp() - else: - action = action_dist.deterministic_sample() + def _get_tf_exploration_action_op(self, action_dist, explore): + sample = action_dist.sample() + deterministic_sample = action_dist.deterministic_sample() + action = tf.cond( + tf.constant(explore) if isinstance(explore, bool) else explore, + true_fn=lambda: sample, + false_fn=lambda: deterministic_sample) + + def logp_false_fn(): # TODO(sven): Move into (deterministic_)sample(logp=True|False) - if isinstance(action, TupleActions): - batch_size = tf.shape(action[0][0])[0] + if isinstance(sample, TupleActions): + batch_size = tf.shape(action[0])[0] else: batch_size = tf.shape(action)[0] - logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32) - return action, logp + return tf.zeros(shape=(batch_size, ), dtype=tf.float32) + + logp = tf.cond( + tf.constant(explore) if isinstance(explore, bool) else explore, + true_fn=lambda: action_dist.sampled_action_logp(), + false_fn=logp_false_fn) + + return TupleActions(action) if isinstance(sample, TupleActions) \ + else action, logp @staticmethod def _get_torch_exploration_action(action_dist, explore): diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 0afaac2c9..9c180b671 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -100,7 +100,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): else: with tf.Session() as sess: x = sess.run(x) - check( + return check( x, y, decimals=decimals, diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 981147290..799ec09dc 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_torch torch, _ = try_import_torch() -def sequence_mask(lengths, maxlen, dtype=torch.bool): +def sequence_mask(lengths, maxlen, dtype=None): """ Exact same behavior as tf.sequence_mask. Thanks to Dimitris Papatheodorou @@ -15,6 +15,26 @@ def sequence_mask(lengths, maxlen, dtype=torch.bool): mask = ~(torch.ones((len(lengths), maxlen)).cumsum(dim=1).t() > lengths). \ t() - mask.type(dtype) + mask.type(dtype or torch.bool) return mask + + +def convert_to_non_torch_type(stats_dict): + """Converts values in stats_dict to non-Tensor numpy or python types. + + Args: + stats_dict (dict): A flat key, value dict, the values of which will be + converted and returned as a new dict. + + Returns: + dict: A new dict with the same structure as stats_dict, but with all + values converted to non-torch Tensor types. + """ + ret = {} + for k, v in stats_dict.items(): + if isinstance(v, torch.Tensor): + ret[k] = v.item() if len(v.size()) == 0 else v.numpy() + else: + ret[k] = v + return ret