mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:01:11 +08:00
[rllib] Add magic methods for rollouts (#2024)
This commit is contained in:
@@ -36,8 +36,8 @@ class SampleBatch(object):
|
||||
@staticmethod
|
||||
def concat_samples(samples):
|
||||
out = {}
|
||||
for k in samples[0].data.keys():
|
||||
out[k] = np.concatenate([s.data[k] for s in samples])
|
||||
for k in samples[0].keys():
|
||||
out[k] = np.concatenate([s[k] for s in samples])
|
||||
return SampleBatch(out)
|
||||
|
||||
def concat(self, other):
|
||||
@@ -50,10 +50,10 @@ class SampleBatch(object):
|
||||
{"a": [1, 2, 3, 4, 5]}
|
||||
"""
|
||||
|
||||
assert self.data.keys() == other.data.keys(), "must have same columns"
|
||||
assert self.keys() == other.keys(), "must have same columns"
|
||||
out = {}
|
||||
for k in self.data.keys():
|
||||
out[k] = np.concatenate([self.data[k], other.data[k]])
|
||||
for k in self.keys():
|
||||
out[k] = np.concatenate([self[k], other[k]])
|
||||
return SampleBatch(out)
|
||||
|
||||
def rows(self):
|
||||
@@ -70,7 +70,7 @@ class SampleBatch(object):
|
||||
|
||||
for i in range(self.count):
|
||||
row = {}
|
||||
for k in self.data.keys():
|
||||
for k in self.keys():
|
||||
row[k] = self[k][i]
|
||||
yield row
|
||||
|
||||
@@ -85,19 +85,34 @@ class SampleBatch(object):
|
||||
|
||||
out = []
|
||||
for k in keys:
|
||||
out.append(self.data[k])
|
||||
out.append(self[k])
|
||||
return out
|
||||
|
||||
def shuffle(self):
|
||||
permutation = np.random.permutation(self.count)
|
||||
for key, val in self.data.items():
|
||||
self.data[key] = val[permutation]
|
||||
for key, val in self.items():
|
||||
self[key] = val[permutation]
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __setitem__(self, key, item):
|
||||
self.data[key] = item
|
||||
|
||||
def __str__(self):
|
||||
return "SampleBatch({})".format(str(self.data))
|
||||
|
||||
def __repr__(self):
|
||||
return "SampleBatch({})".format(str(self.data))
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def __iter__(self):
|
||||
return self.data.__iter__()
|
||||
|
||||
def __contains__(self, x):
|
||||
return x in self.data
|
||||
|
||||
@@ -12,7 +12,6 @@ from ray.rllib.optimizers import AsyncOptimizer, SampleBatch
|
||||
|
||||
|
||||
class AsyncOptimizerTest(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
@@ -21,8 +20,9 @@ class AsyncOptimizerTest(unittest.TestCase):
|
||||
local = _MockEvaluator()
|
||||
remotes = ray.remote(_MockEvaluator)
|
||||
remote_evaluators = [remotes.remote() for i in range(5)]
|
||||
test_optimizer = AsyncOptimizer(
|
||||
{"grads_per_step": 10}, local, remote_evaluators)
|
||||
test_optimizer = AsyncOptimizer({
|
||||
"grads_per_step": 10
|
||||
}, local, remote_evaluators)
|
||||
test_optimizer.step()
|
||||
self.assertTrue(all(local.get_weights() == 0))
|
||||
|
||||
@@ -33,11 +33,11 @@ class SampleBatchTest(unittest.TestCase):
|
||||
b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
|
||||
b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
|
||||
b12 = b1.concat(b2)
|
||||
self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1])
|
||||
self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4])
|
||||
self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
|
||||
self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])
|
||||
b = SampleBatch.concat_samples([b1, b2, b3])
|
||||
self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1])
|
||||
self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5])
|
||||
self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
|
||||
self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -26,22 +26,22 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
|
||||
processed rewards."""
|
||||
|
||||
traj = {}
|
||||
trajsize = len(rollout.data["actions"])
|
||||
for key in rollout.data:
|
||||
traj[key] = np.stack(rollout.data[key])
|
||||
trajsize = len(rollout["actions"])
|
||||
for key in rollout:
|
||||
traj[key] = np.stack(rollout[key])
|
||||
|
||||
if use_gae:
|
||||
assert "vf_preds" in rollout.data, "Values not found!"
|
||||
vpred_t = np.stack(
|
||||
rollout.data["vf_preds"] + [np.array(rollout.last_r)]).squeeze()
|
||||
assert "vf_preds" in rollout, "Values not found!"
|
||||
vpred_t = np.stack(rollout["vf_preds"] +
|
||||
[np.array(rollout.last_r)]).squeeze()
|
||||
delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
|
||||
# This formula for the advantage comes from
|
||||
# "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
|
||||
traj["advantages"] = discount(delta_t, gamma * lambda_)
|
||||
traj["value_targets"] = traj["advantages"] + traj["vf_preds"]
|
||||
else:
|
||||
rewards_plus_v = np.stack(
|
||||
rollout.data["rewards"] + [np.array(rollout.last_r)]).squeeze()
|
||||
rewards_plus_v = np.stack(rollout["rewards"] +
|
||||
[np.array(rollout.last_r)]).squeeze()
|
||||
traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]
|
||||
|
||||
for i in range(traj["advantages"].shape[0]):
|
||||
|
||||
@@ -56,9 +56,30 @@ class PartialRollout(object):
|
||||
terminal (bool): if rollout has terminated."""
|
||||
return self.data["dones"][-1]
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
CompletedRollout = namedtuple(
|
||||
"CompletedRollout", ["episode_length", "episode_reward"])
|
||||
def __setitem__(self, key, item):
|
||||
self.data[key] = item
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def __iter__(self):
|
||||
return self.data.__iter__()
|
||||
|
||||
def __next__(self):
|
||||
return self.data.__next__()
|
||||
|
||||
def __contains__(self, x):
|
||||
return x in self.data
|
||||
|
||||
|
||||
CompletedRollout = namedtuple("CompletedRollout",
|
||||
["episode_length", "episode_reward"])
|
||||
|
||||
|
||||
class SyncSampler(object):
|
||||
@@ -71,16 +92,15 @@ class SyncSampler(object):
|
||||
thread."""
|
||||
async = False
|
||||
|
||||
def __init__(self, env, policy, obs_filter,
|
||||
num_local_steps, horizon=None):
|
||||
def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
self._obs_filter = obs_filter
|
||||
self.rollout_provider = _env_runner(
|
||||
self.env, self.policy, self.num_local_steps, self.horizon,
|
||||
self._obs_filter)
|
||||
self.rollout_provider = _env_runner(self.env, self.policy,
|
||||
self.num_local_steps, self.horizon,
|
||||
self._obs_filter)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
@@ -108,10 +128,10 @@ class AsyncSampler(threading.Thread):
|
||||
accumulate and the gradient can be calculated on up to 5 batches."""
|
||||
async = True
|
||||
|
||||
def __init__(self, env, policy, obs_filter,
|
||||
num_local_steps, horizon=None):
|
||||
assert getattr(obs_filter, "is_concurrent", False), (
|
||||
"Observation Filter must support concurrent updates.")
|
||||
def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
|
||||
assert getattr(
|
||||
obs_filter, "is_concurrent",
|
||||
False), ("Observation Filter must support concurrent updates.")
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
@@ -132,9 +152,9 @@ class AsyncSampler(threading.Thread):
|
||||
raise e
|
||||
|
||||
def _run(self):
|
||||
rollout_provider = _env_runner(
|
||||
self.env, self.policy, self.num_local_steps,
|
||||
self.horizon, self._obs_filter)
|
||||
rollout_provider = _env_runner(self.env, self.policy,
|
||||
self.num_local_steps, self.horizon,
|
||||
self._obs_filter)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -232,13 +252,14 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
|
||||
action = np.concatenate(action, axis=0).flatten()
|
||||
|
||||
# Collect the experience.
|
||||
rollout.add(obs=last_observation,
|
||||
actions=action,
|
||||
rewards=reward,
|
||||
dones=terminal,
|
||||
features=last_features,
|
||||
new_obs=observation,
|
||||
**pi_info)
|
||||
rollout.add(
|
||||
obs=last_observation,
|
||||
actions=action,
|
||||
rewards=reward,
|
||||
dones=terminal,
|
||||
features=last_features,
|
||||
new_obs=observation,
|
||||
**pi_info)
|
||||
|
||||
last_observation = observation
|
||||
last_features = features
|
||||
@@ -247,8 +268,8 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
|
||||
terminal_end = True
|
||||
yield CompletedRollout(length, rewards)
|
||||
|
||||
if (length >= horizon or
|
||||
not env.metadata.get("semantics.autoreset")):
|
||||
if (length >= horizon
|
||||
or not env.metadata.get("semantics.autoreset")):
|
||||
last_observation = obs_filter(env.reset())
|
||||
if hasattr(policy, "get_initial_features"):
|
||||
last_features = policy.get_initial_features()
|
||||
|
||||
Reference in New Issue
Block a user