diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 6ae66e6da..823652f4d 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -4,6 +4,7 @@ from __future__ import print_function from collections import defaultdict, namedtuple import numpy as np +import random import six.moves.queue as queue import threading @@ -258,6 +259,7 @@ def _env_runner(async_vector_env, agent_id, policy_id, t=episode.length - 1, + eps_id=episode.episode_id, obs=last_observation, actions=episode.last_action_for(agent_id), rewards=rewards[env_id][agent_id], @@ -367,6 +369,7 @@ class _MultiAgentEpisode(object): self.batch_builder = batch_builder_factory() self.total_reward = 0.0 self.length = 0 + self.episode_id = random.randrange(int(2e9)) self.agent_rewards = defaultdict(float) self._policies = policies self._policy_mapping_fn = policy_mapping_fn diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index d7225d7a4..ce9e0803b 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -140,7 +140,7 @@ class TFPolicyGraph(PolicyGraph): "state_in_{}".format(i) for i in range(len(self._state_inputs)) ] feature_sequences, initial_states, seq_lens = chop_into_sequences( - batch["t"], [batch[k] for k in feature_keys], + batch["eps_id"], [batch[k] for k in feature_keys], [batch[k] for k in state_keys], self._max_seq_len) for k, v in zip(feature_keys, feature_sequences): feed_dict[self._loss_input_dict[k]] = v diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 1365b5a69..b18fbb087 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -49,14 +49,12 @@ def add_time_dimension(padded_inputs, seq_lens): return tf.reshape(padded_inputs, new_shape) -def chop_into_sequences(time_column, feature_columns, state_columns, +def 
chop_into_sequences(episode_ids, feature_columns, state_columns, max_seq_len): """Truncate and pad experiences into fixed-length sequences. Arguments: - time_column (list): Timesteps per feature / state. This contains - sequences of monotonically increasing step values, e.g., - [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2]. + episode_ids (list): List of episode ids for each step. feature_columns (list): List of arrays containing features. state_columns (list): List of arrays containing LSTM state values. max_seq_len (int): Max length of sequences before truncation. @@ -70,7 +68,7 @@ def chop_into_sequences(time_column, feature_columns, state_columns, Examples: >>> f_pad, s_init, seq_lens = chop_into_sequences( - time_column=[0, 1, 0, 1, 2, 3], + episode_ids=[1, 1, 5, 5, 5, 5], feature_columns=[[4, 4, 8, 8, 8, 8], [1, 1, 0, 1, 1, 0]], state_columns=[[4, 5, 4, 5, 5, 5]], @@ -84,18 +82,19 @@ def chop_into_sequences(time_column, feature_columns, state_columns, [2, 3, 1] """ - prev_t = -1 + prev_id = None seq_lens = [] seq_len = 0 - for t in time_column: - if t <= prev_t or seq_len >= max_seq_len: + for eps_id in episode_ids: + if (prev_id is not None and eps_id != prev_id) or \ + seq_len >= max_seq_len: seq_lens.append(seq_len) seq_len = 0 seq_len += 1 - prev_t = t + prev_id = eps_id if seq_len: seq_lens.append(seq_len) - assert sum(seq_lens) == len(time_column) + assert sum(seq_lens) == len(episode_ids) # Dynamically shrink max len as needed to optimize memory usage max_seq_len = max(seq_lens) @@ -111,7 +110,7 @@ def chop_into_sequences(time_column, feature_columns, state_columns, f_pad[seq_base + seq_offset] = f[i] i += 1 seq_base += max_seq_len - assert i == len(time_column), f + assert i == len(episode_ids), f feature_sequences.append(f_pad) initial_states = [] diff --git a/python/ray/rllib/test/test_lstm.py b/python/ray/rllib/test/test_lstm.py index 0fd6dffc3..2abfb7680 100644 --- a/python/ray/rllib/test/test_lstm.py +++ b/python/ray/rllib/test/test_lstm.py @@ -9,11 +9,11 
@@ from ray.rllib.models.lstm import chop_into_sequences class LSTMUtilsTest(unittest.TestCase): def testBasic(self): - t = [1, 2, 3, 1, 2, 3, 4, 5] + eps_ids = [1, 1, 1, 5, 5, 5, 5, 5] f = [[101, 102, 103, 201, 202, 203, 204, 205], [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] - f_pad, s_init, seq_lens = chop_into_sequences(t, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4) self.assertEqual([f.tolist() for f in f_pad], [ [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0], [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0], @@ -23,10 +23,10 @@ class LSTMUtilsTest(unittest.TestCase): self.assertEqual(seq_lens.tolist(), [3, 4, 1]) def testDynamicMaxLen(self): - t = [1, 1, 2] + eps_ids = [5, 2, 2] f = [[1, 1, 1]] s = [[1, 1, 1]] - f_pad, s_init, seq_lens = chop_into_sequences(t, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4) self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]]) self.assertEqual([s.tolist() for s in s_init], [[1, 1]]) self.assertEqual(seq_lens.tolist(), [1, 2])