mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 05:41:19 +08:00
[rllib] Fix corner case in rnn episode handling
Use episode ids rather than timesteps to decide where sequences are cut: when batches are concatenated, an increasing t does not guarantee that consecutive steps belong to the same episode.
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
from collections import defaultdict, namedtuple
|
||||
import numpy as np
|
||||
import random
|
||||
import six.moves.queue as queue
|
||||
import threading
|
||||
|
||||
@@ -258,6 +259,7 @@ def _env_runner(async_vector_env,
|
||||
agent_id,
|
||||
policy_id,
|
||||
t=episode.length - 1,
|
||||
eps_id=episode.episode_id,
|
||||
obs=last_observation,
|
||||
actions=episode.last_action_for(agent_id),
|
||||
rewards=rewards[env_id][agent_id],
|
||||
@@ -367,6 +369,7 @@ class _MultiAgentEpisode(object):
|
||||
self.batch_builder = batch_builder_factory()
|
||||
self.total_reward = 0.0
|
||||
self.length = 0
|
||||
self.episode_id = random.randrange(2e9)
|
||||
self.agent_rewards = defaultdict(float)
|
||||
self._policies = policies
|
||||
self._policy_mapping_fn = policy_mapping_fn
|
||||
|
||||
@@ -140,7 +140,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
"state_in_{}".format(i) for i in range(len(self._state_inputs))
|
||||
]
|
||||
feature_sequences, initial_states, seq_lens = chop_into_sequences(
|
||||
batch["t"], [batch[k] for k in feature_keys],
|
||||
batch["eps_id"], [batch[k] for k in feature_keys],
|
||||
[batch[k] for k in state_keys], self._max_seq_len)
|
||||
for k, v in zip(feature_keys, feature_sequences):
|
||||
feed_dict[self._loss_input_dict[k]] = v
|
||||
|
||||
@@ -49,14 +49,12 @@ def add_time_dimension(padded_inputs, seq_lens):
|
||||
return tf.reshape(padded_inputs, new_shape)
|
||||
|
||||
|
||||
def chop_into_sequences(time_column, feature_columns, state_columns,
|
||||
def chop_into_sequences(episode_ids, feature_columns, state_columns,
|
||||
max_seq_len):
|
||||
"""Truncate and pad experiences into fixed-length sequences.
|
||||
|
||||
Arguments:
|
||||
time_column (list): Timesteps per feature / state. This contains
|
||||
sequences of monotonically increasing step values, e.g.,
|
||||
[0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2].
|
||||
episode_ids (list): List of episode ids for each step.
|
||||
feature_columns (list): List of arrays containing features.
|
||||
state_columns (list): List of arrays containing LSTM state values.
|
||||
max_seq_len (int): Max length of sequences before truncation.
|
||||
@@ -70,7 +68,7 @@ def chop_into_sequences(time_column, feature_columns, state_columns,
|
||||
|
||||
Examples:
|
||||
>>> f_pad, s_init, seq_lens = chop_into_sequences(
|
||||
time_column=[0, 1, 0, 1, 2, 3],
|
||||
episode_ids=[1, 1, 5, 5, 5, 5],
|
||||
feature_columns=[[4, 4, 8, 8, 8, 8],
|
||||
[1, 1, 0, 1, 1, 0]],
|
||||
state_columns=[[4, 5, 4, 5, 5, 5]],
|
||||
@@ -84,18 +82,19 @@ def chop_into_sequences(time_column, feature_columns, state_columns,
|
||||
[2, 3, 1]
|
||||
"""
|
||||
|
||||
prev_t = -1
|
||||
prev_id = None
|
||||
seq_lens = []
|
||||
seq_len = 0
|
||||
for t in time_column:
|
||||
if t <= prev_t or seq_len >= max_seq_len:
|
||||
for eps_id in episode_ids:
|
||||
if (prev_id is not None and eps_id != prev_id) or \
|
||||
seq_len >= max_seq_len:
|
||||
seq_lens.append(seq_len)
|
||||
seq_len = 0
|
||||
seq_len += 1
|
||||
prev_t = t
|
||||
prev_id = eps_id
|
||||
if seq_len:
|
||||
seq_lens.append(seq_len)
|
||||
assert sum(seq_lens) == len(time_column)
|
||||
assert sum(seq_lens) == len(episode_ids)
|
||||
|
||||
# Dynamically shrink max len as needed to optimize memory usage
|
||||
max_seq_len = max(seq_lens)
|
||||
@@ -111,7 +110,7 @@ def chop_into_sequences(time_column, feature_columns, state_columns,
|
||||
f_pad[seq_base + seq_offset] = f[i]
|
||||
i += 1
|
||||
seq_base += max_seq_len
|
||||
assert i == len(time_column), f
|
||||
assert i == len(episode_ids), f
|
||||
feature_sequences.append(f_pad)
|
||||
|
||||
initial_states = []
|
||||
|
||||
@@ -9,11 +9,11 @@ from ray.rllib.models.lstm import chop_into_sequences
|
||||
|
||||
class LSTMUtilsTest(unittest.TestCase):
|
||||
def testBasic(self):
|
||||
t = [1, 2, 3, 1, 2, 3, 4, 5]
|
||||
eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
|
||||
f = [[101, 102, 103, 201, 202, 203, 204, 205],
|
||||
[[101], [102], [103], [201], [202], [203], [204], [205]]]
|
||||
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(t, f, s, 4)
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4)
|
||||
self.assertEqual([f.tolist() for f in f_pad], [
|
||||
[101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
|
||||
[[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
|
||||
@@ -23,10 +23,10 @@ class LSTMUtilsTest(unittest.TestCase):
|
||||
self.assertEqual(seq_lens.tolist(), [3, 4, 1])
|
||||
|
||||
def testDynamicMaxLen(self):
|
||||
t = [1, 1, 2]
|
||||
eps_ids = [5, 2, 2]
|
||||
f = [[1, 1, 1]]
|
||||
s = [[1, 1, 1]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(t, f, s, 4)
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4)
|
||||
self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
|
||||
self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
|
||||
self.assertEqual(seq_lens.tolist(), [1, 2])
|
||||
|
||||
Reference in New Issue
Block a user