[rllib] Fix corner case in rnn episode handling

We should use episode ids instead of the timestep to determine when sequences should be cut, since when batches are concatenated, increasing t does not guarantee we are part of the same episode.
This commit is contained in:
Eric Liang
2018-07-30 13:24:43 -07:00
committed by GitHub
parent 696a229ece
commit 62a52ee989
4 changed files with 18 additions and 16 deletions
+3
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
from collections import defaultdict, namedtuple
import numpy as np
import random
import six.moves.queue as queue
import threading
@@ -258,6 +259,7 @@ def _env_runner(async_vector_env,
agent_id,
policy_id,
t=episode.length - 1,
eps_id=episode.episode_id,
obs=last_observation,
actions=episode.last_action_for(agent_id),
rewards=rewards[env_id][agent_id],
@@ -367,6 +369,7 @@ class _MultiAgentEpisode(object):
self.batch_builder = batch_builder_factory()
self.total_reward = 0.0
self.length = 0
self.episode_id = random.randrange(2e9)
self.agent_rewards = defaultdict(float)
self._policies = policies
self._policy_mapping_fn = policy_mapping_fn
@@ -140,7 +140,7 @@ class TFPolicyGraph(PolicyGraph):
"state_in_{}".format(i) for i in range(len(self._state_inputs))
]
feature_sequences, initial_states, seq_lens = chop_into_sequences(
batch["t"], [batch[k] for k in feature_keys],
batch["eps_id"], [batch[k] for k in feature_keys],
[batch[k] for k in state_keys], self._max_seq_len)
for k, v in zip(feature_keys, feature_sequences):
feed_dict[self._loss_input_dict[k]] = v
+10 -11
View File
@@ -49,14 +49,12 @@ def add_time_dimension(padded_inputs, seq_lens):
return tf.reshape(padded_inputs, new_shape)
def chop_into_sequences(time_column, feature_columns, state_columns,
def chop_into_sequences(episode_ids, feature_columns, state_columns,
max_seq_len):
"""Truncate and pad experiences into fixed-length sequences.
Arguments:
time_column (list): Timesteps per feature / state. This contains
sequences of monotonically increasing step values, e.g.,
[0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2].
episode_ids (list): List of episode ids for each step.
feature_columns (list): List of arrays containing features.
state_columns (list): List of arrays containing LSTM state values.
max_seq_len (int): Max length of sequences before truncation.
@@ -70,7 +68,7 @@ def chop_into_sequences(time_column, feature_columns, state_columns,
Examples:
>>> f_pad, s_init, seq_lens = chop_into_sequences(
time_column=[0, 1, 0, 1, 2, 3],
episode_id=[1, 1, 5, 5, 5, 5],
feature_columns=[[4, 4, 8, 8, 8, 8],
[1, 1, 0, 1, 1, 0]],
state_columns=[[4, 5, 4, 5, 5, 5]],
@@ -84,18 +82,19 @@ def chop_into_sequences(time_column, feature_columns, state_columns,
[2, 3, 1]
"""
prev_t = -1
prev_id = None
seq_lens = []
seq_len = 0
for t in time_column:
if t <= prev_t or seq_len >= max_seq_len:
for eps_id in episode_ids:
if (prev_id is not None and eps_id != prev_id) or \
seq_len >= max_seq_len:
seq_lens.append(seq_len)
seq_len = 0
seq_len += 1
prev_t = t
prev_id = eps_id
if seq_len:
seq_lens.append(seq_len)
assert sum(seq_lens) == len(time_column)
assert sum(seq_lens) == len(episode_ids)
# Dynamically shrink max len as needed to optimize memory usage
max_seq_len = max(seq_lens)
@@ -111,7 +110,7 @@ def chop_into_sequences(time_column, feature_columns, state_columns,
f_pad[seq_base + seq_offset] = f[i]
i += 1
seq_base += max_seq_len
assert i == len(time_column), f
assert i == len(episode_ids), f
feature_sequences.append(f_pad)
initial_states = []
+4 -4
View File
@@ -9,11 +9,11 @@ from ray.rllib.models.lstm import chop_into_sequences
class LSTMUtilsTest(unittest.TestCase):
def testBasic(self):
t = [1, 2, 3, 1, 2, 3, 4, 5]
eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
f = [[101, 102, 103, 201, 202, 203, 204, 205],
[[101], [102], [103], [201], [202], [203], [204], [205]]]
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
f_pad, s_init, seq_lens = chop_into_sequences(t, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4)
self.assertEqual([f.tolist() for f in f_pad], [
[101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
[[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
@@ -23,10 +23,10 @@ class LSTMUtilsTest(unittest.TestCase):
self.assertEqual(seq_lens.tolist(), [3, 4, 1])
def testDynamicMaxLen(self):
t = [1, 1, 2]
eps_ids = [5, 2, 2]
f = [[1, 1, 1]]
s = [[1, 1, 1]]
f_pad, s_init, seq_lens = chop_into_sequences(t, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4)
self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
self.assertEqual(seq_lens.tolist(), [1, 2])