[rllib] Envs for vectorized execution, async execution, and policy serving (#2170)

## What do these changes do?

**Vectorized envs**: Users can either implement `VectorEnv` directly or set `num_envs=N` to auto-vectorize gym envs (auto-vectorization batches only the action computation).

```
# CartPole-v0 on single core with 64x64 MLP:

# vector_width=1:
Actions per second 2720.1284458322966

# vector_width=8:
Actions per second 13773.035334888269

# vector_width=64:
Actions per second 37903.20472563333
```

**Async envs**: The more general form of `VectorEnv` is `AsyncVectorEnv`, which allows agents to execute out of lockstep. We use this as an adapter to support `ServingEnv`. Since we can convert any other form of env to `AsyncVectorEnv`, utils.sampler has been rewritten to run against this interface.

**Policy serving**: This provides an env which is not stepped. Rather, the env executes in its own thread, querying the policy for actions via `self.get_action(obs)`, and reporting results via `self.log_returns(rewards)`. We also support logging of off-policy actions via `self.log_action(obs, action)`. This is a more convenient API for some use cases, and also provides parallelizable support for policy serving (for example, if you start an HTTP server in the env) and ingest of offline logs (if the env reads from serving logs).

Any of these types of envs can be passed to RLlib agents. RLlib handles conversions internally in CommonPolicyEvaluator, for example:
 ```
        gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
        rllib.ServingEnv => rllib.AsyncVectorEnv
```
This commit is contained in:
Eric Liang
2018-06-18 11:55:32 -07:00
committed by GitHub
parent 8560993b46
commit 7dee2c6735
28 changed files with 1218 additions and 342 deletions
+5 -1
View File
@@ -9,6 +9,9 @@ from ray.tune.registry import register_trainable
from ray.rllib.utils.policy_graph import PolicyGraph
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
from ray.rllib.utils.vector_env import VectorEnv
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.optimizers.sample_batch import SampleBatch
@@ -23,5 +26,6 @@ def _register_all():
_register_all()
__all__ = [
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch"
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch",
"AsyncVectorEnv", "VectorEnv", "ServingEnv",
]
+6 -2
View File
@@ -17,6 +17,8 @@ from ray.tune.trial import Resources
DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 2,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
# Size of rollout batch
"batch_size": 10,
# Use LSTM model - only applicable for image states
@@ -101,7 +103,8 @@ class A3CAgent(Agent):
batch_mode="truncate_episodes",
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
self.remote_evaluators = [
remote_cls.remote(
self.env_creator, self.policy_cls,
@@ -109,7 +112,8 @@ class A3CAgent(Agent):
batch_mode="truncate_episodes", sample_async=True,
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for i in range(self.config["num_workers"])]
self.optimizer = AsyncOptimizer(
+4 -4
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from threading import Lock
import torch
@@ -33,13 +34,12 @@ class SharedTorchPolicy(PolicyGraph):
self.optimizer = torch.optim.Adam(
self._model.parameters(), lr=self.config["lr"])
def compute_single_action(self, obs, state, is_training=False):
def compute_actions(self, obs, state, is_training=False):
assert not state, "RNN not supported"
with self.lock:
ob = torch.from_numpy(obs).float().unsqueeze(0)
ob = torch.from_numpy(np.array(obs)).float()
logits, values = self._model(ob)
samples = F.softmax(logits, dim=1).multinomial(1).squeeze()
values = values.squeeze()
samples = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
return var_to_np(samples), [], {"vf_preds": var_to_np(values)}
def compute_gradients(self, samples):
@@ -1,85 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn.functional as F
from ray.rllib.a3c.torchpolicy import TorchPolicy
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
from ray.rllib.models.catalog import ModelCatalog
class SharedTorchPolicy(TorchPolicy):
    """Torch policy for non-recurrent models.

    A single forward pass of the model yields both the action logits and
    the value estimates; the logits and value heads are both fed from the
    model's ``hidden_layers`` output.
    """

    other_output = ["vf_preds"]
    is_recurrent = False

    def __init__(self, registry, ob_space, ac_space, config, **kwargs):
        super(SharedTorchPolicy, self).__init__(
            registry, ob_space, ac_space, config, **kwargs)

    def _setup_graph(self, ob_space, ac_space):
        """Create the model and its Adam optimizer from the config."""
        _, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
        model_options = self.config["model"]
        self._model = ModelCatalog.get_torch_model(
            self.registry, ob_space, self.logit_dim, model_options)
        learning_rate = self.config["lr"]
        self.optimizer = torch.optim.Adam(
            self._model.parameters(), lr=learning_rate)

    def compute(self, ob, *args):
        """Sample an action for a SINGLE observation.

        Returns:
            The sampled action and a dict holding the value prediction
            under ``"vf_preds"``, both converted with ``var_to_np``.
        """
        with self.lock:
            # Add a leading batch dimension of one.
            batched = torch.from_numpy(ob).float().unsqueeze(0)
            logits, values = self._model(batched)
            # TODO(alok): Support non-categorical distributions. Multinomial
            # is only for categorical.
            action_probs = F.softmax(logits, dim=1)
            sampled_actions = action_probs.multinomial(1).squeeze()
            values = values.squeeze()
            extras = {"vf_preds": var_to_np(values)}
            return var_to_np(sampled_actions), extras

    def compute_logits(self, ob, *args):
        """Return the action logits for a single observation."""
        with self.lock:
            batched = torch.from_numpy(ob).float().unsqueeze(0)
            hidden = self._model.hidden_layers(batched)
            logits = self._model.logits(hidden)
            return var_to_np(logits)

    def value(self, ob, *args):
        """Return the value-branch output for a single observation."""
        with self.lock:
            batched = torch.from_numpy(ob).float().unsqueeze(0)
            hidden = self._model.hidden_layers(batched)
            estimate = self._model.value_branch(hidden).squeeze()
            return var_to_np(estimate)

    def _evaluate(self, obs, actions):
        """Evaluate a batch of observations against the taken actions.

        Returns:
            The value outputs, the log-probabilities of ``actions`` and the
            entropy term summed over the whole batch.
        """
        logits, values = self._model(obs)
        log_probs = F.log_softmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)
        chosen = actions.view(-1, 1)
        action_log_probs = log_probs.gather(1, chosen)
        # TODO(alok): set distribution based on action space and use its
        # `.entropy()` method to calculate automatically
        entropy = -(log_probs * probs).sum(-1).sum()
        return values, action_log_probs, entropy

    def _backward(self, batch):
        """Compute the loss and back-propagate it through the model.

        The loss is encoded in here; defining a new loss function would
        start by rewriting this function.
        """
        states, actions, advs, rs, _ = convert_batch(batch)
        values, action_log_probs, entropy = self._evaluate(states, actions)
        pi_err = -advs.dot(action_log_probs.reshape(-1))
        value_err = F.mse_loss(values.reshape(-1), rs)

        self.optimizer.zero_grad()
        loss_terms = [
            pi_err,
            self.config["vf_loss_coeff"] * value_err,
            self.config["entropy_coeff"] * entropy,
        ]
        overall_err = sum(loss_terms)
        overall_err.backward()
        # Clip gradients only after they have been populated by backward().
        torch.nn.utils.clip_grad_norm_(
            self._model.parameters(), self.config["grad_clip"])
-82
View File
@@ -1,82 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
from ray.rllib.a3c.policy import Policy
from threading import Lock
class TorchPolicy(Policy):
    """The policy base class for Torch.

    The model is a separate object than the policy. This could be changed
    in the future.

    Subclasses must implement ``_setup_graph`` and ``_backward``. The methods
    below read ``self._model`` and ``self.optimizer``, so ``_setup_graph`` is
    expected to create both.
    """

    def __init__(self,
                 registry,
                 ob_space,
                 action_space,
                 config,
                 name="local",
                 summarize=True):
        self.registry = registry
        self.local_steps = 0
        self.config = config
        self.summarize = summarize
        self._setup_graph(ob_space, action_space)
        torch.set_num_threads(2)
        self.lock = Lock()

    def apply_gradients(self, grads):
        """Load the given gradients into the model and take one step.

        Args:
            grads: Iterable of numpy arrays, one per model parameter, in
                the order of ``self._model.parameters()``.
        """
        self.optimizer.zero_grad()
        for g, p in zip(grads, self._model.parameters()):
            p.grad = torch.from_numpy(g)
        self.optimizer.step()

    def get_weights(self):
        """Return the model's state dict."""
        # !! This only returns references to the data.
        return self._model.state_dict()

    def set_weights(self, weights):
        """Load the given state dict into the model while holding the lock."""
        with self.lock:
            self._model.load_state_dict(weights)

    def compute_gradients(self, samples):
        """_backward generates the gradient in each model parameter.
        This is taken out.

        Args:
            samples: SampleBatch of data needed for gradient calculation.

        Return:
            gradients (list of np arrays): List of gradients
            info (dict): Extra information (user-defined)"""
        with self.lock:
            self._backward(samples)
            # Note that return values are just references;
            # calling zero_grad will modify the values
            return [p.grad.data.numpy() for p in self._model.parameters()], {}

    def model_update(self, batch):
        """Implements compute + apply

        TODO(rliaw): Pytorch has nice caching property that doesn't require
        full batch to be passed in. Can exploit that later"""
        with self.lock:
            self._backward(batch)
            self.optimizer.step()

    def _setup_graph(self, ob_space, action_space):
        """Build the model and optimizer; must be overridden by subclasses."""
        # `self` was missing from this signature, so the call made in
        # __init__ failed with a TypeError instead of NotImplementedError.
        raise NotImplementedError

    def _backward(self, batch):
        """Implements the loss function and calculates the gradient.
        Pytorch automatically generates a backward trace for each tensor.
        Assumption right now is that variables are moved, so the backward
        trace is lost.

        This function regenerates the backward trace and
        calculates the gradient."""
        raise NotImplementedError
+2
View File
@@ -95,6 +95,8 @@ DEFAULT_CONFIG = {
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
# Whether to allocate GPUs for workers (if > 0).
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
+8 -4
View File
@@ -89,6 +89,8 @@ DEFAULT_CONFIG = {
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
# Whether to allocate GPUs for workers (if > 0).
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
@@ -125,10 +127,11 @@ class DQNAgent(Agent):
self.local_evaluator = CommonPolicyEvaluator(
self.env_creator, self._policy_graph,
batch_steps=adjusted_batch_size,
batch_mode="pack_episodes", preprocessor_pref="deepmind",
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
remote_cls = CommonPolicyEvaluator.as_remote(
num_cpus=self.config["num_cpus_per_worker"],
num_gpus=self.config["num_gpus_per_worker"])
@@ -136,10 +139,11 @@ class DQNAgent(Agent):
remote_cls.remote(
self.env_creator, self._policy_graph,
batch_steps=adjusted_batch_size,
batch_mode="pack_episodes", preprocessor_pref="deepmind",
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config)
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for _ in range(self.config["num_workers"])]
self.exploration0 = self._make_exploration_schedule(0)
+1 -3
View File
@@ -223,11 +223,9 @@ def _postprocess_dqn(policy_graph, sample_batch):
"obs": obs, "actions": actions, "rewards": rewards,
"new_obs": new_obs, "dones": dones,
"weights": np.ones_like(rewards)})
assert batch.count == policy_graph.config["sample_batch_size"], \
(batch.count, policy_graph.config["sample_batch_size"])
# Prioritize on the worker side
if policy_graph.config["worker_side_prioritization"]:
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
td_errors = policy_graph.compute_td_error(
batch["obs"], batch["actions"], batch["rewards"],
batch["new_obs"], batch["dones"], batch["weights"])
+1 -1
View File
@@ -63,7 +63,7 @@ class Categorical(ActionDistribution):
reduction_indices=[1])
def sample(self):
return tf.multinomial(self.inputs, 1)[0]
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
class DiagGaussian(ActionDistribution):
-6
View File
@@ -125,22 +125,16 @@ def get_preprocessor(space):
legacy_patch_shapes(space)
obs_shape = space.shape
print("Observation shape is {}".format(obs_shape))
if isinstance(space, gym.spaces.Discrete):
print("Using one-hot preprocessor for discrete envs.")
preprocessor = OneHotPreprocessor
elif obs_shape == ATARI_OBS_SHAPE:
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
preprocessor = AtariPixelPreprocessor
elif obs_shape == ATARI_RAM_OBS_SHAPE:
print("Assuming Atari ram env, using AtariRamPreprocessor.")
preprocessor = AtariRamPreprocessor
elif isinstance(space, gym.spaces.Tuple):
print("Using a TupleFlatteningPreprocessor")
preprocessor = TupleFlatteningPreprocessor
else:
print("Not using any observation preprocessor.")
preprocessor = NoPreprocessor
return preprocessor
+1 -1
View File
@@ -56,5 +56,5 @@ class FullyConnectedNetwork(Model):
value: value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res).reshape(-1)
value = self.value_branch(res).squeeze(1)
return logits, value
+1 -1
View File
@@ -65,5 +65,5 @@ class VisionNetwork(Model):
value (PyTorch): value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res)
value = self.value_branch(res).squeeze(1)
return logits, value
@@ -105,6 +105,11 @@ class PolicyEvaluator(object):
return os.uname()[1]
def apply(self, func, *args):
    """Call ``func`` with this evaluator as its first argument.

    Any extra positional arguments are forwarded after the evaluator, and
    whatever ``func`` returns is handed back to the caller.
    """
    outcome = func(self, *args)
    return outcome
class TFMultiGPUSupport(PolicyEvaluator):
"""The multi-GPU TF optimizer requires additional TF-specific support.
@@ -110,3 +110,23 @@ class PolicyOptimizer(object):
self.num_steps_trained = data[0]
self.num_steps_sampled = data[1]
def foreach_evaluator(self, func):
    """Run ``func`` on the local evaluator and on every remote evaluator.

    Returns:
        A list with the local result first, followed by the remote
        results in the order of ``self.remote_evaluators``.
    """
    # The local call runs to completion before any remote work is issued.
    results = [func(self.local_evaluator)]
    pending = [ev.apply.remote(func) for ev in self.remote_evaluators]
    results.extend(ray.get(pending))
    return results
def foreach_evaluator_with_index(self, func):
    """Run ``func(evaluator, index)`` on every evaluator instance.

    The local evaluator is given index 0; remote evaluators are numbered
    from 1 in the order of ``self.remote_evaluators``.

    Returns:
        A list with the local result first, followed by the remote
        results.
    """
    results = [func(self.local_evaluator, 0)]
    pending = [
        ev.apply.remote(func, index)
        for index, ev in enumerate(self.remote_evaluators, start=1)
    ]
    results.extend(ray.get(pending))
    return results
+34 -3
View File
@@ -7,17 +7,47 @@ import numpy as np
class SampleBatchBuilder(object):
"""Util to build a SampleBatch incrementally."""
"""Util to build a SampleBatch incrementally.
For efficiency, SampleBatches hold values in column form (as arrays).
However, it is useful to add data one row (dict) at a time.
"""
def __init__(self):
self.postprocessed = []
self.buffers = collections.defaultdict(list)
self.count = 0
def add_values(self, **values):
"""Add the given dictionary (row) of values to this batch."""
for k, v in values.items():
self.buffers[k].append(v)
self.count += 1
def build(self):
return SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
def postprocess_batch_so_far(self, postprocessor):
"""Apply the given postprocessor to any unprocessed rows."""
batch = postprocessor(self._build_buffers())
self.postprocessed.append(batch)
def build_and_reset(self, postprocessor):
"""Returns a sample batch including all previously added values.
Any unprocessed rows will be first postprocessed with the given
postprocessor. The internal state of this builder will be reset.
"""
self.postprocess_batch_so_far(postprocessor)
batch = SampleBatch.concat_samples(self.postprocessed)
self.postprocessed = []
self.count = 0
return batch
def _build_buffers(self):
batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
self.buffers.clear()
return batch
class SampleBatch(object):
@@ -41,6 +71,7 @@ class SampleBatch(object):
@staticmethod
def concat_samples(samples):
    """Concatenate the given batches column by column.

    Batches whose ``count`` is zero are ignored. The column names are taken
    from the first non-empty batch, so an IndexError is raised when every
    batch is empty.
    """
    non_empty = [s for s in samples if s.count > 0]
    merged = {
        key: np.concatenate([s[key] for s in non_empty])
        for key in non_empty[0].keys()
    }
    return SampleBatch(merged)
+4 -1
View File
@@ -12,7 +12,9 @@ from ray.tune.trial import Resources
DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 4,
"num_workers": 0,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
# Size of rollout batch
"batch_size": 512,
# Discount factor of MDP
@@ -57,6 +59,7 @@ class PGAgent(Agent):
"model_config": self.config["model"],
"env_config": self.config["env_config"],
"policy_config": self.config,
"num_envs": self.config["num_envs"],
},
num_workers=self.config["num_workers"],
optimizer_config=self.config["optimizer"])
+6 -3
View File
@@ -82,11 +82,14 @@ class ProximalPolicyGraph(object):
self.policy_results = [
self.sampler, self.curr_logits, tf.constant("NA")]
def compute_single_action(self, observation, features, is_training=False):
def compute_actions(self, observations, features, is_training=False):
action, logprobs, vf = self.sess.run(
self.policy_results,
feed_dict={self.observations: [observation]})
return action[0], [], {"vf_preds": vf[0], "logprobs": logprobs[0]}
feed_dict={self.observations: observations})
return action, [], {"vf_preds": vf, "logprobs": logprobs}
def postprocess_trajectory(self, batch):
    """Return the given batch unchanged; no postprocessing is applied."""
    unchanged = batch
    return unchanged
def get_initial_state(self):
    """Return the initial state, which is always a new empty list."""
    return list()
+3 -3
View File
@@ -69,7 +69,7 @@ DEFAULT_CONFIG = {
# number of steps is obtained
"min_steps_per_task": 200,
# Number of actors used to collect the rollouts
"num_workers": 5,
"num_workers": 2,
# Whether to allocate GPUs for workers (if > 0).
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
@@ -299,5 +299,5 @@ class PPOAgent(Agent):
def compute_action(self, observation):
observation = self.local_evaluator.obs_filter(
observation, update=False)
return self.local_evaluator.common_policy.compute_single_action(
observation, [], False)[0]
return self.local_evaluator.common_policy.compute_actions(
[observation], [], False)[0][0]
@@ -7,9 +7,13 @@ import time
import unittest
import ray
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.pg import PGAgent
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \
collect_metrics
from ray.rllib.utils.policy_graph import PolicyGraph
from ray.rllib.utils.process_rollout import compute_advantages
from ray.rllib.utils.vector_env import VectorEnv
from ray.tune.registry import register_env
class MockPolicyGraph(PolicyGraph):
@@ -20,6 +24,55 @@ class MockPolicyGraph(PolicyGraph):
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
class BadPolicyGraph(PolicyGraph):
    """Policy graph whose action computation always fails."""

    def compute_actions(self, obs_batch, state_batches, is_training=False):
        """Raise unconditionally, regardless of the inputs."""
        raise Exception("intentional error")

    def postprocess_trajectory(self, batch):
        """Return ``compute_advantages`` applied to the batch, GAE off."""
        with_advantages = compute_advantages(batch, 100.0, 0.9, use_gae=False)
        return with_advantages
class MockEnv(gym.Env):
    """Deterministic env with fixed-length episodes.

    Every step yields observation 0 and reward 1; the episode is reported
    done once ``episode_length`` steps have been taken since the last reset.
    """

    def __init__(self, episode_length):
        self.episode_length = episode_length
        # Number of steps taken since the last reset.
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Zero the step counter and return it as the first observation."""
        self.i = 0
        return self.i

    def step(self, action):
        """Advance one step; the action itself is ignored."""
        self.i += 1
        finished = self.i >= self.episode_length
        return 0, 1, finished, {}
class MockVectorEnv(VectorEnv):
    """VectorEnv made of ``num_envs`` independent MockEnv instances."""

    def __init__(self, episode_length, num_envs):
        self.envs = [MockEnv(episode_length) for _ in range(num_envs)]
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)
        self.num_envs = num_envs

    def vector_reset(self):
        """Reset every sub-env and return the list of observations."""
        return [env.reset() for env in self.envs]

    def reset_at(self, index):
        """Reset only the sub-env at ``index`` and return its observation."""
        target = self.envs[index]
        return target.reset()

    def vector_step(self, actions):
        """Step each sub-env with the action at the same position.

        Returns:
            Four parallel lists: observations, rewards, dones and infos.
        """
        obs_batch = []
        rew_batch = []
        done_batch = []
        info_batch = []
        for position, env in enumerate(self.envs):
            obs, rew, done, info = env.step(actions[position])
            obs_batch.append(obs)
            rew_batch.append(rew)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch
class TestCommonPolicyEvaluator(unittest.TestCase):
def testBasic(self):
ev = CommonPolicyEvaluator(
@@ -30,43 +83,134 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertIn(key, batch)
self.assertGreater(batch["advantages"][0], 1)
def testPackEpisodes(self):
for batch_size in [1, 10, 100, 1000]:
ev = CommonPolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
batch_steps=batch_size,
batch_mode="pack_episodes")
def testQueryEvaluators(self):
register_env("test", lambda _: gym.make("CartPole-v0"))
pg = PGAgent(env="test", config={"num_workers": 2, "batch_size": 5})
results = pg.optimizer.foreach_evaluator(lambda ev: ev.batch_steps)
results2 = pg.optimizer.foreach_evaluator_with_index(
lambda ev, i: (i, ev.batch_steps))
self.assertEqual(results, [5, 5, 5])
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
def testMetrics(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
remote_ev = CommonPolicyEvaluator.as_remote().remote(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
ev.sample()
ray.get(remote_ev.sample.remote())
result = collect_metrics(ev, [remote_ev])
self.assertEqual(result.episodes_total, 20)
self.assertEqual(result.episode_reward_mean, 10)
def testAsync(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
sample_async=True,
policy_graph=MockPolicyGraph)
batch = ev.sample()
for key in ["obs", "actions", "rewards", "dones", "advantages"]:
self.assertIn(key, batch)
self.assertGreater(batch["advantages"][0], 1)
def testAutoConcat(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=40),
policy_graph=MockPolicyGraph,
sample_async=True,
batch_steps=10,
batch_mode="truncate_episodes",
observation_filter="ConcurrentMeanStdFilter")
time.sleep(2)
batch = ev.sample()
self.assertEqual(batch.count, 40) # auto-concat up to 5 episodes
def testAutoVectorization(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=20),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=16, num_envs=8)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, batch_size)
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result.episodes_total, 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result.episodes_total, 8)
def testBatchDivisibilityCheck(self):
self.assertRaises(
ValueError,
lambda: CommonPolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=8),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=15, num_envs=4))
def testBatchesSmallerWhenVectorized(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=8),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=16, num_envs=4)
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result.episodes_total, 0)
batch = ev.sample()
result = collect_metrics(ev, [])
self.assertEqual(result.episodes_total, 4)
def testVectorEnvSupport(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: MockVectorEnv(
episode_length=20, num_envs=8),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=10)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result.episodes_total, 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result.episodes_total, 8)
def testTruncateEpisodes(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
batch_steps=2,
batch_steps=15,
batch_mode="truncate_episodes")
batch = ev.sample()
self.assertEqual(batch.count, 2)
ev = CommonPolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
batch_steps=1000,
batch_mode="truncate_episodes")
self.assertLess(batch.count, 200)
self.assertEqual(batch.count, 15)
def testCompleteEpisodes(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
batch_steps=2,
batch_steps=5,
batch_mode="complete_episodes")
batch = ev.sample()
self.assertGreater(batch.count, 2)
self.assertTrue(batch["dones"][-1])
self.assertEqual(batch.count, 10)
def testCompleteEpisodesPacking(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
batch_steps=15,
batch_mode="complete_episodes")
batch = ev.sample()
self.assertGreater(batch.count, 2)
self.assertTrue(batch["dones"][-1])
self.assertEqual(batch.count, 20)
def testFilterSync(self):
ev = CommonPolicyEvaluator(
@@ -129,5 +273,5 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
if __name__ == '__main__':
ray.init()
ray.init(num_cpus=5)
unittest.main(verbosity=2)
+192
View File
@@ -0,0 +1,192 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import numpy as np
import random
import unittest
import uuid
import ray
from ray.rllib.dqn import DQNAgent
from ray.rllib.pg import PGAgent
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.test.test_common_policy_evaluator import BadPolicyGraph, \
MockPolicyGraph, MockEnv
from ray.tune.registry import register_env
class SimpleServing(ServingEnv):
    """Serving env that drives a wrapped env fully on-policy.

    Every action is obtained from ``get_action`` and every reward is
    reported through ``log_returns``.
    """

    def __init__(self, env):
        ServingEnv.__init__(self, env.action_space, env.observation_space)
        self.env = env

    def run(self):
        """Play episodes back to back, forever."""
        self.start_episode()
        observation = self.env.reset()
        while True:
            chosen = self.get_action(observation)
            observation, reward, done, info = self.env.step(chosen)
            self.log_returns(reward, info=info)
            if not done:
                continue
            # Close the finished episode before opening the next one.
            self.end_episode(observation)
            observation = self.env.reset()
            self.start_episode()
class PartOffPolicyServing(ServingEnv):
    """Serving env that mixes on-policy and random (off-policy) actions.

    On each step a random draw below ``off_pol_frac`` selects a sampled
    action that is reported via ``log_action``; otherwise the action comes
    from ``get_action``.
    """

    def __init__(self, env, off_pol_frac):
        ServingEnv.__init__(self, env.action_space, env.observation_space)
        self.env = env
        self.off_pol_frac = off_pol_frac

    def run(self):
        """Play episodes back to back, forever."""
        self.start_episode()
        observation = self.env.reset()
        while True:
            if random.random() >= self.off_pol_frac:
                chosen = self.get_action(observation)
            else:
                chosen = self.env.action_space.sample()
                self.log_action(observation, chosen)
            observation, reward, done, info = self.env.step(chosen)
            self.log_returns(reward, info=info)
            if not done:
                continue
            self.end_episode(observation)
            observation = self.env.reset()
            self.start_episode()
class SimpleOffPolicyServing(ServingEnv):
    """Serving env that never queries the policy.

    Every action is sampled from the wrapped env's action space and
    reported via ``log_action``.
    """

    def __init__(self, env):
        ServingEnv.__init__(self, env.action_space, env.observation_space)
        self.env = env

    def run(self):
        """Play episodes back to back, forever."""
        self.start_episode()
        observation = self.env.reset()
        while True:
            # Take random actions
            chosen = self.env.action_space.sample()
            self.log_action(observation, chosen)
            observation, reward, done, info = self.env.step(chosen)
            self.log_returns(reward, info=info)
            if not done:
                continue
            self.end_episode(observation)
            observation = self.env.reset()
            self.start_episode()
class MultiServing(ServingEnv):
    """Serving env that interleaves episodes from five independent envs.

    Each loop iteration picks two distinct envs at random, starts a new
    episode (with a fresh id) for any of them that has none in progress,
    and advances both by one step.
    """

    def __init__(self, env_creator):
        self.env_creator = env_creator
        self.env = env_creator()
        ServingEnv.__init__(
            self, self.env.action_space, self.env.observation_space)

    def run(self):
        """Step randomly chosen pairs of envs, forever."""
        envs = [self.env_creator() for _ in range(5)]
        # Latest observation per env index; an index is absent while that
        # env has no episode in progress.
        latest_obs = {}
        # Episode id currently associated with each env index.
        episode_ids = {}
        while True:
            active = np.random.choice(range(5), 2, replace=False)
            for idx in active:
                if idx in latest_obs:
                    continue
                episode_ids[idx] = uuid.uuid4().hex
                self.start_episode(episode_id=episode_ids[idx])
                latest_obs[idx] = envs[idx].reset()
            # Collect the actions for both envs before stepping either one.
            actions = [
                self.get_action(latest_obs[idx], episode_id=episode_ids[idx])
                for idx in active
            ]
            for idx, action in zip(active, actions):
                obs, reward, done, _ = envs[idx].step(action)
                latest_obs[idx] = obs
                self.log_returns(reward, episode_id=episode_ids[idx])
                if done:
                    self.end_episode(obs, episode_id=episode_ids[idx])
                    del latest_obs[idx]
class TestServingEnv(unittest.TestCase):
    """Sampling and training tests that run against ServingEnv subclasses."""

    def _check_batch_counts(self, env_creator, batch_mode, expected_count):
        """Sample three batches and check the row count of each one."""
        ev = CommonPolicyEvaluator(
            env_creator=env_creator,
            policy_graph=MockPolicyGraph,
            batch_steps=40,
            batch_mode=batch_mode)
        for _ in range(3):
            batch = ev.sample()
            self.assertEqual(batch.count, expected_count)

    def _train_until_solved(self, agent):
        """Train for at most 100 iterations until mean reward reaches 100."""
        for iteration in range(100):
            result = agent.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                iteration, result.episode_reward_mean,
                result.timesteps_total))
            if result.episode_reward_mean >= 100:
                return
        raise Exception("failed to improve reward")

    def testServingEnvCompleteEpisodes(self):
        self._check_batch_counts(
            lambda _: SimpleServing(MockEnv(25)), "complete_episodes", 50)

    def testServingEnvTruncateEpisodes(self):
        self._check_batch_counts(
            lambda _: SimpleServing(MockEnv(25)), "truncate_episodes", 40)

    def testServingEnvOffPolicy(self):
        self._check_batch_counts(
            lambda _: SimpleOffPolicyServing(MockEnv(25)),
            "complete_episodes", 50)

    def testServingEnvBadActions(self):
        ev = CommonPolicyEvaluator(
            env_creator=lambda _: SimpleServing(MockEnv(25)),
            policy_graph=BadPolicyGraph,
            sample_async=True,
            batch_steps=40,
            batch_mode="truncate_episodes")
        with self.assertRaises(Exception):
            ev.sample()

    def testTrainCartpoleOffPolicy(self):
        register_env(
            "test3", lambda _: PartOffPolicyServing(
                gym.make("CartPole-v0"), off_pol_frac=0.2))
        dqn = DQNAgent(env="test3", config={"exploration_fraction": 0.001})
        self._train_until_solved(dqn)

    def testTrainCartpole(self):
        register_env(
            "test", lambda _: SimpleServing(gym.make("CartPole-v0")))
        pg = PGAgent(env="test", config={"num_workers": 0})
        self._train_until_solved(pg)

    def testTrainCartpoleMulti(self):
        register_env(
            "test2", lambda _: MultiServing(lambda: gym.make("CartPole-v0")))
        pg = PGAgent(env="test2", config={"num_workers": 0})
        self._train_until_solved(pg)
if __name__ == '__main__':
ray.init()
unittest.main(verbosity=2)
+115
View File
@@ -0,0 +1,115 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class AsyncVectorEnv(object):
    """The lowest-level env interface used by RLlib for sampling.

    An AsyncVectorEnv models several agents that execute asynchronously:
    poll() hands back observations from whichever agents are ready, and
    send_actions() delivers the actions chosen for them.

    Every other env type can be adapted to this interface. RLlib performs
    the conversion internally in CommonPolicyEvaluator, for example:

        gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
        rllib.ServingEnv => rllib.AsyncVectorEnv

    Examples:
        >>> env = MyAsyncVectorEnv()
        >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
        >>> print(obs)
        {"car_0": [2.4, 1.6], "car_1": [3.4, -3.2]}
        >>> env.send_actions({"car_0": 0, "car_1": 1})
        >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
        >>> print(obs)
        {"car_0": [4.1, 1.7], "car_1": [3.2, -4.2]}
    """

    def poll(self):
        """Return observations from the agents that are ready.

        Each returned value is a dict keyed by agent episode id. The set of
        agents may change from one call to the next.

        Returns:
            obs (dict): New observations for each ready episode.
            rewards (dict): Reward values for each ready episode. The value
                is None for an episode that has just started.
            dones (dict): Done values for each ready episode. True means
                the episode is terminated.
            infos (dict): Info values for each ready episode.
            off_policy_actions (dict): Actions that agents took off-policy;
                an entry is present only for agents that did so.
        """
        raise NotImplementedError

    def send_actions(self, action_dict):
        """Send actions back to the running agents in this env.

        Arguments:
            action_dict (dict): Actions for each agent to take.
        """
        raise NotImplementedError

    def try_reset(self, agent_id):
        """Try to reset the agent with the given id.

        Envs that do not support synchronous reset may return None, which
        is what this default implementation does.

        Returns:
            obs (obj|None): Resetted observation or None if not supported.
        """
        return None

    def get_unwrapped(self):
        """Return a reference to an underlying gym env, if there is one.

        Returns:
            env (gym.Env|None): Underlying gym env or None.
        """
        return None
class _VectorEnvToAsync(AsyncVectorEnv):
    """Adapter that exposes a VectorEnv through the AsyncVectorEnv interface.

    The caller is assumed to pass the full vector of actions on every
    send_actions() call, and to call reset_at() on all completed
    environments before calling send_actions().
    """

    def __init__(self, vector_env):
        """Wrap `vector_env` and reset all of its sub-environments.

        Arguments:
            vector_env (VectorEnv): The env to adapt.
        """
        self.vector_env = vector_env
        self.num_envs = vector_env.num_envs
        # The first poll() reports the reset observations, with no reward
        # and info yet and nothing done.
        self.new_obs = self.vector_env.vector_reset()
        self.cur_rewards = [None] * self.num_envs
        self.cur_dones = [False] * self.num_envs
        self.cur_infos = [None] * self.num_envs

    def poll(self):
        """Return the buffered step results as dicts keyed by env index.

        The buffers are emptied afterwards, so a second poll() without an
        intervening send_actions() yields empty dicts.

        Returns:
            Tuple of (obs, rewards, dones, infos, off_policy_actions); the
            last element is always an empty dict.
        """
        obs_by_idx, rewards_by_idx, dones_by_idx, infos_by_idx = [
            dict(enumerate(values))
            for values in (self.new_obs, self.cur_rewards,
                           self.cur_dones, self.cur_infos)
        ]
        self.new_obs, self.cur_rewards = [], []
        self.cur_dones, self.cur_infos = [], []
        return obs_by_idx, rewards_by_idx, dones_by_idx, infos_by_idx, {}

    def send_actions(self, action_dict):
        """Step every sub-environment and buffer the results for poll().

        Arguments:
            action_dict (dict): Maps each env index in range(num_envs) to
                the action for that env.
        """
        ordered_actions = [
            action_dict[idx] for idx in range(self.num_envs)]
        step_result = self.vector_env.vector_step(ordered_actions)
        (self.new_obs, self.cur_rewards,
         self.cur_dones, self.cur_infos) = step_result

    def try_reset(self, agent_id):
        """Reset the sub-environment at index `agent_id`; return its obs."""
        return self.vector_env.reset_at(agent_id)

    def get_unwrapped(self):
        """Return whatever the wrapped VectorEnv reports as unwrapped env."""
        return self.vector_env.get_unwrapped()
+4
View File
@@ -6,6 +6,10 @@ import cv2
cv2.ocl.setUseOpenCL(False)
def is_atari(env):
    """Return True if `env.unwrapped` exists and has an `ale` attribute."""
    if not hasattr(env, "unwrapped"):
        return False
    return hasattr(env.unwrapped, "ale")
class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30, random_starts=False):
"""Sample initial states by taking random number of no-ops on reset.
@@ -8,12 +8,16 @@ import tensorflow as tf
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.optimizers import SampleBatch
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator
from ray.rllib.utils.atari_wrappers import wrap_deepmind
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.vector_env import VectorEnv
from ray.tune.registry import get_registry
from ray.tune.result import TrainingResult
@@ -53,10 +57,8 @@ def collect_metrics(local_evaluator, remote_evaluators):
class CommonPolicyEvaluator(PolicyEvaluator):
"""Policy evaluator implementation that operates on a rllib.PolicyGraph.
TODO: vector env
TODO: multi-agent
TODO: consumer buffering for multi-agent
TODO: complete episode batch mode
TODO: multi-gpu
Examples:
# Create a policy evaluator and using it to collect experiences.
@@ -89,9 +91,11 @@ class CommonPolicyEvaluator(PolicyEvaluator):
tf_session_creator=None,
batch_steps=100,
batch_mode="truncate_episodes",
episode_horizon=None,
preprocessor_pref="rllib",
sample_async=False,
compress_observations=False,
num_envs=1,
observation_filter="NoFilter",
registry=None,
env_config=None,
@@ -108,13 +112,20 @@ class CommonPolicyEvaluator(PolicyEvaluator):
This is optional and only useful with TFPolicyGraph.
batch_steps (int): The target number of env transitions to include
in each sample batch returned from this evaluator.
batch_mode (str): One of the following choices:
complete_episodes: each batch will be at least batch_steps
in size, and will include one or more complete episodes.
truncate_episodes: each batch will be around batch_steps
in size, and include transitions from one episode only.
pack_episodes: each batch will be exactly batch_steps in
size, and may include transitions from multiple episodes.
batch_mode (str): One of the following batch modes:
"truncate_episodes": Each call to sample() will return a batch
of exactly `batch_steps` in size. Episodes may be truncated
in order to meet this size requirement. When
`num_envs > 1`, episodes will be truncated to sequences of
`batch_steps / num_envs` in length.
"complete_episodes": Each call to sample() will return a batch
of at least `batch_steps` in size. Episodes will not be
truncated, but multiple episodes may be packed within one
batch to meet the batch size. Note that when
`num_envs > 1`, episode steps will be buffered until the
episode completes, and hence batches may contain
significant amounts of off-policy data.
episode_horizon (int): Whether to stop episodes at this horizon.
preprocessor_pref (str): Whether to prefer RLlib preprocessors
("rllib") or deepmind ("deepmind") when applicable.
sample_async (bool): Whether to compute samples asynchronously in
@@ -122,6 +133,9 @@ class CommonPolicyEvaluator(PolicyEvaluator):
to be slightly off-policy.
compress_observations (bool): If true, compress the observations
returned.
num_envs (int): If more than one, will create multiple envs
and vectorize the computation of actions. This has no effect
if the env already implements VectorEnv.
observation_filter (str): Name of observation filter to use.
registry (tune.Registry): User-registered objects. Pass in the
value from tune.registry.get_registry() if you're having
@@ -135,9 +149,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
env_config = env_config or {}
policy_config = policy_config or {}
model_config = model_config or {}
assert batch_mode in [
"complete_episodes", "truncate_episodes", "pack_episodes"]
self.env_creator = env_creator
self.policy_graph = policy_graph
self.batch_steps = batch_steps
@@ -145,15 +156,25 @@ class CommonPolicyEvaluator(PolicyEvaluator):
self.compress_observations = compress_observations
self.env = env_creator(env_config)
is_atari = hasattr(self.env.unwrapped, "ale")
if is_atari and "custom_preprocessor" not in model_config and \
if isinstance(self.env, VectorEnv) or \
isinstance(self.env, ServingEnv) or \
isinstance(self.env, AsyncVectorEnv):
def wrap(env):
return env # we can't auto-wrap these env types
elif is_atari(self.env) and \
"custom_preprocessor" not in model_config and \
preprocessor_pref == "deepmind":
self.env = wrap_deepmind(self.env, dim=model_config.get("dim", 80))
def wrap(env):
return wrap_deepmind(env, dim=model_config.get("dim", 80))
else:
self.env = ModelCatalog.get_preprocessor_as_wrapper(
registry, self.env, model_config)
def wrap(env):
return ModelCatalog.get_preprocessor_as_wrapper(
registry, env, model_config)
self.env = wrap(self.env)
def make_env():
return wrap(env_creator(env_config))
self.vectorized = hasattr(self.env, "vector_reset")
self.policy_map = {}
if issubclass(policy_graph, TFPolicyGraph):
@@ -179,24 +200,41 @@ class CommonPolicyEvaluator(PolicyEvaluator):
observation_filter, self.env.observation_space.shape)
self.filters = {"obs_filter": self.obs_filter}
if self.vectorized:
raise NotImplementedError("Vector envs not yet supported")
else:
if batch_mode not in [
"pack_episodes", "truncate_episodes", "complete_episodes"]:
raise NotImplementedError("Batch mode not yet supported")
pack = batch_mode == "pack_episodes"
if batch_mode == "complete_episodes":
batch_steps = 999999
if sample_async:
self.sampler = AsyncSampler(
self.env, self.policy_map["default"], self.obs_filter,
batch_steps, pack=pack)
self.sampler.start()
# Always use vector env for consistency even if num_envs = 1
if not isinstance(self.env, AsyncVectorEnv):
if isinstance(self.env, ServingEnv):
self.vector_env = _ServingEnvToAsync(self.env)
else:
self.sampler = SyncSampler(
self.env, self.policy_map["default"], self.obs_filter,
batch_steps, pack=pack)
if not isinstance(self.env, VectorEnv):
self.env = VectorEnv.wrap(
make_env, [self.env], num_envs=num_envs)
self.vector_env = _VectorEnvToAsync(self.env)
else:
self.vector_env = self.env
if self.batch_mode == "truncate_episodes":
if batch_steps % num_envs != 0:
raise ValueError(
"In 'truncate_episodes' batch mode, `batch_steps` must be "
"evenly divisible by `num_envs`. Got {} and {}.".format(
batch_steps, num_envs))
batch_steps = batch_steps // num_envs
pack_episodes = True
elif self.batch_mode == "complete_episodes":
batch_steps = float("inf") # never cut episodes
pack_episodes = False # sampler will return 1 episode per poll
else:
raise ValueError(
"Unsupported batch mode: {}".format(self.batch_mode))
if sample_async:
self.sampler = AsyncSampler(
self.vector_env, self.policy_map["default"], self.obs_filter,
batch_steps, horizon=episode_horizon, pack=pack_episodes)
self.sampler.start()
else:
self.sampler = SyncSampler(
self.vector_env, self.policy_map["default"], self.obs_filter,
batch_steps, horizon=episode_horizon, pack=pack_episodes)
def sample(self):
"""Evaluate the current policies and return a batch of experiences.
@@ -205,8 +243,13 @@ class CommonPolicyEvaluator(PolicyEvaluator):
SampleBatch from evaluating the current policies.
"""
batch = self.policy_map["default"].postprocess_trajectory(
self.sampler.get_data())
batches = [self.sampler.get_data()]
steps_so_far = batches[0].count
while steps_so_far < self.batch_steps:
batch = self.sampler.get_data()
steps_so_far += batch.count
batches.append(batch)
batch = SampleBatch.concat_samples(batches)
if self.compress_observations:
batch["obs"] = [pack(o) for o in batch["obs"]]
@@ -214,11 +257,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
return batch
def apply(self, func):
"""Apply the given function to this evaluator instance."""
return func(self)
def for_policy(self, func):
"""Apply the given function to this evaluator's default policy."""
+2 -1
View File
@@ -69,7 +69,8 @@ class PolicyGraph(object):
"""Implements algorithm-specific trajectory postprocessing.
Arguments:
sample_batch (SampleBatch): batch of experiences for the policy
sample_batch (SampleBatch): batch of experiences for the policy,
which will contain at most one episode trajectory.
other_agent_batches (dict): In a multi-agent env, this contains the
experience batches seen by other agents.
+160 -72
View File
@@ -2,12 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
from collections import defaultdict, namedtuple
import numpy as np
import six.moves.queue as queue
import threading
from ray.rllib.optimizers.sample_batch import SampleBatchBuilder
from ray.rllib.utils.vector_env import VectorEnv
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
CompletedRollout = namedtuple("CompletedRollout",
@@ -22,17 +24,20 @@ class SyncSampler(object):
This class provides data on invocation, rather than on a separate
thread."""
_async = False
def __init__(
self, env, policy, obs_filter, num_local_steps, horizon=None,
pack=False):
self, env, policy, obs_filter, num_local_steps,
horizon=None, pack=False):
if not isinstance(env, AsyncVectorEnv):
if not isinstance(env, VectorEnv):
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
env = _VectorEnvToAsync(env)
self.async_vector_env = env
self.num_local_steps = num_local_steps
self.horizon = horizon
self.env = env
self.policy = policy
self._obs_filter = obs_filter
self.rollout_provider = _env_runner(self.env, self.policy,
self.rollout_provider = _env_runner(self.async_vector_env, self.policy,
self.num_local_steps, self.horizon,
self._obs_filter, pack)
self.metrics_queue = queue.Queue()
@@ -60,28 +65,29 @@ class AsyncSampler(threading.Thread):
Note that batch_size is only a unit of measure here. Batches can
accumulate and the gradient can be calculated on up to 5 batches."""
_async = True
def __init__(
self, env, policy, obs_filter, num_local_steps, horizon=None,
pack=False):
self, env, policy, obs_filter, num_local_steps,
horizon=None, pack=False):
assert getattr(
obs_filter, "is_concurrent",
False), ("Observation Filter must support concurrent updates.")
if not isinstance(env, AsyncVectorEnv):
if not isinstance(env, VectorEnv):
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
env = _VectorEnvToAsync(env)
self.async_vector_env = env
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.horizon = horizon
self.env = env
self.policy = policy
self._obs_filter = obs_filter
self.started = False
self.daemon = True
self.pack = pack
def run(self):
self.started = True
try:
self._run()
except BaseException as e:
@@ -89,7 +95,7 @@ class AsyncSampler(threading.Thread):
raise e
def _run(self):
rollout_provider = _env_runner(self.env, self.policy,
rollout_provider = _env_runner(self.async_vector_env, self.policy,
self.num_local_steps, self.horizon,
self._obs_filter, self.pack)
while True:
@@ -103,15 +109,17 @@ class AsyncSampler(threading.Thread):
self.queue.put(item, timeout=600.0)
def get_data(self):
"""Gets currently accumulated data.
Returns:
rollout (SampleBatch): trajectory data (unprocessed)
"""
assert self.started, "Sampler never started running!"
rollout = self.queue.get(timeout=600.0)
# Propagate errors
if isinstance(rollout, BaseException):
raise rollout
# We can't auto-concat rollouts in vector mode
if self.async_vector_env.num_envs > 1:
return rollout
# Auto-concat rollouts; TODO(ekl) is this important for A3C perf?
while not rollout["dones"][-1]:
try:
part = self.queue.get_nowait()
@@ -132,7 +140,8 @@ class AsyncSampler(threading.Thread):
return completed
def _env_runner(env, policy, num_local_steps, horizon, obs_filter, pack):
def _env_runner(
async_vector_env, policy, num_local_steps, horizon, obs_filter, pack):
"""This implements the logic of the thread runner.
It continually runs the policy, and as long as the rollout exceeds a
@@ -141,9 +150,9 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter, pack):
`num_local_steps` is reached.
Args:
env: Environment generated by env_creator
async_vector_env: env implementing AsyncVectorEnv.
policy: Policy used to interact with environment. Also sets fields
to be included in `SampleBatch`
to be included in `SampleBatch`.
num_local_steps: Number of steps before `SampleBatch` is yielded. Set
to infinity to yield complete episodes.
horizon: Horizon of the episode.
@@ -155,67 +164,146 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter, pack):
rollout (SampleBatch): Object containing state, action, reward,
terminal condition, and other fields as dictated by `policy`.
"""
last_observation = obs_filter(env.reset())
try:
horizon = horizon if horizon else env.spec.max_episode_steps
if not horizon:
horizon = async_vector_env.get_unwrapped().spec.max_episode_steps
except Exception:
print("Warning, no horizon specified, assuming infinite")
if not horizon:
horizon = 999999
last_features = policy.get_initial_state()
features = last_features
length = 0
rewards = 0
rollout_number = 0
horizon = float("inf")
# Pool of batch builders, which can be shared across episodes to pack
# trajectory data.
batch_builder_pool = []
def get_batch_builder():
if batch_builder_pool:
return batch_builder_pool.pop()
else:
return SampleBatchBuilder()
episodes = defaultdict(
lambda: _Episode(policy.get_initial_state(), get_batch_builder))
while True:
batch_builder = SampleBatchBuilder()
# Get observations from ready envs
unfiltered_obs, rewards, dones, _, off_policy_actions = \
async_vector_env.poll()
ready_eids = []
ready_obs = []
ready_rnn_states = []
for _ in range(num_local_steps):
# Assume batch size one for now
action, features, pi_info = policy.compute_single_action(
last_observation, last_features, is_training=True)
for i, state_value in enumerate(last_features):
pi_info["state_in_{}".format(i)] = state_value
for i, state_value in enumerate(features):
pi_info["state_out_{}".format(i)] = state_value
observation, reward, terminal, info = env.step(action)
observation = obs_filter(observation)
# Process and record the new observations
for eid, raw_obs in unfiltered_obs.items():
episode = episodes[eid]
filtered_obs = obs_filter(raw_obs)
ready_eids.append(eid)
ready_obs.append(filtered_obs)
ready_rnn_states.append(episode.rnn_state)
length += 1
rewards += reward
if length >= horizon:
terminal = True
if episode.last_observation is None:
episode.last_observation = filtered_obs
continue # This is the initial observation after a reset
# Concatenate multiagent actions
if isinstance(action, list):
action = np.concatenate(action, axis=0).flatten()
episode.length += 1
episode.total_reward += rewards[eid]
# Collect the experience.
batch_builder.add_values(
obs=last_observation,
actions=action,
rewards=reward,
dones=terminal,
new_obs=observation,
**pi_info)
# Handle episode terminations
if dones[eid] or episode.length >= horizon:
done = True
yield CompletedRollout(episode.length, episode.total_reward)
else:
done = False
last_observation = observation
last_features = features
episode.batch_builder.add_values(
obs=episode.last_observation,
actions=episode.last_action_flat(),
rewards=rewards[eid],
dones=done,
new_obs=filtered_obs,
**episode.last_pi_info)
if terminal:
yield CompletedRollout(length, rewards)
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if (done and not pack) or \
episode.batch_builder.count >= num_local_steps:
yield episode.batch_builder.build_and_reset(
policy.postprocess_trajectory)
elif done:
# Make sure postprocessor never goes across episode boundaries
episode.batch_builder.postprocess_batch_so_far(
policy.postprocess_trajectory)
if (length >= horizon or
not env.metadata.get("semantics.autoreset")):
last_observation = obs_filter(env.reset())
last_features = policy.get_initial_state()
rollout_number += 1
length = 0
rewards = 0
if not pack:
break
if done:
# Handle episode termination
batch_builder_pool.append(episode.batch_builder)
del episodes[eid]
resetted_obs = async_vector_env.try_reset(eid)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
assert horizon == float("inf"), \
"Setting episode horizon requires reset() support."
ready_eids.pop()
ready_obs.pop()
ready_rnn_states.pop()
else:
# Reset successful, put in the new obs as ready
episode = episodes[eid]
episode.last_observation = obs_filter(resetted_obs)
ready_obs[-1] = episode.last_observation
ready_rnn_states[-1] = episode.rnn_state
else:
episode.last_observation = filtered_obs
# Once we have enough experience, yield it, and have the ThreadRunner
# place it on a queue.
yield batch_builder.build()
if not ready_eids:
continue # No actions to take
# Compute action for ready envs
ready_rnn_state_cols = _to_column_format(ready_rnn_states)
actions, new_rnn_state_cols, pi_info_cols = policy.compute_actions(
ready_obs, ready_rnn_state_cols, is_training=True)
# Add RNN state info
for f_i, column in enumerate(ready_rnn_state_cols):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(new_rnn_state_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
async_vector_env.send_actions(dict(zip(ready_eids, actions)))
# Store the computed action info
for i, eid in enumerate(ready_eids):
episode = episodes[eid]
if eid in off_policy_actions:
episode.last_action = off_policy_actions[eid]
else:
episode.last_action = actions[i]
episode.rnn_state = [column[i] for column in new_rnn_state_cols]
episode.last_pi_info = {
k: column[i] for k, column in pi_info_cols.items()}
def _to_column_format(rnn_state_rows):
num_cols = len(rnn_state_rows[0])
return [
[row[i] for row in rnn_state_rows] for i in range(num_cols)]
class _Episode(object):
def __init__(self, init_rnn_state, batch_builder_factory):
self.rnn_state = init_rnn_state
self.batch_builder = batch_builder_factory()
self.last_action = None
self.last_observation = None
self.last_pi_info = None
self.total_reward = 0.0
self.length = 0
def last_action_flat(self):
# Concatenate multiagent actions
if isinstance(self.last_action, list):
return np.concatenate(self.last_action, axis=0).flatten()
return self.last_action
+266
View File
@@ -0,0 +1,266 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import queue
import threading
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
class ServingEnv(threading.Thread):
    """Environment that provides policy serving.

    Unlike simulator envs, control is inverted. The environment queries the
    policy to obtain actions and logs observations and rewards for training.
    This is in contrast to gym.Env, where the algorithm drives the simulation
    through env.step() calls.

    You can use ServingEnv as the backend for policy serving (by serving HTTP
    requests in the run loop), for ingesting offline logs data (by reading
    offline transitions in the run loop), or other custom use cases not easily
    expressed through gym.Env.

    ServingEnv supports both on-policy serving (through self.get_action()),
    and off-policy serving (through self.log_action()).

    This env is thread-safe, but individual episodes must be executed
    serially.

    TODO: Provide a HTTP server/client example based on ServingEnv.

    Examples:
        >>> register_env("my_env", lambda config: YourServingEnv(config))
        >>> agent = DQNAgent(env="my_env")
        >>> while True:
        ...     print(agent.train())
    """

    def __init__(self, action_space, observation_space, max_concurrent=100):
        """Initialize a serving env.

        Arguments:
            action_space (gym.Space): Action space of the env.
            observation_space (gym.Space): Observation space of the env.
            max_concurrent (int): Max number of active episodes to allow at
                once. Exceeding this limit raises an error.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.action_space = action_space
        self.observation_space = observation_space
        # Active episodes by id, and the ids of episodes that have ended.
        self._episodes = {}
        self._finished = set()
        # Counter used to generate ids for auto-assigned episodes.
        self._num_episodes = 0
        # Id of the active auto-assigned episode, or None if there is none.
        self._cur_default_episode_id = None
        # Shared with every episode so they can signal that data is ready.
        self._results_avail_condition = threading.Condition()
        self._max_concurrent_episodes = max_concurrent

    def run(self):
        """Override this to implement the run loop.

        Your loop should continuously:
            1. Call self.start_episode()
            2. Call self.get_action() or self.log_action()
            3. Call self.log_returns()
            4. Call self.end_episode()
            5. Wait if nothing to do.

        Multiple episodes may be started at the same time.
        """
        raise NotImplementedError

    def start_episode(self, episode_id=None):
        """Record the start of an episode.

        Arguments:
            episode_id (str): Unique string id for the episode or None for
                it to be auto-assigned. Auto-assignment only works if there
                is at most one active episode at a time.

        Returns:
            episode_id (str): Id of the episode that was started. This is
                the auto-assigned id when None was passed in.

        Raises:
            ValueError: If an auto-assigned episode is still active, or if
                the id belongs to a finished or already started episode.
        """
        if episode_id is None:
            if self._cur_default_episode_id:
                raise ValueError(
                    "An existing episode is still active. You must pass "
                    "`episode_id` if there are going to be multiple active "
                    "episodes at once.")
            episode_id = "default_{}".format(self._num_episodes)
            self._cur_default_episode_id = episode_id
        self._num_episodes += 1

        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id in self._episodes:
            raise ValueError(
                "Episode {} is already started".format(episode_id))

        self._episodes[episode_id] = _Episode(
            episode_id, self._results_avail_condition)
        return episode_id

    def get_action(self, observation, episode_id=None):
        """Record an observation and get the on-policy action.

        Arguments:
            observation (obj): Current environment observation.
            episode_id (str): Episode id passed to start_episode() or None.

        Returns:
            action (obj): Action from the env action space.
        """
        episode = self._get(episode_id)
        return episode.wait_for_action(observation)

    def log_action(self, observation, action, episode_id=None):
        """Record an observation and (off-policy) action taken.

        Arguments:
            observation (obj): Current environment observation.
            action (obj): Action for the observation.
            episode_id (str): Episode id passed to start_episode() or None.
        """
        episode = self._get(episode_id)
        episode.log_action(observation, action)

    def log_returns(self, reward, info=None, episode_id=None):
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Arguments:
            reward (float): Reward from the environment.
            info (dict): Optional info value. A truthy value replaces the
                episode's current info; a falsy one leaves it unchanged.
            episode_id (str): Episode id passed to start_episode() or None.
        """
        episode = self._get(episode_id)
        episode.cur_reward += reward
        if info:
            episode.cur_info = info

    def end_episode(self, observation, episode_id=None):
        """Record the end of an episode.

        Arguments:
            observation (obj): Current environment observation.
            episode_id (str): Episode id passed to start_episode() or None.
        """
        episode = self._get(episode_id)
        self._finished.add(episode.episode_id)
        # Only forget the default episode when it is the one that ended.
        # Ending an explicitly-named episode must not orphan a default
        # episode that is still running.
        if episode.episode_id == self._cur_default_episode_id:
            self._cur_default_episode_id = None
        episode.done(observation)

    def _get(self, episode_id=None):
        """Get a started episode or raise an error.

        Arguments:
            episode_id (str): Episode id, or None for the current
                auto-assigned episode.

        Raises:
            ValueError: If the episode has finished or was never started.
        """
        if episode_id is None:
            episode_id = self._cur_default_episode_id

        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id not in self._episodes:
            raise ValueError("Episode {} not found.".format(episode_id))

        return self._episodes[episode_id]
class _ServingEnvToAsync(AsyncVectorEnv):
    """Internal adapter of ServingEnv to AsyncVectorEnv."""

    def __init__(self, serving_env):
        """Wrap `serving_env` and start its thread.

        Arguments:
            serving_env (ServingEnv): The serving env to adapt.
        """
        self.serving_env = serving_env
        serving_env.start()

    def poll(self):
        """Block until at least one episode has data, then return it.

        Returns:
            Tuple of (obs, rewards, dones, infos, off_policy_actions), each
            a dict keyed by episode id.

        Raises:
            Exception: If the serving thread is found dead after a wait.
        """
        with self.serving_env._results_avail_condition:
            results = self._poll()
            while len(results[0]) == 0:
                self.serving_env._results_avail_condition.wait()
                results = self._poll()
                # Thread.isAlive() was removed in Python 3.9; is_alive() is
                # the supported spelling.
                if not self.serving_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        limit = self.serving_env._max_concurrent_episodes
        # NOTE(review): a count equal to the limit also fails this check.
        assert len(results[0]) < limit, \
            ("Too many concurrent episodes, were some leaked? This ServingEnv "
             "was created with max_concurrent={}".format(limit))
        return results

    def _poll(self):
        """Collect at most one pending data item from every episode.

        Episodes marked done are removed from the serving env's table.
        """
        all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
        off_policy_actions = {}
        # Iterate over a copy since finished episodes are deleted in-loop.
        for eid, episode in self.serving_env._episodes.copy().items():
            data = episode.get_data()
            if episode.cur_done:
                del self.serving_env._episodes[eid]
            if data:
                all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_dones[eid] = data["done"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        return all_obs, all_rewards, all_dones, all_infos, off_policy_actions

    def send_actions(self, action_dict):
        """Put each action on the action queue of its episode.

        Arguments:
            action_dict (dict): Maps episode id to the action for it.
        """
        for eid, action in action_dict.items():
            self.serving_env._episodes[eid].action_queue.put(action)
class _Episode(object):
"""Tracked state for each active episode."""
def __init__(self, episode_id, results_avail_condition):
self.episode_id = episode_id
self.results_avail_condition = results_avail_condition
self.data_queue = queue.Queue()
self.action_queue = queue.Queue()
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_done = False
self.cur_info = {}
def get_data(self):
if self.data_queue.empty():
return None
return self.data_queue.get_nowait()
def log_action(self, observation, action):
self.new_observation = observation
self.new_action = action
self._send()
self.action_queue.get(True, timeout=60.0)
def wait_for_action(self, observation):
self.new_observation = observation
self._send()
return self.action_queue.get(True, timeout=60.0)
def done(self, observation):
self.new_observation = observation
self.cur_done = True
self._send()
def _send(self):
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"done": self.cur_done,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()
+117
View File
@@ -0,0 +1,117 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import queue
import threading
class VectorEnv(object):
    """Environment interface that supports batch evaluation.

    Subclasses are expected to define these attributes:

    Attributes:
        action_space (gym.Space): Action space of individual envs.
        observation_space (gym.Space): Observation space of individual envs.
        num_envs (int): Number of envs to batch over.
    """

    @staticmethod
    def wrap(make_env=None, existing_envs=None, num_envs=1):
        """Build a VectorEnv backed by gym envs.

        Arguments:
            make_env (func|None): Factory that produces a new gym env.
            existing_envs (list|None): Gym envs that already exist.
            num_envs (int): Desired total number of envs.

        Returns:
            A _VectorizedGymEnv built from the arguments.
        """
        envs = existing_envs if existing_envs else []
        return _VectorizedGymEnv(make_env, envs, num_envs)

    def vector_reset(self):
        """Reset all sub-environments. Must be overridden."""
        raise NotImplementedError

    def reset_at(self, index):
        """Reset the sub-environment at `index`. Must be overridden."""
        raise NotImplementedError

    def vector_step(self, actions):
        """Step all sub-environments with `actions`. Must be overridden."""
        raise NotImplementedError

    def get_unwrapped(self):
        """Return an underlying env. Must be overridden."""
        raise NotImplementedError
class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper for gym envs to implement VectorEnv.

    Arguments:
        make_env (func|None): Factory that produces a new gym env. Must be
            defined if the number of existing envs is less than num_envs.
        existing_envs (list): List of existing gym envs.
        num_envs (int): Desired num gym envs to keep total.

    Raises:
        ValueError: If more envs are needed but `make_env` is not given.
    """

    def __init__(self, make_env, existing_envs, num_envs):
        self.make_env = make_env
        # Copy the list so that padding it below with newly created envs
        # never mutates the list object owned by the caller.
        self.envs = list(existing_envs)
        self.num_envs = num_envs
        if make_env and num_envs > 1:
            # Keep a pool of about sqrt(num_envs) pre-reset spare envs.
            self.resetter = _AsyncResetter(
                make_env, int(self.num_envs ** 0.5))
        else:
            self.resetter = _SimpleResetter(make_env)
        if len(self.envs) < self.num_envs and make_env is None:
            raise ValueError(
                "`make_env` must be provided when fewer than `num_envs` "
                "existing envs are given. Got {} existing envs and "
                "num_envs={}.".format(len(self.envs), self.num_envs))
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env())

    def vector_reset(self):
        """Reset every env and return the list of initial observations."""
        return [e.reset() for e in self.envs]

    def reset_at(self, index):
        """Swap the env at `index` for a freshly reset one.

        The resetter may hand back a different env object than the one
        given to it; whichever it returns is stored at `index`.

        Returns:
            obs (obj): Initial observation of the env now at `index`.
        """
        new_obs, new_env = self.resetter.trade_for_resetted(self.envs[index])
        self.envs[index] = new_env
        return new_obs

    def vector_step(self, actions):
        """Step the first `num_envs` envs with their matching actions.

        Arguments:
            actions (list): Action for each env, indexed by env position.

        Returns:
            Tuple of four lists (obs, rewards, dones, infos), each with one
            entry per stepped env.
        """
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            obs, rew, done, info = self.envs[i].step(actions[i])
            obs_batch.append(obs)
            rew_batch.append(rew)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    def get_unwrapped(self):
        """Return the first env as the representative underlying env."""
        return self.envs[0]
class _AsyncResetter(threading.Thread):
"""Does env reset asynchronously in the background.
This is useful since resetting an env can be 100x slower than stepping."""
def __init__(self, make_env, pool_size):
threading.Thread.__init__(self)
self.make_env = make_env
self.pool_size = 0
self.to_reset = queue.Queue()
self.resetted = queue.Queue()
self.daemon = True
self.pool_size = pool_size
while self.resetted.qsize() < self.pool_size:
env = self.make_env()
obs = env.reset()
self.resetted.put((obs, env))
self.start()
def run(self):
while True:
env = self.to_reset.get()
obs = env.reset()
self.resetted.put((obs, env))
def trade_for_resetted(self, env):
self.to_reset.put(env)
new_obs, new_env = self.resetted.get(timeout=30)
return new_obs, new_env
class _SimpleResetter(object):
def __init__(self, make_env):
pass
def trade_for_resetted(self, env):
return env.reset(), env