mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 15:23:24 +08:00
[rllib] Basic infrastructure for off-policy estimation (IS, WIS) (#3941)
This commit is contained in:
@@ -97,6 +97,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -153,7 +154,9 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"vf_preds": self.vf}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"vf_preds": self.vf})
|
||||
|
||||
def _value(self, ob, *args):
|
||||
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||
|
||||
@@ -13,7 +13,8 @@ import tensorflow as tf
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
@@ -145,18 +146,22 @@ COMMON_CONFIG = {
|
||||
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
|
||||
# - a function that returns a rllib.offline.InputReader
|
||||
"input": "sampler",
|
||||
# Specify how to evaluate the current policy. This only makes sense to set
|
||||
# when the input is not already generating simulation data:
|
||||
# - None: don't evaluate the policy. The episode reward and other
|
||||
# metrics will be NaN if using offline data.
|
||||
# Specify how to evaluate the current policy. This only has an effect when
|
||||
# reading offline experiences. Available options:
|
||||
# - "wis": the weighted step-wise importance sampling estimator.
|
||||
# - "is": the step-wise importance sampling estimator.
|
||||
# - "simulation": run the environment in the background, but use
|
||||
# this data for evaluation only and not for learning.
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Note that postprocessing will be done using the *current*
|
||||
# policy, not the *behaviour* policy, which is typically undesirable for
|
||||
# on-policy algorithms.
|
||||
"postprocess_inputs": False,
|
||||
# If positive, input batches will be shuffled via a sliding window buffer
|
||||
# of this number of batches. Use this if the input data is not in random
|
||||
# enough order. Input is delayed until the shuffle buffer is filled.
|
||||
"shuffle_buffer_size": 0,
|
||||
# __sphinx_doc_input_end__
|
||||
# __sphinx_doc_output_begin__
|
||||
# Specify where experiences should be saved:
|
||||
@@ -552,10 +557,10 @@ class Agent(Trainable):
|
||||
raise ValueError(
|
||||
"The `use_gpu_for_workers` config is deprecated, please use "
|
||||
"`num_gpus_per_worker=1` instead.")
|
||||
if (config["input"] == "sampler"
|
||||
and config["input_evaluation"] is not None):
|
||||
if type(config["input_evaluation"]) != list:
|
||||
raise ValueError(
|
||||
"`input_evaluation` should not be set when input=sampler")
|
||||
"`input_evaluation` must be a list of strings, got {}".format(
|
||||
config["input_evaluation"]))
|
||||
|
||||
def _make_evaluator(self,
|
||||
cls,
|
||||
@@ -575,9 +580,13 @@ class Agent(Trainable):
|
||||
elif config["input"] == "sampler":
|
||||
input_creator = (lambda ioctx: ioctx.default_sampler_input())
|
||||
elif isinstance(config["input"], dict):
|
||||
input_creator = (lambda ioctx: MixedInput(config["input"], ioctx))
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
MixedInput(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
else:
|
||||
input_creator = (lambda ioctx: JsonReader(config["input"], ioctx))
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
JsonReader(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
|
||||
if isinstance(config["output"], FunctionType):
|
||||
output_creator = config["output"]
|
||||
@@ -596,6 +605,11 @@ class Agent(Trainable):
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
|
||||
if config["input"] == "sampler":
|
||||
input_evaluation = []
|
||||
else:
|
||||
input_evaluation = config["input_evaluation"]
|
||||
|
||||
return cls(
|
||||
env_creator,
|
||||
self.config["multiagent"]["policy_graphs"] or policy_graph,
|
||||
@@ -622,7 +636,7 @@ class Agent(Trainable):
|
||||
log_level=config["log_level"],
|
||||
callbacks=config["callbacks"],
|
||||
input_creator=input_creator,
|
||||
input_evaluation_method=config["input_evaluation"],
|
||||
input_evaluation=input_evaluation,
|
||||
output_creator=output_creator,
|
||||
remote_worker_envs=remote_worker_envs)
|
||||
|
||||
|
||||
@@ -269,6 +269,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
q_t, self.q_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
self.stats = {
|
||||
"mean_q": tf.reduce_mean(q_t),
|
||||
"max_q": tf.reduce_max(q_t),
|
||||
"min_q": tf.reduce_min(q_t),
|
||||
}
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
|
||||
output_actions)
|
||||
@@ -416,6 +421,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.loss.td_error,
|
||||
"stats": self.stats,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
|
||||
@@ -68,6 +68,11 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"exploration_final_eps": 0.02,
|
||||
# Update the target network every `target_network_update_freq` steps.
|
||||
"target_network_update_freq": 500,
|
||||
# Use softmax for sampling actions.
|
||||
"soft_q": False,
|
||||
# Softmax temperature. Q values are divided by this value prior to softmax.
|
||||
# Softmax approaches argmax as the temperature drops to zero.
|
||||
"softmax_temp": 1.0,
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer. Note that if async_updates is set, then
|
||||
|
||||
@@ -8,8 +8,7 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.models import ModelCatalog, Categorical
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
@@ -182,7 +181,14 @@ class QNetwork(object):
|
||||
|
||||
|
||||
class QValuePolicy(object):
|
||||
def __init__(self, q_values, observations, num_actions, stochastic, eps):
|
||||
def __init__(self, q_values, observations, num_actions, stochastic, eps,
|
||||
softmax, softmax_temp):
|
||||
if softmax:
|
||||
action_dist = Categorical(q_values / softmax_temp)
|
||||
self.action = action_dist.sample()
|
||||
self.action_prob = action_dist.sampled_action_prob()
|
||||
return
|
||||
|
||||
deterministic_actions = tf.argmax(q_values, axis=1)
|
||||
batch_size = tf.shape(observations)[0]
|
||||
|
||||
@@ -200,6 +206,7 @@ class QValuePolicy(object):
|
||||
deterministic_actions)
|
||||
self.action = tf.cond(stochastic, lambda: stochastic_actions,
|
||||
lambda: deterministic_actions)
|
||||
self.action_prob = None
|
||||
|
||||
|
||||
class QLoss(object):
|
||||
@@ -300,10 +307,12 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values, q_logits, q_dist, _ = self._build_q_network(
|
||||
self.cur_observations, observation_space)
|
||||
self.q_values = q_values
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
self.output_actions = self._build_q_value_policy(q_values)
|
||||
self.output_actions, self.action_prob = self._build_q_value_policy(
|
||||
q_values)
|
||||
|
||||
# Replay inputs
|
||||
self.obs_t = tf.placeholder(
|
||||
@@ -387,6 +396,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=self.action_prob,
|
||||
loss=model.loss() + self.loss.loss,
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops)
|
||||
@@ -412,6 +422,13 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"q_values": self.q_values,
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
return {
|
||||
@@ -474,8 +491,10 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
return qnet.value, qnet.logits, qnet.dist, qnet.model
|
||||
|
||||
def _build_q_value_policy(self, q_values):
|
||||
return QValuePolicy(q_values, self.cur_observations, self.num_actions,
|
||||
self.stochastic, self.eps).action
|
||||
policy = QValuePolicy(
|
||||
q_values, self.cur_observations, self.num_actions, self.stochastic,
|
||||
self.eps, self.config["soft_q"], self.config["softmax_temp"])
|
||||
return policy.action, policy.action_prob
|
||||
|
||||
def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best,
|
||||
q_dist_tp1_best):
|
||||
@@ -511,26 +530,16 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
rewards[i] += gamma**j * rewards[i + j]
|
||||
|
||||
|
||||
def _postprocess_dqn(policy_graph, sample_batch):
|
||||
obs, actions, rewards, new_obs, dones = [
|
||||
list(x) for x in sample_batch.columns(
|
||||
["obs", "actions", "rewards", "new_obs", "dones"])
|
||||
]
|
||||
|
||||
def _postprocess_dqn(policy_graph, batch):
|
||||
# N-step Q adjustments
|
||||
if policy_graph.config["n_step"] > 1:
|
||||
_adjust_nstep(policy_graph.config["n_step"],
|
||||
policy_graph.config["gamma"], obs, actions, rewards,
|
||||
new_obs, dones)
|
||||
policy_graph.config["gamma"], batch["obs"],
|
||||
batch["actions"], batch["rewards"], batch["new_obs"],
|
||||
batch["dones"])
|
||||
|
||||
batch = SampleBatch({
|
||||
"obs": obs,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": new_obs,
|
||||
"dones": dones,
|
||||
"weights": np.ones_like(rewards)
|
||||
})
|
||||
if "weights" not in batch:
|
||||
batch["weights"] = np.ones_like(batch["rewards"])
|
||||
|
||||
# Prioritize on the worker side
|
||||
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
|
||||
|
||||
@@ -215,6 +215,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -270,7 +271,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"behaviour_logits": self.model.outputs}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"behaviour_logits": self.model.outputs})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
|
||||
@@ -19,8 +19,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"postprocess_inputs": True,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "complete_episodes",
|
||||
# Read data from historic data and evaluate by a sampler
|
||||
"input_evaluation": "simulation",
|
||||
# Use importance sampling estimators for reward
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Learning rate for adam optimizer
|
||||
"lr": 1e-4,
|
||||
# Number of timesteps collected for each SGD round
|
||||
|
||||
@@ -107,6 +107,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.obs_t,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + objective,
|
||||
loss_inputs=self.loss_inputs,
|
||||
state_inputs=self.model.state_in,
|
||||
|
||||
@@ -67,6 +67,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
sess,
|
||||
obs_input=obs,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
|
||||
@@ -320,6 +320,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -373,7 +374,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
out = {"behaviour_logits": self.model.outputs}
|
||||
if not self.config["vtrace"]:
|
||||
out["vf_preds"] = self.value_function
|
||||
return out
|
||||
return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out)
|
||||
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
@@ -234,6 +234,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=obs_ph,
|
||||
action_sampler=self.sampler,
|
||||
action_prob=curr_action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss_obj.loss,
|
||||
loss_inputs=self.loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -307,7 +308,11 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"vf_preds": self.value_function, "logits": self.logits}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"vf_preds": self.value_function,
|
||||
"logits": self.logits
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
|
||||
Reference in New Issue
Block a user