mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 08:01:50 +08:00
[RLlib] PPO torch memory leak and unnecessary torch.Tensor creation and gc'ing. (#7238)
* Take out stats to analyze memory leak in torch PPO. * WIP * WIP * WIP * WIP * WIP * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * LINT. * Fix determine_tests_to_run.py. * minor change to re-test after determine_tests_to_run.py. * LINT. * update comments. * WIP * WIP * WIP * FIX. * Fix sequence_mask being dependent on torch being installed. * Fix strange ray-core tf-error in test_memory_scheduling test case. * Fix strange ray-core tf-error in test_memory_scheduling test case. * Fix strange ray-core tf-error in test_memory_scheduling test case. * Fix strange ray-core tf-error in test_memory_scheduling test case.
This commit is contained in:
+1
-1
@@ -199,7 +199,7 @@ matrix:
|
||||
- ./ci/suppress_output ./ci/travis/install-ray.sh
|
||||
script:
|
||||
- if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi
|
||||
- travis_wait 30 bazel test --build_tests_only --test_tag_filters=quick_train --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
|
||||
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=quick_train --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
|
||||
# Test everything that does not have any of the "main" labels:
|
||||
# "learning_tests|quick_train|examples|tests_dir".
|
||||
- ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=-learning_tests,-quick_train,-examples,-tests_dir --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
|
||||
|
||||
+2
-3
@@ -99,7 +99,6 @@ py_test(
|
||||
srcs = ["agents/impala/tests/test_vtrace.py"]
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# contrib Agents
|
||||
# --------------------------------------------------------------------
|
||||
@@ -1028,7 +1027,7 @@ py_test(
|
||||
|
||||
|
||||
py_test(
|
||||
name = "autoregressive_action_dist", main = "examples/autoregressive_action_dist.py",
|
||||
name = "examples/autoregressive_action_dist", main = "examples/autoregressive_action_dist.py",
|
||||
tags = ["examples", "examples_A"],
|
||||
size = "large",
|
||||
srcs = ["examples/autoregressive_action_dist.py"],
|
||||
@@ -1177,7 +1176,7 @@ py_test(
|
||||
py_test(
|
||||
name = "examples/custom_tf_policy",
|
||||
tags = ["examples", "examples_C"],
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["examples/custom_tf_policy.py"],
|
||||
args = ["--iters=2", "--num-cpus=4"]
|
||||
)
|
||||
|
||||
@@ -38,19 +38,22 @@ def add_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
|
||||
completed = sample_batch[SampleBatch.DONES][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
|
||||
return compute_advantages(
|
||||
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
|
||||
policy.config["use_gae"], policy.config["use_critic"])
|
||||
|
||||
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
policy.config["use_gae"],
|
||||
policy.config["use_critic"])
|
||||
|
||||
|
||||
def model_value_predictions(policy, input_dict, state_batches, model,
|
||||
action_dist):
|
||||
return {SampleBatch.VF_PREDS: model.value_function().cpu().numpy()}
|
||||
return {SampleBatch.VF_PREDS: model.value_function()}
|
||||
|
||||
|
||||
def apply_grad_clipping(policy):
|
||||
@@ -68,9 +71,8 @@ def torch_optimizer(policy, config):
|
||||
|
||||
class ValueNetworkMixin:
|
||||
def _value(self, obs):
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
|
||||
_ = self.model({"obs": obs}, [], [1])
|
||||
return self.model.value_function().detach().cpu().numpy().squeeze()
|
||||
_ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
|
||||
return self.model.value_function()[0]
|
||||
|
||||
|
||||
A3CTorchPolicy = build_torch_policy(
|
||||
|
||||
@@ -96,10 +96,12 @@ def choose_policy_optimizer(workers, config):
|
||||
|
||||
|
||||
def update_kl(trainer, fetches):
|
||||
# Single-agent.
|
||||
if "kl" in fetches:
|
||||
# single-agent
|
||||
trainer.workers.local_worker().for_policy(
|
||||
lambda pi: pi.update_kl(fetches["kl"]))
|
||||
|
||||
# Multi-agent.
|
||||
else:
|
||||
|
||||
def update(pi, pi_id):
|
||||
@@ -108,7 +110,6 @@ def update_kl(trainer, fetches):
|
||||
else:
|
||||
logger.debug("No data for {}, not updating kl".format(pi_id))
|
||||
|
||||
# multi-agent
|
||||
trainer.workers.local_worker().foreach_trainable_policy(update)
|
||||
|
||||
|
||||
|
||||
@@ -144,15 +144,15 @@ def kl_and_loss_stats(policy, train_batch):
|
||||
return {
|
||||
"cur_kl_coeff": policy.kl_coeff,
|
||||
"cur_lr": policy.cur_lr,
|
||||
"total_loss": policy.loss_obj.loss.cpu().detach().numpy(),
|
||||
"policy_loss": policy.loss_obj.mean_policy_loss.cpu().detach().numpy(),
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss.cpu().detach().numpy(),
|
||||
"total_loss": policy.loss_obj.loss,
|
||||
"policy_loss": policy.loss_obj.mean_policy_loss,
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
train_batch[Postprocessing.VALUE_TARGETS],
|
||||
policy.model.value_function(),
|
||||
framework="torch").cpu().detach().numpy(),
|
||||
"kl": policy.loss_obj.mean_kl.cpu().detach().numpy(),
|
||||
"entropy": policy.loss_obj.mean_entropy.cpu().detach().numpy(),
|
||||
framework="torch"),
|
||||
"kl": policy.loss_obj.mean_kl,
|
||||
"entropy": policy.loss_obj.mean_entropy,
|
||||
"entropy_coeff": policy.entropy_coeff,
|
||||
}
|
||||
|
||||
@@ -161,8 +161,8 @@ def vf_preds_and_logits_fetches(policy, input_dict, state_batches, model,
|
||||
action_dist):
|
||||
"""Adds value function and logits outputs to experience train_batches."""
|
||||
return {
|
||||
SampleBatch.VF_PREDS: policy.model.value_function().cpu().numpy(),
|
||||
BEHAVIOUR_LOGITS: policy.model.last_output().cpu().numpy(),
|
||||
SampleBatch.VF_PREDS: policy.model.value_function(),
|
||||
BEHAVIOUR_LOGITS: policy.model.last_output(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@ class TestPPO(unittest.TestCase):
|
||||
|
||||
# Torch.
|
||||
config["use_pytorch"] = True
|
||||
config["simple_optimizer"] = True
|
||||
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
|
||||
for i in range(num_iterations):
|
||||
trainer.train()
|
||||
|
||||
@@ -4,6 +4,10 @@ from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -97,8 +101,12 @@ def build_torch_policy(name,
|
||||
episode=None):
|
||||
if not postprocess_fn:
|
||||
return sample_batch
|
||||
return postprocess_fn(self, sample_batch, other_agent_batches,
|
||||
episode)
|
||||
|
||||
# Do all post-processing always with no_grad().
|
||||
# Not using this here will introduce a memory leak (issue #6962).
|
||||
with torch.no_grad():
|
||||
return postprocess_fn(self, sample_batch, other_agent_batches,
|
||||
episode)
|
||||
|
||||
@override(TorchPolicy)
|
||||
def extra_grad_process(self):
|
||||
@@ -110,14 +118,16 @@ def build_torch_policy(name,
|
||||
@override(TorchPolicy)
|
||||
def extra_action_out(self, input_dict, state_batches, model,
|
||||
action_dist=None):
|
||||
if extra_action_out_fn:
|
||||
return extra_action_out_fn(
|
||||
self, input_dict, state_batches, model, action_dist
|
||||
)
|
||||
else:
|
||||
return TorchPolicy.extra_action_out(
|
||||
self, input_dict, state_batches, model, action_dist
|
||||
)
|
||||
with torch.no_grad():
|
||||
if extra_action_out_fn:
|
||||
stats_dict = extra_action_out_fn(
|
||||
self, input_dict, state_batches, model, action_dist
|
||||
)
|
||||
else:
|
||||
stats_dict = TorchPolicy.extra_action_out(
|
||||
self, input_dict, state_batches, model, action_dist
|
||||
)
|
||||
return convert_to_non_torch_type(stats_dict)
|
||||
|
||||
@override(TorchPolicy)
|
||||
def optimizer(self):
|
||||
@@ -128,10 +138,12 @@ def build_torch_policy(name,
|
||||
|
||||
@override(TorchPolicy)
|
||||
def extra_grad_info(self, train_batch):
|
||||
if stats_fn:
|
||||
return stats_fn(self, train_batch)
|
||||
else:
|
||||
return TorchPolicy.extra_grad_info(self, train_batch)
|
||||
with torch.no_grad():
|
||||
if stats_fn:
|
||||
stats_dict = stats_fn(self, train_batch)
|
||||
else:
|
||||
stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
|
||||
return convert_to_non_torch_type(stats_dict)
|
||||
|
||||
def with_updates(**overrides):
|
||||
return build_torch_policy(**dict(original_kwargs, **overrides))
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.exploration.exploration import Exploration
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
||||
tf_function
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.tuple_actions import TupleActions
|
||||
|
||||
tf = try_import_tf()
|
||||
@@ -73,22 +72,29 @@ class StochasticSampling(Exploration):
|
||||
else:
|
||||
return self._get_tf_exploration_action_op(action_dist, explore)
|
||||
|
||||
@staticmethod
|
||||
@tf_function(tf)
|
||||
def _get_tf_exploration_action_op(action_dist, explore):
|
||||
if explore:
|
||||
action = action_dist.sample()
|
||||
# TODO(sven): Change `sample` to accept `sample(logp=True|False)`
|
||||
logp = action_dist.sampled_action_logp()
|
||||
else:
|
||||
action = action_dist.deterministic_sample()
|
||||
def _get_tf_exploration_action_op(self, action_dist, explore):
|
||||
sample = action_dist.sample()
|
||||
deterministic_sample = action_dist.deterministic_sample()
|
||||
action = tf.cond(
|
||||
tf.constant(explore) if isinstance(explore, bool) else explore,
|
||||
true_fn=lambda: sample,
|
||||
false_fn=lambda: deterministic_sample)
|
||||
|
||||
def logp_false_fn():
|
||||
# TODO(sven): Move into (deterministic_)sample(logp=True|False)
|
||||
if isinstance(action, TupleActions):
|
||||
batch_size = tf.shape(action[0][0])[0]
|
||||
if isinstance(sample, TupleActions):
|
||||
batch_size = tf.shape(action[0])[0]
|
||||
else:
|
||||
batch_size = tf.shape(action)[0]
|
||||
logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)
|
||||
return action, logp
|
||||
return tf.zeros(shape=(batch_size, ), dtype=tf.float32)
|
||||
|
||||
logp = tf.cond(
|
||||
tf.constant(explore) if isinstance(explore, bool) else explore,
|
||||
true_fn=lambda: action_dist.sampled_action_logp(),
|
||||
false_fn=logp_false_fn)
|
||||
|
||||
return TupleActions(action) if isinstance(sample, TupleActions) \
|
||||
else action, logp
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_exploration_action(action_dist, explore):
|
||||
|
||||
@@ -100,7 +100,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
|
||||
else:
|
||||
with tf.Session() as sess:
|
||||
x = sess.run(x)
|
||||
check(
|
||||
return check(
|
||||
x,
|
||||
y,
|
||||
decimals=decimals,
|
||||
|
||||
@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_torch
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen, dtype=torch.bool):
|
||||
def sequence_mask(lengths, maxlen, dtype=None):
|
||||
"""
|
||||
Exact same behavior as tf.sequence_mask.
|
||||
Thanks to Dimitris Papatheodorou
|
||||
@@ -15,6 +15,26 @@ def sequence_mask(lengths, maxlen, dtype=torch.bool):
|
||||
|
||||
mask = ~(torch.ones((len(lengths), maxlen)).cumsum(dim=1).t() > lengths). \
|
||||
t()
|
||||
mask.type(dtype)
|
||||
mask.type(dtype or torch.bool)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def convert_to_non_torch_type(stats_dict):
|
||||
"""Converts values in stats_dict to non-Tensor numpy or python types.
|
||||
|
||||
Args:
|
||||
stats_dict (dict): A flat key, value dict, the values of which will be
|
||||
converted and returned as a new dict.
|
||||
|
||||
Returns:
|
||||
dict: A new dict with the same structure as stats_dict, but with all
|
||||
values converted to non-torch Tensor types.
|
||||
"""
|
||||
ret = {}
|
||||
for k, v in stats_dict.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
ret[k] = v.item() if len(v.size()) == 0 else v.numpy()
|
||||
else:
|
||||
ret[k] = v
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user