diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 7af13bb7c..1b7f13c41 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -99,7 +99,7 @@ Here are some rules of thumb for scaling training with RLlib. 3. If the model is compute intensive (e.g., a large deep residual network) and inference is the bottleneck, consider allocating GPUs to workers by setting ``num_gpus_per_worker: 1``. If you only have a single GPU, consider ``num_workers: 0`` to use the learner GPU for inference. For efficient use of GPU time, use a small number of GPU workers and a large number of `envs per worker `__. - + 4. Finally, if both model and environment are compute intensive, then enable `remote worker envs `__ with `async batching `__ by setting ``remote_worker_envs: True`` and optionally ``remote_env_batch_wait_ms``. This batches inference on GPUs in the rollout workers while letting envs run asynchronously in separate actors, similar to the `SEED `__ architecture. The number of workers and number of envs per worker should be tuned to maximize GPU utilization. If your env requires GPUs to function, or if multi-node SGD is needed, then also consider :ref:`DD-PPO `. 
Common Parameters diff --git a/python/ray/util/sgd/torch/examples/mnist_cnn.pt b/python/ray/util/sgd/torch/examples/mnist_cnn.pt new file mode 100644 index 000000000..1c4364e16 Binary files /dev/null and b/python/ray/util/sgd/torch/examples/mnist_cnn.pt differ diff --git a/rllib/models/tests/test_distributions.py b/rllib/models/tests/test_distributions.py index ebd3525ac..ae7a1bdf4 100644 --- a/rllib/models/tests/test_distributions.py +++ b/rllib/models/tests/test_distributions.py @@ -9,7 +9,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchMultiCategorical, \ TorchSquashedGaussian, TorchBeta from ray.rllib.utils import try_import_tf, try_import_torch from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \ - softmax, SMALL_NUMBER + softmax, SMALL_NUMBER, LARGE_INTEGER from ray.rllib.utils.test_utils import check, framework_iterator tf = try_import_tf() @@ -19,6 +19,47 @@ torch, _ = try_import_torch() class TestDistributions(unittest.TestCase): """Tests ActionDistribution classes.""" + def _stability_test(self, + distribution_cls, + network_output_shape, + fw, + sess=None, + bounds=None): + extreme_values = [ + 0.0, + float(LARGE_INTEGER), + -float(LARGE_INTEGER), + 1.1e-34, + 1.1e34, + -1.1e-34, + -1.1e34, + SMALL_NUMBER, + -SMALL_NUMBER, + ] + inputs = np.zeros(shape=network_output_shape, dtype=np.float32) + for batch_item in range(network_output_shape[0]): + for num in range(len(inputs[batch_item])): + inputs[batch_item][num] = np.random.choice(extreme_values) + dist = distribution_cls(inputs, {}) + for _ in range(100): + sample = dist.sample() + if fw != "tf": + sample_check = sample.numpy() + else: + sample_check = sess.run(sample) + assert not np.any(np.isnan(sample_check)) + assert np.all(np.isfinite(sample_check)) + if bounds: + assert np.min(sample_check) >= bounds[0] + assert np.max(sample_check) <= bounds[1] + logp = dist.logp(sample) + if fw != "tf": + logp_check = logp.numpy() + else: + logp_check = sess.run(logp) + assert 
not np.any(np.isnan(logp_check)) + assert np.all(np.isfinite(logp_check)) + def test_categorical(self): """Tests the Categorical ActionDistribution (tf only).""" num_samples = 100000 @@ -103,9 +144,15 @@ class TestDistributions(unittest.TestCase): input_space = Box(-2.0, 2.0, shape=(200, 10)) low, high = -2.0, 1.0 - for fw, sess in framework_iterator(session=True): + for fw, sess in framework_iterator( + frameworks=("torch", "tf", "eager"), session=True): cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian + # Stability test: feed extreme NN outputs and check that sampling + # and logp stay finite (no NaN or +/-inf) and samples stay in bounds. + self._stability_test( + cls, input_space.shape, fw=fw, sess=sess, bounds=(low, high)) + # Batch of size=n and deterministic. inputs = input_space.sample() means, _ = np.split(inputs, 2, axis=-1) @@ -125,8 +172,8 @@ class TestDistributions(unittest.TestCase): values = sess.run(values) else: values = values.numpy() - self.assertTrue(np.max(values) < high) - self.assertTrue(np.min(values) > low) + self.assertTrue(np.max(values) <= high) + self.assertTrue(np.min(values) >= low) check(np.mean(values), expected.mean(), decimals=1) @@ -143,11 +190,13 @@ class TestDistributions(unittest.TestCase): # Unsquash values, then get log-llh from regular gaussian. 
# atanh_in = np.clip((values - low) / (high - low) * 2.0 - 1.0, # -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER) - atanh_in = (values - low) / (high - low) * 2.0 - 1.0 - unsquashed_values = np.arctanh(atanh_in) + normed_values = (values - low) / (high - low) * 2.0 - 1.0 + save_normed_values = np.clip(normed_values, -1.0 + SMALL_NUMBER, + 1.0 - SMALL_NUMBER) + unsquashed_values = np.arctanh(save_normed_values) log_prob_unsquashed = np.sum( - np.log( - norm.pdf(unsquashed_values, means, stds) + SMALL_NUMBER), + np.log(norm.pdf(unsquashed_values, means, + stds)), -1) log_prob = log_prob_unsquashed - \ np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2), diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index 16288aecc..818de8234 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -294,27 +294,34 @@ class SquashedGaussian(TFActionDistribution): @override(ActionDistribution) def logp(self, x): + # Unsquash values (from [low,high] to ]-inf,inf[) unsquashed_values = self._unsquash(x) - log_prob = tf.reduce_sum( - self.distr.log_prob(value=unsquashed_values), axis=-1) + # Get log prob of unsquashed values from our Normal. + log_prob_gaussian = self.distr.log_prob(unsquashed_values) + # Clamp per-component log-probs to [-100, 100], only then sum up. + log_prob_gaussian = tf.clip_by_value(log_prob_gaussian, -100, 100) + log_prob_gaussian = tf.reduce_sum(log_prob_gaussian, axis=-1) + # Subtract the tanh Jacobian term to get the squashed log-prob. unsquashed_values_tanhd = tf.math.tanh(unsquashed_values) - log_prob -= tf.math.reduce_sum( + log_prob = log_prob_gaussian - tf.reduce_sum( tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), axis=-1) return log_prob def _squash(self, raw_values): - # Make sure raw_values are not too high/low (such that tanh would - # return exactly 1.0/-1.0, which would lead to +/-inf log-probs). 
- return (tf.clip_by_value( - tf.math.tanh(raw_values), - -1.0 + SMALL_NUMBER, - 1.0 - SMALL_NUMBER) + 1.0) / 2.0 * (self.high - self.low) + \ - self.low + # Returned values are within [low, high] (including `low` and `high`). + squashed = ((tf.math.tanh(raw_values) + 1.0) / 2.0) * \ + (self.high - self.low) + self.low + return tf.clip_by_value(squashed, self.low, self.high) def _unsquash(self, values): - return tf.math.atanh((values - self.low) / - (self.high - self.low) * 2.0 - 1.0) + normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \ + 1.0 + # Keep the atanh input away from +/-1, where atanh is +/-inf. + save_normed_values = tf.clip_by_value( + normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER) + unsquashed = tf.math.atanh(save_normed_values) + return unsquashed class Deterministic(TFActionDistribution): diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py index a4210d695..3e1810e09 100644 --- a/rllib/models/torch/torch_action_dist.py +++ b/rllib/models/torch/torch_action_dist.py @@ -205,24 +205,33 @@ class TorchSquashedGaussian(TorchDistributionWrapper): @override(ActionDistribution) def logp(self, x): + # Unsquash values (from [low,high] to ]-inf,inf[) unsquashed_values = self._unsquash(x) - log_prob = torch.sum(self.dist.log_prob(unsquashed_values), dim=-1) + # Get log prob of unsquashed values from our Normal. + log_prob_gaussian = self.dist.log_prob(unsquashed_values) + # Clamp per-component log-probs to [-100, 100], only then sum up. + log_prob_gaussian = torch.clamp(log_prob_gaussian, -100, 100) + log_prob_gaussian = torch.sum(log_prob_gaussian, dim=-1) + # Subtract the tanh Jacobian term to get the squashed log-prob. 
unsquashed_values_tanhd = torch.tanh(unsquashed_values) - log_prob -= torch.sum( + log_prob = log_prob_gaussian - torch.sum( torch.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), dim=-1) return log_prob def _squash(self, raw_values): - # Make sure raw_values are not too high/low (such that tanh would - # return exactly 1.0/-1.0, which would lead to +/-inf log-probs). - return (torch.clamp( - torch.tanh(raw_values), - -1.0 + SMALL_NUMBER, - 1.0 - SMALL_NUMBER) + 1.0) / 2.0 * (self.high - self.low) + \ - self.low + # Returned values are within [low, high] (including `low` and `high`). + squashed = ((torch.tanh(raw_values) + 1.0) / 2.0) * \ + (self.high - self.low) + self.low + return torch.clamp(squashed, self.low, self.high) def _unsquash(self, values): - return atanh((values - self.low) / (self.high - self.low) * 2.0 - 1.0) + normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \ + 1.0 + # Keep the atanh input away from +/-1, where atanh is +/-inf. + save_normed_values = torch.clamp(normed_values, -1.0 + SMALL_NUMBER, + 1.0 - SMALL_NUMBER) + unsquashed = atanh(save_normed_values) + return unsquashed class TorchBeta(TorchDistributionWrapper):