mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 03:54:16 +08:00
[rllib] Envs for vectorized execution, async execution, and policy serving (#2170)
## What do these changes do?
**Vectorized envs**: Users can either implement `VectorEnv`, or alternatively set `num_envs=N` to auto-vectorize gym envs (this vectorizes just the action computation part).
```
# CartPole-v0 on single core with 64x64 MLP:
# vector_width=1:
Actions per second 2720.1284458322966
# vector_width=8:
Actions per second 13773.035334888269
# vector_width=64:
Actions per second 37903.20472563333
```
**Async envs**: The more general form of `VectorEnv` is `AsyncVectorEnv`, which allows agents to execute out of lockstep. We use this as an adapter to support `ServingEnv`. Since we can convert any other form of env to `AsyncVectorEnv`, utils.sampler has been rewritten to run against this interface.
**Policy serving**: This provides an env which is not stepped. Rather, the env executes in its own thread, querying the policy for actions via `self.get_action(obs)`, and reporting results via `self.log_returns(rewards)`. We also support logging of off-policy actions via `self.log_action(obs, action)`. This is a more convenient API for some use cases, and also provides parallelizable support for policy serving (for example, if you start a HTTP server in the env) and ingest of offline logs (if the env reads from serving logs).
Any of these types of envs can be passed to RLlib agents. RLlib handles conversions internally in CommonPolicyEvaluator, for example:
```
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
rllib.ServingEnv => rllib.AsyncVectorEnv
```
This commit is contained in:
@@ -9,6 +9,9 @@ from ray.tune.registry import register_trainable
|
||||
from ray.rllib.utils.policy_graph import PolicyGraph
|
||||
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.rllib.utils.serving_env import ServingEnv
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch
|
||||
|
||||
|
||||
@@ -23,5 +26,6 @@ def _register_all():
|
||||
_register_all()
|
||||
|
||||
__all__ = [
|
||||
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch"
|
||||
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch",
|
||||
"AsyncVectorEnv", "VectorEnv", "ServingEnv",
|
||||
]
|
||||
|
||||
@@ -17,6 +17,8 @@ from ray.tune.trial import Resources
|
||||
DEFAULT_CONFIG = {
|
||||
# Number of workers (excluding master)
|
||||
"num_workers": 2,
|
||||
# Number of environments to evaluate vectorwise per worker.
|
||||
"num_envs": 1,
|
||||
# Size of rollout batch
|
||||
"batch_size": 10,
|
||||
# Use LSTM model - only applicable for image states
|
||||
@@ -101,7 +103,8 @@ class A3CAgent(Agent):
|
||||
batch_mode="truncate_episodes",
|
||||
tf_session_creator=session_creator,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config)
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
self.remote_evaluators = [
|
||||
remote_cls.remote(
|
||||
self.env_creator, self.policy_cls,
|
||||
@@ -109,7 +112,8 @@ class A3CAgent(Agent):
|
||||
batch_mode="truncate_episodes", sample_async=True,
|
||||
tf_session_creator=session_creator,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config)
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
for i in range(self.config["num_workers"])]
|
||||
|
||||
self.optimizer = AsyncOptimizer(
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
@@ -33,13 +34,12 @@ class SharedTorchPolicy(PolicyGraph):
|
||||
self.optimizer = torch.optim.Adam(
|
||||
self._model.parameters(), lr=self.config["lr"])
|
||||
|
||||
def compute_single_action(self, obs, state, is_training=False):
|
||||
def compute_actions(self, obs, state, is_training=False):
|
||||
assert not state, "RNN not supported"
|
||||
with self.lock:
|
||||
ob = torch.from_numpy(obs).float().unsqueeze(0)
|
||||
ob = torch.from_numpy(np.array(obs)).float()
|
||||
logits, values = self._model(ob)
|
||||
samples = F.softmax(logits, dim=1).multinomial(1).squeeze()
|
||||
values = values.squeeze()
|
||||
samples = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
|
||||
return var_to_np(samples), [], {"vf_preds": var_to_np(values)}
|
||||
|
||||
def compute_gradients(self, samples):
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ray.rllib.a3c.torchpolicy import TorchPolicy
|
||||
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
|
||||
class SharedTorchPolicy(TorchPolicy):
|
||||
"""Assumes nonrecurrent."""
|
||||
|
||||
other_output = ["vf_preds"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(self, registry, ob_space, ac_space, config, **kwargs):
|
||||
super(SharedTorchPolicy, self).__init__(registry, ob_space, ac_space,
|
||||
config, **kwargs)
|
||||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
self._model = ModelCatalog.get_torch_model(
|
||||
self.registry, ob_space, self.logit_dim, self.config["model"])
|
||||
self.optimizer = torch.optim.Adam(
|
||||
self._model.parameters(), lr=self.config["lr"])
|
||||
|
||||
def compute(self, ob, *args):
|
||||
"""Should take in a SINGLE ob"""
|
||||
with self.lock:
|
||||
ob = torch.from_numpy(ob).float().unsqueeze(0)
|
||||
logits, values = self._model(ob)
|
||||
# TODO(alok): Support non-categorical distributions. Multinomial
|
||||
# is only for categorical.
|
||||
sampled_actions = F.softmax(logits, dim=1).multinomial(1).squeeze()
|
||||
values = values.squeeze()
|
||||
return var_to_np(sampled_actions), {"vf_preds": var_to_np(values)}
|
||||
|
||||
def compute_logits(self, ob, *args):
|
||||
with self.lock:
|
||||
ob = torch.from_numpy(ob).float().unsqueeze(0)
|
||||
res = self._model.hidden_layers(ob)
|
||||
return var_to_np(self._model.logits(res))
|
||||
|
||||
def value(self, ob, *args):
|
||||
with self.lock:
|
||||
ob = torch.from_numpy(ob).float().unsqueeze(0)
|
||||
res = self._model.hidden_layers(ob)
|
||||
res = self._model.value_branch(res)
|
||||
res = res.squeeze()
|
||||
return var_to_np(res)
|
||||
|
||||
def _evaluate(self, obs, actions):
|
||||
"""Passes in multiple obs."""
|
||||
logits, values = self._model(obs)
|
||||
log_probs = F.log_softmax(logits, dim=1)
|
||||
probs = F.softmax(logits, dim=1)
|
||||
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
|
||||
# TODO(alok): set distribution based on action space and use its
|
||||
# `.entropy()` method to calculate automatically
|
||||
entropy = -(log_probs * probs).sum(-1).sum()
|
||||
return values, action_log_probs, entropy
|
||||
|
||||
def _backward(self, batch):
|
||||
"""Loss is encoded in here. Defining a new loss function
|
||||
would start by rewriting this function"""
|
||||
|
||||
states, actions, advs, rs, _ = convert_batch(batch)
|
||||
values, action_log_probs, entropy = self._evaluate(states, actions)
|
||||
pi_err = -advs.dot(action_log_probs.reshape(-1))
|
||||
value_err = F.mse_loss(values.reshape(-1), rs)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
overall_err = sum([
|
||||
pi_err,
|
||||
self.config["vf_loss_coeff"] * value_err,
|
||||
self.config["entropy_coeff"] * entropy,
|
||||
])
|
||||
|
||||
overall_err.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self._model.parameters(),
|
||||
self.config["grad_clip"])
|
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
|
||||
from ray.rllib.a3c.policy import Policy
|
||||
from threading import Lock
|
||||
|
||||
|
||||
class TorchPolicy(Policy):
|
||||
"""The policy base class for Torch.
|
||||
|
||||
The model is a separate object than the policy. This could be changed
|
||||
in the future."""
|
||||
|
||||
def __init__(self,
|
||||
registry,
|
||||
ob_space,
|
||||
action_space,
|
||||
config,
|
||||
name="local",
|
||||
summarize=True):
|
||||
self.registry = registry
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = summarize
|
||||
self._setup_graph(ob_space, action_space)
|
||||
torch.set_num_threads(2)
|
||||
self.lock = Lock()
|
||||
|
||||
def apply_gradients(self, grads):
|
||||
self.optimizer.zero_grad()
|
||||
for g, p in zip(grads, self._model.parameters()):
|
||||
p.grad = torch.from_numpy(g)
|
||||
self.optimizer.step()
|
||||
|
||||
def get_weights(self):
|
||||
# !! This only returns references to the data.
|
||||
return self._model.state_dict()
|
||||
|
||||
def set_weights(self, weights):
|
||||
with self.lock:
|
||||
self._model.load_state_dict(weights)
|
||||
|
||||
def compute_gradients(self, samples):
|
||||
"""_backward generates the gradient in each model parameter.
|
||||
This is taken out.
|
||||
|
||||
Args:
|
||||
samples: SampleBatch of data needed for gradient calculation.
|
||||
|
||||
Return:
|
||||
gradients (list of np arrays): List of gradients
|
||||
info (dict): Extra information (user-defined)"""
|
||||
with self.lock:
|
||||
self._backward(samples)
|
||||
# Note that return values are just references;
|
||||
# calling zero_grad will modify the values
|
||||
return [p.grad.data.numpy() for p in self._model.parameters()], {}
|
||||
|
||||
def model_update(self, batch):
|
||||
"""Implements compute + apply
|
||||
|
||||
TODO(rliaw): Pytorch has nice caching property that doesn't require
|
||||
full batch to be passed in. Can exploit that later"""
|
||||
with self.lock:
|
||||
self._backward(batch)
|
||||
self.optimizer.step()
|
||||
|
||||
def _setup_graph(ob_space, action_space):
|
||||
raise NotImplementedError
|
||||
|
||||
def _backward(self, batch):
|
||||
"""Implements the loss function and calculates the gradient.
|
||||
Pytorch automatically generates a backward trace for each tensor.
|
||||
Assumption right now is that variables are moved, so the backward
|
||||
trace is lost.
|
||||
|
||||
This function regenerates the backward trace and
|
||||
caluclates the gradient."""
|
||||
raise NotImplementedError
|
||||
@@ -95,6 +95,8 @@ DEFAULT_CONFIG = {
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you"re using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Number of environments to evaluate vectorwise per worker.
|
||||
"num_envs": 1,
|
||||
# Whether to allocate GPUs for workers (if > 0).
|
||||
"num_gpus_per_worker": 0,
|
||||
# Whether to allocate CPUs for workers (if > 0).
|
||||
|
||||
@@ -89,6 +89,8 @@ DEFAULT_CONFIG = {
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you"re using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Number of environments to evaluate vectorwise per worker.
|
||||
"num_envs": 1,
|
||||
# Whether to allocate GPUs for workers (if > 0).
|
||||
"num_gpus_per_worker": 0,
|
||||
# Whether to allocate CPUs for workers (if > 0).
|
||||
@@ -125,10 +127,11 @@ class DQNAgent(Agent):
|
||||
self.local_evaluator = CommonPolicyEvaluator(
|
||||
self.env_creator, self._policy_graph,
|
||||
batch_steps=adjusted_batch_size,
|
||||
batch_mode="pack_episodes", preprocessor_pref="deepmind",
|
||||
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
|
||||
compress_observations=True,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config)
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
remote_cls = CommonPolicyEvaluator.as_remote(
|
||||
num_cpus=self.config["num_cpus_per_worker"],
|
||||
num_gpus=self.config["num_gpus_per_worker"])
|
||||
@@ -136,10 +139,11 @@ class DQNAgent(Agent):
|
||||
remote_cls.remote(
|
||||
self.env_creator, self._policy_graph,
|
||||
batch_steps=adjusted_batch_size,
|
||||
batch_mode="pack_episodes", preprocessor_pref="deepmind",
|
||||
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
|
||||
compress_observations=True,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config)
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
for _ in range(self.config["num_workers"])]
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(0)
|
||||
|
||||
@@ -223,11 +223,9 @@ def _postprocess_dqn(policy_graph, sample_batch):
|
||||
"obs": obs, "actions": actions, "rewards": rewards,
|
||||
"new_obs": new_obs, "dones": dones,
|
||||
"weights": np.ones_like(rewards)})
|
||||
assert batch.count == policy_graph.config["sample_batch_size"], \
|
||||
(batch.count, policy_graph.config["sample_batch_size"])
|
||||
|
||||
# Prioritize on the worker side
|
||||
if policy_graph.config["worker_side_prioritization"]:
|
||||
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
|
||||
td_errors = policy_graph.compute_td_error(
|
||||
batch["obs"], batch["actions"], batch["rewards"],
|
||||
batch["new_obs"], batch["dones"], batch["weights"])
|
||||
|
||||
@@ -63,7 +63,7 @@ class Categorical(ActionDistribution):
|
||||
reduction_indices=[1])
|
||||
|
||||
def sample(self):
|
||||
return tf.multinomial(self.inputs, 1)[0]
|
||||
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
|
||||
|
||||
|
||||
class DiagGaussian(ActionDistribution):
|
||||
|
||||
@@ -125,22 +125,16 @@ def get_preprocessor(space):
|
||||
|
||||
legacy_patch_shapes(space)
|
||||
obs_shape = space.shape
|
||||
print("Observation shape is {}".format(obs_shape))
|
||||
|
||||
if isinstance(space, gym.spaces.Discrete):
|
||||
print("Using one-hot preprocessor for discrete envs.")
|
||||
preprocessor = OneHotPreprocessor
|
||||
elif obs_shape == ATARI_OBS_SHAPE:
|
||||
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
|
||||
preprocessor = AtariPixelPreprocessor
|
||||
elif obs_shape == ATARI_RAM_OBS_SHAPE:
|
||||
print("Assuming Atari ram env, using AtariRamPreprocessor.")
|
||||
preprocessor = AtariRamPreprocessor
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
print("Using a TupleFlatteningPreprocessor")
|
||||
preprocessor = TupleFlatteningPreprocessor
|
||||
else:
|
||||
print("Not using any observation preprocessor.")
|
||||
preprocessor = NoPreprocessor
|
||||
|
||||
return preprocessor
|
||||
|
||||
@@ -56,5 +56,5 @@ class FullyConnectedNetwork(Model):
|
||||
value: value function for each state"""
|
||||
res = self.hidden_layers(obs)
|
||||
logits = self.logits(res)
|
||||
value = self.value_branch(res).reshape(-1)
|
||||
value = self.value_branch(res).squeeze(1)
|
||||
return logits, value
|
||||
|
||||
@@ -65,5 +65,5 @@ class VisionNetwork(Model):
|
||||
value (PyTorch): value function for each state"""
|
||||
res = self.hidden_layers(obs)
|
||||
logits = self.logits(res)
|
||||
value = self.value_branch(res)
|
||||
value = self.value_branch(res).squeeze(1)
|
||||
return logits, value
|
||||
|
||||
@@ -105,6 +105,11 @@ class PolicyEvaluator(object):
|
||||
|
||||
return os.uname()[1]
|
||||
|
||||
def apply(self, func, *args):
|
||||
"""Apply the given function to this evaluator instance."""
|
||||
|
||||
return func(self, *args)
|
||||
|
||||
|
||||
class TFMultiGPUSupport(PolicyEvaluator):
|
||||
"""The multi-GPU TF optimizer requires additional TF-specific support.
|
||||
|
||||
@@ -110,3 +110,23 @@ class PolicyOptimizer(object):
|
||||
|
||||
self.num_steps_trained = data[0]
|
||||
self.num_steps_sampled = data[1]
|
||||
|
||||
def foreach_evaluator(self, func):
|
||||
"""Apply the given function to each evaluator instance."""
|
||||
|
||||
local_result = [func(self.local_evaluator)]
|
||||
remote_results = ray.get(
|
||||
[ev.apply.remote(func) for ev in self.remote_evaluators])
|
||||
return local_result + remote_results
|
||||
|
||||
def foreach_evaluator_with_index(self, func):
|
||||
"""Apply the given function to each evaluator instance.
|
||||
|
||||
The index will be passed as the second arg to the given function.
|
||||
"""
|
||||
|
||||
local_result = [func(self.local_evaluator, 0)]
|
||||
remote_results = ray.get(
|
||||
[ev.apply.remote(func, i + 1)
|
||||
for i, ev in enumerate(self.remote_evaluators)])
|
||||
return local_result + remote_results
|
||||
|
||||
@@ -7,17 +7,47 @@ import numpy as np
|
||||
|
||||
|
||||
class SampleBatchBuilder(object):
|
||||
"""Util to build a SampleBatch incrementally."""
|
||||
"""Util to build a SampleBatch incrementally.
|
||||
|
||||
For efficiency, SampleBatches hold values in column form (as arrays).
|
||||
However, it is useful to add data one row (dict) at a time.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.postprocessed = []
|
||||
self.buffers = collections.defaultdict(list)
|
||||
self.count = 0
|
||||
|
||||
def add_values(self, **values):
|
||||
"""Add the given dictionary (row) of values to this batch."""
|
||||
|
||||
for k, v in values.items():
|
||||
self.buffers[k].append(v)
|
||||
self.count += 1
|
||||
|
||||
def build(self):
|
||||
return SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
|
||||
def postprocess_batch_so_far(self, postprocessor):
|
||||
"""Apply the given postprocessor to any unprocessed rows."""
|
||||
|
||||
batch = postprocessor(self._build_buffers())
|
||||
self.postprocessed.append(batch)
|
||||
|
||||
def build_and_reset(self, postprocessor):
|
||||
"""Returns a sample batch including all previously added values.
|
||||
|
||||
Any unprocessed rows will be first postprocessed with the given
|
||||
postprocessor. The internal state of this builder will be reset.
|
||||
"""
|
||||
|
||||
self.postprocess_batch_so_far(postprocessor)
|
||||
batch = SampleBatch.concat_samples(self.postprocessed)
|
||||
self.postprocessed = []
|
||||
self.count = 0
|
||||
return batch
|
||||
|
||||
def _build_buffers(self):
|
||||
batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
|
||||
self.buffers.clear()
|
||||
return batch
|
||||
|
||||
|
||||
class SampleBatch(object):
|
||||
@@ -41,6 +71,7 @@ class SampleBatch(object):
|
||||
@staticmethod
|
||||
def concat_samples(samples):
|
||||
out = {}
|
||||
samples = [s for s in samples if s.count > 0]
|
||||
for k in samples[0].keys():
|
||||
out[k] = np.concatenate([s[k] for s in samples])
|
||||
return SampleBatch(out)
|
||||
|
||||
@@ -12,7 +12,9 @@ from ray.tune.trial import Resources
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
# Number of workers (excluding master)
|
||||
"num_workers": 4,
|
||||
"num_workers": 0,
|
||||
# Number of environments to evaluate vectorwise per worker.
|
||||
"num_envs": 1,
|
||||
# Size of rollout batch
|
||||
"batch_size": 512,
|
||||
# Discount factor of MDP
|
||||
@@ -57,6 +59,7 @@ class PGAgent(Agent):
|
||||
"model_config": self.config["model"],
|
||||
"env_config": self.config["env_config"],
|
||||
"policy_config": self.config,
|
||||
"num_envs": self.config["num_envs"],
|
||||
},
|
||||
num_workers=self.config["num_workers"],
|
||||
optimizer_config=self.config["optimizer"])
|
||||
|
||||
@@ -82,11 +82,14 @@ class ProximalPolicyGraph(object):
|
||||
self.policy_results = [
|
||||
self.sampler, self.curr_logits, tf.constant("NA")]
|
||||
|
||||
def compute_single_action(self, observation, features, is_training=False):
|
||||
def compute_actions(self, observations, features, is_training=False):
|
||||
action, logprobs, vf = self.sess.run(
|
||||
self.policy_results,
|
||||
feed_dict={self.observations: [observation]})
|
||||
return action[0], [], {"vf_preds": vf[0], "logprobs": logprobs[0]}
|
||||
feed_dict={self.observations: observations})
|
||||
return action, [], {"vf_preds": vf, "logprobs": logprobs}
|
||||
|
||||
def postprocess_trajectory(self, batch):
|
||||
return batch
|
||||
|
||||
def get_initial_state(self):
|
||||
return []
|
||||
|
||||
@@ -69,7 +69,7 @@ DEFAULT_CONFIG = {
|
||||
# number of steps is obtained
|
||||
"min_steps_per_task": 200,
|
||||
# Number of actors used to collect the rollouts
|
||||
"num_workers": 5,
|
||||
"num_workers": 2,
|
||||
# Whether to allocate GPUs for workers (if > 0).
|
||||
"num_gpus_per_worker": 0,
|
||||
# Whether to allocate CPUs for workers (if > 0).
|
||||
@@ -299,5 +299,5 @@ class PPOAgent(Agent):
|
||||
def compute_action(self, observation):
|
||||
observation = self.local_evaluator.obs_filter(
|
||||
observation, update=False)
|
||||
return self.local_evaluator.common_policy.compute_single_action(
|
||||
observation, [], False)[0]
|
||||
return self.local_evaluator.common_policy.compute_actions(
|
||||
[observation], [], False)[0][0]
|
||||
|
||||
@@ -7,9 +7,13 @@ import time
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.pg import PGAgent
|
||||
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \
|
||||
collect_metrics
|
||||
from ray.rllib.utils.policy_graph import PolicyGraph
|
||||
from ray.rllib.utils.process_rollout import compute_advantages
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class MockPolicyGraph(PolicyGraph):
|
||||
@@ -20,6 +24,55 @@ class MockPolicyGraph(PolicyGraph):
|
||||
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
|
||||
|
||||
|
||||
class BadPolicyGraph(PolicyGraph):
|
||||
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
||||
raise Exception("intentional error")
|
||||
|
||||
def postprocess_trajectory(self, batch):
|
||||
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
|
||||
|
||||
|
||||
class MockEnv(gym.Env):
|
||||
def __init__(self, episode_length):
|
||||
self.episode_length = episode_length
|
||||
self.i = 0
|
||||
self.observation_space = gym.spaces.Discrete(1)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
self.i = 0
|
||||
return self.i
|
||||
|
||||
def step(self, action):
|
||||
self.i += 1
|
||||
return 0, 1, self.i >= self.episode_length, {}
|
||||
|
||||
|
||||
class MockVectorEnv(VectorEnv):
|
||||
def __init__(self, episode_length, num_envs):
|
||||
self.envs = [
|
||||
MockEnv(episode_length) for _ in range(num_envs)]
|
||||
self.observation_space = gym.spaces.Discrete(1)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.num_envs = num_envs
|
||||
|
||||
def vector_reset(self):
|
||||
return [e.reset() for e in self.envs]
|
||||
|
||||
def reset_at(self, index):
|
||||
return self.envs[index].reset()
|
||||
|
||||
def vector_step(self, actions):
|
||||
obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
|
||||
for i in range(len(self.envs)):
|
||||
obs, rew, done, info = self.envs[i].step(actions[i])
|
||||
obs_batch.append(obs)
|
||||
rew_batch.append(rew)
|
||||
done_batch.append(done)
|
||||
info_batch.append(info)
|
||||
return obs_batch, rew_batch, done_batch, info_batch
|
||||
|
||||
|
||||
class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
def testBasic(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
@@ -30,43 +83,134 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertIn(key, batch)
|
||||
self.assertGreater(batch["advantages"][0], 1)
|
||||
|
||||
def testPackEpisodes(self):
|
||||
for batch_size in [1, 10, 100, 1000]:
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=batch_size,
|
||||
batch_mode="pack_episodes")
|
||||
def testQueryEvaluators(self):
|
||||
register_env("test", lambda _: gym.make("CartPole-v0"))
|
||||
pg = PGAgent(env="test", config={"num_workers": 2, "batch_size": 5})
|
||||
results = pg.optimizer.foreach_evaluator(lambda ev: ev.batch_steps)
|
||||
results2 = pg.optimizer.foreach_evaluator_with_index(
|
||||
lambda ev, i: (i, ev.batch_steps))
|
||||
self.assertEqual(results, [5, 5, 5])
|
||||
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
|
||||
|
||||
def testMetrics(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
|
||||
remote_ev = CommonPolicyEvaluator.as_remote().remote(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
|
||||
ev.sample()
|
||||
ray.get(remote_ev.sample.remote())
|
||||
result = collect_metrics(ev, [remote_ev])
|
||||
self.assertEqual(result.episodes_total, 20)
|
||||
self.assertEqual(result.episode_reward_mean, 10)
|
||||
|
||||
def testAsync(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
sample_async=True,
|
||||
policy_graph=MockPolicyGraph)
|
||||
batch = ev.sample()
|
||||
for key in ["obs", "actions", "rewards", "dones", "advantages"]:
|
||||
self.assertIn(key, batch)
|
||||
self.assertGreater(batch["advantages"][0], 1)
|
||||
|
||||
def testAutoConcat(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=40),
|
||||
policy_graph=MockPolicyGraph,
|
||||
sample_async=True,
|
||||
batch_steps=10,
|
||||
batch_mode="truncate_episodes",
|
||||
observation_filter="ConcurrentMeanStdFilter")
|
||||
time.sleep(2)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 40) # auto-concat up to 5 episodes
|
||||
|
||||
def testAutoVectorization(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=20),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=16, num_envs=8)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, batch_size)
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result.episodes_total, 0)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result.episodes_total, 8)
|
||||
|
||||
def testBatchDivisibilityCheck(self):
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
lambda: CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=15, num_envs=4))
|
||||
|
||||
def testBatchesSmallerWhenVectorized(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=16, num_envs=4)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result.episodes_total, 0)
|
||||
batch = ev.sample()
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result.episodes_total, 4)
|
||||
|
||||
def testVectorEnvSupport(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MockVectorEnv(
|
||||
episode_length=20, num_envs=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=10)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result.episodes_total, 0)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result.episodes_total, 8)
|
||||
|
||||
def testTruncateEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=2,
|
||||
batch_steps=15,
|
||||
batch_mode="truncate_episodes")
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 2)
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=1000,
|
||||
batch_mode="truncate_episodes")
|
||||
self.assertLess(batch.count, 200)
|
||||
self.assertEqual(batch.count, 15)
|
||||
|
||||
def testCompleteEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=2,
|
||||
batch_steps=5,
|
||||
batch_mode="complete_episodes")
|
||||
batch = ev.sample()
|
||||
self.assertGreater(batch.count, 2)
|
||||
self.assertTrue(batch["dones"][-1])
|
||||
self.assertEqual(batch.count, 10)
|
||||
|
||||
def testCompleteEpisodesPacking(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=15,
|
||||
batch_mode="complete_episodes")
|
||||
batch = ev.sample()
|
||||
self.assertGreater(batch.count, 2)
|
||||
self.assertTrue(batch["dones"][-1])
|
||||
self.assertEqual(batch.count, 20)
|
||||
|
||||
def testFilterSync(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
@@ -129,5 +273,5 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ray.init()
|
||||
ray.init(num_cpus=5)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import random
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn import DQNAgent
|
||||
from ray.rllib.pg import PGAgent
|
||||
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.utils.serving_env import ServingEnv
|
||||
from ray.rllib.test.test_common_policy_evaluator import BadPolicyGraph, \
|
||||
MockPolicyGraph, MockEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class SimpleServing(ServingEnv):
|
||||
def __init__(self, env):
|
||||
ServingEnv.__init__(self, env.action_space, env.observation_space)
|
||||
self.env = env
|
||||
|
||||
def run(self):
|
||||
self.start_episode()
|
||||
obs = self.env.reset()
|
||||
while True:
|
||||
action = self.get_action(obs)
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.log_returns(reward, info=info)
|
||||
if done:
|
||||
self.end_episode(obs)
|
||||
obs = self.env.reset()
|
||||
self.start_episode()
|
||||
|
||||
|
||||
class PartOffPolicyServing(ServingEnv):
|
||||
def __init__(self, env, off_pol_frac):
|
||||
ServingEnv.__init__(self, env.action_space, env.observation_space)
|
||||
self.env = env
|
||||
self.off_pol_frac = off_pol_frac
|
||||
|
||||
def run(self):
|
||||
self.start_episode()
|
||||
obs = self.env.reset()
|
||||
while True:
|
||||
if random.random() < self.off_pol_frac:
|
||||
action = self.env.action_space.sample()
|
||||
self.log_action(obs, action)
|
||||
else:
|
||||
action = self.get_action(obs)
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.log_returns(reward, info=info)
|
||||
if done:
|
||||
self.end_episode(obs)
|
||||
obs = self.env.reset()
|
||||
self.start_episode()
|
||||
|
||||
|
||||
class SimpleOffPolicyServing(ServingEnv):
|
||||
def __init__(self, env):
|
||||
ServingEnv.__init__(self, env.action_space, env.observation_space)
|
||||
self.env = env
|
||||
|
||||
def run(self):
|
||||
self.start_episode()
|
||||
obs = self.env.reset()
|
||||
while True:
|
||||
# Take random actions
|
||||
action = self.env.action_space.sample()
|
||||
self.log_action(obs, action)
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.log_returns(reward, info=info)
|
||||
if done:
|
||||
self.end_episode(obs)
|
||||
obs = self.env.reset()
|
||||
self.start_episode()
|
||||
|
||||
|
||||
class MultiServing(ServingEnv):
|
||||
def __init__(self, env_creator):
|
||||
self.env_creator = env_creator
|
||||
self.env = env_creator()
|
||||
ServingEnv.__init__(
|
||||
self, self.env.action_space, self.env.observation_space)
|
||||
|
||||
def run(self):
|
||||
envs = [self.env_creator() for _ in range(5)]
|
||||
cur_obs = {}
|
||||
eids = {}
|
||||
while True:
|
||||
active = np.random.choice(range(5), 2, replace=False)
|
||||
for i in active:
|
||||
if i not in cur_obs:
|
||||
eids[i] = uuid.uuid4().hex
|
||||
self.start_episode(episode_id=eids[i])
|
||||
cur_obs[i] = envs[i].reset()
|
||||
actions = [
|
||||
self.get_action(
|
||||
cur_obs[i], episode_id=eids[i]) for i in active]
|
||||
for i, action in zip(active, actions):
|
||||
obs, reward, done, _ = envs[i].step(action)
|
||||
cur_obs[i] = obs
|
||||
self.log_returns(reward, episode_id=eids[i])
|
||||
if done:
|
||||
self.end_episode(obs, episode_id=eids[i])
|
||||
del cur_obs[i]
|
||||
|
||||
|
||||
class TestServingEnv(unittest.TestCase):
|
||||
def testServingEnvCompleteEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
def testServingEnvTruncateEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=40,
|
||||
batch_mode="truncate_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 40)
|
||||
|
||||
def testServingEnvOffPolicy(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
def testServingEnvBadActions(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=BadPolicyGraph,
|
||||
sample_async=True,
|
||||
batch_steps=40,
|
||||
batch_mode="truncate_episodes")
|
||||
self.assertRaises(Exception, lambda: ev.sample())
|
||||
|
||||
def testTrainCartpoleOffPolicy(self):
|
||||
register_env(
|
||||
"test3", lambda _: PartOffPolicyServing(
|
||||
gym.make("CartPole-v0"), off_pol_frac=0.2))
|
||||
dqn = DQNAgent(env="test3", config={"exploration_fraction": 0.001})
|
||||
for i in range(100):
|
||||
result = dqn.train()
|
||||
print("Iteration {}, reward {}, timesteps {}".format(
|
||||
i, result.episode_reward_mean, result.timesteps_total))
|
||||
if result.episode_reward_mean >= 100:
|
||||
return
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
def testTrainCartpole(self):
|
||||
register_env(
|
||||
"test", lambda _: SimpleServing(gym.make("CartPole-v0")))
|
||||
pg = PGAgent(env="test", config={"num_workers": 0})
|
||||
for i in range(100):
|
||||
result = pg.train()
|
||||
print("Iteration {}, reward {}, timesteps {}".format(
|
||||
i, result.episode_reward_mean, result.timesteps_total))
|
||||
if result.episode_reward_mean >= 100:
|
||||
return
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
def testTrainCartpoleMulti(self):
|
||||
register_env(
|
||||
"test2", lambda _: MultiServing(lambda: gym.make("CartPole-v0")))
|
||||
pg = PGAgent(env="test2", config={"num_workers": 0})
|
||||
for i in range(100):
|
||||
result = pg.train()
|
||||
print("Iteration {}, reward {}, timesteps {}".format(
|
||||
i, result.episode_reward_mean, result.timesteps_total))
|
||||
if result.episode_reward_mean >= 100:
|
||||
return
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ray.init()
|
||||
unittest.main(verbosity=2)
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class AsyncVectorEnv(object):
|
||||
"""The lowest-level env interface used by RLlib for sampling.
|
||||
|
||||
AsyncVectorEnv models multiple agents executing asynchronously. A call to
|
||||
poll() returns observations from ready agents, and actions for those agents
|
||||
can be sent back via send_actions().
|
||||
|
||||
All other env types can be adapted to AsyncVectorEnv. RLlib handles these
|
||||
conversions internally in CommonPolicyEvaluator, for example:
|
||||
|
||||
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
|
||||
rllib.ServingEnv => rllib.AsyncVectorEnv
|
||||
|
||||
Examples:
|
||||
>>> env = MyAsyncVectorEnv()
|
||||
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
|
||||
>>> print(obs)
|
||||
{"car_0": [2.4, 1.6], "car_1": [3.4, -3.2]}
|
||||
>>> env.send_actions({"car_0": 0, "car_1": 1})
|
||||
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
|
||||
>>> print(obs)
|
||||
{"car_0": [4.1, 1.7], "car_1": [3.2, -4.2]}
|
||||
"""
|
||||
|
||||
def poll(self):
|
||||
"""Returns observations from ready agents.
|
||||
|
||||
The returns are dicts mapping from agent episode ids to values. The
|
||||
number of agents can vary over time.
|
||||
|
||||
Returns:
|
||||
obs (dict): New observations for each ready episode.
|
||||
rewards (dict): Reward values for each ready episode. If the
|
||||
episode is just started, the value will be None.
|
||||
dones (dict): Done values for each ready episode. If True, the
|
||||
episode is terminated.
|
||||
infos (dict): Info values for each ready episode.
|
||||
off_policy_actions (dict): Agents may take off-policy actions. When
|
||||
that happens, there will be an entry in this dict that contains
|
||||
the taken action.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
"""Called to send actions back to running agents in this env.
|
||||
|
||||
Arguments:
|
||||
action_dict (dict): Actions for each agent to take.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def try_reset(self, agent_id):
|
||||
"""Attempt to reset the agent with the given id.
|
||||
|
||||
If the environment does not support synchronous reset, None can be
|
||||
returned here.
|
||||
|
||||
Returns:
|
||||
obs (obj|None): Resetted observation or None if not supported.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_unwrapped(self):
|
||||
"""Return a reference to some underlying gym env, if any.
|
||||
|
||||
Returns:
|
||||
env (gym.Env|None): Underlying gym env or None.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
class _VectorEnvToAsync(AsyncVectorEnv):
|
||||
"""Wraps VectorEnv to implement AsyncVectorEnv.
|
||||
|
||||
We assume the caller will always send the full vector of actions in each
|
||||
call to send_actions(), and that they call reset_at() on all completed
|
||||
environments before calling send_actions().
|
||||
"""
|
||||
|
||||
def __init__(self, vector_env):
|
||||
self.vector_env = vector_env
|
||||
self.num_envs = vector_env.num_envs
|
||||
self.new_obs = self.vector_env.vector_reset()
|
||||
self.cur_rewards = [None for _ in range(self.num_envs)]
|
||||
self.cur_dones = [False for _ in range(self.num_envs)]
|
||||
self.cur_infos = [None for _ in range(self.num_envs)]
|
||||
|
||||
def poll(self):
|
||||
new_obs = dict(enumerate(self.new_obs))
|
||||
rewards = dict(enumerate(self.cur_rewards))
|
||||
dones = dict(enumerate(self.cur_dones))
|
||||
infos = dict(enumerate(self.cur_infos))
|
||||
self.new_obs = []
|
||||
self.cur_rewards = []
|
||||
self.cur_dones = []
|
||||
self.cur_infos = []
|
||||
return new_obs, rewards, dones, infos, {}
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
action_vector = [None] * self.num_envs
|
||||
for i in range(self.num_envs):
|
||||
action_vector[i] = action_dict[i]
|
||||
self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
|
||||
self.vector_env.vector_step(action_vector)
|
||||
|
||||
def try_reset(self, agent_id):
|
||||
return self.vector_env.reset_at(agent_id)
|
||||
|
||||
def get_unwrapped(self):
|
||||
return self.vector_env.get_unwrapped()
|
||||
@@ -6,6 +6,10 @@ import cv2
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
|
||||
def is_atari(env):
|
||||
return hasattr(env, "unwrapped") and hasattr(env.unwrapped, "ale")
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
def __init__(self, env, noop_max=30, random_starts=False):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
|
||||
@@ -8,12 +8,16 @@ import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.optimizers import SampleBatch
|
||||
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.utils.atari_wrappers import wrap_deepmind
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
|
||||
from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.utils.compression import pack
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync
|
||||
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.tune.registry import get_registry
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
@@ -53,10 +57,8 @@ def collect_metrics(local_evaluator, remote_evaluators):
|
||||
class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
"""Policy evaluator implementation that operates on a rllib.PolicyGraph.
|
||||
|
||||
TODO: vector env
|
||||
TODO: multi-agent
|
||||
TODO: consumer buffering for multi-agent
|
||||
TODO: complete episode batch mode
|
||||
TODO: multi-gpu
|
||||
|
||||
Examples:
|
||||
# Create a policy evaluator and using it to collect experiences.
|
||||
@@ -89,9 +91,11 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
tf_session_creator=None,
|
||||
batch_steps=100,
|
||||
batch_mode="truncate_episodes",
|
||||
episode_horizon=None,
|
||||
preprocessor_pref="rllib",
|
||||
sample_async=False,
|
||||
compress_observations=False,
|
||||
num_envs=1,
|
||||
observation_filter="NoFilter",
|
||||
registry=None,
|
||||
env_config=None,
|
||||
@@ -108,13 +112,20 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
This is optional and only useful with TFPolicyGraph.
|
||||
batch_steps (int): The target number of env transitions to include
|
||||
in each sample batch returned from this evaluator.
|
||||
batch_mode (str): One of the following choices:
|
||||
complete_episodes: each batch will be at least batch_steps
|
||||
in size, and will include one or more complete episodes.
|
||||
truncate_episodes: each batch will be around batch_steps
|
||||
in size, and include transitions from one episode only.
|
||||
pack_episodes: each batch will be exactly batch_steps in
|
||||
size, and may include transitions from multiple episodes.
|
||||
batch_mode (str): One of the following batch modes:
|
||||
"truncate_episodes": Each call to sample() will return a batch
|
||||
of exactly `batch_steps` in size. Episodes may be truncated
|
||||
in order to meet this size requirement. When
|
||||
`num_envs > 1`, episodes will be truncated to sequences of
|
||||
`batch_size / num_envs` in length.
|
||||
"complete_episodes": Each call to sample() will return a batch
|
||||
of at least `batch_steps in size. Episodes will not be
|
||||
truncated, but multiple episodes may be packed within one
|
||||
batch to meet the batch size. Note that when
|
||||
`num_envs > 1`, episode steps will be buffered until the
|
||||
episode completes, and hence batches may contain
|
||||
significant amounts of off-policy data.
|
||||
episode_horizon (int): Whether to stop episodes at this horizon.
|
||||
preprocessor_pref (str): Whether to prefer RLlib preprocessors
|
||||
("rllib") or deepmind ("deepmind") when applicable.
|
||||
sample_async (bool): Whether to compute samples asynchronously in
|
||||
@@ -122,6 +133,9 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
to be slightly off-policy.
|
||||
compress_observations (bool): If true, compress the observations
|
||||
returned.
|
||||
num_envs (int): If more than one, will create multiple envs
|
||||
and vectorize the computation of actions. This has no effect if
|
||||
if the env already implements VectorEnv.
|
||||
observation_filter (str): Name of observation filter to use.
|
||||
registry (tune.Registry): User-registered objects. Pass in the
|
||||
value from tune.registry.get_registry() if you're having
|
||||
@@ -135,9 +149,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
env_config = env_config or {}
|
||||
policy_config = policy_config or {}
|
||||
model_config = model_config or {}
|
||||
|
||||
assert batch_mode in [
|
||||
"complete_episodes", "truncate_episodes", "pack_episodes"]
|
||||
self.env_creator = env_creator
|
||||
self.policy_graph = policy_graph
|
||||
self.batch_steps = batch_steps
|
||||
@@ -145,15 +156,25 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
self.compress_observations = compress_observations
|
||||
|
||||
self.env = env_creator(env_config)
|
||||
is_atari = hasattr(self.env.unwrapped, "ale")
|
||||
if is_atari and "custom_preprocessor" not in model_config and \
|
||||
if isinstance(self.env, VectorEnv) or \
|
||||
isinstance(self.env, ServingEnv) or \
|
||||
isinstance(self.env, AsyncVectorEnv):
|
||||
def wrap(env):
|
||||
return env # we can't auto-wrap these env types
|
||||
elif is_atari(self.env) and \
|
||||
"custom_preprocessor" not in model_config and \
|
||||
preprocessor_pref == "deepmind":
|
||||
self.env = wrap_deepmind(self.env, dim=model_config.get("dim", 80))
|
||||
def wrap(env):
|
||||
return wrap_deepmind(env, dim=model_config.get("dim", 80))
|
||||
else:
|
||||
self.env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, self.env, model_config)
|
||||
def wrap(env):
|
||||
return ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, env, model_config)
|
||||
self.env = wrap(self.env)
|
||||
|
||||
def make_env():
|
||||
return wrap(env_creator(env_config))
|
||||
|
||||
self.vectorized = hasattr(self.env, "vector_reset")
|
||||
self.policy_map = {}
|
||||
|
||||
if issubclass(policy_graph, TFPolicyGraph):
|
||||
@@ -179,24 +200,41 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
observation_filter, self.env.observation_space.shape)
|
||||
self.filters = {"obs_filter": self.obs_filter}
|
||||
|
||||
if self.vectorized:
|
||||
raise NotImplementedError("Vector envs not yet supported")
|
||||
else:
|
||||
if batch_mode not in [
|
||||
"pack_episodes", "truncate_episodes", "complete_episodes"]:
|
||||
raise NotImplementedError("Batch mode not yet supported")
|
||||
pack = batch_mode == "pack_episodes"
|
||||
if batch_mode == "complete_episodes":
|
||||
batch_steps = 999999
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
self.env, self.policy_map["default"], self.obs_filter,
|
||||
batch_steps, pack=pack)
|
||||
self.sampler.start()
|
||||
# Always use vector env for consistency even if num_envs = 1
|
||||
if not isinstance(self.env, AsyncVectorEnv):
|
||||
if isinstance(self.env, ServingEnv):
|
||||
self.vector_env = _ServingEnvToAsync(self.env)
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
self.env, self.policy_map["default"], self.obs_filter,
|
||||
batch_steps, pack=pack)
|
||||
if not isinstance(self.env, VectorEnv):
|
||||
self.env = VectorEnv.wrap(
|
||||
make_env, [self.env], num_envs=num_envs)
|
||||
self.vector_env = _VectorEnvToAsync(self.env)
|
||||
else:
|
||||
self.vector_env = self.env
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
if batch_steps % num_envs != 0:
|
||||
raise ValueError(
|
||||
"In 'truncate_episodes' batch mode, `batch_steps` must be "
|
||||
"evenly divisible by `num_envs`. Got {} and {}.".format(
|
||||
batch_steps, num_envs))
|
||||
batch_steps = batch_steps // num_envs
|
||||
pack_episodes = True
|
||||
elif self.batch_mode == "complete_episodes":
|
||||
batch_steps = float("inf") # never cut episodes
|
||||
pack_episodes = False # sampler will return 1 episode per poll
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported batch mode: {}".format(self.batch_mode))
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
self.vector_env, self.policy_map["default"], self.obs_filter,
|
||||
batch_steps, horizon=episode_horizon, pack=pack_episodes)
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
self.vector_env, self.policy_map["default"], self.obs_filter,
|
||||
batch_steps, horizon=episode_horizon, pack=pack_episodes)
|
||||
|
||||
def sample(self):
|
||||
"""Evaluate the current policies and return a batch of experiences.
|
||||
@@ -205,8 +243,13 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
SampleBatch from evaluating the current policies.
|
||||
"""
|
||||
|
||||
batch = self.policy_map["default"].postprocess_trajectory(
|
||||
self.sampler.get_data())
|
||||
batches = [self.sampler.get_data()]
|
||||
steps_so_far = batches[0].count
|
||||
while steps_so_far < self.batch_steps:
|
||||
batch = self.sampler.get_data()
|
||||
steps_so_far += batch.count
|
||||
batches.append(batch)
|
||||
batch = SampleBatch.concat_samples(batches)
|
||||
|
||||
if self.compress_observations:
|
||||
batch["obs"] = [pack(o) for o in batch["obs"]]
|
||||
@@ -214,11 +257,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
|
||||
return batch
|
||||
|
||||
def apply(self, func):
|
||||
"""Apply the given function to this evaluator instance."""
|
||||
|
||||
return func(self)
|
||||
|
||||
def for_policy(self, func):
|
||||
"""Apply the given function to this evaluator's default policy."""
|
||||
|
||||
|
||||
@@ -69,7 +69,8 @@ class PolicyGraph(object):
|
||||
"""Implements algorithm-specific trajectory postprocessing.
|
||||
|
||||
Arguments:
|
||||
sample_batch (SampleBatch): batch of experiences for the policy
|
||||
sample_batch (SampleBatch): batch of experiences for the policy,
|
||||
which will contain at most one episode trajectory.
|
||||
other_agent_batches (dict): In a multi-agent env, this contains the
|
||||
experience batches seen by other agents.
|
||||
|
||||
|
||||
@@ -2,12 +2,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
from collections import defaultdict, namedtuple
|
||||
import numpy as np
|
||||
import six.moves.queue as queue
|
||||
import threading
|
||||
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatchBuilder
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
|
||||
|
||||
|
||||
CompletedRollout = namedtuple("CompletedRollout",
|
||||
@@ -22,17 +24,20 @@ class SyncSampler(object):
|
||||
|
||||
This class provides data on invocation, rather than on a separate
|
||||
thread."""
|
||||
_async = False
|
||||
|
||||
def __init__(
|
||||
self, env, policy, obs_filter, num_local_steps, horizon=None,
|
||||
pack=False):
|
||||
self, env, policy, obs_filter, num_local_steps,
|
||||
horizon=None, pack=False):
|
||||
if not isinstance(env, AsyncVectorEnv):
|
||||
if not isinstance(env, VectorEnv):
|
||||
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
|
||||
env = _VectorEnvToAsync(env)
|
||||
self.async_vector_env = env
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
self._obs_filter = obs_filter
|
||||
self.rollout_provider = _env_runner(self.env, self.policy,
|
||||
self.rollout_provider = _env_runner(self.async_vector_env, self.policy,
|
||||
self.num_local_steps, self.horizon,
|
||||
self._obs_filter, pack)
|
||||
self.metrics_queue = queue.Queue()
|
||||
@@ -60,28 +65,29 @@ class AsyncSampler(threading.Thread):
|
||||
|
||||
Note that batch_size is only a unit of measure here. Batches can
|
||||
accumulate and the gradient can be calculated on up to 5 batches."""
|
||||
_async = True
|
||||
|
||||
def __init__(
|
||||
self, env, policy, obs_filter, num_local_steps, horizon=None,
|
||||
pack=False):
|
||||
self, env, policy, obs_filter, num_local_steps,
|
||||
horizon=None, pack=False):
|
||||
assert getattr(
|
||||
obs_filter, "is_concurrent",
|
||||
False), ("Observation Filter must support concurrent updates.")
|
||||
if not isinstance(env, AsyncVectorEnv):
|
||||
if not isinstance(env, VectorEnv):
|
||||
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
|
||||
env = _VectorEnvToAsync(env)
|
||||
self.async_vector_env = env
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
self._obs_filter = obs_filter
|
||||
self.started = False
|
||||
self.daemon = True
|
||||
self.pack = pack
|
||||
|
||||
def run(self):
|
||||
self.started = True
|
||||
try:
|
||||
self._run()
|
||||
except BaseException as e:
|
||||
@@ -89,7 +95,7 @@ class AsyncSampler(threading.Thread):
|
||||
raise e
|
||||
|
||||
def _run(self):
|
||||
rollout_provider = _env_runner(self.env, self.policy,
|
||||
rollout_provider = _env_runner(self.async_vector_env, self.policy,
|
||||
self.num_local_steps, self.horizon,
|
||||
self._obs_filter, self.pack)
|
||||
while True:
|
||||
@@ -103,15 +109,17 @@ class AsyncSampler(threading.Thread):
|
||||
self.queue.put(item, timeout=600.0)
|
||||
|
||||
def get_data(self):
|
||||
"""Gets currently accumulated data.
|
||||
|
||||
Returns:
|
||||
rollout (SampleBatch): trajectory data (unprocessed)
|
||||
"""
|
||||
assert self.started, "Sampler never started running!"
|
||||
rollout = self.queue.get(timeout=600.0)
|
||||
|
||||
# Propagate errors
|
||||
if isinstance(rollout, BaseException):
|
||||
raise rollout
|
||||
|
||||
# We can't auto-concat rollouts in vector mode
|
||||
if self.async_vector_env.num_envs > 1:
|
||||
return rollout
|
||||
|
||||
# Auto-concat rollouts; TODO(ekl) is this important for A3C perf?
|
||||
while not rollout["dones"][-1]:
|
||||
try:
|
||||
part = self.queue.get_nowait()
|
||||
@@ -132,7 +140,8 @@ class AsyncSampler(threading.Thread):
|
||||
return completed
|
||||
|
||||
|
||||
def _env_runner(env, policy, num_local_steps, horizon, obs_filter, pack):
|
||||
def _env_runner(
|
||||
async_vector_env, policy, num_local_steps, horizon, obs_filter, pack):
|
||||
"""This implements the logic of the thread runner.
|
||||
|
||||
It continually runs the policy, and as long as the rollout exceeds a
|
||||
@@ -141,9 +150,9 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter, pack):
|
||||
`num_local_steps` is reached.
|
||||
|
||||
Args:
|
||||
env: Environment generated by env_creator
|
||||
async_vector_env: env implementing AsyncVectorEnv.
|
||||
policy: Policy used to interact with environment. Also sets fields
|
||||
to be included in `SampleBatch`
|
||||
to be included in `SampleBatch`.
|
||||
num_local_steps: Number of steps before `SampleBatch` is yielded. Set
|
||||
to infinity to yield complete episodes.
|
||||
horizon: Horizon of the episode.
|
||||
@@ -155,67 +164,146 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter, pack):
|
||||
rollout (SampleBatch): Object containing state, action, reward,
|
||||
terminal condition, and other fields as dictated by `policy`.
|
||||
"""
|
||||
last_observation = obs_filter(env.reset())
|
||||
|
||||
try:
|
||||
horizon = horizon if horizon else env.spec.max_episode_steps
|
||||
if not horizon:
|
||||
horizon = async_vector_env.get_unwrapped().spec.max_episode_steps
|
||||
except Exception:
|
||||
print("Warning, no horizon specified, assuming infinite")
|
||||
if not horizon:
|
||||
horizon = 999999
|
||||
last_features = policy.get_initial_state()
|
||||
features = last_features
|
||||
length = 0
|
||||
rewards = 0
|
||||
rollout_number = 0
|
||||
horizon = float("inf")
|
||||
|
||||
# Pool of batch builders, which can be shared across episodes to pack
|
||||
# trajectory data.
|
||||
batch_builder_pool = []
|
||||
|
||||
def get_batch_builder():
|
||||
if batch_builder_pool:
|
||||
return batch_builder_pool.pop()
|
||||
else:
|
||||
return SampleBatchBuilder()
|
||||
|
||||
episodes = defaultdict(
|
||||
lambda: _Episode(policy.get_initial_state(), get_batch_builder))
|
||||
|
||||
while True:
|
||||
batch_builder = SampleBatchBuilder()
|
||||
# Get observations from ready envs
|
||||
unfiltered_obs, rewards, dones, _, off_policy_actions = \
|
||||
async_vector_env.poll()
|
||||
ready_eids = []
|
||||
ready_obs = []
|
||||
ready_rnn_states = []
|
||||
|
||||
for _ in range(num_local_steps):
|
||||
# Assume batch size one for now
|
||||
action, features, pi_info = policy.compute_single_action(
|
||||
last_observation, last_features, is_training=True)
|
||||
for i, state_value in enumerate(last_features):
|
||||
pi_info["state_in_{}".format(i)] = state_value
|
||||
for i, state_value in enumerate(features):
|
||||
pi_info["state_out_{}".format(i)] = state_value
|
||||
observation, reward, terminal, info = env.step(action)
|
||||
observation = obs_filter(observation)
|
||||
# Process and record the new observations
|
||||
for eid, raw_obs in unfiltered_obs.items():
|
||||
episode = episodes[eid]
|
||||
filtered_obs = obs_filter(raw_obs)
|
||||
ready_eids.append(eid)
|
||||
ready_obs.append(filtered_obs)
|
||||
ready_rnn_states.append(episode.rnn_state)
|
||||
|
||||
length += 1
|
||||
rewards += reward
|
||||
if length >= horizon:
|
||||
terminal = True
|
||||
if episode.last_observation is None:
|
||||
episode.last_observation = filtered_obs
|
||||
continue # This is the initial observation after a reset
|
||||
|
||||
# Concatenate multiagent actions
|
||||
if isinstance(action, list):
|
||||
action = np.concatenate(action, axis=0).flatten()
|
||||
episode.length += 1
|
||||
episode.total_reward += rewards[eid]
|
||||
|
||||
# Collect the experience.
|
||||
batch_builder.add_values(
|
||||
obs=last_observation,
|
||||
actions=action,
|
||||
rewards=reward,
|
||||
dones=terminal,
|
||||
new_obs=observation,
|
||||
**pi_info)
|
||||
# Handle episode terminations
|
||||
if dones[eid] or episode.length >= horizon:
|
||||
done = True
|
||||
yield CompletedRollout(episode.length, episode.total_reward)
|
||||
else:
|
||||
done = False
|
||||
|
||||
last_observation = observation
|
||||
last_features = features
|
||||
episode.batch_builder.add_values(
|
||||
obs=episode.last_observation,
|
||||
actions=episode.last_action_flat(),
|
||||
rewards=rewards[eid],
|
||||
dones=done,
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info)
|
||||
|
||||
if terminal:
|
||||
yield CompletedRollout(length, rewards)
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if (done and not pack) or \
|
||||
episode.batch_builder.count >= num_local_steps:
|
||||
yield episode.batch_builder.build_and_reset(
|
||||
policy.postprocess_trajectory)
|
||||
elif done:
|
||||
# Make sure postprocessor never goes across episode boundaries
|
||||
episode.batch_builder.postprocess_batch_so_far(
|
||||
policy.postprocess_trajectory)
|
||||
|
||||
if (length >= horizon or
|
||||
not env.metadata.get("semantics.autoreset")):
|
||||
last_observation = obs_filter(env.reset())
|
||||
last_features = policy.get_initial_state()
|
||||
rollout_number += 1
|
||||
length = 0
|
||||
rewards = 0
|
||||
if not pack:
|
||||
break
|
||||
if done:
|
||||
# Handle episode termination
|
||||
batch_builder_pool.append(episode.batch_builder)
|
||||
del episodes[eid]
|
||||
resetted_obs = async_vector_env.try_reset(eid)
|
||||
if resetted_obs is None:
|
||||
# Reset not supported, drop this env from the ready list
|
||||
assert horizon == float("inf"), \
|
||||
"Setting episode horizon requires reset() support."
|
||||
ready_eids.pop()
|
||||
ready_obs.pop()
|
||||
ready_rnn_states.pop()
|
||||
else:
|
||||
# Reset successful, put in the new obs as ready
|
||||
episode = episodes[eid]
|
||||
episode.last_observation = obs_filter(resetted_obs)
|
||||
ready_obs[-1] = episode.last_observation
|
||||
ready_rnn_states[-1] = episode.rnn_state
|
||||
else:
|
||||
episode.last_observation = filtered_obs
|
||||
|
||||
# Once we have enough experience, yield it, and have the ThreadRunner
|
||||
# place it on a queue.
|
||||
yield batch_builder.build()
|
||||
if not ready_eids:
|
||||
continue # No actions to take
|
||||
|
||||
# Compute action for ready envs
|
||||
ready_rnn_state_cols = _to_column_format(ready_rnn_states)
|
||||
actions, new_rnn_state_cols, pi_info_cols = policy.compute_actions(
|
||||
ready_obs, ready_rnn_state_cols, is_training=True)
|
||||
|
||||
# Add RNN state info
|
||||
for f_i, column in enumerate(ready_rnn_state_cols):
|
||||
pi_info_cols["state_in_{}".format(f_i)] = column
|
||||
for f_i, column in enumerate(new_rnn_state_cols):
|
||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
async_vector_env.send_actions(dict(zip(ready_eids, actions)))
|
||||
|
||||
# Store the computed action info
|
||||
for i, eid in enumerate(ready_eids):
|
||||
episode = episodes[eid]
|
||||
if eid in off_policy_actions:
|
||||
episode.last_action = off_policy_actions[eid]
|
||||
else:
|
||||
episode.last_action = actions[i]
|
||||
episode.rnn_state = [column[i] for column in new_rnn_state_cols]
|
||||
episode.last_pi_info = {
|
||||
k: column[i] for k, column in pi_info_cols.items()}
|
||||
|
||||
|
||||
def _to_column_format(rnn_state_rows):
|
||||
num_cols = len(rnn_state_rows[0])
|
||||
return [
|
||||
[row[i] for row in rnn_state_rows] for i in range(num_cols)]
|
||||
|
||||
|
||||
class _Episode(object):
|
||||
def __init__(self, init_rnn_state, batch_builder_factory):
|
||||
self.rnn_state = init_rnn_state
|
||||
self.batch_builder = batch_builder_factory()
|
||||
self.last_action = None
|
||||
self.last_observation = None
|
||||
self.last_pi_info = None
|
||||
self.total_reward = 0.0
|
||||
self.length = 0
|
||||
|
||||
def last_action_flat(self):
|
||||
# Concatenate multiagent actions
|
||||
if isinstance(self.last_action, list):
|
||||
return np.concatenate(self.last_action, axis=0).flatten()
|
||||
return self.last_action
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import queue
|
||||
import threading
|
||||
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
|
||||
|
||||
|
||||
class ServingEnv(threading.Thread):
|
||||
"""Environment that provides policy serving.
|
||||
|
||||
Unlike simulator envs, control is inverted. The environment queries the
|
||||
policy to obtain actions and logs observations and rewards for training.
|
||||
This is in contrast to gym.Env, where the algorithm drives the simulation
|
||||
through env.step() calls.
|
||||
|
||||
You can use ServingEnv as the backend for policy serving (by serving HTTP
|
||||
requests in the run loop), for ingesting offline logs data (by reading
|
||||
offline transitions in the run loop), or other custom use cases not easily
|
||||
expressed through gym.Env.
|
||||
|
||||
ServingEnv supports both on-policy serving (through self.get_action()), and
|
||||
off-policy serving (through self.log_action()).
|
||||
|
||||
This env is thread-safe, but individual episodes must be executed serially.
|
||||
|
||||
TODO: Provide a HTTP server/client example based on ServingEnv.
|
||||
|
||||
Examples:
|
||||
>>> register_env("my_env", lambda config: YourServingEnv(config))
|
||||
>>> agent = DQNAgent(env="my_env")
|
||||
>>> while True:
|
||||
print(agent.train())
|
||||
"""
|
||||
|
||||
def __init__(self, action_space, observation_space, max_concurrent=100):
|
||||
"""Initialize a serving env.
|
||||
|
||||
Arguments:
|
||||
action_space (gym.Space): Action space of the env.
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
max_concurrent (int): Max number of active episodes to allow at
|
||||
once. Exceeding this limit raises an error.
|
||||
"""
|
||||
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.action_space = action_space
|
||||
self.observation_space = observation_space
|
||||
self._episodes = {}
|
||||
self._finished = set()
|
||||
self._num_episodes = 0
|
||||
self._cur_default_episode_id = None
|
||||
self._results_avail_condition = threading.Condition()
|
||||
self._max_concurrent_episodes = max_concurrent
|
||||
|
||||
def run(self):
|
||||
"""Override this to implement the run loop.
|
||||
|
||||
Your loop should continuously:
|
||||
1. Call self.start_episode()
|
||||
2. Call self.get_action() or self.log_action()
|
||||
3. Call self.log_returns()
|
||||
4. Call self.end_episode()
|
||||
5. Wait if nothing to do.
|
||||
|
||||
Multiple episodes may be started at the same time.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def start_episode(self, episode_id=None):
|
||||
"""Record the start of an episode.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Unique string id for the episode or None for
|
||||
it to be auto-assigned. Auto-assignment only works if there
|
||||
is at most one active episode at a time.
|
||||
"""
|
||||
|
||||
if episode_id is None:
|
||||
if self._cur_default_episode_id:
|
||||
raise ValueError(
|
||||
"An existing episode is still active. You must pass "
|
||||
"`episode_id` if there are going to be multiple active "
|
||||
"episodes at once.")
|
||||
episode_id = "default_{}".format(self._num_episodes)
|
||||
self._cur_default_episode_id = episode_id
|
||||
self._num_episodes += 1
|
||||
|
||||
if episode_id in self._finished:
|
||||
raise ValueError(
|
||||
"Episode {} has already completed.".format(episode_id))
|
||||
|
||||
if episode_id in self._episodes:
|
||||
raise ValueError(
|
||||
"Episode {} is already started".format(episode_id))
|
||||
|
||||
self._episodes[episode_id] = _Episode(
|
||||
episode_id, self._results_avail_condition)
|
||||
|
||||
def get_action(self, observation, episode_id=None):
|
||||
"""Record an observation and get the on-policy action.
|
||||
|
||||
Arguments:
|
||||
observation (obj): Current environment observation.
|
||||
episode_id (str): Episode id passed to start_episode() or None.
|
||||
|
||||
Returns:
|
||||
action (obj): Action from the env action space.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
return episode.wait_for_action(observation)
|
||||
|
||||
def log_action(self, observation, action, episode_id=None):
|
||||
"""Record an observation and (off-policy) action taken.
|
||||
|
||||
Arguments:
|
||||
observation (obj): Current environment observation.
|
||||
action (obj): Action for the observation.
|
||||
episode_id (str): Episode id passed to start_episode() or None.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
episode.log_action(observation, action)
|
||||
|
||||
def log_returns(self, reward, info=None, episode_id=None):
|
||||
"""Record returns from the environment.
|
||||
|
||||
The reward will be attributed to the previous action taken by the
|
||||
episode. Rewards accumulate until the next action. If no reward is
|
||||
logged before the next action, a reward of 0.0 is assumed.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id passed to start_episode() or None.
|
||||
reward (float): Reward from the environment.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
episode.cur_reward += reward
|
||||
if info:
|
||||
episode.cur_info = info
|
||||
|
||||
def end_episode(self, observation, episode_id=None):
|
||||
"""Record the end of an episode.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id passed by start_episode() or None.
|
||||
observation (obj): Current environment observation.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
self._finished.add(episode.episode_id)
|
||||
self._cur_default_episode_id = None
|
||||
episode.done(observation)
|
||||
|
||||
def _get(self, episode_id=None):
|
||||
"""Get a started episode or raise an error."""
|
||||
|
||||
if episode_id is None:
|
||||
episode_id = self._cur_default_episode_id
|
||||
|
||||
if episode_id in self._finished:
|
||||
raise ValueError(
|
||||
"Episode {} has already completed.".format(episode_id))
|
||||
|
||||
if episode_id not in self._episodes:
|
||||
raise ValueError("Episode {} not found.".format(episode_id))
|
||||
|
||||
return self._episodes[episode_id]
|
||||
|
||||
|
||||
class _ServingEnvToAsync(AsyncVectorEnv):
|
||||
"""Internal adapter of ServingEnv to AsyncVectorEnv."""
|
||||
|
||||
def __init__(self, serving_env):
|
||||
self.serving_env = serving_env
|
||||
serving_env.start()
|
||||
|
||||
def poll(self):
|
||||
with self.serving_env._results_avail_condition:
|
||||
results = self._poll()
|
||||
while len(results[0]) == 0:
|
||||
self.serving_env._results_avail_condition.wait()
|
||||
results = self._poll()
|
||||
if not self.serving_env.isAlive():
|
||||
raise Exception("Serving thread has stopped.")
|
||||
limit = self.serving_env._max_concurrent_episodes
|
||||
assert len(results[0]) < limit, \
|
||||
("Too many concurrent episodes, were some leaked? This ServingEnv "
|
||||
"was created with max_concurrent={}".format(limit))
|
||||
return results
|
||||
|
||||
def _poll(self):
|
||||
all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
|
||||
off_policy_actions = {}
|
||||
for eid, episode in self.serving_env._episodes.copy().items():
|
||||
data = episode.get_data()
|
||||
if episode.cur_done:
|
||||
del self.serving_env._episodes[eid]
|
||||
if data:
|
||||
all_obs[eid] = data["obs"]
|
||||
all_rewards[eid] = data["reward"]
|
||||
all_dones[eid] = data["done"]
|
||||
all_infos[eid] = data["info"]
|
||||
if "off_policy_action" in data:
|
||||
off_policy_actions[eid] = data["off_policy_action"]
|
||||
return all_obs, all_rewards, all_dones, all_infos, off_policy_actions
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
for eid, action in action_dict.items():
|
||||
self.serving_env._episodes[eid].action_queue.put(action)
|
||||
|
||||
|
||||
class _Episode(object):
|
||||
"""Tracked state for each active episode."""
|
||||
|
||||
def __init__(self, episode_id, results_avail_condition):
|
||||
self.episode_id = episode_id
|
||||
self.results_avail_condition = results_avail_condition
|
||||
self.data_queue = queue.Queue()
|
||||
self.action_queue = queue.Queue()
|
||||
self.new_observation = None
|
||||
self.new_action = None
|
||||
self.cur_reward = 0.0
|
||||
self.cur_done = False
|
||||
self.cur_info = {}
|
||||
|
||||
def get_data(self):
|
||||
if self.data_queue.empty():
|
||||
return None
|
||||
return self.data_queue.get_nowait()
|
||||
|
||||
def log_action(self, observation, action):
|
||||
self.new_observation = observation
|
||||
self.new_action = action
|
||||
self._send()
|
||||
self.action_queue.get(True, timeout=60.0)
|
||||
|
||||
def wait_for_action(self, observation):
|
||||
self.new_observation = observation
|
||||
self._send()
|
||||
return self.action_queue.get(True, timeout=60.0)
|
||||
|
||||
def done(self, observation):
|
||||
self.new_observation = observation
|
||||
self.cur_done = True
|
||||
self._send()
|
||||
|
||||
def _send(self):
|
||||
item = {
|
||||
"obs": self.new_observation,
|
||||
"reward": self.cur_reward,
|
||||
"done": self.cur_done,
|
||||
"info": self.cur_info,
|
||||
}
|
||||
if self.new_action is not None:
|
||||
item["off_policy_action"] = self.new_action
|
||||
self.new_observation = None
|
||||
self.new_action = None
|
||||
self.cur_reward = 0.0
|
||||
with self.results_avail_condition:
|
||||
self.data_queue.put_nowait(item)
|
||||
self.results_avail_condition.notify()
|
||||
@@ -0,0 +1,117 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
|
||||
class VectorEnv(object):
|
||||
"""An environment that supports batch evaluation.
|
||||
|
||||
Subclasses must define the following attributes:
|
||||
|
||||
Attributes:
|
||||
action_space (gym.Space): Action space of individual envs.
|
||||
observation_space (gym.Space): Observation space of individual envs.
|
||||
num_envs (int): Number of envs to batch over.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def wrap(make_env=None, existing_envs=None, num_envs=1):
|
||||
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
|
||||
|
||||
def vector_reset(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_at(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
def vector_step(self, actions):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_unwrapped(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _VectorizedGymEnv(VectorEnv):
|
||||
"""Internal wrapper for gym envs to implement VectorEnv.
|
||||
|
||||
Arguents:
|
||||
make_env (func|None): Factory that produces a new gym env. Must be
|
||||
defined if the number of existing envs is less than num_envs.
|
||||
existing_envs (list): List of existing gym envs.
|
||||
num_envs (int): Desired num gym envs to keep total.
|
||||
"""
|
||||
|
||||
def __init__(self, make_env, existing_envs, num_envs):
|
||||
self.make_env = make_env
|
||||
self.envs = existing_envs
|
||||
self.num_envs = num_envs
|
||||
if make_env and num_envs > 1:
|
||||
self.resetter = _AsyncResetter(
|
||||
make_env, int(self.num_envs ** 0.5))
|
||||
else:
|
||||
self.resetter = _SimpleResetter(make_env)
|
||||
while len(self.envs) < self.num_envs:
|
||||
self.envs.append(self.make_env())
|
||||
|
||||
def vector_reset(self):
|
||||
return [e.reset() for e in self.envs]
|
||||
|
||||
def reset_at(self, index):
|
||||
new_obs, new_env = self.resetter.trade_for_resetted(self.envs[index])
|
||||
self.envs[index] = new_env
|
||||
return new_obs
|
||||
|
||||
def vector_step(self, actions):
|
||||
obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
|
||||
for i in range(self.num_envs):
|
||||
obs, rew, done, info = self.envs[i].step(actions[i])
|
||||
obs_batch.append(obs)
|
||||
rew_batch.append(rew)
|
||||
done_batch.append(done)
|
||||
info_batch.append(info)
|
||||
return obs_batch, rew_batch, done_batch, info_batch
|
||||
|
||||
def get_unwrapped(self):
|
||||
return self.envs[0]
|
||||
|
||||
|
||||
class _AsyncResetter(threading.Thread):
|
||||
"""Does env reset asynchronously in the background.
|
||||
|
||||
This is useful since resetting an env can be 100x slower than stepping."""
|
||||
|
||||
def __init__(self, make_env, pool_size):
|
||||
threading.Thread.__init__(self)
|
||||
self.make_env = make_env
|
||||
self.pool_size = 0
|
||||
self.to_reset = queue.Queue()
|
||||
self.resetted = queue.Queue()
|
||||
self.daemon = True
|
||||
self.pool_size = pool_size
|
||||
while self.resetted.qsize() < self.pool_size:
|
||||
env = self.make_env()
|
||||
obs = env.reset()
|
||||
self.resetted.put((obs, env))
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
env = self.to_reset.get()
|
||||
obs = env.reset()
|
||||
self.resetted.put((obs, env))
|
||||
|
||||
def trade_for_resetted(self, env):
|
||||
self.to_reset.put(env)
|
||||
new_obs, new_env = self.resetted.get(timeout=30)
|
||||
return new_obs, new_env
|
||||
|
||||
|
||||
class _SimpleResetter(object):
|
||||
def __init__(self, make_env):
|
||||
pass
|
||||
|
||||
def trade_for_resetted(self, env):
|
||||
return env.reset(), env
|
||||
Reference in New Issue
Block a user