[rllib] Basic infrastructure for off-policy estimation (IS, WIS) (#3941)

This commit is contained in:
Eric Liang
2019-02-13 16:25:05 -08:00
committed by GitHub
parent 729d0b2825
commit 2dccf383dd
34 changed files with 549 additions and 131 deletions
@@ -97,6 +97,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=self.observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -153,7 +154,9 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return {"vf_preds": self.vf}
return dict(
TFPolicyGraph.extra_compute_action_fetches(self),
**{"vf_preds": self.vf})
def _value(self, ob, *args):
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
+26 -12
View File
@@ -13,7 +13,8 @@ import tensorflow as tf
from types import FunctionType
import ray
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
@@ -145,18 +146,22 @@ COMMON_CONFIG = {
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
# - a function that returns a rllib.offline.InputReader
"input": "sampler",
# Specify how to evaluate the current policy. This only makes sense to set
# when the input is not already generating simulation data:
# - None: don't evaluate the policy. The episode reward and other
# metrics will be NaN if using offline data.
# Specify how to evaluate the current policy. This only has an effect when
# reading offline experiences. Available options:
# - "wis": the weighted step-wise importance sampling estimator.
# - "is": the step-wise importance sampling estimator.
# - "simulation": run the environment in the background, but use
# this data for evaluation only and not for learning.
"input_evaluation": None,
"input_evaluation": ["is", "wis"],
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behaviour* policy, which is typically undesirable for
# on-policy algorithms.
"postprocess_inputs": False,
# If positive, input batches will be shuffled via a sliding window buffer
# of this number of batches. Use this if the input data is not in random
# enough order. Input is delayed until the shuffle buffer is filled.
"shuffle_buffer_size": 0,
# __sphinx_doc_input_end__
# __sphinx_doc_output_begin__
# Specify where experiences should be saved:
@@ -552,10 +557,10 @@ class Agent(Trainable):
raise ValueError(
"The `use_gpu_for_workers` config is deprecated, please use "
"`num_gpus_per_worker=1` instead.")
if (config["input"] == "sampler"
and config["input_evaluation"] is not None):
if type(config["input_evaluation"]) != list:
raise ValueError(
"`input_evaluation` should not be set when input=sampler")
"`input_evaluation` must be a list of strings, got {}".format(
config["input_evaluation"]))
def _make_evaluator(self,
cls,
@@ -575,9 +580,13 @@ class Agent(Trainable):
elif config["input"] == "sampler":
input_creator = (lambda ioctx: ioctx.default_sampler_input())
elif isinstance(config["input"], dict):
input_creator = (lambda ioctx: MixedInput(config["input"], ioctx))
input_creator = (lambda ioctx: ShuffledInput(
MixedInput(config["input"], ioctx),
config["shuffle_buffer_size"]))
else:
input_creator = (lambda ioctx: JsonReader(config["input"], ioctx))
input_creator = (lambda ioctx: ShuffledInput(
JsonReader(config["input"], ioctx),
config["shuffle_buffer_size"]))
if isinstance(config["output"], FunctionType):
output_creator = config["output"]
@@ -596,6 +605,11 @@ class Agent(Trainable):
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
if config["input"] == "sampler":
input_evaluation = []
else:
input_evaluation = config["input_evaluation"]
return cls(
env_creator,
self.config["multiagent"]["policy_graphs"] or policy_graph,
@@ -622,7 +636,7 @@ class Agent(Trainable):
log_level=config["log_level"],
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation_method=config["input_evaluation"],
input_evaluation=input_evaluation,
output_creator=output_creator,
remote_worker_envs=remote_worker_envs)
@@ -269,6 +269,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
q_t, self.q_model = self._build_q_network(
self.obs_t, observation_space, self.act_t)
self.q_func_vars = _scope_vars(scope.name)
self.stats = {
"mean_q": tf.reduce_mean(q_t),
"max_q": tf.reduce_max(q_t),
"min_q": tf.reduce_min(q_t),
}
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
output_actions)
@@ -416,6 +421,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
"stats": self.stats,
}
@override(PolicyGraph)
+5
View File
@@ -68,6 +68,11 @@ DEFAULT_CONFIG = with_common_config({
"exploration_final_eps": 0.02,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
# Use softmax for sampling actions.
"soft_q": False,
# Softmax temperature. Q values are divided by this value prior to softmax.
# Softmax approaches argmax as the temperature drops to zero.
"softmax_temp": 1.0,
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
+31 -22
View File
@@ -8,8 +8,7 @@ import tensorflow as tf
import tensorflow.contrib.layers as layers
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.models import ModelCatalog, Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -182,7 +181,14 @@ class QNetwork(object):
class QValuePolicy(object):
def __init__(self, q_values, observations, num_actions, stochastic, eps):
def __init__(self, q_values, observations, num_actions, stochastic, eps,
softmax, softmax_temp):
if softmax:
action_dist = Categorical(q_values / softmax_temp)
self.action = action_dist.sample()
self.action_prob = action_dist.sampled_action_prob()
return
deterministic_actions = tf.argmax(q_values, axis=1)
batch_size = tf.shape(observations)[0]
@@ -200,6 +206,7 @@ class QValuePolicy(object):
deterministic_actions)
self.action = tf.cond(stochastic, lambda: stochastic_actions,
lambda: deterministic_actions)
self.action_prob = None
class QLoss(object):
@@ -300,10 +307,12 @@ class DQNPolicyGraph(TFPolicyGraph):
with tf.variable_scope(Q_SCOPE) as scope:
q_values, q_logits, q_dist, _ = self._build_q_network(
self.cur_observations, observation_space)
self.q_values = q_values
self.q_func_vars = _scope_vars(scope.name)
# Action outputs
self.output_actions = self._build_q_value_policy(q_values)
self.output_actions, self.action_prob = self._build_q_value_policy(
q_values)
# Replay inputs
self.obs_t = tf.placeholder(
@@ -387,6 +396,7 @@ class DQNPolicyGraph(TFPolicyGraph):
self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions,
action_prob=self.action_prob,
loss=model.loss() + self.loss.loss,
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops)
@@ -412,6 +422,13 @@ class DQNPolicyGraph(TFPolicyGraph):
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
return grads_and_vars
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return dict(
TFPolicyGraph.extra_compute_action_fetches(self), **{
"q_values": self.q_values,
})
@override(TFPolicyGraph)
def extra_compute_action_feed_dict(self):
return {
@@ -474,8 +491,10 @@ class DQNPolicyGraph(TFPolicyGraph):
return qnet.value, qnet.logits, qnet.dist, qnet.model
def _build_q_value_policy(self, q_values):
return QValuePolicy(q_values, self.cur_observations, self.num_actions,
self.stochastic, self.eps).action
policy = QValuePolicy(
q_values, self.cur_observations, self.num_actions, self.stochastic,
self.eps, self.config["soft_q"], self.config["softmax_temp"])
return policy.action, policy.action_prob
def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best,
q_dist_tp1_best):
@@ -511,26 +530,16 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
rewards[i] += gamma**j * rewards[i + j]
def _postprocess_dqn(policy_graph, sample_batch):
obs, actions, rewards, new_obs, dones = [
list(x) for x in sample_batch.columns(
["obs", "actions", "rewards", "new_obs", "dones"])
]
def _postprocess_dqn(policy_graph, batch):
# N-step Q adjustments
if policy_graph.config["n_step"] > 1:
_adjust_nstep(policy_graph.config["n_step"],
policy_graph.config["gamma"], obs, actions, rewards,
new_obs, dones)
policy_graph.config["gamma"], batch["obs"],
batch["actions"], batch["rewards"], batch["new_obs"],
batch["dones"])
batch = SampleBatch({
"obs": obs,
"actions": actions,
"rewards": rewards,
"new_obs": new_obs,
"dones": dones,
"weights": np.ones_like(rewards)
})
if "weights" not in batch:
batch["weights"] = np.ones_like(batch["rewards"])
# Prioritize on the worker side
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
@@ -215,6 +215,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -270,7 +271,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return {"behaviour_logits": self.model.outputs}
return dict(
TFPolicyGraph.extra_compute_action_fetches(self),
**{"behaviour_logits": self.model.outputs})
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):
+2 -2
View File
@@ -19,8 +19,8 @@ DEFAULT_CONFIG = with_common_config({
"postprocess_inputs": True,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "complete_episodes",
# Read data from historic data and evaluate by a sampler
"input_evaluation": "simulation",
# Use importance sampling estimators for reward
"input_evaluation": ["is", "wis"],
# Learning rate for adam optimizer
"lr": 1e-4,
# Number of timesteps collected for each SGD round
@@ -107,6 +107,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
self.sess,
obs_input=self.obs_t,
action_sampler=self.output_actions,
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + objective,
loss_inputs=self.loss_inputs,
state_inputs=self.model.state_in,
@@ -67,6 +67,7 @@ class PGPolicyGraph(TFPolicyGraph):
sess,
obs_input=obs,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -320,6 +320,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
@@ -373,7 +374,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
out = {"behaviour_logits": self.model.outputs}
if not self.config["vtrace"]:
out["vf_preds"] = self.value_function
return out
return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out)
def extra_compute_grad_fetches(self):
return self.stats_fetches
@@ -234,6 +234,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=obs_ph,
action_sampler=self.sampler,
action_prob=curr_action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss_obj.loss,
loss_inputs=self.loss_in,
state_inputs=self.model.state_in,
@@ -307,7 +308,11 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return {"vf_preds": self.value_function, "logits": self.logits}
return dict(
TFPolicyGraph.extra_compute_action_fetches(self), **{
"vf_preds": self.value_function,
"logits": self.logits
})
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):
+32 -3
View File
@@ -8,6 +8,8 @@ import collections
import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.sampler import RolloutMetrics
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.utils.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
@@ -31,15 +33,14 @@ def collect_episodes(local_evaluator,
"""Gathers new episodes metrics tuples from the given evaluators."""
pending = [
a.apply.remote(lambda ev: ev.sampler.get_metrics())
for a in remote_evaluators
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_evaluators
]
collected, _ = ray.wait(
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
num_metric_batches_dropped = len(pending) - len(collected)
metric_lists = ray.get(collected)
metric_lists.append(local_evaluator.sampler.get_metrics())
metric_lists.append(local_evaluator.get_metrics())
episodes = []
for metrics in metric_lists:
episodes.extend(metrics)
@@ -60,6 +61,9 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
logger.warning("WARNING: {} workers have NOT returned metrics".format(
num_dropped))
episodes, estimates = _partition(episodes)
new_episodes, _ = _partition(new_episodes)
episode_rewards = []
episode_lengths = []
policy_rewards = collections.defaultdict(list)
@@ -95,6 +99,16 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
custom_metrics[k + "_max"] = float("nan")
del custom_metrics[k]
estimators = collections.defaultdict(lambda: collections.defaultdict(list))
for e in estimates:
acc = estimators[e.estimator_name]
for k, v in e.metrics.items():
acc[k].append(v)
for name, metrics in estimators.items():
for k, v_list in metrics.items():
metrics[k] = np.mean(v_list)
estimators[name] = dict(metrics)
return dict(
episode_reward_max=max_reward,
episode_reward_min=min_reward,
@@ -103,4 +117,19 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
episodes_this_iter=len(new_episodes),
policy_reward_mean=dict(policy_rewards),
custom_metrics=dict(custom_metrics),
off_policy_estimator=dict(estimators),
num_metric_batches_dropped=num_dropped)
def _partition(episodes):
"""Divides metrics data into true rollouts vs off-policy estimates."""
rollouts, estimates = [], []
for e in episodes:
if isinstance(e, RolloutMetrics):
rollouts.append(e)
elif isinstance(e, OffPolicyEstimate):
estimates.append(e)
else:
raise ValueError("Unknown metric type: {}".format(e))
return rollouts, estimates
+42 -18
View File
@@ -19,6 +19,8 @@ from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
@@ -116,7 +118,7 @@ class PolicyEvaluator(EvaluatorInterface):
log_level=None,
callbacks=None,
input_creator=lambda ioctx: ioctx.default_sampler_input(),
input_evaluation_method=None,
input_evaluation=frozenset([]),
output_creator=lambda ioctx: NoopOutput(),
remote_worker_envs=False):
"""Initialize a policy evaluator.
@@ -184,11 +186,11 @@ class PolicyEvaluator(EvaluatorInterface):
callbacks (dict): Dict of custom debug callbacks.
input_creator (func): Function that returns an InputReader object
for loading previous generated experiences.
input_evaluation_method (str): How to evaluate the current policy.
This only applies when the input is reading offline data.
Options are:
- None: don't evaluate the policy. The episode reward and
other metrics will be NaN.
input_evaluation (list): How to evaluate the policy performance.
This only makes sense to set when the input is reading offline
data. The possible values include:
- "is": the step-wise importance sampling estimator.
- "wis": the weighted step-wise is estimator.
- "simulation": run the environment in the background, but
use this data for evaluation only and never for learning.
output_creator (func): Function that returns an OutputWriter object
@@ -316,16 +318,24 @@ class PolicyEvaluator(EvaluatorInterface):
raise ValueError("Unsupported batch mode: {}".format(
self.batch_mode))
if input_evaluation_method == "simulation":
logger.warning(
"Requested 'simulation' input evaluation method: "
"will discard all sampler outputs and keep only metrics.")
sample_async = True
elif input_evaluation_method is None:
pass
else:
raise ValueError("Unknown evaluation method: {}".format(
input_evaluation_method))
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
self.reward_estimators = []
for method in input_evaluation:
if method == "simulation":
logger.warning(
"Requested 'simulation' input evaluation method: "
"will discard all sampler outputs and keep only metrics.")
sample_async = True
elif method == "is":
ise = ImportanceSamplingEstimator.create(self.io_context)
self.reward_estimators.append(ise)
elif method == "wis":
wise = WeightedImportanceSamplingEstimator.create(
self.io_context)
self.reward_estimators.append(wise)
else:
raise ValueError(
"Unknown evaluation method: {}".format(method))
if sample_async:
self.sampler = AsyncSampler(
@@ -341,7 +351,7 @@ class PolicyEvaluator(EvaluatorInterface):
pack=pack_episodes,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
blackhole_outputs=input_evaluation_method == "simulation")
blackhole_outputs="simulation" in input_evaluation)
self.sampler.start()
else:
self.sampler = SyncSampler(
@@ -358,7 +368,6 @@ class PolicyEvaluator(EvaluatorInterface):
tf_sess=self.tf_sess,
clip_actions=clip_actions)
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
self.input_reader = input_creator(self.io_context)
assert isinstance(self.input_reader, InputReader), self.input_reader
self.output_writer = output_creator(self.io_context)
@@ -402,6 +411,12 @@ class PolicyEvaluator(EvaluatorInterface):
# for better compression inside the writer.
self.output_writer.write(batch)
# Do off-policy estimation if needed
if self.reward_estimators:
for sub_batch in batch.split_by_episode():
for estimator in self.reward_estimators:
estimator.process(sub_batch)
if self.compress_observations:
if isinstance(batch, MultiAgentBatch):
for data in batch.policy_batches.values():
@@ -504,6 +519,15 @@ class PolicyEvaluator(EvaluatorInterface):
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
return grad_fetch
@DeveloperAPI
def get_metrics(self):
"""Returns a list of new RolloutMetric objects from evaluation."""
out = self.sampler.get_metrics()
for m in self.reward_estimators:
out.extend(m.get_metrics())
return out
@DeveloperAPI
def foreach_env(self, func):
"""Apply the given function to each underlying env instance."""
+17 -2
View File
@@ -16,6 +16,8 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.offline import InputReader
from ray.rllib.utils.annotations import override
from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
@@ -31,7 +33,20 @@ PolicyEvalData = namedtuple("PolicyEvalData", [
])
class SyncSampler(object):
class SamplerInput(InputReader):
"""Reads input experiences from an existing sampler."""
@override(InputReader)
def next(self):
batches = [self.get_data()]
batches.extend(self.get_extra_batches())
if len(batches) > 1:
return batches[0].concat_samples(batches)
else:
return batches[0]
class SyncSampler(SamplerInput):
def __init__(self,
env,
policies,
@@ -87,7 +102,7 @@ class SyncSampler(object):
return extra
class AsyncSampler(threading.Thread):
class AsyncSampler(threading.Thread, SamplerInput):
def __init__(self,
env,
policies,
+11 -2
View File
@@ -52,6 +52,7 @@ class TFPolicyGraph(PolicyGraph):
action_sampler,
loss,
loss_inputs,
action_prob=None,
state_inputs=None,
state_outputs=None,
prev_action_input=None,
@@ -77,6 +78,7 @@ class TFPolicyGraph(PolicyGraph):
and has shape [BATCH_SIZE, data...]. These keys will be read
from postprocessed sample batches and fed into the specified
placeholders during loss computation.
action_prob (Tensor): probability of the sampled action.
state_inputs (list): list of RNN state input Tensors.
state_outputs (list): list of RNN state output Tensors.
prev_action_input (Tensor): placeholder for previous actions
@@ -104,6 +106,7 @@ class TFPolicyGraph(PolicyGraph):
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
self._is_training = self._get_is_training_placeholder()
self._action_prob = action_prob
self._state_inputs = state_inputs or []
self._state_outputs = state_outputs or []
for i, ph in enumerate(self._state_inputs):
@@ -231,8 +234,14 @@ class TFPolicyGraph(PolicyGraph):
@DeveloperAPI
def extra_compute_action_fetches(self):
"""Extra values to fetch and return from compute_actions()."""
return {} # e.g, value function
"""Extra values to fetch and return from compute_actions().
By default we only return action probability info (if present).
"""
if self._action_prob is not None:
return {"action_prob": self._action_prob}
else:
return {}
@DeveloperAPI
def extra_compute_grad_feed_dict(self):
@@ -33,6 +33,7 @@ if __name__ == "__main__":
agent_index=0,
obs=obs,
actions=action,
action_prob=1.0, # put the true action probability here
rewards=rew,
prev_actions=prev_action,
prev_rewards=prev_reward,
+32 -6
View File
@@ -24,6 +24,7 @@ class ActionDistribution(object):
@DeveloperAPI
def __init__(self, inputs):
self.inputs = inputs
self.sample_op = self._build_sample_op()
@DeveloperAPI
def logp(self, x):
@@ -37,13 +38,27 @@ class ActionDistribution(object):
@DeveloperAPI
def entropy(self):
"""The entroy of the action distribution."""
"""The entropy of the action distribution."""
raise NotImplementedError
@DeveloperAPI
def _build_sample_op(self):
"""Implement this instead of sample(), to enable op reuse.
This is needed since the sample op is non-deterministic and is shared
between sample() and sampled_action_prob().
"""
raise NotImplementedError
@DeveloperAPI
def sample(self):
"""Draw a sample from the action distribution."""
raise NotImplementedError
return self.sample_op
@DeveloperAPI
def sampled_action_prob(self):
"""Returns the log probability of the sampled action."""
return tf.exp(self.logp(self.sample_op))
class Categorical(ActionDistribution):
@@ -95,7 +110,7 @@ class Categorical(ActionDistribution):
p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1])
@override(ActionDistribution)
def sample(self):
def _build_sample_op(self):
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
@@ -107,11 +122,11 @@ class DiagGaussian(ActionDistribution):
"""
def __init__(self, inputs):
ActionDistribution.__init__(self, inputs)
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.log_std = log_std
self.std = tf.exp(log_std)
ActionDistribution.__init__(self, inputs)
@override(ActionDistribution)
def logp(self, x):
@@ -136,7 +151,7 @@ class DiagGaussian(ActionDistribution):
reduction_indices=[1])
@override(ActionDistribution)
def sample(self):
def _build_sample_op(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
@@ -147,7 +162,11 @@ class Deterministic(ActionDistribution):
"""
@override(ActionDistribution)
def sample(self):
def sampled_action_prob(self):
return 1.0
@override(ActionDistribution)
def _build_sample_op(self):
return self.inputs
@@ -205,5 +224,12 @@ class MultiActionDistribution(ActionDistribution):
def sample(self):
return TupleActions([s.sample() for s in self.child_distributions])
@override(ActionDistribution)
def sampled_action_prob(self):
p = self.child_distributions[0].sampled_action_prob()
for c in self.child_distributions[1:]:
p *= c.sampled_action_prob()
return p
TupleActions = namedtuple("TupleActions", ["batches"])
+2
View File
@@ -8,6 +8,7 @@ from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.mixed_input import MixedInput
from ray.rllib.offline.shuffled_input import ShuffledInput
__all__ = [
"IOContext",
@@ -17,4 +18,5 @@ __all__ = [
"OutputWriter",
"InputReader",
"MixedInput",
"ShuffledInput",
]
-17
View File
@@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import PublicAPI
@@ -18,19 +17,3 @@ class InputReader(object):
SampleBatch or MultiAgentBatch read.
"""
raise NotImplementedError
class SamplerInput(InputReader):
"""Reads input experiences from an existing sampler."""
def __init__(self, sampler):
self.sampler = sampler
@override(InputReader)
def next(self):
batches = [self.sampler.get_data()]
batches.extend(self.sampler.get_extra_batches())
if len(batches) > 1:
return batches[0].concat_samples(batches)
else:
return batches[0]
+1 -2
View File
@@ -4,7 +4,6 @@ from __future__ import print_function
import os
from ray.rllib.offline.input_reader import SamplerInput
from ray.rllib.utils.annotations import PublicAPI
@@ -35,4 +34,4 @@ class IOContext(object):
@PublicAPI
def default_sampler_input(self):
return SamplerInput(self.evaluator.sampler)
return self.evaluator.sampler
+46
View File
@@ -0,0 +1,46 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
from ray.rllib.utils.annotations import override
class ImportanceSamplingEstimator(OffPolicyEstimator):
"""The step-wise IS estimator.
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
def __init__(self, policy, gamma):
OffPolicyEstimator.__init__(self, policy, gamma)
@override(OffPolicyEstimator)
def estimate(self, batch):
self.check_can_estimate_for(batch)
rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_prob(batch)
# calculate importance ratios
p = []
for t in range(batch.count - 1):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
# calculate stepwise IS estimate
V_prev, V_step_IS = 0.0, 0.0
for t in range(batch.count - 1):
V_prev += rewards[t] * self.gamma**t
V_step_IS += p[t] * rewards[t] * self.gamma**t
estimation = OffPolicyEstimate(
"is", {
"V_prev": V_prev,
"V_step_IS": V_step_IS,
"V_gain_est": V_step_IS / max(1e-8, V_prev),
})
return estimation
+1
View File
@@ -44,6 +44,7 @@ class JsonReader(InputReader):
self.ioctx = ioctx or IOContext()
if isinstance(inputs, six.string_types):
inputs = os.path.abspath(os.path.expanduser(inputs))
if os.path.isdir(inputs):
inputs = os.path.join(inputs, "*.json")
logger.warning(
+2 -1
View File
@@ -43,13 +43,13 @@ class JsonWriter(OutputWriter):
compress_columns (list): list of sample batch columns to compress.
"""
self.path = path
self.ioctx = ioctx or IOContext()
self.max_file_size = max_file_size
self.compress_columns = compress_columns
if urlparse(path).scheme:
self.path_is_uri = True
else:
path = os.path.abspath(os.path.expanduser(path))
# Try to create local dirs if they don't exist
try:
os.makedirs(path)
@@ -57,6 +57,7 @@ class JsonWriter(OutputWriter):
pass # already exists
assert os.path.exists(path), "Failed to create {}".format(path)
self.path_is_uri = False
self.path = path
self.file_index = 0
self.bytes_written = 0
self.cur_file = None
+3 -3
View File
@@ -6,10 +6,10 @@ import numpy as np
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.annotations import override, DeveloperAPI
@PublicAPI
@DeveloperAPI
class MixedInput(InputReader):
"""Mixes input from a number of other input sources.
@@ -21,7 +21,7 @@ class MixedInput(InputReader):
}, ioctx)
"""
@PublicAPI
@DeveloperAPI
def __init__(self, dist, ioctx):
"""Initialize a MixedInput.
@@ -0,0 +1,107 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import logging
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
OffPolicyEstimate = namedtuple("OffPolicyEstimate",
["estimator_name", "metrics"])
@DeveloperAPI
class OffPolicyEstimator(object):
"""Interface for an off policy reward estimator."""
@DeveloperAPI
def __init__(self, policy, gamma):
"""Creates an off-policy estimator.
Arguments:
policy (PolicyGraph): Policy graph to evaluate.
gamma (float): Discount of the MDP.
"""
self.policy = policy
self.gamma = gamma
self.new_estimates = []
@classmethod
def create(cls, ioctx):
"""Create an off-policy estimator from a IOContext."""
gamma = ioctx.evaluator.policy_config["gamma"]
# Grab a reference to the current model
keys = list(ioctx.evaluator.policy_map.keys())
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
"You can set `input_evaluation: []` to resolve this.")
policy = ioctx.evaluator.get_policy(keys[0])
return cls(policy, gamma)
@DeveloperAPI
def estimate(self, batch):
"""Returns an estimate for the given batch of experiences.
The batch will only contain data from one episode, but it may only be
a fragment of an episode.
"""
raise NotImplementedError
@DeveloperAPI
def action_prob(self, batch):
"""Returns the probs for the batch actions for the current policy."""
num_state_inputs = 0
for k in batch.keys():
if k.startswith("state_in_"):
num_state_inputs += 1
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
_, _, info = self.policy.compute_actions(
obs_batch=batch["obs"],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.data.get("prev_action"),
prev_reward_batch=batch.data.get("prev_reward"),
info_batch=batch.data.get("info"))
if "action_prob" not in info:
raise ValueError(
"Off-policy estimation is not possible unless the policy "
"returns action probabilities when computing actions (i.e., "
"the 'action_prob' key is output by the policy graph). You "
"can set `input_evaluation: []` to resolve this.")
return info["action_prob"]
@DeveloperAPI
def process(self, batch):
self.new_estimates.append(self.estimate(batch))
@DeveloperAPI
def check_can_estimate_for(self, batch):
"""Returns whether we can support OPE for this batch."""
if isinstance(batch, MultiAgentBatch):
raise ValueError(
"IS-estimation is not implemented for multi-agent batches. "
"You can set `input_evaluation: []` to resolve this.")
if "action_prob" not in batch:
raise ValueError(
"Off-policy estimation is not possible unless the inputs "
"include action probabilities (i.e., the policy is stochastic "
"and emits the 'action_prob' key). You can set "
"`input_evaluation: []` to resolve this.")
@DeveloperAPI
def get_metrics(self):
"""Return a list of new episode metric estimates since the last call.
Returns:
list of OffPolicyEstimate objects.
"""
out = self.new_estimates
self.new_estimates = []
return out
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import random
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.utils.annotations import override, DeveloperAPI
logger = logging.getLogger(__name__)
@DeveloperAPI
class ShuffledInput(InputReader):
"""Randomizes data over a sliding window buffer of N batches.
This increases the randomization of the data, which is useful if the
batches were not in random order to start with.
"""
@DeveloperAPI
def __init__(self, child, n=0):
"""Initialize a MixedInput.
Arguments:
child (InputReader): child input reader to shuffle.
n (int): if positive, shuffle input over this many batches.
"""
self.n = n
self.child = child
self.buffer = []
@override(InputReader)
def next(self):
if self.n <= 1:
return self.child.next()
if len(self.buffer) < self.n:
logger.info("Filling shuffle buffer to {} batches".format(self.n))
while len(self.buffer) < self.n:
self.buffer.append(self.child.next())
logger.info("Shuffle buffer filled")
i = random.randint(0, len(self.buffer) - 1)
self.buffer[i] = self.child.next()
return random.choice(self.buffer)
+56
View File
@@ -0,0 +1,56 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
from ray.rllib.utils.annotations import override
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
"""The weighted step-wise IS estimator.
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
def __init__(self, policy, gamma):
OffPolicyEstimator.__init__(self, policy, gamma)
self.filter_values = []
self.filter_counts = []
@override(OffPolicyEstimator)
def estimate(self, batch):
self.check_can_estimate_for(batch)
rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_prob(batch)
# calculate importance ratios
p = []
for t in range(batch.count - 1):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
for t, v in enumerate(p):
if t >= len(self.filter_values):
self.filter_values.append(v)
self.filter_counts.append(1.0)
else:
self.filter_values[t] += v
self.filter_counts[t] += 1.0
# calculate stepwise weighted IS estimate
V_prev, V_step_WIS = 0.0, 0.0
for t in range(batch.count - 1):
V_prev += rewards[t] * self.gamma**t
w_t = self.filter_values[t] / self.filter_counts[t]
V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t
estimation = OffPolicyEstimate(
"wis", {
"V_prev": V_prev,
"V_step_WIS": V_step_WIS,
"V_gain_est": V_step_WIS / max(1e-8, V_prev),
})
return estimation
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+6 -6
View File
@@ -69,7 +69,7 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": self.test_dir,
"input_evaluation": None,
"input_evaluation": [],
})
result = agent.train()
self.assertEqual(result["timesteps_total"], 250) # read from input
@@ -101,7 +101,7 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": self.test_dir,
"input_evaluation": None,
"input_evaluation": [],
"postprocess_inputs": True, # adds back 'advantages'
})
@@ -115,7 +115,7 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": self.test_dir,
"input_evaluation": "simulation",
"input_evaluation": ["simulation"],
})
for _ in range(50):
result = agent.train()
@@ -130,7 +130,7 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": glob.glob(self.test_dir + "/*.json"),
"input_evaluation": None,
"input_evaluation": [],
"sample_batch_size": 99,
})
result = agent.train()
@@ -147,7 +147,7 @@ class AgentIOTest(unittest.TestCase):
"sampler": 0.9,
},
"train_batch_size": 2000,
"input_evaluation": None,
"input_evaluation": [],
})
result = agent.train()
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
@@ -185,7 +185,7 @@ class AgentIOTest(unittest.TestCase):
config={
"num_workers": 0,
"input": self.test_dir,
"input_evaluation": "simulation",
"input_evaluation": ["simulation"],
"train_batch_size": 2000,
"multiagent": {
"policy_graphs": {