[RLlib] Fix PPO torch memory leak and avoid unnecessary torch.Tensor creation and gc'ing. (#7238)

* Take out stats to analyze memory leak in torch PPO.

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* LINT.

* Fix determine_tests_to_run.py.

* minor change to re-test after determine_tests_to_run.py.

* LINT.

* update comments.

* WIP

* WIP

* WIP

* FIX.

* Fix sequence_mask being dependent on torch being installed.

* Fix strange ray-core tf-error in test_memory_scheduling test case.

* Fix strange ray-core tf-error in test_memory_scheduling test case.

* Fix strange ray-core tf-error in test_memory_scheduling test case.

* Fix strange ray-core tf-error in test_memory_scheduling test case.
This commit is contained in:
Sven Mika
2020-02-22 20:02:31 +01:00
committed by GitHub
parent 01dd520797
commit e2edca45d4
10 changed files with 93 additions and 54 deletions
+1 -1
View File
@@ -199,7 +199,7 @@ matrix:
- ./ci/suppress_output ./ci/travis/install-ray.sh
script:
- if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi
- travis_wait 30 bazel test --build_tests_only --test_tag_filters=quick_train --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=quick_train --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
# Test everything that does not have any of the "main" labels:
# "learning_tests|quick_train|examples|tests_dir".
- ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=-learning_tests,-quick_train,-examples,-tests_dir --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
+2 -3
View File
@@ -99,7 +99,6 @@ py_test(
srcs = ["agents/impala/tests/test_vtrace.py"]
)
# --------------------------------------------------------------------
# contrib Agents
# --------------------------------------------------------------------
@@ -1028,7 +1027,7 @@ py_test(
py_test(
name = "autoregressive_action_dist", main = "examples/autoregressive_action_dist.py",
name = "examples/autoregressive_action_dist", main = "examples/autoregressive_action_dist.py",
tags = ["examples", "examples_A"],
size = "large",
srcs = ["examples/autoregressive_action_dist.py"],
@@ -1177,7 +1176,7 @@ py_test(
py_test(
name = "examples/custom_tf_policy",
tags = ["examples", "examples_C"],
size = "small",
size = "medium",
srcs = ["examples/custom_tf_policy.py"],
args = ["--iters=2", "--num-cpus=4"]
)
+9 -7
View File
@@ -38,19 +38,22 @@ def add_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
return compute_advantages(
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
policy.config["use_gae"], policy.config["use_critic"])
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
policy.config["lambda"],
policy.config["use_gae"],
policy.config["use_critic"])
def model_value_predictions(policy, input_dict, state_batches, model,
action_dist):
return {SampleBatch.VF_PREDS: model.value_function().cpu().numpy()}
return {SampleBatch.VF_PREDS: model.value_function()}
def apply_grad_clipping(policy):
@@ -68,9 +71,8 @@ def torch_optimizer(policy, config):
class ValueNetworkMixin:
def _value(self, obs):
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_ = self.model({"obs": obs}, [], [1])
return self.model.value_function().detach().cpu().numpy().squeeze()
_ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
return self.model.value_function()[0]
A3CTorchPolicy = build_torch_policy(
+3 -2
View File
@@ -96,10 +96,12 @@ def choose_policy_optimizer(workers, config):
def update_kl(trainer, fetches):
# Single-agent.
if "kl" in fetches:
# single-agent
trainer.workers.local_worker().for_policy(
lambda pi: pi.update_kl(fetches["kl"]))
# Multi-agent.
else:
def update(pi, pi_id):
@@ -108,7 +110,6 @@ def update_kl(trainer, fetches):
else:
logger.debug("No data for {}, not updating kl".format(pi_id))
# multi-agent
trainer.workers.local_worker().foreach_trainable_policy(update)
+8 -8
View File
@@ -144,15 +144,15 @@ def kl_and_loss_stats(policy, train_batch):
return {
"cur_kl_coeff": policy.kl_coeff,
"cur_lr": policy.cur_lr,
"total_loss": policy.loss_obj.loss.cpu().detach().numpy(),
"policy_loss": policy.loss_obj.mean_policy_loss.cpu().detach().numpy(),
"vf_loss": policy.loss_obj.mean_vf_loss.cpu().detach().numpy(),
"total_loss": policy.loss_obj.loss,
"policy_loss": policy.loss_obj.mean_policy_loss,
"vf_loss": policy.loss_obj.mean_vf_loss,
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
policy.model.value_function(),
framework="torch").cpu().detach().numpy(),
"kl": policy.loss_obj.mean_kl.cpu().detach().numpy(),
"entropy": policy.loss_obj.mean_entropy.cpu().detach().numpy(),
framework="torch"),
"kl": policy.loss_obj.mean_kl,
"entropy": policy.loss_obj.mean_entropy,
"entropy_coeff": policy.entropy_coeff,
}
@@ -161,8 +161,8 @@ def vf_preds_and_logits_fetches(policy, input_dict, state_batches, model,
action_dist):
"""Adds value function and logits outputs to experience train_batches."""
return {
SampleBatch.VF_PREDS: policy.model.value_function().cpu().numpy(),
BEHAVIOUR_LOGITS: policy.model.last_output().cpu().numpy(),
SampleBatch.VF_PREDS: policy.model.value_function(),
BEHAVIOUR_LOGITS: policy.model.last_output(),
}
-1
View File
@@ -39,7 +39,6 @@ class TestPPO(unittest.TestCase):
# Torch.
config["use_pytorch"] = True
config["simple_optimizer"] = True
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
for i in range(num_iterations):
trainer.train()
+26 -14
View File
@@ -4,6 +4,10 @@ from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
torch, _ = try_import_torch()
@DeveloperAPI
@@ -97,8 +101,12 @@ def build_torch_policy(name,
episode=None):
if not postprocess_fn:
return sample_batch
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
# Do all post-processing always with no_grad().
# Not using this here will introduce a memory leak (issue #6962).
with torch.no_grad():
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
@override(TorchPolicy)
def extra_grad_process(self):
@@ -110,14 +118,16 @@ def build_torch_policy(name,
@override(TorchPolicy)
def extra_action_out(self, input_dict, state_batches, model,
action_dist=None):
if extra_action_out_fn:
return extra_action_out_fn(
self, input_dict, state_batches, model, action_dist
)
else:
return TorchPolicy.extra_action_out(
self, input_dict, state_batches, model, action_dist
)
with torch.no_grad():
if extra_action_out_fn:
stats_dict = extra_action_out_fn(
self, input_dict, state_batches, model, action_dist
)
else:
stats_dict = TorchPolicy.extra_action_out(
self, input_dict, state_batches, model, action_dist
)
return convert_to_non_torch_type(stats_dict)
@override(TorchPolicy)
def optimizer(self):
@@ -128,10 +138,12 @@ def build_torch_policy(name,
@override(TorchPolicy)
def extra_grad_info(self, train_batch):
if stats_fn:
return stats_fn(self, train_batch)
else:
return TorchPolicy.extra_grad_info(self, train_batch)
with torch.no_grad():
if stats_fn:
stats_dict = stats_fn(self, train_batch)
else:
stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
return convert_to_non_torch_type(stats_dict)
def with_updates(**overrides):
return build_torch_policy(**dict(original_kwargs, **overrides))
+21 -15
View File
@@ -1,8 +1,7 @@
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
tf_function
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.tuple_actions import TupleActions
tf = try_import_tf()
@@ -73,22 +72,29 @@ class StochasticSampling(Exploration):
else:
return self._get_tf_exploration_action_op(action_dist, explore)
@staticmethod
@tf_function(tf)
def _get_tf_exploration_action_op(action_dist, explore):
if explore:
action = action_dist.sample()
# TODO(sven): Change `sample` to accept `sample(logp=True|False)`
logp = action_dist.sampled_action_logp()
else:
action = action_dist.deterministic_sample()
def _get_tf_exploration_action_op(self, action_dist, explore):
sample = action_dist.sample()
deterministic_sample = action_dist.deterministic_sample()
action = tf.cond(
tf.constant(explore) if isinstance(explore, bool) else explore,
true_fn=lambda: sample,
false_fn=lambda: deterministic_sample)
def logp_false_fn():
# TODO(sven): Move into (deterministic_)sample(logp=True|False)
if isinstance(action, TupleActions):
batch_size = tf.shape(action[0][0])[0]
if isinstance(sample, TupleActions):
batch_size = tf.shape(action[0])[0]
else:
batch_size = tf.shape(action)[0]
logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)
return action, logp
return tf.zeros(shape=(batch_size, ), dtype=tf.float32)
logp = tf.cond(
tf.constant(explore) if isinstance(explore, bool) else explore,
true_fn=lambda: action_dist.sampled_action_logp(),
false_fn=logp_false_fn)
return TupleActions(action) if isinstance(sample, TupleActions) \
else action, logp
@staticmethod
def _get_torch_exploration_action(action_dist, explore):
+1 -1
View File
@@ -100,7 +100,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
else:
with tf.Session() as sess:
x = sess.run(x)
check(
return check(
x,
y,
decimals=decimals,
+22 -2
View File
@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_torch
torch, _ = try_import_torch()
def sequence_mask(lengths, maxlen, dtype=torch.bool):
def sequence_mask(lengths, maxlen, dtype=None):
"""
Exact same behavior as tf.sequence_mask.
Thanks to Dimitris Papatheodorou
@@ -15,6 +15,26 @@ def sequence_mask(lengths, maxlen, dtype=torch.bool):
mask = ~(torch.ones((len(lengths), maxlen)).cumsum(dim=1).t() > lengths). \
t()
mask.type(dtype)
mask.type(dtype or torch.bool)
return mask
def convert_to_non_torch_type(stats_dict):
    """Converts values in stats_dict to non-Tensor numpy or python types.

    Args:
        stats_dict (dict): A flat key, value dict, the values of which will be
            converted and returned as a new dict.

    Returns:
        dict: A new dict with the same keys as stats_dict. Every
            torch.Tensor value is converted: 0-dim tensors become python
            scalars (via `.item()`), all other tensors become numpy arrays.
            Non-tensor values are passed through unchanged.
    """
    ret = {}
    for k, v in stats_dict.items():
        if isinstance(v, torch.Tensor):
            if v.dim() == 0:
                ret[k] = v.item()
            else:
                # `Tensor.numpy()` raises for tensors that require grad or
                # that live on a non-CPU device, so detach from the graph
                # and move to host memory before converting.
                ret[k] = v.detach().cpu().numpy()
        else:
            ret[k] = v
    return ret