mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:08:13 +08:00
[rllib] Basic infrastructure for off-policy estimation (IS, WIS) (#3941)
This commit is contained in:
+10
-25
@@ -19,36 +19,21 @@ import shlex
|
||||
# These lines added to enable Sphinx to work without installing Ray.
|
||||
import mock
|
||||
MOCK_MODULES = [
|
||||
"gym",
|
||||
"gym.spaces",
|
||||
"scipy",
|
||||
"scipy.signal",
|
||||
"tensorflow",
|
||||
"tensorflow.contrib",
|
||||
"tensorflow.contrib.all_reduce",
|
||||
"tensorflow.contrib.all_reduce.python",
|
||||
"tensorflow.contrib.layers",
|
||||
"tensorflow.contrib.slim",
|
||||
"tensorflow.contrib.rnn",
|
||||
"tensorflow.core",
|
||||
"tensorflow.core.util",
|
||||
"tensorflow.python",
|
||||
"tensorflow.python.client",
|
||||
"tensorflow.python.util",
|
||||
"ray.core.generated",
|
||||
"gym", "gym.spaces", "scipy", "scipy.signal", "tensorflow",
|
||||
"tensorflow.contrib", "tensorflow.contrib.all_reduce",
|
||||
"tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers",
|
||||
"tensorflow.contrib.slim", "tensorflow.contrib.rnn", "tensorflow.core",
|
||||
"tensorflow.core.util", "tensorflow.python", "tensorflow.python.client",
|
||||
"tensorflow.python.util", "ray.core.generated",
|
||||
"ray.core.generated.ActorCheckpointIdData",
|
||||
"ray.core.generated.ClientTableData",
|
||||
"ray.core.generated.GcsTableEntry",
|
||||
"ray.core.generated.ClientTableData", "ray.core.generated.GcsTableEntry",
|
||||
"ray.core.generated.HeartbeatTableData",
|
||||
"ray.core.generated.HeartbeatBatchTableData",
|
||||
"ray.core.generated.DriverTableData",
|
||||
"ray.core.generated.ErrorTableData",
|
||||
"ray.core.generated.DriverTableData", "ray.core.generated.ErrorTableData",
|
||||
"ray.core.generated.ProfileTableData",
|
||||
"ray.core.generated.ObjectTableData",
|
||||
"ray.core.generated.ray.protocol.Task",
|
||||
"ray.core.generated.TablePrefix",
|
||||
"ray.core.generated.TablePubsub",
|
||||
"ray.core.generated.Language",
|
||||
"ray.core.generated.ray.protocol.Task", "ray.core.generated.TablePrefix",
|
||||
"ray.core.generated.TablePubsub", "ray.core.generated.Language",
|
||||
"ray._raylet"
|
||||
]
|
||||
for mod_name in MOCK_MODULES:
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 34 KiB After Width: | Height: | Size: 65 KiB |
@@ -44,14 +44,46 @@ Then, we can tell DQN to train using these previously generated experiences with
|
||||
--env=CartPole-v0 \
|
||||
--config='{
|
||||
"input": "/tmp/cartpole-out",
|
||||
"input_evaluation": [],
|
||||
"exploration_final_eps": 0,
|
||||
"exploration_fraction": 0}'
|
||||
|
||||
Since the input experiences are not from running simulations, RLlib cannot report the true policy performance during training. However, you can use ``tensorboard --logdir=~/ray_results`` to monitor training progress via other metrics such as estimated Q-value:
|
||||
**Off-policy estimation:** Since the input experiences are not from running simulations, RLlib cannot report the true policy performance during training. However, you can use ``tensorboard --logdir=~/ray_results`` to monitor training progress via other metrics such as estimated Q-value. Alternatively, `off-policy estimation <https://arxiv.org/pdf/1511.03722.pdf>`__ can be used, which requires both the source and target action probabilities to be available (i.e., the ``action_prob`` batch key). For DQN, this means enabling soft Q learning so that actions are sampled from a probability distribution:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ rllib train \
|
||||
--run=DQN \
|
||||
--env=CartPole-v0 \
|
||||
--config='{
|
||||
"input": "/tmp/cartpole-out",
|
||||
"input_evaluation": ["is", "wis"],
|
||||
"soft_q": true,
|
||||
"softmax_temp": 1.0}'
|
||||
|
||||
This example plot shows the Q-value metric in addition to importance sampling (IS) and weighted importance sampling (WIS) gain estimates (>1.0 means there is an estimated improvement over the original policy):
|
||||
|
||||
.. image:: offline-q.png
|
||||
|
||||
In offline input mode, no simulations are run, though you still need to specify the environment in order to define the action and observation spaces. If true simulation is also possible (i.e., your env supports ``step()``), you can also set ``"input_evaluation": "simulation"`` to tell RLlib to run background simulations to estimate current policy performance. The output of these simulations will not be used for learning.
|
||||
**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.estimate(episode_batch)`` to perform counterfactual estimation as needed. The estimators take in a policy graph object and gamma value for the environment:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
agent = DQNAgent(...)
|
||||
... # train agent offline
|
||||
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
|
||||
estimator = WeightedImportanceSamplingEstimator(agent.get_policy(), gamma=0.99)
|
||||
reader = JsonReader("/path/to/data")
|
||||
for _ in range(1000):
|
||||
batch = reader.next()
|
||||
for episode in batch.split_by_episode():
|
||||
print(estimator.estimate(episode))
|
||||
|
||||
|
||||
**Simulation-based estimation:** If true simulation is also possible (i.e., your env supports ``step()``), you can also set ``"input_evaluation": ["simulation"]`` to tell RLlib to run background simulations to estimate current policy performance. The output of these simulations will not be used for learning. Note that in all cases you still need to specify an environment object to define the action and observation spaces. However, you don't need to implement functions like reset() and step().
|
||||
|
||||
Example: Converting external experiences to batch format
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -97,6 +97,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -153,7 +154,9 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"vf_preds": self.vf}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"vf_preds": self.vf})
|
||||
|
||||
def _value(self, ob, *args):
|
||||
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||
|
||||
@@ -13,7 +13,8 @@ import tensorflow as tf
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
@@ -145,18 +146,22 @@ COMMON_CONFIG = {
|
||||
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
|
||||
# - a function that returns a rllib.offline.InputReader
|
||||
"input": "sampler",
|
||||
# Specify how to evaluate the current policy. This only makes sense to set
|
||||
# when the input is not already generating simulation data:
|
||||
# - None: don't evaluate the policy. The episode reward and other
|
||||
# metrics will be NaN if using offline data.
|
||||
# Specify how to evaluate the current policy. This only has an effect when
|
||||
# reading offline experiences. Available options:
|
||||
# - "wis": the weighted step-wise importance sampling estimator.
|
||||
# - "is": the step-wise importance sampling estimator.
|
||||
# - "simulation": run the environment in the background, but use
|
||||
# this data for evaluation only and not for learning.
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Note that postprocessing will be done using the *current*
|
||||
# policy, not the *behaviour* policy, which is typically undesirable for
|
||||
# on-policy algorithms.
|
||||
"postprocess_inputs": False,
|
||||
# If positive, input batches will be shuffled via a sliding window buffer
|
||||
# of this number of batches. Use this if the input data is not in random
|
||||
# enough order. Input is delayed until the shuffle buffer is filled.
|
||||
"shuffle_buffer_size": 0,
|
||||
# __sphinx_doc_input_end__
|
||||
# __sphinx_doc_output_begin__
|
||||
# Specify where experiences should be saved:
|
||||
@@ -552,10 +557,10 @@ class Agent(Trainable):
|
||||
raise ValueError(
|
||||
"The `use_gpu_for_workers` config is deprecated, please use "
|
||||
"`num_gpus_per_worker=1` instead.")
|
||||
if (config["input"] == "sampler"
|
||||
and config["input_evaluation"] is not None):
|
||||
if type(config["input_evaluation"]) != list:
|
||||
raise ValueError(
|
||||
"`input_evaluation` should not be set when input=sampler")
|
||||
"`input_evaluation` must be a list of strings, got {}".format(
|
||||
config["input_evaluation"]))
|
||||
|
||||
def _make_evaluator(self,
|
||||
cls,
|
||||
@@ -575,9 +580,13 @@ class Agent(Trainable):
|
||||
elif config["input"] == "sampler":
|
||||
input_creator = (lambda ioctx: ioctx.default_sampler_input())
|
||||
elif isinstance(config["input"], dict):
|
||||
input_creator = (lambda ioctx: MixedInput(config["input"], ioctx))
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
MixedInput(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
else:
|
||||
input_creator = (lambda ioctx: JsonReader(config["input"], ioctx))
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
JsonReader(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
|
||||
if isinstance(config["output"], FunctionType):
|
||||
output_creator = config["output"]
|
||||
@@ -596,6 +605,11 @@ class Agent(Trainable):
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
|
||||
if config["input"] == "sampler":
|
||||
input_evaluation = []
|
||||
else:
|
||||
input_evaluation = config["input_evaluation"]
|
||||
|
||||
return cls(
|
||||
env_creator,
|
||||
self.config["multiagent"]["policy_graphs"] or policy_graph,
|
||||
@@ -622,7 +636,7 @@ class Agent(Trainable):
|
||||
log_level=config["log_level"],
|
||||
callbacks=config["callbacks"],
|
||||
input_creator=input_creator,
|
||||
input_evaluation_method=config["input_evaluation"],
|
||||
input_evaluation=input_evaluation,
|
||||
output_creator=output_creator,
|
||||
remote_worker_envs=remote_worker_envs)
|
||||
|
||||
|
||||
@@ -269,6 +269,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
q_t, self.q_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
self.stats = {
|
||||
"mean_q": tf.reduce_mean(q_t),
|
||||
"max_q": tf.reduce_max(q_t),
|
||||
"min_q": tf.reduce_min(q_t),
|
||||
}
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
|
||||
output_actions)
|
||||
@@ -416,6 +421,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.loss.td_error,
|
||||
"stats": self.stats,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
|
||||
@@ -68,6 +68,11 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"exploration_final_eps": 0.02,
|
||||
# Update the target network every `target_network_update_freq` steps.
|
||||
"target_network_update_freq": 500,
|
||||
# Use softmax for sampling actions.
|
||||
"soft_q": False,
|
||||
# Softmax temperature. Q values are divided by this value prior to softmax.
|
||||
# Softmax approaches argmax as the temperature drops to zero.
|
||||
"softmax_temp": 1.0,
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer. Note that if async_updates is set, then
|
||||
|
||||
@@ -8,8 +8,7 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.models import ModelCatalog, Categorical
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
@@ -182,7 +181,14 @@ class QNetwork(object):
|
||||
|
||||
|
||||
class QValuePolicy(object):
|
||||
def __init__(self, q_values, observations, num_actions, stochastic, eps):
|
||||
def __init__(self, q_values, observations, num_actions, stochastic, eps,
|
||||
softmax, softmax_temp):
|
||||
if softmax:
|
||||
action_dist = Categorical(q_values / softmax_temp)
|
||||
self.action = action_dist.sample()
|
||||
self.action_prob = action_dist.sampled_action_prob()
|
||||
return
|
||||
|
||||
deterministic_actions = tf.argmax(q_values, axis=1)
|
||||
batch_size = tf.shape(observations)[0]
|
||||
|
||||
@@ -200,6 +206,7 @@ class QValuePolicy(object):
|
||||
deterministic_actions)
|
||||
self.action = tf.cond(stochastic, lambda: stochastic_actions,
|
||||
lambda: deterministic_actions)
|
||||
self.action_prob = None
|
||||
|
||||
|
||||
class QLoss(object):
|
||||
@@ -300,10 +307,12 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values, q_logits, q_dist, _ = self._build_q_network(
|
||||
self.cur_observations, observation_space)
|
||||
self.q_values = q_values
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
self.output_actions = self._build_q_value_policy(q_values)
|
||||
self.output_actions, self.action_prob = self._build_q_value_policy(
|
||||
q_values)
|
||||
|
||||
# Replay inputs
|
||||
self.obs_t = tf.placeholder(
|
||||
@@ -387,6 +396,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=self.action_prob,
|
||||
loss=model.loss() + self.loss.loss,
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops)
|
||||
@@ -412,6 +422,13 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"q_values": self.q_values,
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
return {
|
||||
@@ -474,8 +491,10 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
return qnet.value, qnet.logits, qnet.dist, qnet.model
|
||||
|
||||
def _build_q_value_policy(self, q_values):
|
||||
return QValuePolicy(q_values, self.cur_observations, self.num_actions,
|
||||
self.stochastic, self.eps).action
|
||||
policy = QValuePolicy(
|
||||
q_values, self.cur_observations, self.num_actions, self.stochastic,
|
||||
self.eps, self.config["soft_q"], self.config["softmax_temp"])
|
||||
return policy.action, policy.action_prob
|
||||
|
||||
def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best,
|
||||
q_dist_tp1_best):
|
||||
@@ -511,26 +530,16 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
rewards[i] += gamma**j * rewards[i + j]
|
||||
|
||||
|
||||
def _postprocess_dqn(policy_graph, sample_batch):
|
||||
obs, actions, rewards, new_obs, dones = [
|
||||
list(x) for x in sample_batch.columns(
|
||||
["obs", "actions", "rewards", "new_obs", "dones"])
|
||||
]
|
||||
|
||||
def _postprocess_dqn(policy_graph, batch):
|
||||
# N-step Q adjustments
|
||||
if policy_graph.config["n_step"] > 1:
|
||||
_adjust_nstep(policy_graph.config["n_step"],
|
||||
policy_graph.config["gamma"], obs, actions, rewards,
|
||||
new_obs, dones)
|
||||
policy_graph.config["gamma"], batch["obs"],
|
||||
batch["actions"], batch["rewards"], batch["new_obs"],
|
||||
batch["dones"])
|
||||
|
||||
batch = SampleBatch({
|
||||
"obs": obs,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": new_obs,
|
||||
"dones": dones,
|
||||
"weights": np.ones_like(rewards)
|
||||
})
|
||||
if "weights" not in batch:
|
||||
batch["weights"] = np.ones_like(batch["rewards"])
|
||||
|
||||
# Prioritize on the worker side
|
||||
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
|
||||
|
||||
@@ -215,6 +215,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -270,7 +271,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"behaviour_logits": self.model.outputs}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"behaviour_logits": self.model.outputs})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
|
||||
@@ -19,8 +19,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"postprocess_inputs": True,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "complete_episodes",
|
||||
# Read data from historic data and evaluate by a sampler
|
||||
"input_evaluation": "simulation",
|
||||
# Use importance sampling estimators for reward
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Learning rate for adam optimizer
|
||||
"lr": 1e-4,
|
||||
# Number of timesteps collected for each SGD round
|
||||
|
||||
@@ -107,6 +107,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.obs_t,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + objective,
|
||||
loss_inputs=self.loss_inputs,
|
||||
state_inputs=self.model.state_in,
|
||||
|
||||
@@ -67,6 +67,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
sess,
|
||||
obs_input=obs,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
|
||||
@@ -320,6 +320,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -373,7 +374,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
out = {"behaviour_logits": self.model.outputs}
|
||||
if not self.config["vtrace"]:
|
||||
out["vf_preds"] = self.value_function
|
||||
return out
|
||||
return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out)
|
||||
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
@@ -234,6 +234,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=obs_ph,
|
||||
action_sampler=self.sampler,
|
||||
action_prob=curr_action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss_obj.loss,
|
||||
loss_inputs=self.loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -307,7 +308,11 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"vf_preds": self.value_function, "logits": self.logits}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"vf_preds": self.value_function,
|
||||
"logits": self.logits
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
|
||||
@@ -8,6 +8,8 @@ import collections
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.sampler import RolloutMetrics
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,15 +33,14 @@ def collect_episodes(local_evaluator,
|
||||
"""Gathers new episodes metrics tuples from the given evaluators."""
|
||||
|
||||
pending = [
|
||||
a.apply.remote(lambda ev: ev.sampler.get_metrics())
|
||||
for a in remote_evaluators
|
||||
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_evaluators
|
||||
]
|
||||
collected, _ = ray.wait(
|
||||
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
|
||||
num_metric_batches_dropped = len(pending) - len(collected)
|
||||
|
||||
metric_lists = ray.get(collected)
|
||||
metric_lists.append(local_evaluator.sampler.get_metrics())
|
||||
metric_lists.append(local_evaluator.get_metrics())
|
||||
episodes = []
|
||||
for metrics in metric_lists:
|
||||
episodes.extend(metrics)
|
||||
@@ -60,6 +61,9 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
logger.warning("WARNING: {} workers have NOT returned metrics".format(
|
||||
num_dropped))
|
||||
|
||||
episodes, estimates = _partition(episodes)
|
||||
new_episodes, _ = _partition(new_episodes)
|
||||
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
policy_rewards = collections.defaultdict(list)
|
||||
@@ -95,6 +99,16 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
custom_metrics[k + "_max"] = float("nan")
|
||||
del custom_metrics[k]
|
||||
|
||||
estimators = collections.defaultdict(lambda: collections.defaultdict(list))
|
||||
for e in estimates:
|
||||
acc = estimators[e.estimator_name]
|
||||
for k, v in e.metrics.items():
|
||||
acc[k].append(v)
|
||||
for name, metrics in estimators.items():
|
||||
for k, v_list in metrics.items():
|
||||
metrics[k] = np.mean(v_list)
|
||||
estimators[name] = dict(metrics)
|
||||
|
||||
return dict(
|
||||
episode_reward_max=max_reward,
|
||||
episode_reward_min=min_reward,
|
||||
@@ -103,4 +117,19 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
episodes_this_iter=len(new_episodes),
|
||||
policy_reward_mean=dict(policy_rewards),
|
||||
custom_metrics=dict(custom_metrics),
|
||||
off_policy_estimator=dict(estimators),
|
||||
num_metric_batches_dropped=num_dropped)
|
||||
|
||||
|
||||
def _partition(episodes):
|
||||
"""Divides metrics data into true rollouts vs off-policy estimates."""
|
||||
|
||||
rollouts, estimates = [], []
|
||||
for e in episodes:
|
||||
if isinstance(e, RolloutMetrics):
|
||||
rollouts.append(e)
|
||||
elif isinstance(e, OffPolicyEstimate):
|
||||
estimates.append(e)
|
||||
else:
|
||||
raise ValueError("Unknown metric type: {}".format(e))
|
||||
return rollouts, estimates
|
||||
|
||||
@@ -19,6 +19,8 @@ from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
|
||||
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
@@ -116,7 +118,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
log_level=None,
|
||||
callbacks=None,
|
||||
input_creator=lambda ioctx: ioctx.default_sampler_input(),
|
||||
input_evaluation_method=None,
|
||||
input_evaluation=frozenset([]),
|
||||
output_creator=lambda ioctx: NoopOutput(),
|
||||
remote_worker_envs=False):
|
||||
"""Initialize a policy evaluator.
|
||||
@@ -184,11 +186,11 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
callbacks (dict): Dict of custom debug callbacks.
|
||||
input_creator (func): Function that returns an InputReader object
|
||||
for loading previous generated experiences.
|
||||
input_evaluation_method (str): How to evaluate the current policy.
|
||||
This only applies when the input is reading offline data.
|
||||
Options are:
|
||||
- None: don't evaluate the policy. The episode reward and
|
||||
other metrics will be NaN.
|
||||
input_evaluation (list): How to evaluate the policy performance.
|
||||
This only makes sense to set when the input is reading offline
|
||||
data. The possible values include:
|
||||
- "is": the step-wise importance sampling estimator.
|
||||
- "wis": the weighted step-wise is estimator.
|
||||
- "simulation": run the environment in the background, but
|
||||
use this data for evaluation only and never for learning.
|
||||
output_creator (func): Function that returns an OutputWriter object
|
||||
@@ -316,16 +318,24 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
raise ValueError("Unsupported batch mode: {}".format(
|
||||
self.batch_mode))
|
||||
|
||||
if input_evaluation_method == "simulation":
|
||||
logger.warning(
|
||||
"Requested 'simulation' input evaluation method: "
|
||||
"will discard all sampler outputs and keep only metrics.")
|
||||
sample_async = True
|
||||
elif input_evaluation_method is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unknown evaluation method: {}".format(
|
||||
input_evaluation_method))
|
||||
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
|
||||
self.reward_estimators = []
|
||||
for method in input_evaluation:
|
||||
if method == "simulation":
|
||||
logger.warning(
|
||||
"Requested 'simulation' input evaluation method: "
|
||||
"will discard all sampler outputs and keep only metrics.")
|
||||
sample_async = True
|
||||
elif method == "is":
|
||||
ise = ImportanceSamplingEstimator.create(self.io_context)
|
||||
self.reward_estimators.append(ise)
|
||||
elif method == "wis":
|
||||
wise = WeightedImportanceSamplingEstimator.create(
|
||||
self.io_context)
|
||||
self.reward_estimators.append(wise)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown evaluation method: {}".format(method))
|
||||
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
@@ -341,7 +351,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
blackhole_outputs=input_evaluation_method == "simulation")
|
||||
blackhole_outputs="simulation" in input_evaluation)
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
@@ -358,7 +368,6 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions)
|
||||
|
||||
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
|
||||
self.input_reader = input_creator(self.io_context)
|
||||
assert isinstance(self.input_reader, InputReader), self.input_reader
|
||||
self.output_writer = output_creator(self.io_context)
|
||||
@@ -402,6 +411,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
# for better compression inside the writer.
|
||||
self.output_writer.write(batch)
|
||||
|
||||
# Do off-policy estimation if needed
|
||||
if self.reward_estimators:
|
||||
for sub_batch in batch.split_by_episode():
|
||||
for estimator in self.reward_estimators:
|
||||
estimator.process(sub_batch)
|
||||
|
||||
if self.compress_observations:
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
for data in batch.policy_batches.values():
|
||||
@@ -504,6 +519,15 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
|
||||
return grad_fetch
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
|
||||
"""Returns a list of new RolloutMetric objects from evaluation."""
|
||||
|
||||
out = self.sampler.get_metrics()
|
||||
for m in self.reward_estimators:
|
||||
out.extend(m.get_metrics())
|
||||
return out
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_env(self, func):
|
||||
"""Apply the given function to each underlying env instance."""
|
||||
|
||||
@@ -16,6 +16,8 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.offline import InputReader
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,7 +33,20 @@ PolicyEvalData = namedtuple("PolicyEvalData", [
|
||||
])
|
||||
|
||||
|
||||
class SyncSampler(object):
|
||||
class SamplerInput(InputReader):
|
||||
"""Reads input experiences from an existing sampler."""
|
||||
|
||||
@override(InputReader)
|
||||
def next(self):
|
||||
batches = [self.get_data()]
|
||||
batches.extend(self.get_extra_batches())
|
||||
if len(batches) > 1:
|
||||
return batches[0].concat_samples(batches)
|
||||
else:
|
||||
return batches[0]
|
||||
|
||||
|
||||
class SyncSampler(SamplerInput):
|
||||
def __init__(self,
|
||||
env,
|
||||
policies,
|
||||
@@ -87,7 +102,7 @@ class SyncSampler(object):
|
||||
return extra
|
||||
|
||||
|
||||
class AsyncSampler(threading.Thread):
|
||||
class AsyncSampler(threading.Thread, SamplerInput):
|
||||
def __init__(self,
|
||||
env,
|
||||
policies,
|
||||
|
||||
@@ -52,6 +52,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
action_sampler,
|
||||
loss,
|
||||
loss_inputs,
|
||||
action_prob=None,
|
||||
state_inputs=None,
|
||||
state_outputs=None,
|
||||
prev_action_input=None,
|
||||
@@ -77,6 +78,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
and has shape [BATCH_SIZE, data...]. These keys will be read
|
||||
from postprocessed sample batches and fed into the specified
|
||||
placeholders during loss computation.
|
||||
action_prob (Tensor): probability of the sampled action.
|
||||
state_inputs (list): list of RNN state input Tensors.
|
||||
state_outputs (list): list of RNN state output Tensors.
|
||||
prev_action_input (Tensor): placeholder for previous actions
|
||||
@@ -104,6 +106,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
self._is_training = self._get_is_training_placeholder()
|
||||
self._action_prob = action_prob
|
||||
self._state_inputs = state_inputs or []
|
||||
self._state_outputs = state_outputs or []
|
||||
for i, ph in enumerate(self._state_inputs):
|
||||
@@ -231,8 +234,14 @@ class TFPolicyGraph(PolicyGraph):
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_action_fetches(self):
|
||||
"""Extra values to fetch and return from compute_actions()."""
|
||||
return {} # e.g, value function
|
||||
"""Extra values to fetch and return from compute_actions().
|
||||
|
||||
By default we only return action probability info (if present).
|
||||
"""
|
||||
if self._action_prob is not None:
|
||||
return {"action_prob": self._action_prob}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_grad_feed_dict(self):
|
||||
|
||||
@@ -33,6 +33,7 @@ if __name__ == "__main__":
|
||||
agent_index=0,
|
||||
obs=obs,
|
||||
actions=action,
|
||||
action_prob=1.0, # put the true action probability here
|
||||
rewards=rew,
|
||||
prev_actions=prev_action,
|
||||
prev_rewards=prev_reward,
|
||||
|
||||
@@ -24,6 +24,7 @@ class ActionDistribution(object):
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs):
|
||||
self.inputs = inputs
|
||||
self.sample_op = self._build_sample_op()
|
||||
|
||||
@DeveloperAPI
|
||||
def logp(self, x):
|
||||
@@ -37,13 +38,27 @@ class ActionDistribution(object):
|
||||
|
||||
@DeveloperAPI
|
||||
def entropy(self):
|
||||
"""The entroy of the action distribution."""
|
||||
"""The entropy of the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def _build_sample_op(self):
|
||||
"""Implement this instead of sample(), to enable op reuse.
|
||||
|
||||
This is needed since the sample op is non-deterministic and is shared
|
||||
between sample() and sampled_action_prob().
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self):
|
||||
"""Draw a sample from the action distribution."""
|
||||
raise NotImplementedError
|
||||
return self.sample_op
|
||||
|
||||
@DeveloperAPI
|
||||
def sampled_action_prob(self):
|
||||
"""Returns the log probability of the sampled action."""
|
||||
return tf.exp(self.logp(self.sample_op))
|
||||
|
||||
|
||||
class Categorical(ActionDistribution):
|
||||
@@ -95,7 +110,7 @@ class Categorical(ActionDistribution):
|
||||
p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1])
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sample(self):
|
||||
def _build_sample_op(self):
|
||||
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
|
||||
|
||||
|
||||
@@ -107,11 +122,11 @@ class DiagGaussian(ActionDistribution):
|
||||
"""
|
||||
|
||||
def __init__(self, inputs):
|
||||
ActionDistribution.__init__(self, inputs)
|
||||
mean, log_std = tf.split(inputs, 2, axis=1)
|
||||
self.mean = mean
|
||||
self.log_std = log_std
|
||||
self.std = tf.exp(log_std)
|
||||
ActionDistribution.__init__(self, inputs)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
@@ -136,7 +151,7 @@ class DiagGaussian(ActionDistribution):
|
||||
reduction_indices=[1])
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sample(self):
|
||||
def _build_sample_op(self):
|
||||
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
|
||||
|
||||
@@ -147,7 +162,11 @@ class Deterministic(ActionDistribution):
|
||||
"""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sample(self):
|
||||
def sampled_action_prob(self):
|
||||
return 1.0
|
||||
|
||||
@override(ActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
return self.inputs
|
||||
|
||||
|
||||
@@ -205,5 +224,12 @@ class MultiActionDistribution(ActionDistribution):
|
||||
def sample(self):
|
||||
return TupleActions([s.sample() for s in self.child_distributions])
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sampled_action_prob(self):
|
||||
p = self.child_distributions[0].sampled_action_prob()
|
||||
for c in self.child_distributions[1:]:
|
||||
p *= c.sampled_action_prob()
|
||||
return p
|
||||
|
||||
|
||||
TupleActions = namedtuple("TupleActions", ["batches"])
|
||||
|
||||
@@ -8,6 +8,7 @@ from ray.rllib.offline.json_writer import JsonWriter
|
||||
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.mixed_input import MixedInput
|
||||
from ray.rllib.offline.shuffled_input import ShuffledInput
|
||||
|
||||
__all__ = [
|
||||
"IOContext",
|
||||
@@ -17,4 +18,5 @@ __all__ = [
|
||||
"OutputWriter",
|
||||
"InputReader",
|
||||
"MixedInput",
|
||||
"ShuffledInput",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@@ -18,19 +17,3 @@ class InputReader(object):
|
||||
SampleBatch or MultiAgentBatch read.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SamplerInput(InputReader):
|
||||
"""Reads input experiences from an existing sampler."""
|
||||
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
@override(InputReader)
|
||||
def next(self):
|
||||
batches = [self.sampler.get_data()]
|
||||
batches.extend(self.sampler.get_extra_batches())
|
||||
if len(batches) > 1:
|
||||
return batches[0].concat_samples(batches)
|
||||
else:
|
||||
return batches[0]
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ray.rllib.offline.input_reader import SamplerInput
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@@ -35,4 +34,4 @@ class IOContext(object):
|
||||
|
||||
@PublicAPI
|
||||
def default_sampler_input(self):
|
||||
return SamplerInput(self.evaluator.sampler)
|
||||
return self.evaluator.sampler
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class ImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
"""The step-wise IS estimator.
|
||||
|
||||
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
|
||||
|
||||
def __init__(self, policy, gamma):
|
||||
OffPolicyEstimator.__init__(self, policy, gamma)
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def estimate(self, batch):
|
||||
self.check_can_estimate_for(batch)
|
||||
|
||||
rewards, old_prob = batch["rewards"], batch["action_prob"]
|
||||
new_prob = self.action_prob(batch)
|
||||
|
||||
# calculate importance ratios
|
||||
p = []
|
||||
for t in range(batch.count - 1):
|
||||
if t == 0:
|
||||
pt_prev = 1.0
|
||||
else:
|
||||
pt_prev = p[t - 1]
|
||||
p.append(pt_prev * new_prob[t] / old_prob[t])
|
||||
|
||||
# calculate stepwise IS estimate
|
||||
V_prev, V_step_IS = 0.0, 0.0
|
||||
for t in range(batch.count - 1):
|
||||
V_prev += rewards[t] * self.gamma**t
|
||||
V_step_IS += p[t] * rewards[t] * self.gamma**t
|
||||
|
||||
estimation = OffPolicyEstimate(
|
||||
"is", {
|
||||
"V_prev": V_prev,
|
||||
"V_step_IS": V_step_IS,
|
||||
"V_gain_est": V_step_IS / max(1e-8, V_prev),
|
||||
})
|
||||
return estimation
|
||||
@@ -44,6 +44,7 @@ class JsonReader(InputReader):
|
||||
|
||||
self.ioctx = ioctx or IOContext()
|
||||
if isinstance(inputs, six.string_types):
|
||||
inputs = os.path.abspath(os.path.expanduser(inputs))
|
||||
if os.path.isdir(inputs):
|
||||
inputs = os.path.join(inputs, "*.json")
|
||||
logger.warning(
|
||||
|
||||
@@ -43,13 +43,13 @@ class JsonWriter(OutputWriter):
|
||||
compress_columns (list): list of sample batch columns to compress.
|
||||
"""
|
||||
|
||||
self.path = path
|
||||
self.ioctx = ioctx or IOContext()
|
||||
self.max_file_size = max_file_size
|
||||
self.compress_columns = compress_columns
|
||||
if urlparse(path).scheme:
|
||||
self.path_is_uri = True
|
||||
else:
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
# Try to create local dirs if they don't exist
|
||||
try:
|
||||
os.makedirs(path)
|
||||
@@ -57,6 +57,7 @@ class JsonWriter(OutputWriter):
|
||||
pass # already exists
|
||||
assert os.path.exists(path), "Failed to create {}".format(path)
|
||||
self.path_is_uri = False
|
||||
self.path = path
|
||||
self.file_index = 0
|
||||
self.bytes_written = 0
|
||||
self.cur_file = None
|
||||
|
||||
@@ -6,10 +6,10 @@ import numpy as np
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@DeveloperAPI
|
||||
class MixedInput(InputReader):
|
||||
"""Mixes input from a number of other input sources.
|
||||
|
||||
@@ -21,7 +21,7 @@ class MixedInput(InputReader):
|
||||
}, ioctx)
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
@DeveloperAPI
|
||||
def __init__(self, dist, ioctx):
|
||||
"""Initialize a MixedInput.
|
||||
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OffPolicyEstimate = namedtuple("OffPolicyEstimate",
|
||||
["estimator_name", "metrics"])
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class OffPolicyEstimator(object):
|
||||
"""Interface for an off policy reward estimator."""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, policy, gamma):
|
||||
"""Creates an off-policy estimator.
|
||||
|
||||
Arguments:
|
||||
policy (PolicyGraph): Policy graph to evaluate.
|
||||
gamma (float): Discount of the MDP.
|
||||
"""
|
||||
self.policy = policy
|
||||
self.gamma = gamma
|
||||
self.new_estimates = []
|
||||
|
||||
@classmethod
|
||||
def create(cls, ioctx):
|
||||
"""Create an off-policy estimator from a IOContext."""
|
||||
gamma = ioctx.evaluator.policy_config["gamma"]
|
||||
# Grab a reference to the current model
|
||||
keys = list(ioctx.evaluator.policy_map.keys())
|
||||
if len(keys) > 1:
|
||||
raise NotImplementedError(
|
||||
"Off-policy estimation is not implemented for multi-agent. "
|
||||
"You can set `input_evaluation: []` to resolve this.")
|
||||
policy = ioctx.evaluator.get_policy(keys[0])
|
||||
return cls(policy, gamma)
|
||||
|
||||
@DeveloperAPI
|
||||
def estimate(self, batch):
|
||||
"""Returns an estimate for the given batch of experiences.
|
||||
|
||||
The batch will only contain data from one episode, but it may only be
|
||||
a fragment of an episode.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def action_prob(self, batch):
|
||||
"""Returns the probs for the batch actions for the current policy."""
|
||||
|
||||
num_state_inputs = 0
|
||||
for k in batch.keys():
|
||||
if k.startswith("state_in_"):
|
||||
num_state_inputs += 1
|
||||
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
|
||||
_, _, info = self.policy.compute_actions(
|
||||
obs_batch=batch["obs"],
|
||||
state_batches=[batch[k] for k in state_keys],
|
||||
prev_action_batch=batch.data.get("prev_action"),
|
||||
prev_reward_batch=batch.data.get("prev_reward"),
|
||||
info_batch=batch.data.get("info"))
|
||||
if "action_prob" not in info:
|
||||
raise ValueError(
|
||||
"Off-policy estimation is not possible unless the policy "
|
||||
"returns action probabilities when computing actions (i.e., "
|
||||
"the 'action_prob' key is output by the policy graph). You "
|
||||
"can set `input_evaluation: []` to resolve this.")
|
||||
return info["action_prob"]
|
||||
|
||||
@DeveloperAPI
|
||||
def process(self, batch):
|
||||
self.new_estimates.append(self.estimate(batch))
|
||||
|
||||
@DeveloperAPI
|
||||
def check_can_estimate_for(self, batch):
|
||||
"""Returns whether we can support OPE for this batch."""
|
||||
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
raise ValueError(
|
||||
"IS-estimation is not implemented for multi-agent batches. "
|
||||
"You can set `input_evaluation: []` to resolve this.")
|
||||
|
||||
if "action_prob" not in batch:
|
||||
raise ValueError(
|
||||
"Off-policy estimation is not possible unless the inputs "
|
||||
"include action probabilities (i.e., the policy is stochastic "
|
||||
"and emits the 'action_prob' key). You can set "
|
||||
"`input_evaluation: []` to resolve this.")
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
|
||||
"""Return a list of new episode metric estimates since the last call.
|
||||
|
||||
Returns:
|
||||
list of OffPolicyEstimate objects.
|
||||
"""
|
||||
out = self.new_estimates
|
||||
self.new_estimates = []
|
||||
return out
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ShuffledInput(InputReader):
|
||||
"""Randomizes data over a sliding window buffer of N batches.
|
||||
|
||||
This increases the randomization of the data, which is useful if the
|
||||
batches were not in random order to start with.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, child, n=0):
|
||||
"""Initialize a MixedInput.
|
||||
|
||||
Arguments:
|
||||
child (InputReader): child input reader to shuffle.
|
||||
n (int): if positive, shuffle input over this many batches.
|
||||
"""
|
||||
self.n = n
|
||||
self.child = child
|
||||
self.buffer = []
|
||||
|
||||
@override(InputReader)
|
||||
def next(self):
|
||||
if self.n <= 1:
|
||||
return self.child.next()
|
||||
if len(self.buffer) < self.n:
|
||||
logger.info("Filling shuffle buffer to {} batches".format(self.n))
|
||||
while len(self.buffer) < self.n:
|
||||
self.buffer.append(self.child.next())
|
||||
logger.info("Shuffle buffer filled")
|
||||
i = random.randint(0, len(self.buffer) - 1)
|
||||
self.buffer[i] = self.child.next()
|
||||
return random.choice(self.buffer)
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
"""The weighted step-wise IS estimator.
|
||||
|
||||
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
|
||||
|
||||
def __init__(self, policy, gamma):
|
||||
OffPolicyEstimator.__init__(self, policy, gamma)
|
||||
self.filter_values = []
|
||||
self.filter_counts = []
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def estimate(self, batch):
|
||||
self.check_can_estimate_for(batch)
|
||||
|
||||
rewards, old_prob = batch["rewards"], batch["action_prob"]
|
||||
new_prob = self.action_prob(batch)
|
||||
|
||||
# calculate importance ratios
|
||||
p = []
|
||||
for t in range(batch.count - 1):
|
||||
if t == 0:
|
||||
pt_prev = 1.0
|
||||
else:
|
||||
pt_prev = p[t - 1]
|
||||
p.append(pt_prev * new_prob[t] / old_prob[t])
|
||||
for t, v in enumerate(p):
|
||||
if t >= len(self.filter_values):
|
||||
self.filter_values.append(v)
|
||||
self.filter_counts.append(1.0)
|
||||
else:
|
||||
self.filter_values[t] += v
|
||||
self.filter_counts[t] += 1.0
|
||||
|
||||
# calculate stepwise weighted IS estimate
|
||||
V_prev, V_step_WIS = 0.0, 0.0
|
||||
for t in range(batch.count - 1):
|
||||
V_prev += rewards[t] * self.gamma**t
|
||||
w_t = self.filter_values[t] / self.filter_counts[t]
|
||||
V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t
|
||||
|
||||
estimation = OffPolicyEstimate(
|
||||
"wis", {
|
||||
"V_prev": V_prev,
|
||||
"V_step_WIS": V_step_WIS,
|
||||
"V_gain_est": V_step_WIS / max(1e-8, V_prev),
|
||||
})
|
||||
return estimation
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -69,7 +69,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
@@ -101,7 +101,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
"postprocess_inputs": True, # adds back 'advantages'
|
||||
})
|
||||
|
||||
@@ -115,7 +115,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": "simulation",
|
||||
"input_evaluation": ["simulation"],
|
||||
})
|
||||
for _ in range(50):
|
||||
result = agent.train()
|
||||
@@ -130,7 +130,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": glob.glob(self.test_dir + "/*.json"),
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
"sample_batch_size": 99,
|
||||
})
|
||||
result = agent.train()
|
||||
@@ -147,7 +147,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
"sampler": 0.9,
|
||||
},
|
||||
"train_batch_size": 2000,
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
|
||||
@@ -185,7 +185,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": "simulation",
|
||||
"input_evaluation": ["simulation"],
|
||||
"train_batch_size": 2000,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
|
||||
@@ -251,12 +251,20 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
--ray-num-cpus 8 \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100, "min_iter_time_s": 1}'
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env CartPole-v0 \
|
||||
--run MARWIL \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"input": "/ray/python/ray/rllib/test/data/cartpole_small", "learning_starts": 0}'
|
||||
--config '{"input": "/ray/python/ray/rllib/test/data/cartpole_small", "learning_starts": 0, "input_evaluation": ["wis", "is"], "shuffle_buffer_size": 10}'
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env CartPole-v0 \
|
||||
--run DQN \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"input": "/ray/python/ray/rllib/test/data/cartpole_small", "learning_starts": 0, "input_evaluation": ["wis", "is"], "soft_q": true}'
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/test/test_local.py
|
||||
|
||||
Reference in New Issue
Block a user