diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 593fae1c4..6fd725a37 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -27,6 +27,8 @@ DEFAULT_CONFIG = with_common_config({ "train_batch_size": 4000, # Total SGD batch size across all devices for SGD "sgd_minibatch_size": 128, + # Whether to shuffle sequences in the batch when training (recommended) + "shuffle_sequences": True, # Number of SGD iterations in each outer loop "num_sgd_iter": 30, # Stepsize of SGD @@ -79,7 +81,8 @@ def choose_policy_optimizer(workers, config): num_envs_per_worker=config["num_envs_per_worker"], train_batch_size=config["train_batch_size"], standardize_fields=["advantages"], - straggler_mitigation=config["straggler_mitigation"]) + straggler_mitigation=config["straggler_mitigation"], + shuffle_sequences=config["shuffle_sequences"]) def update_kl(trainer, fetches): diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 62b854a86..4d7e1f525 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -128,6 +128,7 @@ def chop_into_sequences(episode_ids, state_columns, max_seq_len, dynamic_max=True, + shuffle=False, _extra_padding=0): """Truncate and pad experiences into fixed-length sequences. @@ -143,6 +144,7 @@ def chop_into_sequences(episode_ids, dynamic_max (bool): Whether to dynamically shrink the max seq len. For example, if max len is 20 and the actual max seq len in the data is 7, it will be shrunk to 7. + shuffle (bool): Whether to shuffle the sequence outputs. _extra_padding (int): Add extra padding to the end of sequences. Returns: @@ -186,6 +188,7 @@ def chop_into_sequences(episode_ids, if seq_len: seq_lens.append(seq_len) assert sum(seq_lens) == len(unique_ids) + seq_lens = np.array(seq_lens) # Dynamically shrink max len as needed to optimize memory usage if dynamic_max: @@ -215,4 +218,17 @@ def chop_into_sequences(episode_ids, i += l initial_states.append(np.array(s_init)) - return feature_sequences, initial_states, np.array(seq_lens) + if shuffle: + permutation = np.random.permutation(len(seq_lens)) + for i, f in enumerate(feature_sequences): + orig_shape = f.shape + f = np.reshape(f, (len(seq_lens), -1) + f.shape[2:]) + f = f[permutation] + f = np.reshape(f, orig_shape) + feature_sequences[i] = f + for i, s in enumerate(initial_states): + s = s[permutation] + initial_states[i] = s + seq_lens = seq_lens[permutation] + + return feature_sequences, initial_states, seq_lens diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 65d7842d8..11fd80e01 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -50,7 +50,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): train_batch_size=1024, num_gpus=0, standardize_fields=[], - straggler_mitigation=False): + straggler_mitigation=False, + shuffle_sequences=True): PolicyOptimizer.__init__(self, workers) self.batch_size = sgd_batch_size @@ -59,6 +60,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size self.straggler_mitigation = straggler_mitigation + self.shuffle_sequences = shuffle_sequences if not num_gpus: self.devices = ["/cpu:0"] else: @@ -157,10 +159,6 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): standardized = (value - value.mean()) / max(1e-4, value.std()) batch[field] = standardized - # Important: don't shuffle RNN sequence elements - if not policy._state_inputs: - batch.shuffle() - num_loaded_tuples = {} with self.load_timer: for policy_id, batch in samples.policy_batches.items(): @@ -168,7 +166,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): continue policy = self.policies[policy_id] - tuples = policy._get_loss_inputs_dict(batch) + tuples = policy._get_loss_inputs_dict( + batch, shuffle=self.shuffle_sequences) data_keys = [ph for _, ph in policy._loss_inputs] if policy._state_inputs: state_keys = policy._state_inputs + [policy._seq_lens] diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index 0642283e8..7348c7b04 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -438,7 +438,8 @@ class TFPolicy(Policy): def _build_compute_gradients(self, builder, postprocessed_batch): builder.add_feed_dict(self.extra_compute_grad_feed_dict()) builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + builder.add_feed_dict( + self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)) fetches = builder.add_fetches( [self._grads, self._get_grad_and_stats_fetches()]) return fetches[0], fetches[1] @@ -455,7 +456,8 @@ class TFPolicy(Policy): def _build_learn_on_batch(self, builder, postprocessed_batch): builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + builder.add_feed_dict( + self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)) builder.add_feed_dict({self._is_training: True}) fetches = builder.add_fetches([ self._apply_op, @@ -473,7 +475,19 @@ class TFPolicy(Policy): **fetches[LEARNER_STATS_KEY]) return fetches - def _get_loss_inputs_dict(self, batch): + def _get_loss_inputs_dict(self, batch, shuffle): + """Return a feed dict from a batch. + + Arguments: + batch (SampleBatch): batch of data to derive inputs from + shuffle (bool): whether to shuffle batch sequences. Shuffle may + be done in-place. This only makes sense if you're further + applying minibatch SGD after getting the outputs. + + Returns: + feed dict of data + """ + feed_dict = {} if self._batch_divisibility_req > 1: meets_divisibility_reqs = ( @@ -485,6 +499,8 @@ class TFPolicy(Policy): # Simple case: not RNN nor do we need to pad if not self._state_inputs and meets_divisibility_reqs: + if shuffle: + batch.shuffle() for k, ph in self._loss_inputs: feed_dict[ph] = batch[k] return feed_dict @@ -507,7 +523,8 @@ class TFPolicy(Policy): batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys], [batch[k] for k in state_keys], max_seq_len, - dynamic_max=dynamic_max) + dynamic_max=dynamic_max, + shuffle=shuffle) for k, v in zip(feature_keys, feature_sequences): feed_dict[self._loss_input_dict[k]] = v for k, v in zip(state_keys, initial_states): diff --git a/python/ray/rllib/tests/test_lstm.py b/python/ray/rllib/tests/test_lstm.py index fb8e6a20b..7cca28d2b 100644 --- a/python/ray/rllib/tests/test_lstm.py +++ b/python/ray/rllib/tests/test_lstm.py @@ -229,6 +229,7 @@ class RNNSequencing(unittest.TestCase): ppo = PPOTrainer( env="counter", config={ + "shuffle_sequences": False, # for deterministic testing "num_workers": 0, "sample_batch_size": 20, "train_batch_size": 20,