From f9b8e77e3b701f0b2de96062f8f6f9087bf3aa6c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 7 Apr 2019 12:11:30 -0700 Subject: [PATCH] [rllib] Don't merge unrolls from same episode when calculating seq lens (#4557) --- .../rllib/agents/qmix/qmix_policy_graph.py | 1 + python/ray/rllib/evaluation/sample_batch.py | 5 +++ .../rllib/evaluation/sample_batch_builder.py | 4 ++ .../ray/rllib/evaluation/tf_policy_graph.py | 1 + python/ray/rllib/models/lstm.py | 11 ++++- python/ray/rllib/optimizers/multi_gpu_impl.py | 42 ++++++++++++++----- python/ray/rllib/tests/test_lstm.py | 29 ++++++++++--- .../ray/rllib/tests/test_policy_evaluator.py | 13 +++++- 8 files changed, 87 insertions(+), 19 deletions(-) diff --git a/python/ray/rllib/agents/qmix/qmix_policy_graph.py b/python/ray/rllib/agents/qmix/qmix_policy_graph.py index 102dcce5d..1758b0adb 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy_graph.py +++ b/python/ray/rllib/agents/qmix/qmix_policy_graph.py @@ -245,6 +245,7 @@ class QMixPolicyGraph(PolicyGraph): [rew, action_mask, act, dones, obs], initial_states, seq_lens = \ chop_into_sequences( samples[SampleBatch.EPS_ID], + samples[SampleBatch.UNROLL_ID], samples[SampleBatch.AGENT_INDEX], [ group_rewards, action_mask, samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES], obs_batch diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py index bbc345e0d..c80f22bdb 100644 --- a/python/ray/rllib/evaluation/sample_batch.py +++ b/python/ray/rllib/evaluation/sample_batch.py @@ -106,6 +106,11 @@ class SampleBatch(object): # Uniquely identifies an episode EPS_ID = "eps_id" + # Uniquely identifies a sample batch. This is important to distinguish RNN + # sequences from the same episode when multiple sample batches are + # concatenated (fusing sequences across batches can be unsafe). + UNROLL_ID = "unroll_id" + # Uniquely identifies an agent within an episode AGENT_INDEX = "agent_index" diff --git a/python/ray/rllib/evaluation/sample_batch_builder.py b/python/ray/rllib/evaluation/sample_batch_builder.py index 973388075..c6d69d7d9 100644 --- a/python/ray/rllib/evaluation/sample_batch_builder.py +++ b/python/ray/rllib/evaluation/sample_batch_builder.py @@ -32,6 +32,7 @@ class SampleBatchBuilder(object): def __init__(self): self.buffers = collections.defaultdict(list) self.count = 0 + self.unroll_id = 0 # disambiguates unrolls within a single episode @PublicAPI def add_values(self, **values): @@ -56,8 +57,11 @@ class SampleBatchBuilder(object): batch = SampleBatch( {k: to_float_array(v) for k, v in self.buffers.items()}) + batch.data[SampleBatch.UNROLL_ID] = np.repeat(self.unroll_id, + batch.count) self.buffers.clear() self.count = 0 + self.unroll_id += 1 return batch diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index ed39ec212..561aef332 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -464,6 +464,7 @@ class TFPolicyGraph(PolicyGraph): ] feature_sequences, initial_states, seq_lens = chop_into_sequences( batch[SampleBatch.EPS_ID], + batch[SampleBatch.UNROLL_ID], batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys], [batch[k] for k in state_keys], max_seq_len, diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index b4da4f1fb..b3fb557b8 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -121,6 +121,7 @@ def add_time_dimension(padded_inputs, seq_lens): @DeveloperAPI def chop_into_sequences(episode_ids, + unroll_ids, agent_indices, feature_columns, state_columns, @@ -131,6 +132,8 @@ def chop_into_sequences(episode_ids, Arguments: episode_ids (list): List of episode ids for each step. + unroll_ids (list): List of identifiers for the sample batch. This is + used to make sure sequences are cut between sample batches. agent_indices (list): List of agent ids for each step. Note that this has to be combined with episode_ids for uniqueness. feature_columns (list): List of arrays containing features. @@ -150,7 +153,9 @@ def chop_into_sequences(episode_ids, Examples: >>> f_pad, s_init, seq_lens = chop_into_sequences( - episode_id=[1, 1, 5, 5, 5, 5], + episode_ids=[1, 1, 5, 5, 5, 5], + unroll_ids=[4, 4, 4, 4, 4, 4], + agent_indices=[0, 0, 0, 0, 0, 0], feature_columns=[[4, 4, 8, 8, 8, 8], [1, 1, 0, 1, 1, 0]], state_columns=[[4, 5, 4, 5, 5, 5]], @@ -167,7 +172,9 @@ def chop_into_sequences(episode_ids, prev_id = None seq_lens = [] seq_len = 0 - unique_ids = np.add(episode_ids, agent_indices) + unique_ids = np.add( + np.add(episode_ids, agent_indices), + np.array(unroll_ids) << 32) for uid in unique_ids: if (prev_id is not None and uid != prev_id) or \ seq_len >= max_seq_len: diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 5a418f79d..9241e869c 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -164,35 +164,55 @@ class LocalSyncParallelOptimizer(object): smallest_array = inputs[0] self._loaded_max_seq_len = 1 - seq_batch_size = (self.max_per_device_batch_size // - self._loaded_max_seq_len * len(self.devices)) - if len(smallest_array) < seq_batch_size: + sequences_per_minibatch = ( + self.max_per_device_batch_size // self._loaded_max_seq_len * len( + self.devices)) + if sequences_per_minibatch < 1: + logger.warn( + ("Target minibatch size is {}, however the rollout sequence " + "length is {}, hence the minibatch size will be raised to " + "{}.").format(self.max_per_device_batch_size, + self._loaded_max_seq_len, + self._loaded_max_seq_len * len(self.devices))) + sequences_per_minibatch = 1 + + if len(smallest_array) < sequences_per_minibatch: # Dynamically shrink the batch size if insufficient data - seq_batch_size = make_divisible_by( + sequences_per_minibatch = make_divisible_by( len(smallest_array), len(self.devices)) - if seq_batch_size < len(self.devices): + + if log_once("data_slicing"): + logger.info( + ("Divided {} rollout sequences, each of length {}, among " + "{} devices.").format( + len(smallest_array), self._loaded_max_seq_len, + len(self.devices))) + + if sequences_per_minibatch < len(self.devices): raise ValueError( "Must load at least 1 tuple sequence per device. Try " "increasing `sgd_minibatch_size` or reducing `max_seq_len` " "to ensure that at least one sequence fits per device.") - self._loaded_per_device_batch_size = ( - seq_batch_size // len(self.devices) * self._loaded_max_seq_len) + self._loaded_per_device_batch_size = (sequences_per_minibatch // len( + self.devices) * self._loaded_max_seq_len) if len(state_inputs) > 0: - # First truncate the RNN state arrays to the seq_batch_size + # First truncate the RNN state arrays to the sequences_per_minib. state_inputs = [ - make_divisible_by(arr, seq_batch_size) for arr in state_inputs + make_divisible_by(arr, sequences_per_minibatch) + for arr in state_inputs ] # Then truncate the data inputs to match inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs] assert len(state_inputs[0]) * seq_len == len(inputs[0]), \ - (len(state_inputs[0]), seq_batch_size, seq_len, len(inputs[0])) + (len(state_inputs[0]), sequences_per_minibatch, seq_len, + len(inputs[0])) for ph, arr in zip(self.loss_inputs, inputs + state_inputs): feed_dict[ph] = arr truncated_len = len(inputs[0]) else: for ph, arr in zip(self.loss_inputs, inputs + state_inputs): - truncated_arr = make_divisible_by(arr, seq_batch_size) + truncated_arr = make_divisible_by(arr, sequences_per_minibatch) feed_dict[ph] = truncated_arr truncated_len = len(truncated_arr) diff --git a/python/ray/rllib/tests/test_lstm.py b/python/ray/rllib/tests/test_lstm.py index 56e2dfd5c..385f2d7bc 100644 --- a/python/ray/rllib/tests/test_lstm.py +++ b/python/ray/rllib/tests/test_lstm.py @@ -25,8 +25,9 @@ class LSTMUtilsTest(unittest.TestCase): f = [[101, 102, 103, 201, 202, 203, 204, 205], [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s, - 4) + f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, + np.ones_like(eps_ids), + agent_ids, f, s, 4) self.assertEqual([f.tolist() for f in f_pad], [ [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0], [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0], @@ -35,6 +36,17 @@ class LSTMUtilsTest(unittest.TestCase): self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]]) self.assertEqual(seq_lens.tolist(), [3, 4, 1]) + def testBatchId(self): + eps_ids = [1, 1, 1, 5, 5, 5, 5, 5] + batch_ids = [1, 1, 2, 2, 3, 3, 4, 4] + agent_ids = [1, 1, 1, 1, 1, 1, 1, 1] + f = [[101, 102, 103, 201, 202, 203, 204, 205], + [[101], [102], [103], [201], [202], [203], [204], [205]]] + s = [[209, 208, 207, 109, 108, 107, 106, 105]] + _, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f, + s, 4) + self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2]) + def testMultiAgent(self): eps_ids = [1, 1, 1, 5, 5, 5, 5, 5] agent_ids = [1, 1, 2, 1, 1, 2, 2, 3] @@ -42,7 +54,13 @@ class LSTMUtilsTest(unittest.TestCase): [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] f_pad, s_init, seq_lens = chop_into_sequences( - eps_ids, agent_ids, f, s, 4, dynamic_max=False) + eps_ids, + np.ones_like(eps_ids), + agent_ids, + f, + s, + 4, + dynamic_max=False) self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1]) self.assertEqual(len(f_pad[0]), 20) self.assertEqual(len(s_init[0]), 5) @@ -52,8 +70,9 @@ class LSTMUtilsTest(unittest.TestCase): agent_ids = [2, 2, 2] f = [[1, 1, 1]] s = [[1, 1, 1]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s, - 4) + f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, + np.ones_like(eps_ids), + agent_ids, f, s, 4) self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]]) self.assertEqual([s.tolist() for s in s_init], [[1, 1]]) self.assertEqual(seq_lens.tolist(), [1, 2]) diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py index 32160b0e5..56cbbca6d 100644 --- a/python/ray/rllib/tests/test_policy_evaluator.py +++ b/python/ray/rllib/tests/test_policy_evaluator.py @@ -16,7 +16,7 @@ from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.postprocessing import compute_advantages -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.env.vector_env import VectorEnv from ray.tune.registry import register_env @@ -155,6 +155,17 @@ class TestPolicyEvaluator(unittest.TestCase): to_prev(batch["actions"])) self.assertGreater(batch["advantages"][0], 1) + def testBatchIds(self): + ev = PolicyEvaluator( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=MockPolicyGraph) + batch1 = ev.sample() + batch2 = ev.sample() + self.assertEqual(len(set(batch1["unroll_id"])), 1) + self.assertEqual(len(set(batch2["unroll_id"])), 1) + self.assertEqual( + len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2) + # 11/23/18: Samples per second 8501.125113727468 def testBaselinePerformance(self): ev = PolicyEvaluator(