[rllib] Basic infrastructure for off-policy estimation (IS, WIS) (#3941)

This commit is contained in:
Eric Liang
2019-02-13 16:25:05 -08:00
committed by GitHub
parent 729d0b2825
commit 2dccf383dd
34 changed files with 549 additions and 131 deletions
@@ -97,6 +97,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=self.observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -153,7 +154,9 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return {"vf_preds": self.vf}
return dict(
TFPolicyGraph.extra_compute_action_fetches(self),
**{"vf_preds": self.vf})
def _value(self, ob, *args):
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
+26 -12
View File
@@ -13,7 +13,8 @@ import tensorflow as tf
from types import FunctionType
import ray
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
@@ -145,18 +146,22 @@ COMMON_CONFIG = {
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
# - a function that returns a rllib.offline.InputReader
"input": "sampler",
# Specify how to evaluate the current policy. This only makes sense to set
# when the input is not already generating simulation data:
# - None: don't evaluate the policy. The episode reward and other
# metrics will be NaN if using offline data.
# Specify how to evaluate the current policy. This only has an effect when
# reading offline experiences. Available options:
# - "wis": the weighted step-wise importance sampling estimator.
# - "is": the step-wise importance sampling estimator.
# - "simulation": run the environment in the background, but use
# this data for evaluation only and not for learning.
"input_evaluation": None,
"input_evaluation": ["is", "wis"],
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behaviour* policy, which is typically undesirable for
# on-policy algorithms.
"postprocess_inputs": False,
# If positive, input batches will be shuffled via a sliding window buffer
# of this number of batches. Use this if the input data is not in random
# enough order. Input is delayed until the shuffle buffer is filled.
"shuffle_buffer_size": 0,
# __sphinx_doc_input_end__
# __sphinx_doc_output_begin__
# Specify where experiences should be saved:
@@ -552,10 +557,10 @@ class Agent(Trainable):
raise ValueError(
"The `use_gpu_for_workers` config is deprecated, please use "
"`num_gpus_per_worker=1` instead.")
if (config["input"] == "sampler"
and config["input_evaluation"] is not None):
if type(config["input_evaluation"]) != list:
raise ValueError(
"`input_evaluation` should not be set when input=sampler")
"`input_evaluation` must be a list of strings, got {}".format(
config["input_evaluation"]))
def _make_evaluator(self,
cls,
@@ -575,9 +580,13 @@ class Agent(Trainable):
elif config["input"] == "sampler":
input_creator = (lambda ioctx: ioctx.default_sampler_input())
elif isinstance(config["input"], dict):
input_creator = (lambda ioctx: MixedInput(config["input"], ioctx))
input_creator = (lambda ioctx: ShuffledInput(
MixedInput(config["input"], ioctx),
config["shuffle_buffer_size"]))
else:
input_creator = (lambda ioctx: JsonReader(config["input"], ioctx))
input_creator = (lambda ioctx: ShuffledInput(
JsonReader(config["input"], ioctx),
config["shuffle_buffer_size"]))
if isinstance(config["output"], FunctionType):
output_creator = config["output"]
@@ -596,6 +605,11 @@ class Agent(Trainable):
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
if config["input"] == "sampler":
input_evaluation = []
else:
input_evaluation = config["input_evaluation"]
return cls(
env_creator,
self.config["multiagent"]["policy_graphs"] or policy_graph,
@@ -622,7 +636,7 @@ class Agent(Trainable):
log_level=config["log_level"],
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation_method=config["input_evaluation"],
input_evaluation=input_evaluation,
output_creator=output_creator,
remote_worker_envs=remote_worker_envs)
@@ -269,6 +269,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
q_t, self.q_model = self._build_q_network(
self.obs_t, observation_space, self.act_t)
self.q_func_vars = _scope_vars(scope.name)
self.stats = {
"mean_q": tf.reduce_mean(q_t),
"max_q": tf.reduce_max(q_t),
"min_q": tf.reduce_min(q_t),
}
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
output_actions)
@@ -416,6 +421,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
"stats": self.stats,
}
@override(PolicyGraph)
+5
View File
@@ -68,6 +68,11 @@ DEFAULT_CONFIG = with_common_config({
"exploration_final_eps": 0.02,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
# Use softmax for sampling actions.
"soft_q": False,
# Softmax temperature. Q values are divided by this value prior to softmax.
# Softmax approaches argmax as the temperature drops to zero.
"softmax_temp": 1.0,
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
+31 -22
View File
@@ -8,8 +8,7 @@ import tensorflow as tf
import tensorflow.contrib.layers as layers
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.models import ModelCatalog, Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -182,7 +181,14 @@ class QNetwork(object):
class QValuePolicy(object):
def __init__(self, q_values, observations, num_actions, stochastic, eps):
def __init__(self, q_values, observations, num_actions, stochastic, eps,
softmax, softmax_temp):
if softmax:
action_dist = Categorical(q_values / softmax_temp)
self.action = action_dist.sample()
self.action_prob = action_dist.sampled_action_prob()
return
deterministic_actions = tf.argmax(q_values, axis=1)
batch_size = tf.shape(observations)[0]
@@ -200,6 +206,7 @@ class QValuePolicy(object):
deterministic_actions)
self.action = tf.cond(stochastic, lambda: stochastic_actions,
lambda: deterministic_actions)
self.action_prob = None
class QLoss(object):
@@ -300,10 +307,12 @@ class DQNPolicyGraph(TFPolicyGraph):
with tf.variable_scope(Q_SCOPE) as scope:
q_values, q_logits, q_dist, _ = self._build_q_network(
self.cur_observations, observation_space)
self.q_values = q_values
self.q_func_vars = _scope_vars(scope.name)
# Action outputs
self.output_actions = self._build_q_value_policy(q_values)
self.output_actions, self.action_prob = self._build_q_value_policy(
q_values)
# Replay inputs
self.obs_t = tf.placeholder(
@@ -387,6 +396,7 @@ class DQNPolicyGraph(TFPolicyGraph):
self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions,
action_prob=self.action_prob,
loss=model.loss() + self.loss.loss,
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops)
@@ -412,6 +422,13 @@ class DQNPolicyGraph(TFPolicyGraph):
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
return grads_and_vars
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return dict(
TFPolicyGraph.extra_compute_action_fetches(self), **{
"q_values": self.q_values,
})
@override(TFPolicyGraph)
def extra_compute_action_feed_dict(self):
return {
@@ -474,8 +491,10 @@ class DQNPolicyGraph(TFPolicyGraph):
return qnet.value, qnet.logits, qnet.dist, qnet.model
def _build_q_value_policy(self, q_values):
return QValuePolicy(q_values, self.cur_observations, self.num_actions,
self.stochastic, self.eps).action
policy = QValuePolicy(
q_values, self.cur_observations, self.num_actions, self.stochastic,
self.eps, self.config["soft_q"], self.config["softmax_temp"])
return policy.action, policy.action_prob
def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best,
q_dist_tp1_best):
@@ -511,26 +530,16 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
rewards[i] += gamma**j * rewards[i + j]
def _postprocess_dqn(policy_graph, sample_batch):
obs, actions, rewards, new_obs, dones = [
list(x) for x in sample_batch.columns(
["obs", "actions", "rewards", "new_obs", "dones"])
]
def _postprocess_dqn(policy_graph, batch):
# N-step Q adjustments
if policy_graph.config["n_step"] > 1:
_adjust_nstep(policy_graph.config["n_step"],
policy_graph.config["gamma"], obs, actions, rewards,
new_obs, dones)
policy_graph.config["gamma"], batch["obs"],
batch["actions"], batch["rewards"], batch["new_obs"],
batch["dones"])
batch = SampleBatch({
"obs": obs,
"actions": actions,
"rewards": rewards,
"new_obs": new_obs,
"dones": dones,
"weights": np.ones_like(rewards)
})
if "weights" not in batch:
batch["weights"] = np.ones_like(batch["rewards"])
# Prioritize on the worker side
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
@@ -215,6 +215,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -270,7 +271,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return {"behaviour_logits": self.model.outputs}
return dict(
TFPolicyGraph.extra_compute_action_fetches(self),
**{"behaviour_logits": self.model.outputs})
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):
+2 -2
View File
@@ -19,8 +19,8 @@ DEFAULT_CONFIG = with_common_config({
"postprocess_inputs": True,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "complete_episodes",
# Read data from historic data and evaluate by a sampler
"input_evaluation": "simulation",
# Use importance sampling estimators for reward
"input_evaluation": ["is", "wis"],
# Learning rate for adam optimizer
"lr": 1e-4,
# Number of timesteps collected for each SGD round
@@ -107,6 +107,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
self.sess,
obs_input=self.obs_t,
action_sampler=self.output_actions,
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + objective,
loss_inputs=self.loss_inputs,
state_inputs=self.model.state_in,
@@ -67,6 +67,7 @@ class PGPolicyGraph(TFPolicyGraph):
sess,
obs_input=obs,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -320,6 +320,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -373,7 +374,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
out = {"behaviour_logits": self.model.outputs}
if not self.config["vtrace"]:
out["vf_preds"] = self.value_function
return out
return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out)
def extra_compute_grad_fetches(self):
return self.stats_fetches
@@ -234,6 +234,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=obs_ph,
action_sampler=self.sampler,
action_prob=curr_action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss_obj.loss,
loss_inputs=self.loss_in,
state_inputs=self.model.state_in,
@@ -307,7 +308,11 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return {"vf_preds": self.value_function, "logits": self.logits}
return dict(
TFPolicyGraph.extra_compute_action_fetches(self), **{
"vf_preds": self.value_function,
"logits": self.logits
})
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):