mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 16:32:21 +08:00
[rllib] Clean up concepts documentation and policy optimizer creation (#4592)
This commit is contained in:
@@ -26,5 +26,6 @@ class A2CTrainer(A3CTrainer):
|
||||
@override(A3CTrainer)
|
||||
def _make_optimizer(self):
|
||||
return SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
{"train_batch_size": self.config["train_batch_size"]})
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
train_batch_size=self.config["train_batch_size"])
|
||||
|
||||
@@ -77,4 +77,4 @@ class A3CTrainer(Trainer):
|
||||
def _make_optimizer(self):
|
||||
return AsyncGradientsOptimizer(self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
**self.config["optimizer"])
|
||||
|
||||
@@ -229,7 +229,8 @@ class DQNTrainer(Trainer):
|
||||
self.remote_evaluators = None
|
||||
|
||||
self.optimizer = getattr(optimizers, config["optimizer_class"])(
|
||||
self.local_evaluator, self.remote_evaluators, config["optimizer"])
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
**config["optimizer"])
|
||||
# Create the remote evaluators *after* the replay actors
|
||||
if self.remote_evaluators is None:
|
||||
self.remote_evaluators = create_remote_evaluators()
|
||||
|
||||
@@ -123,8 +123,9 @@ class ImpalaTrainer(Trainer):
|
||||
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, policy_cls, config["num_workers"])
|
||||
self.optimizer = AsyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, config["optimizer"])
|
||||
self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
**config["optimizer"])
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
|
||||
@@ -53,11 +53,12 @@ class MARWILTrainer(Trainer):
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, self._policy_graph, config["num_workers"])
|
||||
self.optimizer = SyncBatchReplayOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, {
|
||||
"learning_starts": config["learning_starts"],
|
||||
"buffer_size": config["replay_buffer_size"],
|
||||
"train_batch_size": config["train_batch_size"],
|
||||
})
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["replay_buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
|
||||
@@ -49,7 +49,7 @@ class PGTrainer(Trainer):
|
||||
config["optimizer"],
|
||||
**{"train_batch_size": config["train_batch_size"]})
|
||||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, optimizer_config)
|
||||
self.local_evaluator, self.remote_evaluators, **optimizer_config)
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
|
||||
@@ -79,22 +79,22 @@ class PPOTrainer(Trainer):
|
||||
env_creator, self._policy_graph, config["num_workers"])
|
||||
if config["simple_optimizer"]:
|
||||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, {
|
||||
"num_sgd_iter": config["num_sgd_iter"],
|
||||
"train_batch_size": config["train_batch_size"],
|
||||
})
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
train_batch_size=config["train_batch_size"])
|
||||
else:
|
||||
self.optimizer = LocalMultiGPUOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, {
|
||||
"sgd_batch_size": config["sgd_minibatch_size"],
|
||||
"num_sgd_iter": config["num_sgd_iter"],
|
||||
"num_gpus": config["num_gpus"],
|
||||
"sample_batch_size": config["sample_batch_size"],
|
||||
"num_envs_per_worker": config["num_envs_per_worker"],
|
||||
"train_batch_size": config["train_batch_size"],
|
||||
"standardize_fields": ["advantages"],
|
||||
"straggler_mitigation": config["straggler_mitigation"],
|
||||
})
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
sgd_batch_size=config["sgd_minibatch_size"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
num_gpus=config["num_gpus"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
num_envs_per_worker=config["num_envs_per_worker"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
standardize_fields=["advantages"],
|
||||
straggler_mitigation=config["straggler_mitigation"])
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
|
||||
@@ -299,7 +299,7 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
mask_elems,
|
||||
"target_mean": (targets * mask).sum().item() / mask_elems,
|
||||
}
|
||||
return {LEARNER_STATS_KEY: stats}, {}
|
||||
return {LEARNER_STATS_KEY: stats}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
|
||||
@@ -56,7 +56,7 @@ class KerasPolicyGraph(PolicyGraph):
|
||||
epochs=1,
|
||||
verbose=0,
|
||||
steps_per_epoch=20)
|
||||
return {}, {}
|
||||
return {}
|
||||
|
||||
def get_weights(self):
|
||||
return [model.get_weights() for model in self.models]
|
||||
|
||||
@@ -574,14 +574,14 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
continue
|
||||
policy = self.policy_map[pid]
|
||||
if builder and hasattr(policy, "_build_learn_on_batch"):
|
||||
to_fetch[pid], _ = policy._build_learn_on_batch(
|
||||
to_fetch[pid] = policy._build_learn_on_batch(
|
||||
builder, batch)
|
||||
else:
|
||||
info_out[pid], _ = policy.learn_on_batch(batch)
|
||||
info_out[pid] = policy.learn_on_batch(batch)
|
||||
info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
|
||||
else:
|
||||
info_out, _ = (
|
||||
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
|
||||
info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
|
||||
samples)
|
||||
if log_once("learn_out"):
|
||||
logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
|
||||
return info_out
|
||||
|
||||
@@ -163,7 +163,6 @@ class PolicyGraph(object):
|
||||
|
||||
Returns:
|
||||
grad_info: dictionary of extra metadata from compute_gradients().
|
||||
apply_info: dictionary of extra metadata from apply_gradients().
|
||||
|
||||
Examples:
|
||||
>>> batch = ev.sample()
|
||||
@@ -171,8 +170,8 @@ class PolicyGraph(object):
|
||||
"""
|
||||
|
||||
grads, grad_info = self.compute_gradients(samples)
|
||||
apply_info = self.apply_gradients(grads)
|
||||
return grad_info, apply_info
|
||||
self.apply_gradients(grads)
|
||||
return grad_info
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
@@ -191,9 +190,6 @@ class PolicyGraph(object):
|
||||
"""Applies previously computed gradients.
|
||||
|
||||
Either this or learn_on_batch() must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
info (dict): Extra policy-specific values
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
def apply_gradients(self, gradients):
|
||||
builder = TFRunBuilder(self._sess, "apply_gradients")
|
||||
fetches = self._build_apply_gradients(builder, gradients)
|
||||
return builder.get(fetches)
|
||||
builder.get(fetches)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
@@ -267,16 +267,6 @@ class TFPolicyGraph(PolicyGraph):
|
||||
"""Extra values to fetch and return from compute_gradients()."""
|
||||
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_apply_grad_feed_dict(self):
|
||||
"""Extra dict to pass to the apply gradients session run."""
|
||||
return {}
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_apply_grad_fetches(self):
|
||||
"""Extra values to fetch and return from apply_gradients()."""
|
||||
return {} # e.g., batch norm updates
|
||||
|
||||
@DeveloperAPI
|
||||
def optimizer(self):
|
||||
"""TF optimizer to use for policy optimization."""
|
||||
@@ -405,24 +395,20 @@ class TFPolicyGraph(PolicyGraph):
|
||||
raise ValueError(
|
||||
"Unexpected number of gradients to apply, got {} for {}".
|
||||
format(gradients, self._grads))
|
||||
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
builder.add_feed_dict(dict(zip(self._grads, gradients)))
|
||||
fetches = builder.add_fetches(
|
||||
[self._apply_op, self.extra_apply_grad_fetches()])
|
||||
return fetches[1]
|
||||
fetches = builder.add_fetches([self._apply_op])
|
||||
return fetches[0]
|
||||
|
||||
def _build_learn_on_batch(self, builder, postprocessed_batch):
|
||||
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
||||
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
|
||||
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
fetches = builder.add_fetches([
|
||||
self._apply_op,
|
||||
self._get_grad_and_stats_fetches(),
|
||||
self.extra_apply_grad_fetches()
|
||||
])
|
||||
return fetches[1], fetches[2]
|
||||
return fetches[1]
|
||||
|
||||
def _get_grad_and_stats_fetches(self):
|
||||
fetches = self.extra_compute_grad_fetches()
|
||||
|
||||
@@ -118,7 +118,6 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
if g is not None:
|
||||
p.grad = torch.from_numpy(g).to(self.device)
|
||||
self._optimizer.step()
|
||||
return {}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_weights(self):
|
||||
|
||||
@@ -46,7 +46,7 @@ class RandomPolicy(PolicyGraph):
|
||||
|
||||
def learn_on_batch(self, samples):
|
||||
"""No learning."""
|
||||
return {}, {}
|
||||
return {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -48,7 +48,7 @@ class CustomPolicy(PolicyGraph):
|
||||
|
||||
def learn_on_batch(self, samples):
|
||||
# implement your learning code here
|
||||
return {}, {}
|
||||
return {}
|
||||
|
||||
def update_some_value(self, w):
|
||||
# can also call other methods on policies
|
||||
|
||||
@@ -17,8 +17,9 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
gradient computations on the remote workers.
|
||||
"""
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def _init(self, grads_per_step=100):
|
||||
def __init__(self, local_evaluator, remote_evaluators, grads_per_step=100):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
|
||||
self.apply_timer = TimerStat()
|
||||
self.wait_timer = TimerStat()
|
||||
self.dispatch_timer = TimerStat()
|
||||
|
||||
@@ -46,20 +46,22 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
"td_error" array in the info return of compute_gradients(). This error
|
||||
term will be used for sample prioritization."""
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def _init(self,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
prioritized_replay=True,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=1e-6,
|
||||
train_batch_size=512,
|
||||
sample_batch_size=50,
|
||||
num_replay_buffer_shards=1,
|
||||
max_weight_sync_delay=400,
|
||||
debug=False,
|
||||
batch_replay=False):
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
prioritized_replay=True,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=1e-6,
|
||||
train_batch_size=512,
|
||||
sample_batch_size=50,
|
||||
num_replay_buffer_shards=1,
|
||||
max_weight_sync_delay=400,
|
||||
debug=False,
|
||||
batch_replay=False):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
|
||||
self.debug = debug
|
||||
self.batch_replay = batch_replay
|
||||
|
||||
@@ -27,23 +27,26 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
and remote evaluators (IMPALA actors).
|
||||
"""
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def _init(self,
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
num_envs_per_worker=1,
|
||||
num_gpus=0,
|
||||
lr=0.0005,
|
||||
replay_buffer_num_slots=0,
|
||||
replay_proportion=0.0,
|
||||
num_data_loader_buffers=1,
|
||||
max_sample_requests_in_flight_per_worker=2,
|
||||
broadcast_interval=1,
|
||||
num_sgd_iter=1,
|
||||
minibatch_buffer_size=1,
|
||||
learner_queue_size=16,
|
||||
num_aggregation_workers=0,
|
||||
_fake_gpus=False):
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
num_envs_per_worker=1,
|
||||
num_gpus=0,
|
||||
lr=0.0005,
|
||||
replay_buffer_num_slots=0,
|
||||
replay_proportion=0.0,
|
||||
num_data_loader_buffers=1,
|
||||
max_sample_requests_in_flight_per_worker=2,
|
||||
broadcast_interval=1,
|
||||
num_sgd_iter=1,
|
||||
minibatch_buffer_size=1,
|
||||
learner_queue_size=16,
|
||||
num_aggregation_workers=0,
|
||||
_fake_gpus=False):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
|
||||
self._stats_start_time = time.time()
|
||||
self._last_stats_time = {}
|
||||
self._last_stats_sum = {}
|
||||
|
||||
@@ -250,12 +250,10 @@ class LocalSyncParallelOptimizer(object):
|
||||
}
|
||||
for tower in self._towers:
|
||||
feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
|
||||
feed_dict.update(tower.loss_graph.extra_apply_grad_feed_dict())
|
||||
|
||||
fetches = {"train": self._train_op}
|
||||
for tower in self._towers:
|
||||
fetches.update(tower.loss_graph.extra_compute_grad_fetches())
|
||||
fetches.update(tower.loss_graph.extra_apply_grad_fetches())
|
||||
|
||||
return sess.run(fetches, feed_dict=feed_dict)
|
||||
|
||||
|
||||
@@ -39,16 +39,19 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
may result in unexpected behavior.
|
||||
"""
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def _init(self,
|
||||
sgd_batch_size=128,
|
||||
num_sgd_iter=10,
|
||||
sample_batch_size=200,
|
||||
num_envs_per_worker=1,
|
||||
train_batch_size=1024,
|
||||
num_gpus=0,
|
||||
standardize_fields=[],
|
||||
straggler_mitigation=False):
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
sgd_batch_size=128,
|
||||
num_sgd_iter=10,
|
||||
sample_batch_size=200,
|
||||
num_envs_per_worker=1,
|
||||
train_batch_size=1024,
|
||||
num_gpus=0,
|
||||
standardize_fields=[],
|
||||
straggler_mitigation=False):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
|
||||
self.batch_size = sgd_batch_size
|
||||
self.num_sgd_iter = num_sgd_iter
|
||||
self.num_envs_per_worker = num_envs_per_worker
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,11 +38,10 @@ class PolicyOptimizer(object):
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, local_evaluator, remote_evaluators=None, config=None):
|
||||
def __init__(self, local_evaluator, remote_evaluators=None):
|
||||
"""Create an optimizer instance.
|
||||
|
||||
Args:
|
||||
config (dict): Optimizer-specific arguments.
|
||||
local_evaluator (Evaluator): Local evaluator instance, required.
|
||||
remote_evaluators (list): A list of Ray actor handles to remote
|
||||
evaluators instances. If empty, the optimizer should fall back
|
||||
@@ -52,22 +50,11 @@ class PolicyOptimizer(object):
|
||||
self.local_evaluator = local_evaluator
|
||||
self.remote_evaluators = remote_evaluators or []
|
||||
self.episode_history = []
|
||||
self.config = config or {}
|
||||
self._init(**self.config)
|
||||
|
||||
# Counters that should be updated by sub-classes
|
||||
self.num_steps_trained = 0
|
||||
self.num_steps_sampled = 0
|
||||
|
||||
logger.debug("Created policy optimizer with {}: {}".format(
|
||||
config, self))
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self, **config):
|
||||
"""Subclasses should prefer overriding this instead of __init__."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def step(self):
|
||||
"""Takes a logical optimization step.
|
||||
@@ -170,60 +157,3 @@ class PolicyOptimizer(object):
|
||||
for i, ev in enumerate(self.remote_evaluators)
|
||||
])
|
||||
return local_result + remote_results
|
||||
|
||||
@classmethod
|
||||
def make(cls,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
optimizer_batch_size=None,
|
||||
num_workers=0,
|
||||
num_envs_per_worker=None,
|
||||
optimizer_config=None,
|
||||
remote_num_cpus=None,
|
||||
remote_num_gpus=None,
|
||||
**eval_kwargs):
|
||||
"""Creates an Optimizer with local and remote evaluators.
|
||||
|
||||
Args:
|
||||
env_creator(func): Function that returns a gym.Env given an
|
||||
EnvContext wrapped configuration.
|
||||
policy_graph (class|dict): Either a class implementing
|
||||
PolicyGraph, or a dictionary of policy id strings to
|
||||
(PolicyGraph, obs_space, action_space, config) tuples.
|
||||
See PolicyEvaluator documentation.
|
||||
optimizer_batch_size (int): Batch size summed across all workers.
|
||||
Will override worker `batch_steps`.
|
||||
num_workers (int): Number of remote evaluators
|
||||
num_envs_per_worker (int): (Optional) Sets the number
|
||||
environments per evaluator for vectorization.
|
||||
If set, overrides `num_envs` in kwargs
|
||||
for PolicyEvaluator.__init__.
|
||||
optimizer_config (dict): Config passed to the optimizer.
|
||||
remote_num_cpus (int): CPU specification for remote evaluator.
|
||||
remote_num_gpus (int): GPU specification for remote evaluator.
|
||||
**eval_kwargs: PolicyEvaluator Class non-positional args.
|
||||
|
||||
Returns:
|
||||
(Optimizer) Instance of `cls` with evaluators configured
|
||||
accordingly.
|
||||
"""
|
||||
optimizer_config = optimizer_config or {}
|
||||
if num_envs_per_worker:
|
||||
assert num_envs_per_worker > 0, "Improper num_envs_per_worker!"
|
||||
eval_kwargs["num_envs"] = int(num_envs_per_worker)
|
||||
if optimizer_batch_size:
|
||||
assert optimizer_batch_size > 0
|
||||
if num_workers > 1:
|
||||
eval_kwargs["batch_steps"] = \
|
||||
optimizer_batch_size // num_workers
|
||||
else:
|
||||
eval_kwargs["batch_steps"] = optimizer_batch_size
|
||||
evaluator = PolicyEvaluator(env_creator, policy_graph, **eval_kwargs)
|
||||
remote_cls = PolicyEvaluator.as_remote(remote_num_cpus,
|
||||
remote_num_gpus)
|
||||
remote_evaluators = [
|
||||
remote_cls.remote(env_creator, policy_graph, **eval_kwargs)
|
||||
for i in range(num_workers)
|
||||
]
|
||||
|
||||
return cls(evaluator, remote_evaluators, optimizer_config)
|
||||
|
||||
@@ -18,11 +18,14 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
This enables RNN support. Does not currently support prioritization."""
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def _init(self,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
train_batch_size=32):
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
train_batch_size=32):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
|
||||
self.replay_starts = learning_starts
|
||||
self.max_buffer_size = buffer_size
|
||||
self.train_batch_size = train_batch_size
|
||||
|
||||
@@ -28,19 +28,21 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
"td_error" array in the info return of compute_gradients(). This error
|
||||
term will be used for sample prioritization."""
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def _init(self,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
prioritized_replay=True,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta=0.4,
|
||||
schedule_max_timesteps=100000,
|
||||
beta_annealing_fraction=0.2,
|
||||
final_prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=1e-6,
|
||||
train_batch_size=32,
|
||||
sample_batch_size=4):
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
prioritized_replay=True,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta=0.4,
|
||||
schedule_max_timesteps=100000,
|
||||
beta_annealing_fraction=0.2,
|
||||
final_prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=1e-6,
|
||||
train_batch_size=32,
|
||||
sample_batch_size=4):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
|
||||
self.replay_starts = learning_starts
|
||||
# linearly annealing beta used in Rainbow paper
|
||||
|
||||
@@ -22,8 +22,13 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
model weights are then broadcast to all remote evaluators.
|
||||
"""
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def _init(self, num_sgd_iter=1, train_batch_size=1):
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
num_sgd_iter=1,
|
||||
train_batch_size=1):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
|
||||
self.update_weights_timer = TimerStat()
|
||||
self.sample_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
|
||||
@@ -75,7 +75,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
policy_graph=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
optimizer = SyncSamplesOptimizer(ev, [], {})
|
||||
optimizer = SyncSamplesOptimizer(ev, [])
|
||||
for i in range(100):
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev)
|
||||
|
||||
@@ -606,7 +606,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
]
|
||||
else:
|
||||
remote_evs = []
|
||||
optimizer = optimizer_cls(ev, remote_evs, {})
|
||||
optimizer = optimizer_cls(ev, remote_evs)
|
||||
for i in range(200):
|
||||
ev.foreach_policy(lambda p, _: p.set_epsilon(
|
||||
max(0.02, 1 - i * .02))
|
||||
@@ -648,7 +648,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
policy_graph=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
optimizer = SyncSamplesOptimizer(ev, [], {})
|
||||
optimizer = SyncSamplesOptimizer(ev, [])
|
||||
for i in range(100):
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev)
|
||||
|
||||
@@ -27,8 +27,8 @@ class AsyncOptimizerTest(unittest.TestCase):
|
||||
local = _MockEvaluator()
|
||||
remotes = ray.remote(_MockEvaluator)
|
||||
remote_evaluators = [remotes.remote() for i in range(5)]
|
||||
test_optimizer = AsyncGradientsOptimizer(local, remote_evaluators,
|
||||
{"grads_per_step": 10})
|
||||
test_optimizer = AsyncGradientsOptimizer(
|
||||
local, remote_evaluators, grads_per_step=10)
|
||||
test_optimizer.step()
|
||||
self.assertTrue(all(local.get_weights() == 0))
|
||||
|
||||
@@ -115,35 +115,34 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(local, remotes, {})
|
||||
optimizer = AsyncSamplesOptimizer(local, remotes)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def testMultiGPU(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(local, remotes, {
|
||||
"num_gpus": 2,
|
||||
"_fake_gpus": True
|
||||
})
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, num_gpus=2, _fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def testMultiGPUParallelLoad(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(local, remotes, {
|
||||
"num_gpus": 2,
|
||||
"num_data_loader_buffers": 2,
|
||||
"_fake_gpus": True
|
||||
})
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
num_gpus=2,
|
||||
num_data_loader_buffers=2,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def testMultiplePasses(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, {
|
||||
"minibatch_buffer_size": 10,
|
||||
"num_sgd_iter": 10,
|
||||
"sample_batch_size": 10,
|
||||
"train_batch_size": 50,
|
||||
})
|
||||
local,
|
||||
remotes,
|
||||
minibatch_buffer_size=10,
|
||||
num_sgd_iter=10,
|
||||
sample_batch_size=10,
|
||||
train_batch_size=50)
|
||||
self._wait_for(optimizer, 1000, 10000)
|
||||
self.assertLess(optimizer.stats()["num_steps_sampled"], 5000)
|
||||
self.assertGreater(optimizer.stats()["num_steps_trained"], 8000)
|
||||
@@ -151,12 +150,13 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
def testReplay(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, {
|
||||
"replay_buffer_num_slots": 100,
|
||||
"replay_proportion": 10,
|
||||
"sample_batch_size": 10,
|
||||
"train_batch_size": 10,
|
||||
})
|
||||
local,
|
||||
remotes,
|
||||
replay_buffer_num_slots=100,
|
||||
replay_proportion=10,
|
||||
sample_batch_size=10,
|
||||
train_batch_size=10,
|
||||
)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
stats = optimizer.stats()
|
||||
self.assertLess(stats["num_steps_sampled"], 5000)
|
||||
@@ -167,14 +167,14 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
def testReplayAndMultiplePasses(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, {
|
||||
"minibatch_buffer_size": 10,
|
||||
"num_sgd_iter": 10,
|
||||
"replay_buffer_num_slots": 100,
|
||||
"replay_proportion": 10,
|
||||
"sample_batch_size": 10,
|
||||
"train_batch_size": 10,
|
||||
})
|
||||
local,
|
||||
remotes,
|
||||
minibatch_buffer_size=10,
|
||||
num_sgd_iter=10,
|
||||
replay_buffer_num_slots=100,
|
||||
replay_proportion=10,
|
||||
sample_batch_size=10,
|
||||
train_batch_size=10)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
stats = optimizer.stats()
|
||||
@@ -188,17 +188,16 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
def testMultiTierAggregationBadConf(self):
|
||||
local, remotes = self._make_evs()
|
||||
aggregators = TreeAggregator.precreate_aggregators(4)
|
||||
optimizer = AsyncSamplesOptimizer(local, remotes,
|
||||
{"num_aggregation_workers": 4})
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, num_aggregation_workers=4)
|
||||
self.assertRaises(ValueError,
|
||||
lambda: optimizer.aggregator.init(aggregators))
|
||||
|
||||
def testMultiTierAggregation(self):
|
||||
local, remotes = self._make_evs()
|
||||
aggregators = TreeAggregator.precreate_aggregators(1)
|
||||
optimizer = AsyncSamplesOptimizer(local, remotes, {
|
||||
"num_aggregation_workers": 1,
|
||||
})
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, num_aggregation_workers=1)
|
||||
optimizer.aggregator.init(aggregators)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
@@ -207,30 +206,30 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
self.assertRaises(
|
||||
ValueError, lambda: AsyncSamplesOptimizer(
|
||||
local, remotes,
|
||||
{"num_data_loader_buffers": 2, "minibatch_buffer_size": 4}))
|
||||
num_data_loader_buffers=2, minibatch_buffer_size=4))
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, {
|
||||
"num_gpus": 2,
|
||||
"train_batch_size": 100,
|
||||
"sample_batch_size": 50,
|
||||
"_fake_gpus": True
|
||||
})
|
||||
local,
|
||||
remotes,
|
||||
num_gpus=2,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=50,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, {
|
||||
"num_gpus": 2,
|
||||
"train_batch_size": 100,
|
||||
"sample_batch_size": 25,
|
||||
"_fake_gpus": True
|
||||
})
|
||||
local,
|
||||
remotes,
|
||||
num_gpus=2,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=25,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, {
|
||||
"num_gpus": 2,
|
||||
"train_batch_size": 100,
|
||||
"sample_batch_size": 74,
|
||||
"_fake_gpus": True
|
||||
})
|
||||
local,
|
||||
remotes,
|
||||
num_gpus=2,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=74,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def _make_evs(self):
|
||||
|
||||
Reference in New Issue
Block a user