[rllib] Extra Changes for Usability (#2363)

This commit is contained in:
Richard Liaw
2018-07-24 20:51:22 -07:00
committed by GitHub
parent 05490b8cb9
commit 7edc677304
5 changed files with 178 additions and 12 deletions
+6 -4
View File
@@ -3,15 +3,17 @@ from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch, \
SampleBatchBuilder, MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.sample_batch import (SampleBatch, MultiAgentBatch,
SampleBatchBuilder,
MultiAgentSampleBatchBuilder)
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.postprocessing import (compute_advantages,
compute_targets)
from ray.rllib.evaluation.metrics import collect_metrics
__all__ = [
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
"compute_advantages", "collect_metrics"
"compute_advantages", "compute_targets", "collect_metrics"
]
@@ -0,0 +1,65 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.evaluation.policy_graph import PolicyGraph
def _sample(probs):
return [np.random.choice(len(pr), p=pr) for pr in probs]
class KerasPolicyGraph(PolicyGraph):
    """Policy graph wrapping a separate actor model and critic model.

    Note: This class targets Actor-Critic style models specifically and is
    less general than TFPolicyGraph and TorchPolicyGraph.

    Args:
        observation_space (gym.Space): Observation space of the policy.
        action_space (gym.Space): Action space of the policy.
        config (dict): Policy-specific configuration data.
        actor (Model): Model holding the policy. Its ``predict`` output is
            used as per-row sampling probabilities over actions.
        critic (Model): Model holding the value function.
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 actor=None,
                 critic=None):
        PolicyGraph.__init__(self, observation_space, action_space, config)
        self.actor, self.critic = actor, critic
        # Fixed order (actor, critic); get_weights/set_weights rely on it.
        self.models = [actor, critic]

    def compute_actions(self, obs, *args, **kwargs):
        """Sample actions for a batch of observations.

        Returns:
            tuple: (sampled actions, empty state list, extra fetches dict
                holding the flattened critic predictions as "vf_preds").
        """
        obs_batch = np.array(obs)
        action_probs = self.actor.predict(obs_batch)
        value_estimates = self.critic.predict(obs_batch)
        actions = _sample(action_probs)
        return actions, [], {"vf_preds": value_estimates.flatten()}

    def compute_apply(self, batch, *args):
        """Fit the actor and then the critic on the given batch.

        The actor is fit against ``batch["adv_targets"]`` and the critic
        against ``batch["value_targets"]``, both using ``batch["obs"]`` as
        input.

        Returns:
            tuple: Two empty dicts.
        """
        fit_options = dict(epochs=1, verbose=0, steps_per_epoch=20)
        training_plan = (
            (self.actor, "adv_targets"),
            (self.critic, "value_targets"),
        )
        for model, target_key in training_plan:
            model.fit(batch["obs"], batch[target_key], **fit_options)
        return {}, {}

    def get_weights(self):
        """Return the weights of each model, ordered as [actor, critic]."""
        collected = []
        for model in self.models:
            collected.append(model.get_weights())
        return collected

    def set_weights(self, weights):
        """Load per-model weights, given in the order used by get_weights.

        Returns:
            list: The return value of each model's ``set_weights`` call.
        """
        outcomes = []
        for model, model_weights in zip(self.models, weights):
            outcomes.append(model.set_weights(model_weights))
        return outcomes
+26 -3
View File
@@ -11,13 +11,13 @@ def discount(x, gamma):
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
def compute_advantages(rollout, last_r, gamma, lambda_=1.0, use_gae=True):
def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
"""Given a rollout, compute its value targets and the advantage.
Args:
rollout (PartialRollout): Partial Rollout Object
rollout (SampleBatch): SampleBatch of a single trajectory
last_r (float): Value estimation for last observation
gamma (float): Parameter for GAE
gamma (float): Discount factor.
lambda_ (float): Parameter for GAE
use_gae (bool): Using Generalized Advantage Estimation
@@ -52,3 +52,26 @@ def compute_advantages(rollout, last_r, gamma, lambda_=1.0, use_gae=True):
assert all(val.shape[0] == trajsize for val in traj.values()), \
"Rollout stacked incorrectly!"
return SampleBatch(traj)
def compute_targets(rollout, action_space, last_r=0.0, gamma=0.9, lambda_=1.0):
    """Given a rollout, compute policy and value-function training targets.

    Used for categorical crossentropy loss on the policy. Also assumes
    there is a value function. Advantages are obtained by calling
    `compute_advantages` with the given `gamma` and `lambda_`.

    Args:
        rollout (SampleBatch): SampleBatch of a single trajectory
        action_space (gym.Space): Discrete action space; `action_space.n`
            is the width of the advantage targets.
        last_r (float): Value estimation for last observation
        gamma (float): Discount factor.
        lambda_ (float): Parameter for GAE

    Returns:
        The batch returned by `compute_advantages`, with two added columns:
            "adv_targets": array of shape (rollout.count, action_space.n),
                zero everywhere except the taken action's column, which
                holds that step's advantage.
            "value_targets": one-step bootstrapped targets,
                `rewards[t] + gamma * vf_preds[t + 1]`, where the final
                step bootstraps from `last_r`.
    """
    rollout = compute_advantages(rollout, last_r, gamma=gamma, lambda_=lambda_)
    num_steps = rollout.count

    adv_targets = np.zeros((num_steps, action_space.n))
    adv_targets[np.arange(num_steps), rollout["actions"]] = \
        rollout["advantages"]
    rollout["adv_targets"] = adv_targets

    rewards = rollout["rewards"]
    # Promote to a floating dtype so the in-place additions below cannot
    # fail with a casting error on integer rewards; float dtypes are kept
    # as-is. astype() always returns a copy, so `rewards` is not mutated.
    value_targets = rewards.astype(np.result_type(rewards.dtype, np.float32))
    value_targets[:-1] += gamma * rollout["vf_preds"][1:]
    # The last step has no successor prediction in the rollout, so it
    # bootstraps from last_r. The slice form is a no-op on an empty batch.
    value_targets[-1:] += gamma * last_r
    rollout["value_targets"] = value_targets
    return rollout
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
@@ -77,6 +78,17 @@ class PolicyOptimizer(object):
"num_steps_sampled": self.num_steps_sampled,
}
def collect_metrics(self):
    """Gather evaluator metrics and attach this optimizer's stats.

    Returns:
        res (TrainingResult): Result of the module-level
            `collect_metrics` over the local and remote evaluators, with
            its `info` field replaced by `self.stats()`.
    """
    evaluator_result = collect_metrics(self.local_evaluator,
                                       self.remote_evaluators)
    return evaluator_result._replace(info=self.stats())
def save(self):
"""Returns a serializable object representing the optimizer state."""
@@ -109,11 +121,64 @@ class PolicyOptimizer(object):
])
return local_result + remote_results
def collect_metrics(self):
    """Return evaluator metrics with `info` replaced by `self.stats()`."""
    return collect_metrics(
        self.local_evaluator,
        self.remote_evaluators)._replace(info=self.stats())
def _check_not_multiagent(self, sample_batch):
    """Reject multi-agent batches.

    Raises:
        NotImplementedError: If `sample_batch` is a MultiAgentBatch.
    """
    if not isinstance(sample_batch, MultiAgentBatch):
        return
    raise NotImplementedError(
        "This optimizer does not support multi-agent yet.")
@classmethod
def make(cls,
         env_creator,
         policy_graph,
         optimizer_batch_size=None,
         num_workers=0,
         num_envs_per_worker=None,
         optimizer_config=None,
         remote_num_cpus=None,
         remote_num_gpus=None,
         **eval_kwargs):
    """Build an optimizer along with its local and remote evaluators.

    Args:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy_graph (class|dict): Either a class implementing
            PolicyGraph, or a dictionary of policy id strings to
            (PolicyGraph, obs_space, action_space, config) tuples.
            See PolicyEvaluator documentation.
        optimizer_batch_size (int): Batch size summed across all workers.
            When set, overrides `batch_steps` in `eval_kwargs`: it is
            floor-divided by `num_workers` if there is more than one
            worker, and used unchanged otherwise.
        num_workers (int): Number of remote evaluators.
        num_envs_per_worker (int): (Optional) Number of environments per
            evaluator for vectorization. When set, overrides `num_envs`
            in `eval_kwargs`.
        optimizer_config (dict): Config passed to the optimizer; an empty
            dict is used when omitted.
        remote_num_cpus (int): CPU specification for remote evaluators.
        remote_num_gpus (int): GPU specification for remote evaluators.
        **eval_kwargs: Keyword arguments forwarded to every
            PolicyEvaluator, local and remote alike.

    Returns:
        (Optimizer) Instance of `cls` with evaluators configured
            accordingly.
    """
    if not optimizer_config:
        optimizer_config = {}

    if num_envs_per_worker:
        assert num_envs_per_worker > 0, "Improper num_envs_per_worker!"
        eval_kwargs["num_envs"] = int(num_envs_per_worker)

    if optimizer_batch_size:
        assert optimizer_batch_size > 0
        # Only split the batch when several remote workers share it.
        eval_kwargs["batch_steps"] = (
            optimizer_batch_size // num_workers
            if num_workers > 1 else optimizer_batch_size)

    local_evaluator = PolicyEvaluator(
        env_creator, policy_graph, **eval_kwargs)

    remote_evaluator_cls = PolicyEvaluator.as_remote(
        remote_num_cpus, remote_num_gpus)
    remote_evaluators = []
    for _ in range(num_workers):
        remote_evaluators.append(
            remote_evaluator_cls.remote(
                env_creator, policy_graph, **eval_kwargs))

    return cls(local_evaluator, remote_evaluators, optimizer_config)
+12 -1
View File
@@ -76,6 +76,8 @@ class UnifiedLogger(Logger):
self._log_syncer.sync_now(force=True)
def flush(self):
    """Flush each wrapped logger, then force a log sync and wait on it."""
    for child_logger in self._loggers:
        child_logger.flush()
    syncer = self._log_syncer
    syncer.sync_now(force=True)
    syncer.wait()
@@ -109,7 +111,7 @@ def to_tf_values(result, path):
values = []
for attr, value in result.items():
if value is not None:
if type(value) in [int, float]:
if type(value) in [int, float, np.float32, np.float64, np.int32]:
values.append(
tf.Summary.Value(
tag="/".join(path + [attr]), simple_value=value))
@@ -131,6 +133,15 @@ class _TFLogger(Logger):
values = to_tf_values(tmp, ["ray", "tune"])
train_stats = tf.Summary(value=values)
self._file_writer.add_summary(train_stats, result.timesteps_total)
timesteps_value = to_tf_values({
"timesteps_total": result.timesteps_total
}, ["ray", "tune"])
timesteps_stats = tf.Summary(value=timesteps_value)
self._file_writer.add_summary(timesteps_stats,
result.training_iteration)
def flush(self):
    """Flush the underlying file writer."""
    writer = self._file_writer
    writer.flush()
def close(self):
self._file_writer.close()