mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 08:53:44 +08:00
[rllib] Don't merge unrolls from same episode when calculating seq lens (#4557)
This commit is contained in:
@@ -245,6 +245,7 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
[rew, action_mask, act, dones, obs], initial_states, seq_lens = \
|
||||
chop_into_sequences(
|
||||
samples[SampleBatch.EPS_ID],
|
||||
samples[SampleBatch.UNROLL_ID],
|
||||
samples[SampleBatch.AGENT_INDEX], [
|
||||
group_rewards, action_mask, samples[SampleBatch.ACTIONS],
|
||||
samples[SampleBatch.DONES], obs_batch
|
||||
|
||||
@@ -106,6 +106,11 @@ class SampleBatch(object):
|
||||
# Uniquely identifies an episode
|
||||
EPS_ID = "eps_id"
|
||||
|
||||
# Uniquely identifies a sample batch. This is important to distinguish RNN
|
||||
# sequences from the same episode when multiple sample batches are
|
||||
# concatenated (fusing sequences across batches can be unsafe).
|
||||
UNROLL_ID = "unroll_id"
|
||||
|
||||
# Uniquely identifies an agent within an episode
|
||||
AGENT_INDEX = "agent_index"
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ class SampleBatchBuilder(object):
|
||||
def __init__(self):
|
||||
self.buffers = collections.defaultdict(list)
|
||||
self.count = 0
|
||||
self.unroll_id = 0 # disambiguates unrolls within a single episode
|
||||
|
||||
@PublicAPI
|
||||
def add_values(self, **values):
|
||||
@@ -56,8 +57,11 @@ class SampleBatchBuilder(object):
|
||||
batch = SampleBatch(
|
||||
{k: to_float_array(v)
|
||||
for k, v in self.buffers.items()})
|
||||
batch.data[SampleBatch.UNROLL_ID] = np.repeat(self.unroll_id,
|
||||
batch.count)
|
||||
self.buffers.clear()
|
||||
self.count = 0
|
||||
self.unroll_id += 1
|
||||
return batch
|
||||
|
||||
|
||||
|
||||
@@ -464,6 +464,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
]
|
||||
feature_sequences, initial_states, seq_lens = chop_into_sequences(
|
||||
batch[SampleBatch.EPS_ID],
|
||||
batch[SampleBatch.UNROLL_ID],
|
||||
batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
|
||||
[batch[k] for k in state_keys],
|
||||
max_seq_len,
|
||||
|
||||
@@ -121,6 +121,7 @@ def add_time_dimension(padded_inputs, seq_lens):
|
||||
|
||||
@DeveloperAPI
|
||||
def chop_into_sequences(episode_ids,
|
||||
unroll_ids,
|
||||
agent_indices,
|
||||
feature_columns,
|
||||
state_columns,
|
||||
@@ -131,6 +132,8 @@ def chop_into_sequences(episode_ids,
|
||||
|
||||
Arguments:
|
||||
episode_ids (list): List of episode ids for each step.
|
||||
unroll_ids (list): List of identifiers for the sample batch. This is
|
||||
used to make sure sequences are cut between sample batches.
|
||||
agent_indices (list): List of agent ids for each step. Note that this
|
||||
has to be combined with episode_ids for uniqueness.
|
||||
feature_columns (list): List of arrays containing features.
|
||||
@@ -150,7 +153,9 @@ def chop_into_sequences(episode_ids,
|
||||
|
||||
Examples:
|
||||
>>> f_pad, s_init, seq_lens = chop_into_sequences(
|
||||
episode_id=[1, 1, 5, 5, 5, 5],
|
||||
episode_ids=[1, 1, 5, 5, 5, 5],
|
||||
unroll_ids=[4, 4, 4, 4, 4, 4],
|
||||
agent_indices=[0, 0, 0, 0, 0, 0],
|
||||
feature_columns=[[4, 4, 8, 8, 8, 8],
|
||||
[1, 1, 0, 1, 1, 0]],
|
||||
state_columns=[[4, 5, 4, 5, 5, 5]],
|
||||
@@ -167,7 +172,9 @@ def chop_into_sequences(episode_ids,
|
||||
prev_id = None
|
||||
seq_lens = []
|
||||
seq_len = 0
|
||||
unique_ids = np.add(episode_ids, agent_indices)
|
||||
unique_ids = np.add(
|
||||
np.add(episode_ids, agent_indices),
|
||||
np.array(unroll_ids) << 32)
|
||||
for uid in unique_ids:
|
||||
if (prev_id is not None and uid != prev_id) or \
|
||||
seq_len >= max_seq_len:
|
||||
|
||||
@@ -164,35 +164,55 @@ class LocalSyncParallelOptimizer(object):
|
||||
smallest_array = inputs[0]
|
||||
self._loaded_max_seq_len = 1
|
||||
|
||||
seq_batch_size = (self.max_per_device_batch_size //
|
||||
self._loaded_max_seq_len * len(self.devices))
|
||||
if len(smallest_array) < seq_batch_size:
|
||||
sequences_per_minibatch = (
|
||||
self.max_per_device_batch_size // self._loaded_max_seq_len * len(
|
||||
self.devices))
|
||||
if sequences_per_minibatch < 1:
|
||||
logger.warn(
|
||||
("Target minibatch size is {}, however the rollout sequence "
|
||||
"length is {}, hence the minibatch size will be raised to "
|
||||
"{}.").format(self.max_per_device_batch_size,
|
||||
self._loaded_max_seq_len,
|
||||
self._loaded_max_seq_len * len(self.devices)))
|
||||
sequences_per_minibatch = 1
|
||||
|
||||
if len(smallest_array) < sequences_per_minibatch:
|
||||
# Dynamically shrink the batch size if insufficient data
|
||||
seq_batch_size = make_divisible_by(
|
||||
sequences_per_minibatch = make_divisible_by(
|
||||
len(smallest_array), len(self.devices))
|
||||
if seq_batch_size < len(self.devices):
|
||||
|
||||
if log_once("data_slicing"):
|
||||
logger.info(
|
||||
("Divided {} rollout sequences, each of length {}, among "
|
||||
"{} devices.").format(
|
||||
len(smallest_array), self._loaded_max_seq_len,
|
||||
len(self.devices)))
|
||||
|
||||
if sequences_per_minibatch < len(self.devices):
|
||||
raise ValueError(
|
||||
"Must load at least 1 tuple sequence per device. Try "
|
||||
"increasing `sgd_minibatch_size` or reducing `max_seq_len` "
|
||||
"to ensure that at least one sequence fits per device.")
|
||||
self._loaded_per_device_batch_size = (
|
||||
seq_batch_size // len(self.devices) * self._loaded_max_seq_len)
|
||||
self._loaded_per_device_batch_size = (sequences_per_minibatch // len(
|
||||
self.devices) * self._loaded_max_seq_len)
|
||||
|
||||
if len(state_inputs) > 0:
|
||||
# First truncate the RNN state arrays to the seq_batch_size
|
||||
# First truncate the RNN state arrays to sequences_per_minibatch.
|
||||
state_inputs = [
|
||||
make_divisible_by(arr, seq_batch_size) for arr in state_inputs
|
||||
make_divisible_by(arr, sequences_per_minibatch)
|
||||
for arr in state_inputs
|
||||
]
|
||||
# Then truncate the data inputs to match
|
||||
inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs]
|
||||
assert len(state_inputs[0]) * seq_len == len(inputs[0]), \
|
||||
(len(state_inputs[0]), seq_batch_size, seq_len, len(inputs[0]))
|
||||
(len(state_inputs[0]), sequences_per_minibatch, seq_len,
|
||||
len(inputs[0]))
|
||||
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
|
||||
feed_dict[ph] = arr
|
||||
truncated_len = len(inputs[0])
|
||||
else:
|
||||
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
|
||||
truncated_arr = make_divisible_by(arr, seq_batch_size)
|
||||
truncated_arr = make_divisible_by(arr, sequences_per_minibatch)
|
||||
feed_dict[ph] = truncated_arr
|
||||
truncated_len = len(truncated_arr)
|
||||
|
||||
|
||||
@@ -25,8 +25,9 @@ class LSTMUtilsTest(unittest.TestCase):
|
||||
f = [[101, 102, 103, 201, 202, 203, 204, 205],
|
||||
[[101], [102], [103], [201], [202], [203], [204], [205]]]
|
||||
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s,
|
||||
4)
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
|
||||
np.ones_like(eps_ids),
|
||||
agent_ids, f, s, 4)
|
||||
self.assertEqual([f.tolist() for f in f_pad], [
|
||||
[101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
|
||||
[[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
|
||||
@@ -35,6 +36,17 @@ class LSTMUtilsTest(unittest.TestCase):
|
||||
self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
|
||||
self.assertEqual(seq_lens.tolist(), [3, 4, 1])
|
||||
|
||||
def testBatchId(self):
|
||||
eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
|
||||
batch_ids = [1, 1, 2, 2, 3, 3, 4, 4]
|
||||
agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
|
||||
f = [[101, 102, 103, 201, 202, 203, 204, 205],
|
||||
[[101], [102], [103], [201], [202], [203], [204], [205]]]
|
||||
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
|
||||
_, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f,
|
||||
s, 4)
|
||||
self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])
|
||||
|
||||
def testMultiAgent(self):
|
||||
eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
|
||||
agent_ids = [1, 1, 2, 1, 1, 2, 2, 3]
|
||||
@@ -42,7 +54,13 @@ class LSTMUtilsTest(unittest.TestCase):
|
||||
[[101], [102], [103], [201], [202], [203], [204], [205]]]
|
||||
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(
|
||||
eps_ids, agent_ids, f, s, 4, dynamic_max=False)
|
||||
eps_ids,
|
||||
np.ones_like(eps_ids),
|
||||
agent_ids,
|
||||
f,
|
||||
s,
|
||||
4,
|
||||
dynamic_max=False)
|
||||
self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
|
||||
self.assertEqual(len(f_pad[0]), 20)
|
||||
self.assertEqual(len(s_init[0]), 5)
|
||||
@@ -52,8 +70,9 @@ class LSTMUtilsTest(unittest.TestCase):
|
||||
agent_ids = [2, 2, 2]
|
||||
f = [[1, 1, 1]]
|
||||
s = [[1, 1, 1]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s,
|
||||
4)
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
|
||||
np.ones_like(eps_ids),
|
||||
agent_ids, f, s, 4)
|
||||
self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
|
||||
self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
|
||||
self.assertEqual(seq_lens.tolist(), [1, 2])
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
@@ -155,6 +155,17 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
to_prev(batch["actions"]))
|
||||
self.assertGreater(batch["advantages"][0], 1)
|
||||
|
||||
def testBatchIds(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph)
|
||||
batch1 = ev.sample()
|
||||
batch2 = ev.sample()
|
||||
self.assertEqual(len(set(batch1["unroll_id"])), 1)
|
||||
self.assertEqual(len(set(batch2["unroll_id"])), 1)
|
||||
self.assertEqual(
|
||||
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
|
||||
|
||||
# 11/23/18: Samples per second 8501.125113727468
|
||||
def testBaselinePerformance(self):
|
||||
ev = PolicyEvaluator(
|
||||
|
||||
Reference in New Issue
Block a user