mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:30:45 +08:00
[rllib] Basic infrastructure for off-policy estimation (IS, WIS) (#3941)
This commit is contained in:
@@ -97,6 +97,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -153,7 +154,9 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"vf_preds": self.vf}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"vf_preds": self.vf})
|
||||
|
||||
def _value(self, ob, *args):
|
||||
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||
|
||||
@@ -13,7 +13,8 @@ import tensorflow as tf
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
@@ -145,18 +146,22 @@ COMMON_CONFIG = {
|
||||
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
|
||||
# - a function that returns a rllib.offline.InputReader
|
||||
"input": "sampler",
|
||||
# Specify how to evaluate the current policy. This only makes sense to set
|
||||
# when the input is not already generating simulation data:
|
||||
# - None: don't evaluate the policy. The episode reward and other
|
||||
# metrics will be NaN if using offline data.
|
||||
# Specify how to evaluate the current policy. This only has an effect when
|
||||
# reading offline experiences. Available options:
|
||||
# - "wis": the weighted step-wise importance sampling estimator.
|
||||
# - "is": the step-wise importance sampling estimator.
|
||||
# - "simulation": run the environment in the background, but use
|
||||
# this data for evaluation only and not for learning.
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Note that postprocessing will be done using the *current*
|
||||
# policy, not the *behaviour* policy, which is typically undesirable for
|
||||
# on-policy algorithms.
|
||||
"postprocess_inputs": False,
|
||||
# If positive, input batches will be shuffled via a sliding window buffer
|
||||
# of this number of batches. Use this if the input data is not in random
|
||||
# enough order. Input is delayed until the shuffle buffer is filled.
|
||||
"shuffle_buffer_size": 0,
|
||||
# __sphinx_doc_input_end__
|
||||
# __sphinx_doc_output_begin__
|
||||
# Specify where experiences should be saved:
|
||||
@@ -552,10 +557,10 @@ class Agent(Trainable):
|
||||
raise ValueError(
|
||||
"The `use_gpu_for_workers` config is deprecated, please use "
|
||||
"`num_gpus_per_worker=1` instead.")
|
||||
if (config["input"] == "sampler"
|
||||
and config["input_evaluation"] is not None):
|
||||
if type(config["input_evaluation"]) != list:
|
||||
raise ValueError(
|
||||
"`input_evaluation` should not be set when input=sampler")
|
||||
"`input_evaluation` must be a list of strings, got {}".format(
|
||||
config["input_evaluation"]))
|
||||
|
||||
def _make_evaluator(self,
|
||||
cls,
|
||||
@@ -575,9 +580,13 @@ class Agent(Trainable):
|
||||
elif config["input"] == "sampler":
|
||||
input_creator = (lambda ioctx: ioctx.default_sampler_input())
|
||||
elif isinstance(config["input"], dict):
|
||||
input_creator = (lambda ioctx: MixedInput(config["input"], ioctx))
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
MixedInput(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
else:
|
||||
input_creator = (lambda ioctx: JsonReader(config["input"], ioctx))
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
JsonReader(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
|
||||
if isinstance(config["output"], FunctionType):
|
||||
output_creator = config["output"]
|
||||
@@ -596,6 +605,11 @@ class Agent(Trainable):
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
|
||||
if config["input"] == "sampler":
|
||||
input_evaluation = []
|
||||
else:
|
||||
input_evaluation = config["input_evaluation"]
|
||||
|
||||
return cls(
|
||||
env_creator,
|
||||
self.config["multiagent"]["policy_graphs"] or policy_graph,
|
||||
@@ -622,7 +636,7 @@ class Agent(Trainable):
|
||||
log_level=config["log_level"],
|
||||
callbacks=config["callbacks"],
|
||||
input_creator=input_creator,
|
||||
input_evaluation_method=config["input_evaluation"],
|
||||
input_evaluation=input_evaluation,
|
||||
output_creator=output_creator,
|
||||
remote_worker_envs=remote_worker_envs)
|
||||
|
||||
|
||||
@@ -269,6 +269,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
q_t, self.q_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
self.stats = {
|
||||
"mean_q": tf.reduce_mean(q_t),
|
||||
"max_q": tf.reduce_max(q_t),
|
||||
"min_q": tf.reduce_min(q_t),
|
||||
}
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
|
||||
output_actions)
|
||||
@@ -416,6 +421,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.loss.td_error,
|
||||
"stats": self.stats,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
|
||||
@@ -68,6 +68,11 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"exploration_final_eps": 0.02,
|
||||
# Update the target network every `target_network_update_freq` steps.
|
||||
"target_network_update_freq": 500,
|
||||
# Use softmax for sampling actions.
|
||||
"soft_q": False,
|
||||
# Softmax temperature. Q values are divided by this value prior to softmax.
|
||||
# Softmax approaches argmax as the temperature drops to zero.
|
||||
"softmax_temp": 1.0,
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer. Note that if async_updates is set, then
|
||||
|
||||
@@ -8,8 +8,7 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.models import ModelCatalog, Categorical
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
@@ -182,7 +181,14 @@ class QNetwork(object):
|
||||
|
||||
|
||||
class QValuePolicy(object):
|
||||
def __init__(self, q_values, observations, num_actions, stochastic, eps):
|
||||
def __init__(self, q_values, observations, num_actions, stochastic, eps,
|
||||
softmax, softmax_temp):
|
||||
if softmax:
|
||||
action_dist = Categorical(q_values / softmax_temp)
|
||||
self.action = action_dist.sample()
|
||||
self.action_prob = action_dist.sampled_action_prob()
|
||||
return
|
||||
|
||||
deterministic_actions = tf.argmax(q_values, axis=1)
|
||||
batch_size = tf.shape(observations)[0]
|
||||
|
||||
@@ -200,6 +206,7 @@ class QValuePolicy(object):
|
||||
deterministic_actions)
|
||||
self.action = tf.cond(stochastic, lambda: stochastic_actions,
|
||||
lambda: deterministic_actions)
|
||||
self.action_prob = None
|
||||
|
||||
|
||||
class QLoss(object):
|
||||
@@ -300,10 +307,12 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values, q_logits, q_dist, _ = self._build_q_network(
|
||||
self.cur_observations, observation_space)
|
||||
self.q_values = q_values
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
self.output_actions = self._build_q_value_policy(q_values)
|
||||
self.output_actions, self.action_prob = self._build_q_value_policy(
|
||||
q_values)
|
||||
|
||||
# Replay inputs
|
||||
self.obs_t = tf.placeholder(
|
||||
@@ -387,6 +396,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=self.action_prob,
|
||||
loss=model.loss() + self.loss.loss,
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops)
|
||||
@@ -412,6 +422,13 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"q_values": self.q_values,
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
return {
|
||||
@@ -474,8 +491,10 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
return qnet.value, qnet.logits, qnet.dist, qnet.model
|
||||
|
||||
def _build_q_value_policy(self, q_values):
|
||||
return QValuePolicy(q_values, self.cur_observations, self.num_actions,
|
||||
self.stochastic, self.eps).action
|
||||
policy = QValuePolicy(
|
||||
q_values, self.cur_observations, self.num_actions, self.stochastic,
|
||||
self.eps, self.config["soft_q"], self.config["softmax_temp"])
|
||||
return policy.action, policy.action_prob
|
||||
|
||||
def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best,
|
||||
q_dist_tp1_best):
|
||||
@@ -511,26 +530,16 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
rewards[i] += gamma**j * rewards[i + j]
|
||||
|
||||
|
||||
def _postprocess_dqn(policy_graph, sample_batch):
|
||||
obs, actions, rewards, new_obs, dones = [
|
||||
list(x) for x in sample_batch.columns(
|
||||
["obs", "actions", "rewards", "new_obs", "dones"])
|
||||
]
|
||||
|
||||
def _postprocess_dqn(policy_graph, batch):
|
||||
# N-step Q adjustments
|
||||
if policy_graph.config["n_step"] > 1:
|
||||
_adjust_nstep(policy_graph.config["n_step"],
|
||||
policy_graph.config["gamma"], obs, actions, rewards,
|
||||
new_obs, dones)
|
||||
policy_graph.config["gamma"], batch["obs"],
|
||||
batch["actions"], batch["rewards"], batch["new_obs"],
|
||||
batch["dones"])
|
||||
|
||||
batch = SampleBatch({
|
||||
"obs": obs,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": new_obs,
|
||||
"dones": dones,
|
||||
"weights": np.ones_like(rewards)
|
||||
})
|
||||
if "weights" not in batch:
|
||||
batch["weights"] = np.ones_like(batch["rewards"])
|
||||
|
||||
# Prioritize on the worker side
|
||||
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
|
||||
|
||||
@@ -215,6 +215,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -270,7 +271,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"behaviour_logits": self.model.outputs}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"behaviour_logits": self.model.outputs})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
|
||||
@@ -19,8 +19,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"postprocess_inputs": True,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "complete_episodes",
|
||||
# Read data from historic data and evaluate by a sampler
|
||||
"input_evaluation": "simulation",
|
||||
# Use importance sampling estimators for reward
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Learning rate for adam optimizer
|
||||
"lr": 1e-4,
|
||||
# Number of timesteps collected for each SGD round
|
||||
|
||||
@@ -107,6 +107,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.obs_t,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + objective,
|
||||
loss_inputs=self.loss_inputs,
|
||||
state_inputs=self.model.state_in,
|
||||
|
||||
@@ -67,6 +67,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
sess,
|
||||
obs_input=obs,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
|
||||
@@ -320,6 +320,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -373,7 +374,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
out = {"behaviour_logits": self.model.outputs}
|
||||
if not self.config["vtrace"]:
|
||||
out["vf_preds"] = self.value_function
|
||||
return out
|
||||
return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out)
|
||||
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
@@ -234,6 +234,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=obs_ph,
|
||||
action_sampler=self.sampler,
|
||||
action_prob=curr_action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss_obj.loss,
|
||||
loss_inputs=self.loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
@@ -307,7 +308,11 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"vf_preds": self.value_function, "logits": self.logits}
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"vf_preds": self.value_function,
|
||||
"logits": self.logits
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
|
||||
@@ -8,6 +8,8 @@ import collections
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.sampler import RolloutMetrics
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,15 +33,14 @@ def collect_episodes(local_evaluator,
|
||||
"""Gathers new episodes metrics tuples from the given evaluators."""
|
||||
|
||||
pending = [
|
||||
a.apply.remote(lambda ev: ev.sampler.get_metrics())
|
||||
for a in remote_evaluators
|
||||
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_evaluators
|
||||
]
|
||||
collected, _ = ray.wait(
|
||||
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
|
||||
num_metric_batches_dropped = len(pending) - len(collected)
|
||||
|
||||
metric_lists = ray.get(collected)
|
||||
metric_lists.append(local_evaluator.sampler.get_metrics())
|
||||
metric_lists.append(local_evaluator.get_metrics())
|
||||
episodes = []
|
||||
for metrics in metric_lists:
|
||||
episodes.extend(metrics)
|
||||
@@ -60,6 +61,9 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
logger.warning("WARNING: {} workers have NOT returned metrics".format(
|
||||
num_dropped))
|
||||
|
||||
episodes, estimates = _partition(episodes)
|
||||
new_episodes, _ = _partition(new_episodes)
|
||||
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
policy_rewards = collections.defaultdict(list)
|
||||
@@ -95,6 +99,16 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
custom_metrics[k + "_max"] = float("nan")
|
||||
del custom_metrics[k]
|
||||
|
||||
estimators = collections.defaultdict(lambda: collections.defaultdict(list))
|
||||
for e in estimates:
|
||||
acc = estimators[e.estimator_name]
|
||||
for k, v in e.metrics.items():
|
||||
acc[k].append(v)
|
||||
for name, metrics in estimators.items():
|
||||
for k, v_list in metrics.items():
|
||||
metrics[k] = np.mean(v_list)
|
||||
estimators[name] = dict(metrics)
|
||||
|
||||
return dict(
|
||||
episode_reward_max=max_reward,
|
||||
episode_reward_min=min_reward,
|
||||
@@ -103,4 +117,19 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
episodes_this_iter=len(new_episodes),
|
||||
policy_reward_mean=dict(policy_rewards),
|
||||
custom_metrics=dict(custom_metrics),
|
||||
off_policy_estimator=dict(estimators),
|
||||
num_metric_batches_dropped=num_dropped)
|
||||
|
||||
|
||||
def _partition(episodes):
    """Splits collected metrics into rollout metrics and off-policy estimates.

    Arguments:
        episodes (list): metric objects, each either a RolloutMetrics or an
            OffPolicyEstimate.

    Returns:
        Tuple (rollouts, estimates), each a list preserving input order.

    Raises:
        ValueError: if an item is neither of the two known metric types.
    """

    rollouts = []
    estimates = []
    for item in episodes:
        # Pick the destination list first, then append once.
        if isinstance(item, RolloutMetrics):
            bucket = rollouts
        elif isinstance(item, OffPolicyEstimate):
            bucket = estimates
        else:
            raise ValueError("Unknown metric type: {}".format(item))
        bucket.append(item)
    return rollouts, estimates
|
||||
|
||||
@@ -19,6 +19,8 @@ from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
|
||||
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
@@ -116,7 +118,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
log_level=None,
|
||||
callbacks=None,
|
||||
input_creator=lambda ioctx: ioctx.default_sampler_input(),
|
||||
input_evaluation_method=None,
|
||||
input_evaluation=frozenset([]),
|
||||
output_creator=lambda ioctx: NoopOutput(),
|
||||
remote_worker_envs=False):
|
||||
"""Initialize a policy evaluator.
|
||||
@@ -184,11 +186,11 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
callbacks (dict): Dict of custom debug callbacks.
|
||||
input_creator (func): Function that returns an InputReader object
|
||||
for loading previous generated experiences.
|
||||
input_evaluation_method (str): How to evaluate the current policy.
|
||||
This only applies when the input is reading offline data.
|
||||
Options are:
|
||||
- None: don't evaluate the policy. The episode reward and
|
||||
other metrics will be NaN.
|
||||
input_evaluation (list): How to evaluate the policy performance.
|
||||
This only makes sense to set when the input is reading offline
|
||||
data. The possible values include:
|
||||
- "is": the step-wise importance sampling estimator.
|
||||
- "wis": the weighted step-wise is estimator.
|
||||
- "simulation": run the environment in the background, but
|
||||
use this data for evaluation only and never for learning.
|
||||
output_creator (func): Function that returns an OutputWriter object
|
||||
@@ -316,16 +318,24 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
raise ValueError("Unsupported batch mode: {}".format(
|
||||
self.batch_mode))
|
||||
|
||||
if input_evaluation_method == "simulation":
|
||||
logger.warning(
|
||||
"Requested 'simulation' input evaluation method: "
|
||||
"will discard all sampler outputs and keep only metrics.")
|
||||
sample_async = True
|
||||
elif input_evaluation_method is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unknown evaluation method: {}".format(
|
||||
input_evaluation_method))
|
||||
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
|
||||
self.reward_estimators = []
|
||||
for method in input_evaluation:
|
||||
if method == "simulation":
|
||||
logger.warning(
|
||||
"Requested 'simulation' input evaluation method: "
|
||||
"will discard all sampler outputs and keep only metrics.")
|
||||
sample_async = True
|
||||
elif method == "is":
|
||||
ise = ImportanceSamplingEstimator.create(self.io_context)
|
||||
self.reward_estimators.append(ise)
|
||||
elif method == "wis":
|
||||
wise = WeightedImportanceSamplingEstimator.create(
|
||||
self.io_context)
|
||||
self.reward_estimators.append(wise)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown evaluation method: {}".format(method))
|
||||
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
@@ -341,7 +351,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
blackhole_outputs=input_evaluation_method == "simulation")
|
||||
blackhole_outputs="simulation" in input_evaluation)
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
@@ -358,7 +368,6 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions)
|
||||
|
||||
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
|
||||
self.input_reader = input_creator(self.io_context)
|
||||
assert isinstance(self.input_reader, InputReader), self.input_reader
|
||||
self.output_writer = output_creator(self.io_context)
|
||||
@@ -402,6 +411,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
# for better compression inside the writer.
|
||||
self.output_writer.write(batch)
|
||||
|
||||
# Do off-policy estimation if needed
|
||||
if self.reward_estimators:
|
||||
for sub_batch in batch.split_by_episode():
|
||||
for estimator in self.reward_estimators:
|
||||
estimator.process(sub_batch)
|
||||
|
||||
if self.compress_observations:
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
for data in batch.policy_batches.values():
|
||||
@@ -504,6 +519,15 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
|
||||
return grad_fetch
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
    """Returns a list of new metrics objects from evaluation.

    The result holds the metrics reported by the sampler followed by the
    metrics reported by each configured reward estimator, in order.

    Returns:
        A new list; the list handed back by the sampler is not modified.
    """

    # Copy so that appending estimator metrics never mutates a list that
    # the sampler may still hold a reference to.
    out = list(self.sampler.get_metrics())
    for estimator in self.reward_estimators:
        out.extend(estimator.get_metrics())
    return out
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_env(self, func):
|
||||
"""Apply the given function to each underlying env instance."""
|
||||
|
||||
@@ -16,6 +16,8 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.offline import InputReader
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,7 +33,20 @@ PolicyEvalData = namedtuple("PolicyEvalData", [
|
||||
])
|
||||
|
||||
|
||||
class SyncSampler(object):
|
||||
class SamplerInput(InputReader):
    """Input reader that serves experiences produced by a sampler.

    Concrete samplers deriving from this class provide get_data() and
    get_extra_batches(); next() combines their results.
    """

    @override(InputReader)
    def next(self):
        """Returns the next batch, merged with any extra batches."""
        head = self.get_data()
        extras = list(self.get_extra_batches())
        if not extras:
            # Nothing to merge: hand back the primary batch untouched.
            return head
        return head.concat_samples([head] + extras)
|
||||
|
||||
|
||||
class SyncSampler(SamplerInput):
|
||||
def __init__(self,
|
||||
env,
|
||||
policies,
|
||||
@@ -87,7 +102,7 @@ class SyncSampler(object):
|
||||
return extra
|
||||
|
||||
|
||||
class AsyncSampler(threading.Thread):
|
||||
class AsyncSampler(threading.Thread, SamplerInput):
|
||||
def __init__(self,
|
||||
env,
|
||||
policies,
|
||||
|
||||
@@ -52,6 +52,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
action_sampler,
|
||||
loss,
|
||||
loss_inputs,
|
||||
action_prob=None,
|
||||
state_inputs=None,
|
||||
state_outputs=None,
|
||||
prev_action_input=None,
|
||||
@@ -77,6 +78,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
and has shape [BATCH_SIZE, data...]. These keys will be read
|
||||
from postprocessed sample batches and fed into the specified
|
||||
placeholders during loss computation.
|
||||
action_prob (Tensor): probability of the sampled action.
|
||||
state_inputs (list): list of RNN state input Tensors.
|
||||
state_outputs (list): list of RNN state output Tensors.
|
||||
prev_action_input (Tensor): placeholder for previous actions
|
||||
@@ -104,6 +106,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
self._is_training = self._get_is_training_placeholder()
|
||||
self._action_prob = action_prob
|
||||
self._state_inputs = state_inputs or []
|
||||
self._state_outputs = state_outputs or []
|
||||
for i, ph in enumerate(self._state_inputs):
|
||||
@@ -231,8 +234,14 @@ class TFPolicyGraph(PolicyGraph):
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_action_fetches(self):
|
||||
"""Extra values to fetch and return from compute_actions()."""
|
||||
return {} # e.g, value function
|
||||
"""Extra values to fetch and return from compute_actions().
|
||||
|
||||
By default we only return action probability info (if present).
|
||||
"""
|
||||
if self._action_prob is not None:
|
||||
return {"action_prob": self._action_prob}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_grad_feed_dict(self):
|
||||
|
||||
@@ -33,6 +33,7 @@ if __name__ == "__main__":
|
||||
agent_index=0,
|
||||
obs=obs,
|
||||
actions=action,
|
||||
action_prob=1.0, # put the true action probability here
|
||||
rewards=rew,
|
||||
prev_actions=prev_action,
|
||||
prev_rewards=prev_reward,
|
||||
|
||||
@@ -24,6 +24,7 @@ class ActionDistribution(object):
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs):
|
||||
self.inputs = inputs
|
||||
self.sample_op = self._build_sample_op()
|
||||
|
||||
@DeveloperAPI
|
||||
def logp(self, x):
|
||||
@@ -37,13 +38,27 @@ class ActionDistribution(object):
|
||||
|
||||
@DeveloperAPI
|
||||
def entropy(self):
|
||||
"""The entroy of the action distribution."""
|
||||
"""The entropy of the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def _build_sample_op(self):
|
||||
"""Implement this instead of sample(), to enable op reuse.
|
||||
|
||||
This is needed since the sample op is non-deterministic and is shared
|
||||
between sample() and sampled_action_prob().
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self):
|
||||
"""Draw a sample from the action distribution."""
|
||||
raise NotImplementedError
|
||||
return self.sample_op
|
||||
|
||||
@DeveloperAPI
|
||||
def sampled_action_prob(self):
|
||||
"""Returns the probability (exp of logp) of the sampled action."""
|
||||
return tf.exp(self.logp(self.sample_op))
|
||||
|
||||
|
||||
class Categorical(ActionDistribution):
|
||||
@@ -95,7 +110,7 @@ class Categorical(ActionDistribution):
|
||||
p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1])
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sample(self):
|
||||
def _build_sample_op(self):
|
||||
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
|
||||
|
||||
|
||||
@@ -107,11 +122,11 @@ class DiagGaussian(ActionDistribution):
|
||||
"""
|
||||
|
||||
def __init__(self, inputs):
|
||||
ActionDistribution.__init__(self, inputs)
|
||||
mean, log_std = tf.split(inputs, 2, axis=1)
|
||||
self.mean = mean
|
||||
self.log_std = log_std
|
||||
self.std = tf.exp(log_std)
|
||||
ActionDistribution.__init__(self, inputs)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
@@ -136,7 +151,7 @@ class DiagGaussian(ActionDistribution):
|
||||
reduction_indices=[1])
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sample(self):
|
||||
def _build_sample_op(self):
|
||||
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
|
||||
|
||||
@@ -147,7 +162,11 @@ class Deterministic(ActionDistribution):
|
||||
"""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sample(self):
|
||||
def sampled_action_prob(self):
|
||||
return 1.0
|
||||
|
||||
@override(ActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
return self.inputs
|
||||
|
||||
|
||||
@@ -205,5 +224,12 @@ class MultiActionDistribution(ActionDistribution):
|
||||
def sample(self):
|
||||
return TupleActions([s.sample() for s in self.child_distributions])
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sampled_action_prob(self):
|
||||
p = self.child_distributions[0].sampled_action_prob()
|
||||
for c in self.child_distributions[1:]:
|
||||
p *= c.sampled_action_prob()
|
||||
return p
|
||||
|
||||
|
||||
TupleActions = namedtuple("TupleActions", ["batches"])
|
||||
|
||||
@@ -8,6 +8,7 @@ from ray.rllib.offline.json_writer import JsonWriter
|
||||
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.mixed_input import MixedInput
|
||||
from ray.rllib.offline.shuffled_input import ShuffledInput
|
||||
|
||||
__all__ = [
|
||||
"IOContext",
|
||||
@@ -17,4 +18,5 @@ __all__ = [
|
||||
"OutputWriter",
|
||||
"InputReader",
|
||||
"MixedInput",
|
||||
"ShuffledInput",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@@ -18,19 +17,3 @@ class InputReader(object):
|
||||
SampleBatch or MultiAgentBatch read.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SamplerInput(InputReader):
|
||||
"""Reads input experiences from an existing sampler."""
|
||||
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
@override(InputReader)
|
||||
def next(self):
|
||||
batches = [self.sampler.get_data()]
|
||||
batches.extend(self.sampler.get_extra_batches())
|
||||
if len(batches) > 1:
|
||||
return batches[0].concat_samples(batches)
|
||||
else:
|
||||
return batches[0]
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ray.rllib.offline.input_reader import SamplerInput
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@@ -35,4 +34,4 @@ class IOContext(object):
|
||||
|
||||
@PublicAPI
|
||||
def default_sampler_input(self):
|
||||
return SamplerInput(self.evaluator.sampler)
|
||||
return self.evaluator.sampler
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class ImportanceSamplingEstimator(OffPolicyEstimator):
    """The step-wise IS estimator.

    Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""

    def __init__(self, policy, gamma):
        OffPolicyEstimator.__init__(self, policy, gamma)

    @override(OffPolicyEstimator)
    def estimate(self, batch):
        """Computes the step-wise importance sampling estimate for a batch.

        Arguments:
            batch: sample batch exposing "rewards" and "action_prob" columns
                and a `count` attribute. Presumably the steps of a single
                episode -- confirm against the caller.

        Returns:
            OffPolicyEstimate named "is" with the metrics:
                V_prev: discounted sum of the logged rewards.
                V_step_IS: discounted sum of the logged rewards, each
                    weighted by the cumulative importance ratio up to and
                    including its step.
                V_gain_est: V_step_IS / max(1e-8, V_prev).
        """
        self.check_can_estimate_for(batch)

        rewards, old_prob = batch["rewards"], batch["action_prob"]
        # NOTE(review): presumably the probability of the logged actions
        # under the policy being evaluated -- confirm in OffPolicyEstimator.
        new_prob = self.action_prob(batch)

        # Each of the batch.count steps carries its own reward and ratio, so
        # the sums run over the full range; stopping at count - 1 would drop
        # the final step (often the only rewarded one in sparse tasks).
        V_prev, V_step_IS = 0.0, 0.0
        cumulative_ratio = 1.0
        for t in range(batch.count):
            cumulative_ratio = cumulative_ratio * new_prob[t] / old_prob[t]
            V_prev += rewards[t] * self.gamma**t
            V_step_IS += cumulative_ratio * rewards[t] * self.gamma**t

        estimation = OffPolicyEstimate(
            "is", {
                "V_prev": V_prev,
                "V_step_IS": V_step_IS,
                "V_gain_est": V_step_IS / max(1e-8, V_prev),
            })
        return estimation
|
||||
@@ -44,6 +44,7 @@ class JsonReader(InputReader):
|
||||
|
||||
self.ioctx = ioctx or IOContext()
|
||||
if isinstance(inputs, six.string_types):
|
||||
inputs = os.path.abspath(os.path.expanduser(inputs))
|
||||
if os.path.isdir(inputs):
|
||||
inputs = os.path.join(inputs, "*.json")
|
||||
logger.warning(
|
||||
|
||||
@@ -43,13 +43,13 @@ class JsonWriter(OutputWriter):
|
||||
compress_columns (list): list of sample batch columns to compress.
|
||||
"""
|
||||
|
||||
self.path = path
|
||||
self.ioctx = ioctx or IOContext()
|
||||
self.max_file_size = max_file_size
|
||||
self.compress_columns = compress_columns
|
||||
if urlparse(path).scheme:
|
||||
self.path_is_uri = True
|
||||
else:
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
# Try to create local dirs if they don't exist
|
||||
try:
|
||||
os.makedirs(path)
|
||||
@@ -57,6 +57,7 @@ class JsonWriter(OutputWriter):
|
||||
pass # already exists
|
||||
assert os.path.exists(path), "Failed to create {}".format(path)
|
||||
self.path_is_uri = False
|
||||
self.path = path
|
||||
self.file_index = 0
|
||||
self.bytes_written = 0
|
||||
self.cur_file = None
|
||||
|
||||
@@ -6,10 +6,10 @@ import numpy as np
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@DeveloperAPI
|
||||
class MixedInput(InputReader):
|
||||
"""Mixes input from a number of other input sources.
|
||||
|
||||
@@ -21,7 +21,7 @@ class MixedInput(InputReader):
|
||||
}, ioctx)
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
@DeveloperAPI
|
||||
def __init__(self, dist, ioctx):
|
||||
"""Initialize a MixedInput.
|
||||
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Result record emitted by an estimator: a short estimator name (e.g. "is")
# together with the metrics computed for one batch.
OffPolicyEstimate = namedtuple("OffPolicyEstimate", "estimator_name metrics")
|
||||
|
||||
|
||||
@DeveloperAPI
class OffPolicyEstimator(object):
    """Base interface for off-policy reward estimators.

    Subclasses implement estimate(); process() stores each estimate and
    get_metrics() hands the stored estimates back and clears them.
    """

    @DeveloperAPI
    def __init__(self, policy, gamma):
        """Creates an off-policy estimator.

        Arguments:
            policy (PolicyGraph): Policy graph to evaluate.
            gamma (float): Discount of the MDP.
        """
        self.policy = policy
        self.gamma = gamma
        # Filled by process(), drained by get_metrics().
        self.new_estimates = []

    @classmethod
    def create(cls, ioctx):
        """Create an off-policy estimator from a IOContext.

        Reads "gamma" from the evaluator's policy config and uses the
        evaluator's only policy.

        Raises:
            NotImplementedError: if the evaluator holds more than one policy.
        """
        evaluator = ioctx.evaluator
        gamma = evaluator.policy_config["gamma"]
        # Grab a reference to the current model
        policy_ids = list(evaluator.policy_map.keys())
        if len(policy_ids) > 1:
            raise NotImplementedError(
                "Off-policy estimation is not implemented for multi-agent. "
                "You can set `input_evaluation: []` to resolve this.")
        return cls(evaluator.get_policy(policy_ids[0]), gamma)

    @DeveloperAPI
    def estimate(self, batch):
        """Returns an estimate for the given batch of experiences.

        The batch will only contain data from one episode, but it may only be
        a fragment of an episode.
        """
        raise NotImplementedError

    @DeveloperAPI
    def action_prob(self, batch):
        """Returns the current policy's "action_prob" output for the batch.

        Runs policy.compute_actions() on the batch observations (plus any
        "state_in_<i>" columns) and returns the "action_prob" entry of the
        extra info it produces.

        Raises:
            ValueError: if the policy does not output "action_prob".
        """
        # NOTE(review): the value comes from compute_actions(), so it may be
        # the probability of the action sampled in that call rather than of
        # the action recorded in the batch -- confirm against the policy.
        state_count = sum(
            1 for key in batch.keys() if key.startswith("state_in_"))
        # NOTE(review): confirm that "prev_action", "prev_reward" and "info"
        # match the column names used when the batch was written; a missing
        # column is passed through as None.
        _, _, info = self.policy.compute_actions(
            obs_batch=batch["obs"],
            state_batches=[
                batch["state_in_{}".format(i)] for i in range(state_count)
            ],
            prev_action_batch=batch.data.get("prev_action"),
            prev_reward_batch=batch.data.get("prev_reward"),
            info_batch=batch.data.get("info"))
        if "action_prob" in info:
            return info["action_prob"]
        raise ValueError(
            "Off-policy estimation is not possible unless the policy "
            "returns action probabilities when computing actions (i.e., "
            "the 'action_prob' key is output by the policy graph). You "
            "can set `input_evaluation: []` to resolve this.")

    @DeveloperAPI
    def process(self, batch):
        """Estimate the batch and queue the result for get_metrics()."""
        self.new_estimates.append(self.estimate(batch))

    @DeveloperAPI
    def check_can_estimate_for(self, batch):
        """Raise ValueError if estimation is unsupported for this batch.

        Returns nothing when the batch is acceptable.

        Raises:
            ValueError: if the batch is a MultiAgentBatch or has no
                "action_prob" column.
        """
        if isinstance(batch, MultiAgentBatch):
            raise ValueError(
                "IS-estimation is not implemented for multi-agent batches. "
                "You can set `input_evaluation: []` to resolve this.")

        if "action_prob" not in batch:
            raise ValueError(
                "Off-policy estimation is not possible unless the inputs "
                "include action probabilities (i.e., the policy is stochastic "
                "and emits the 'action_prob' key). You can set "
                "`input_evaluation: []` to resolve this.")

    @DeveloperAPI
    def get_metrics(self):
        """Return a list of new episode metric estimates since the last call.

        Returns:
            list of OffPolicyEstimate objects.
        """
        # Hand back the accumulated list and start a fresh one.
        drained, self.new_estimates = self.new_estimates, []
        return drained
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
class ShuffledInput(InputReader):
    """Randomizes data over a sliding window buffer of N batches.

    This increases the randomization of the data, which is useful if the
    batches were not in random order to start with.
    """

    @DeveloperAPI
    def __init__(self, child, n=0):
        """Initialize a ShuffledInput.

        Arguments:
            child (InputReader): child input reader to shuffle.
            n (int): number of batches to keep in the shuffle buffer. Values
                of 1 or less disable shuffling, so batches are read straight
                from the child.
        """
        self.n = n
        self.child = child
        self.buffer = []

    @override(InputReader)
    def next(self):
        """Return the next batch, drawn at random from the buffer.

        With shuffling disabled (n <= 1) this simply forwards to the child.
        Otherwise the buffer is first filled to n batches, then one randomly
        chosen slot is overwritten with a fresh batch from the child and a
        randomly chosen buffered batch is returned. A batch can therefore be
        returned more than once, or be overwritten before it is returned.
        """
        if self.n <= 1:
            return self.child.next()

        if len(self.buffer) < self.n:
            logger.info("Filling shuffle buffer to {} batches".format(self.n))
            while len(self.buffer) < self.n:
                self.buffer.append(self.child.next())
            logger.info("Shuffle buffer filled")

        slot = random.randint(0, len(self.buffer) - 1)
        self.buffer[slot] = self.child.next()
        return random.choice(self.buffer)
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
    """The weighted step-wise IS estimator.

    Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""

    def __init__(self, policy, gamma):
        OffPolicyEstimator.__init__(self, policy, gamma)
        # Per-timestep running totals over every batch estimated so far:
        # filter_values[t] sums the cumulative ratios seen at step t and
        # filter_counts[t] counts how many batches reached step t.
        self.filter_values = []
        self.filter_counts = []

    @override(OffPolicyEstimator)
    def estimate(self, batch):
        """Compute the weighted step-wise IS estimate for a batch.

        The batch's cumulative ratios are folded into the running per-step
        totals first, so the normalizer used for this batch includes the
        batch itself.

        Arguments:
            batch: single-agent batch providing "rewards" and "action_prob"
                columns, indexed by timestep for t in range(batch.count).

        Returns:
            OffPolicyEstimate named "wis" whose metrics are:
                V_prev: discounted return of the logged rewards.
                V_step_WIS: discounted return with each reward weighted by
                    its cumulative ratio divided by the running mean ratio
                    at that step.
                V_gain_est: V_step_WIS / max(1e-8, V_prev).

        Raises:
            ValueError: if check_can_estimate_for() rejects the batch.
        """
        self.check_can_estimate_for(batch)

        rewards, old_prob = batch["rewards"], batch["action_prob"]
        new_prob = self.action_prob(batch)

        # Cumulative importance ratios: p[t] is the product of
        # new_prob[i] / old_prob[i] for i = 0..t. Every timestep of the batch
        # is included; stopping at batch.count - 1 would silently drop the
        # final reward from both value sums below.
        p = []
        running_ratio = 1.0
        for t in range(batch.count):
            running_ratio = running_ratio * new_prob[t] / old_prob[t]
            p.append(running_ratio)

        # Update the per-step running totals, growing them when this batch
        # is longer than any seen before.
        for t, v in enumerate(p):
            if t >= len(self.filter_values):
                self.filter_values.append(v)
                self.filter_counts.append(1.0)
            else:
                self.filter_values[t] += v
                self.filter_counts[t] += 1.0

        # calculate stepwise weighted IS estimate
        V_prev, V_step_WIS = 0.0, 0.0
        for t, ratio in enumerate(p):
            V_prev += rewards[t] * self.gamma**t
            # Mean cumulative ratio observed at step t so far.
            w_t = self.filter_values[t] / self.filter_counts[t]
            V_step_WIS += ratio / w_t * rewards[t] * self.gamma**t

        estimation = OffPolicyEstimate(
            "wis", {
                "V_prev": V_prev,
                "V_step_WIS": V_step_WIS,
                # The 1e-8 floor only protects against dividing by zero; when
                # V_prev is zero or negative the denominator is 1e-8 and the
                # gain is not a meaningful ratio.
                "V_gain_est": V_step_WIS / max(1e-8, V_prev),
            })
        return estimation
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -69,7 +69,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
@@ -101,7 +101,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
"postprocess_inputs": True, # adds back 'advantages'
|
||||
})
|
||||
|
||||
@@ -115,7 +115,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": "simulation",
|
||||
"input_evaluation": ["simulation"],
|
||||
})
|
||||
for _ in range(50):
|
||||
result = agent.train()
|
||||
@@ -130,7 +130,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": glob.glob(self.test_dir + "/*.json"),
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
"sample_batch_size": 99,
|
||||
})
|
||||
result = agent.train()
|
||||
@@ -147,7 +147,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
"sampler": 0.9,
|
||||
},
|
||||
"train_batch_size": 2000,
|
||||
"input_evaluation": None,
|
||||
"input_evaluation": [],
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
|
||||
@@ -185,7 +185,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": "simulation",
|
||||
"input_evaluation": ["simulation"],
|
||||
"train_batch_size": 2000,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
|
||||
Reference in New Issue
Block a user