[rllib] Add magic methods for rollouts (#2024)

This commit is contained in:
Alok Singh
2018-05-16 22:59:46 -07:00
committed by Richard Liaw
parent 7549209aea
commit c0e4c9d3d1
4 changed files with 83 additions and 47 deletions
+24 -9
View File
@@ -36,8 +36,8 @@ class SampleBatch(object):
@staticmethod
def concat_samples(samples):
out = {}
for k in samples[0].data.keys():
out[k] = np.concatenate([s.data[k] for s in samples])
for k in samples[0].keys():
out[k] = np.concatenate([s[k] for s in samples])
return SampleBatch(out)
def concat(self, other):
@@ -50,10 +50,10 @@ class SampleBatch(object):
{"a": [1, 2, 3, 4, 5]}
"""
assert self.data.keys() == other.data.keys(), "must have same columns"
assert self.keys() == other.keys(), "must have same columns"
out = {}
for k in self.data.keys():
out[k] = np.concatenate([self.data[k], other.data[k]])
for k in self.keys():
out[k] = np.concatenate([self[k], other[k]])
return SampleBatch(out)
def rows(self):
@@ -70,7 +70,7 @@ class SampleBatch(object):
for i in range(self.count):
row = {}
for k in self.data.keys():
for k in self.keys():
row[k] = self[k][i]
yield row
@@ -85,19 +85,34 @@ class SampleBatch(object):
out = []
for k in keys:
out.append(self.data[k])
out.append(self[k])
return out
def shuffle(self):
permutation = np.random.permutation(self.count)
for key, val in self.data.items():
self.data[key] = val[permutation]
for key, val in self.items():
self[key] = val[permutation]
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, item):
self.data[key] = item
def __str__(self):
return "SampleBatch({})".format(str(self.data))
def __repr__(self):
return "SampleBatch({})".format(str(self.data))
def keys(self):
return self.data.keys()
def items(self):
return self.data.items()
def __iter__(self):
return self.data.__iter__()
def __contains__(self, x):
return x in self.data
+7 -7
View File
@@ -12,7 +12,6 @@ from ray.rllib.optimizers import AsyncOptimizer, SampleBatch
class AsyncOptimizerTest(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@@ -21,8 +20,9 @@ class AsyncOptimizerTest(unittest.TestCase):
local = _MockEvaluator()
remotes = ray.remote(_MockEvaluator)
remote_evaluators = [remotes.remote() for i in range(5)]
test_optimizer = AsyncOptimizer(
{"grads_per_step": 10}, local, remote_evaluators)
test_optimizer = AsyncOptimizer({
"grads_per_step": 10
}, local, remote_evaluators)
test_optimizer.step()
self.assertTrue(all(local.get_weights() == 0))
@@ -33,11 +33,11 @@ class SampleBatchTest(unittest.TestCase):
b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
b12 = b1.concat(b2)
self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1])
self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4])
self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])
b = SampleBatch.concat_samples([b1, b2, b3])
self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1])
self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5])
self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])
if __name__ == '__main__':
+8 -8
View File
@@ -26,22 +26,22 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
processed rewards."""
traj = {}
trajsize = len(rollout.data["actions"])
for key in rollout.data:
traj[key] = np.stack(rollout.data[key])
trajsize = len(rollout["actions"])
for key in rollout:
traj[key] = np.stack(rollout[key])
if use_gae:
assert "vf_preds" in rollout.data, "Values not found!"
vpred_t = np.stack(
rollout.data["vf_preds"] + [np.array(rollout.last_r)]).squeeze()
assert "vf_preds" in rollout, "Values not found!"
vpred_t = np.stack(rollout["vf_preds"] +
[np.array(rollout.last_r)]).squeeze()
delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
# This formula for the advantage comes
# "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
traj["advantages"] = discount(delta_t, gamma * lambda_)
traj["value_targets"] = traj["advantages"] + traj["vf_preds"]
else:
rewards_plus_v = np.stack(
rollout.data["rewards"] + [np.array(rollout.last_r)]).squeeze()
rewards_plus_v = np.stack(rollout["rewards"] +
[np.array(rollout.last_r)]).squeeze()
traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]
for i in range(traj["advantages"].shape[0]):
+44 -23
View File
@@ -56,9 +56,30 @@ class PartialRollout(object):
terminal (bool): if rollout has terminated."""
return self.data["dones"][-1]
def __getitem__(self, key):
return self.data[key]
CompletedRollout = namedtuple(
"CompletedRollout", ["episode_length", "episode_reward"])
def __setitem__(self, key, item):
self.data[key] = item
def keys(self):
return self.data.keys()
def items(self):
return self.data.items()
def __iter__(self):
return self.data.__iter__()
def __next__(self):
return self.data.__next__()
def __contains__(self, x):
return x in self.data
CompletedRollout = namedtuple("CompletedRollout",
["episode_length", "episode_reward"])
class SyncSampler(object):
@@ -71,16 +92,15 @@ class SyncSampler(object):
thread."""
async = False
def __init__(self, env, policy, obs_filter,
num_local_steps, horizon=None):
def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
self.num_local_steps = num_local_steps
self.horizon = horizon
self.env = env
self.policy = policy
self._obs_filter = obs_filter
self.rollout_provider = _env_runner(
self.env, self.policy, self.num_local_steps, self.horizon,
self._obs_filter)
self.rollout_provider = _env_runner(self.env, self.policy,
self.num_local_steps, self.horizon,
self._obs_filter)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -108,10 +128,10 @@ class AsyncSampler(threading.Thread):
accumulate and the gradient can be calculated on up to 5 batches."""
async = True
def __init__(self, env, policy, obs_filter,
num_local_steps, horizon=None):
assert getattr(obs_filter, "is_concurrent", False), (
"Observation Filter must support concurrent updates.")
def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
assert getattr(
obs_filter, "is_concurrent",
False), ("Observation Filter must support concurrent updates.")
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.metrics_queue = queue.Queue()
@@ -132,9 +152,9 @@ class AsyncSampler(threading.Thread):
raise e
def _run(self):
rollout_provider = _env_runner(
self.env, self.policy, self.num_local_steps,
self.horizon, self._obs_filter)
rollout_provider = _env_runner(self.env, self.policy,
self.num_local_steps, self.horizon,
self._obs_filter)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -232,13 +252,14 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
action = np.concatenate(action, axis=0).flatten()
# Collect the experience.
rollout.add(obs=last_observation,
actions=action,
rewards=reward,
dones=terminal,
features=last_features,
new_obs=observation,
**pi_info)
rollout.add(
obs=last_observation,
actions=action,
rewards=reward,
dones=terminal,
features=last_features,
new_obs=observation,
**pi_info)
last_observation = observation
last_features = features
@@ -247,8 +268,8 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
terminal_end = True
yield CompletedRollout(length, rewards)
if (length >= horizon or
not env.metadata.get("semantics.autoreset")):
if (length >= horizon
or not env.metadata.get("semantics.autoreset")):
last_observation = obs_filter(env.reset())
if hasattr(policy, "get_initial_features"):
last_features = policy.get_initial_features()