[rllib] Clean up concepts documentation and policy optimizer creation (#4592)

This commit is contained in:
Eric Liang
2019-04-12 21:03:26 -07:00
committed by GitHub
parent 0f42f87ebc
commit 6e7680bf21
29 changed files with 303 additions and 270 deletions
+3 -2
View File
@@ -26,5 +26,6 @@ class A2CTrainer(A3CTrainer):
@override(A3CTrainer)
def _make_optimizer(self):
return SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators,
{"train_batch_size": self.config["train_batch_size"]})
self.local_evaluator,
self.remote_evaluators,
train_batch_size=self.config["train_batch_size"])
+1 -1
View File
@@ -77,4 +77,4 @@ class A3CTrainer(Trainer):
def _make_optimizer(self):
return AsyncGradientsOptimizer(self.local_evaluator,
self.remote_evaluators,
self.config["optimizer"])
**self.config["optimizer"])
+2 -1
View File
@@ -229,7 +229,8 @@ class DQNTrainer(Trainer):
self.remote_evaluators = None
self.optimizer = getattr(optimizers, config["optimizer_class"])(
self.local_evaluator, self.remote_evaluators, config["optimizer"])
self.local_evaluator, self.remote_evaluators,
**config["optimizer"])
# Create the remote evaluators *after* the replay actors
if self.remote_evaluators is None:
self.remote_evaluators = create_remote_evaluators()
+3 -2
View File
@@ -123,8 +123,9 @@ class ImpalaTrainer(Trainer):
self.remote_evaluators = self.make_remote_evaluators(
env_creator, policy_cls, config["num_workers"])
self.optimizer = AsyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators, config["optimizer"])
self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
self.remote_evaluators,
**config["optimizer"])
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
+6 -5
View File
@@ -53,11 +53,12 @@ class MARWILTrainer(Trainer):
self.remote_evaluators = self.make_remote_evaluators(
env_creator, self._policy_graph, config["num_workers"])
self.optimizer = SyncBatchReplayOptimizer(
self.local_evaluator, self.remote_evaluators, {
"learning_starts": config["learning_starts"],
"buffer_size": config["replay_buffer_size"],
"train_batch_size": config["train_batch_size"],
})
self.local_evaluator,
self.remote_evaluators,
learning_starts=config["learning_starts"],
buffer_size=config["replay_buffer_size"],
train_batch_size=config["train_batch_size"],
)
@override(Trainer)
def _train(self):
+1 -1
View File
@@ -49,7 +49,7 @@ class PGTrainer(Trainer):
config["optimizer"],
**{"train_batch_size": config["train_batch_size"]})
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators, optimizer_config)
self.local_evaluator, self.remote_evaluators, **optimizer_config)
@override(Trainer)
def _train(self):
+14 -14
View File
@@ -79,22 +79,22 @@ class PPOTrainer(Trainer):
env_creator, self._policy_graph, config["num_workers"])
if config["simple_optimizer"]:
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators, {
"num_sgd_iter": config["num_sgd_iter"],
"train_batch_size": config["train_batch_size"],
})
self.local_evaluator,
self.remote_evaluators,
num_sgd_iter=config["num_sgd_iter"],
train_batch_size=config["train_batch_size"])
else:
self.optimizer = LocalMultiGPUOptimizer(
self.local_evaluator, self.remote_evaluators, {
"sgd_batch_size": config["sgd_minibatch_size"],
"num_sgd_iter": config["num_sgd_iter"],
"num_gpus": config["num_gpus"],
"sample_batch_size": config["sample_batch_size"],
"num_envs_per_worker": config["num_envs_per_worker"],
"train_batch_size": config["train_batch_size"],
"standardize_fields": ["advantages"],
"straggler_mitigation": config["straggler_mitigation"],
})
self.local_evaluator,
self.remote_evaluators,
sgd_batch_size=config["sgd_minibatch_size"],
num_sgd_iter=config["num_sgd_iter"],
num_gpus=config["num_gpus"],
sample_batch_size=config["sample_batch_size"],
num_envs_per_worker=config["num_envs_per_worker"],
train_batch_size=config["train_batch_size"],
standardize_fields=["advantages"],
straggler_mitigation=config["straggler_mitigation"])
@override(Trainer)
def _train(self):
@@ -299,7 +299,7 @@ class QMixPolicyGraph(PolicyGraph):
mask_elems,
"target_mean": (targets * mask).sum().item() / mask_elems,
}
return {LEARNER_STATS_KEY: stats}, {}
return {LEARNER_STATS_KEY: stats}
@override(PolicyGraph)
def get_initial_state(self):
@@ -56,7 +56,7 @@ class KerasPolicyGraph(PolicyGraph):
epochs=1,
verbose=0,
steps_per_epoch=20)
return {}, {}
return {}
def get_weights(self):
return [model.get_weights() for model in self.models]
@@ -574,14 +574,14 @@ class PolicyEvaluator(EvaluatorInterface):
continue
policy = self.policy_map[pid]
if builder and hasattr(policy, "_build_learn_on_batch"):
to_fetch[pid], _ = policy._build_learn_on_batch(
to_fetch[pid] = policy._build_learn_on_batch(
builder, batch)
else:
info_out[pid], _ = policy.learn_on_batch(batch)
info_out[pid] = policy.learn_on_batch(batch)
info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
else:
info_out, _ = (
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
samples)
if log_once("learn_out"):
logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
return info_out
+2 -6
View File
@@ -163,7 +163,6 @@ class PolicyGraph(object):
Returns:
grad_info: dictionary of extra metadata from compute_gradients().
apply_info: dictionary of extra metadata from apply_gradients().
Examples:
>>> batch = ev.sample()
@@ -171,8 +170,8 @@ class PolicyGraph(object):
"""
grads, grad_info = self.compute_gradients(samples)
apply_info = self.apply_gradients(grads)
return grad_info, apply_info
self.apply_gradients(grads)
return grad_info
@DeveloperAPI
def compute_gradients(self, postprocessed_batch):
@@ -191,9 +190,6 @@ class PolicyGraph(object):
"""Applies previously computed gradients.
Either this or learn_on_batch() must be implemented by subclasses.
Returns:
info (dict): Extra policy-specific values
"""
raise NotImplementedError
+4 -18
View File
@@ -193,7 +193,7 @@ class TFPolicyGraph(PolicyGraph):
def apply_gradients(self, gradients):
builder = TFRunBuilder(self._sess, "apply_gradients")
fetches = self._build_apply_gradients(builder, gradients)
return builder.get(fetches)
builder.get(fetches)
@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
@@ -267,16 +267,6 @@ class TFPolicyGraph(PolicyGraph):
"""Extra values to fetch and return from compute_gradients()."""
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
@DeveloperAPI
def extra_apply_grad_feed_dict(self):
"""Extra dict to pass to the apply gradients session run."""
return {}
@DeveloperAPI
def extra_apply_grad_fetches(self):
"""Extra values to fetch and return from apply_gradients()."""
return {} # e.g., batch norm updates
@DeveloperAPI
def optimizer(self):
"""TF optimizer to use for policy optimization."""
@@ -405,24 +395,20 @@ class TFPolicyGraph(PolicyGraph):
raise ValueError(
"Unexpected number of gradients to apply, got {} for {}".
format(gradients, self._grads))
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
builder.add_feed_dict({self._is_training: True})
builder.add_feed_dict(dict(zip(self._grads, gradients)))
fetches = builder.add_fetches(
[self._apply_op, self.extra_apply_grad_fetches()])
return fetches[1]
fetches = builder.add_fetches([self._apply_op])
return fetches[0]
def _build_learn_on_batch(self, builder, postprocessed_batch):
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
builder.add_feed_dict({self._is_training: True})
fetches = builder.add_fetches([
self._apply_op,
self._get_grad_and_stats_fetches(),
self.extra_apply_grad_fetches()
])
return fetches[1], fetches[2]
return fetches[1]
def _get_grad_and_stats_fetches(self):
fetches = self.extra_compute_grad_fetches()
@@ -118,7 +118,6 @@ class TorchPolicyGraph(PolicyGraph):
if g is not None:
p.grad = torch.from_numpy(g).to(self.device)
self._optimizer.step()
return {}
@override(PolicyGraph)
def get_weights(self):
@@ -46,7 +46,7 @@ class RandomPolicy(PolicyGraph):
def learn_on_batch(self, samples):
"""No learning."""
return {}, {}
return {}
if __name__ == "__main__":
@@ -48,7 +48,7 @@ class CustomPolicy(PolicyGraph):
def learn_on_batch(self, samples):
# implement your learning code here
return {}, {}
return {}
def update_some_value(self, w):
# can also call other methods on policies
@@ -17,8 +17,9 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
gradient computations on the remote workers.
"""
@override(PolicyOptimizer)
def _init(self, grads_per_step=100):
def __init__(self, local_evaluator, remote_evaluators, grads_per_step=100):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
self.apply_timer = TimerStat()
self.wait_timer = TimerStat()
self.dispatch_timer = TimerStat()
@@ -46,20 +46,22 @@ class AsyncReplayOptimizer(PolicyOptimizer):
"td_error" array in the info return of compute_gradients(). This error
term will be used for sample prioritization."""
@override(PolicyOptimizer)
def _init(self,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
train_batch_size=512,
sample_batch_size=50,
num_replay_buffer_shards=1,
max_weight_sync_delay=400,
debug=False,
batch_replay=False):
def __init__(self,
local_evaluator,
remote_evaluators,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
train_batch_size=512,
sample_batch_size=50,
num_replay_buffer_shards=1,
max_weight_sync_delay=400,
debug=False,
batch_replay=False):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
self.debug = debug
self.batch_replay = batch_replay
@@ -27,23 +27,26 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
and remote evaluators (IMPALA actors).
"""
@override(PolicyOptimizer)
def _init(self,
train_batch_size=500,
sample_batch_size=50,
num_envs_per_worker=1,
num_gpus=0,
lr=0.0005,
replay_buffer_num_slots=0,
replay_proportion=0.0,
num_data_loader_buffers=1,
max_sample_requests_in_flight_per_worker=2,
broadcast_interval=1,
num_sgd_iter=1,
minibatch_buffer_size=1,
learner_queue_size=16,
num_aggregation_workers=0,
_fake_gpus=False):
def __init__(self,
local_evaluator,
remote_evaluators,
train_batch_size=500,
sample_batch_size=50,
num_envs_per_worker=1,
num_gpus=0,
lr=0.0005,
replay_buffer_num_slots=0,
replay_proportion=0.0,
num_data_loader_buffers=1,
max_sample_requests_in_flight_per_worker=2,
broadcast_interval=1,
num_sgd_iter=1,
minibatch_buffer_size=1,
learner_queue_size=16,
num_aggregation_workers=0,
_fake_gpus=False):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
self._stats_start_time = time.time()
self._last_stats_time = {}
self._last_stats_sum = {}
@@ -250,12 +250,10 @@ class LocalSyncParallelOptimizer(object):
}
for tower in self._towers:
feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
feed_dict.update(tower.loss_graph.extra_apply_grad_feed_dict())
fetches = {"train": self._train_op}
for tower in self._towers:
fetches.update(tower.loss_graph.extra_compute_grad_fetches())
fetches.update(tower.loss_graph.extra_apply_grad_fetches())
return sess.run(fetches, feed_dict=feed_dict)
@@ -39,16 +39,19 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
may result in unexpected behavior.
"""
@override(PolicyOptimizer)
def _init(self,
sgd_batch_size=128,
num_sgd_iter=10,
sample_batch_size=200,
num_envs_per_worker=1,
train_batch_size=1024,
num_gpus=0,
standardize_fields=[],
straggler_mitigation=False):
def __init__(self,
local_evaluator,
remote_evaluators,
sgd_batch_size=128,
num_sgd_iter=10,
sample_batch_size=200,
num_envs_per_worker=1,
train_batch_size=1024,
num_gpus=0,
standardize_fields=[],
straggler_mitigation=False):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
self.batch_size = sgd_batch_size
self.num_sgd_iter = num_sgd_iter
self.num_envs_per_worker = num_envs_per_worker
@@ -6,7 +6,6 @@ import logging
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
logger = logging.getLogger(__name__)
@@ -39,11 +38,10 @@ class PolicyOptimizer(object):
"""
@DeveloperAPI
def __init__(self, local_evaluator, remote_evaluators=None, config=None):
def __init__(self, local_evaluator, remote_evaluators=None):
"""Create an optimizer instance.
Args:
config (dict): Optimizer-specific arguments.
local_evaluator (Evaluator): Local evaluator instance, required.
remote_evaluators (list): A list of Ray actor handles to remote
evaluators instances. If empty, the optimizer should fall back
@@ -52,22 +50,11 @@ class PolicyOptimizer(object):
self.local_evaluator = local_evaluator
self.remote_evaluators = remote_evaluators or []
self.episode_history = []
self.config = config or {}
self._init(**self.config)
# Counters that should be updated by sub-classes
self.num_steps_trained = 0
self.num_steps_sampled = 0
logger.debug("Created policy optimizer with {}: {}".format(
config, self))
@DeveloperAPI
def _init(self, **config):
"""Subclasses should prefer overriding this instead of __init__."""
raise NotImplementedError
@DeveloperAPI
def step(self):
"""Takes a logical optimization step.
@@ -170,60 +157,3 @@ class PolicyOptimizer(object):
for i, ev in enumerate(self.remote_evaluators)
])
return local_result + remote_results
@classmethod
def make(cls,
env_creator,
policy_graph,
optimizer_batch_size=None,
num_workers=0,
num_envs_per_worker=None,
optimizer_config=None,
remote_num_cpus=None,
remote_num_gpus=None,
**eval_kwargs):
"""Creates an Optimizer with local and remote evaluators.
Args:
env_creator(func): Function that returns a gym.Env given an
EnvContext wrapped configuration.
policy_graph (class|dict): Either a class implementing
PolicyGraph, or a dictionary of policy id strings to
(PolicyGraph, obs_space, action_space, config) tuples.
See PolicyEvaluator documentation.
optimizer_batch_size (int): Batch size summed across all workers.
Will override worker `batch_steps`.
num_workers (int): Number of remote evaluators
num_envs_per_worker (int): (Optional) Sets the number
environments per evaluator for vectorization.
If set, overrides `num_envs` in kwargs
for PolicyEvaluator.__init__.
optimizer_config (dict): Config passed to the optimizer.
remote_num_cpus (int): CPU specification for remote evaluator.
remote_num_gpus (int): GPU specification for remote evaluator.
**eval_kwargs: PolicyEvaluator Class non-positional args.
Returns:
(Optimizer) Instance of `cls` with evaluators configured
accordingly.
"""
optimizer_config = optimizer_config or {}
if num_envs_per_worker:
assert num_envs_per_worker > 0, "Improper num_envs_per_worker!"
eval_kwargs["num_envs"] = int(num_envs_per_worker)
if optimizer_batch_size:
assert optimizer_batch_size > 0
if num_workers > 1:
eval_kwargs["batch_steps"] = \
optimizer_batch_size // num_workers
else:
eval_kwargs["batch_steps"] = optimizer_batch_size
evaluator = PolicyEvaluator(env_creator, policy_graph, **eval_kwargs)
remote_cls = PolicyEvaluator.as_remote(remote_num_cpus,
remote_num_gpus)
remote_evaluators = [
remote_cls.remote(env_creator, policy_graph, **eval_kwargs)
for i in range(num_workers)
]
return cls(evaluator, remote_evaluators, optimizer_config)
@@ -18,11 +18,14 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
This enables RNN support. Does not currently support prioritization."""
@override(PolicyOptimizer)
def _init(self,
learning_starts=1000,
buffer_size=10000,
train_batch_size=32):
def __init__(self,
local_evaluator,
remote_evaluators,
learning_starts=1000,
buffer_size=10000,
train_batch_size=32):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
self.replay_starts = learning_starts
self.max_buffer_size = buffer_size
self.train_batch_size = train_batch_size
@@ -28,19 +28,21 @@ class SyncReplayOptimizer(PolicyOptimizer):
"td_error" array in the info return of compute_gradients(). This error
term will be used for sample prioritization."""
@override(PolicyOptimizer)
def _init(self,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
schedule_max_timesteps=100000,
beta_annealing_fraction=0.2,
final_prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
train_batch_size=32,
sample_batch_size=4):
def __init__(self,
local_evaluator,
remote_evaluators,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
schedule_max_timesteps=100000,
beta_annealing_fraction=0.2,
final_prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
train_batch_size=32,
sample_batch_size=4):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
self.replay_starts = learning_starts
# linearly annealing beta used in Rainbow paper
@@ -22,8 +22,13 @@ class SyncSamplesOptimizer(PolicyOptimizer):
model weights are then broadcast to all remote evaluators.
"""
@override(PolicyOptimizer)
def _init(self, num_sgd_iter=1, train_batch_size=1):
def __init__(self,
local_evaluator,
remote_evaluators,
num_sgd_iter=1,
train_batch_size=1):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
self.update_weights_timer = TimerStat()
self.sample_timer = TimerStat()
self.grad_timer = TimerStat()
@@ -75,7 +75,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
policy_graph=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
optimizer = SyncSamplesOptimizer(ev, [], {})
optimizer = SyncSamplesOptimizer(ev, [])
for i in range(100):
optimizer.step()
result = collect_metrics(ev)
@@ -606,7 +606,7 @@ class TestMultiAgentEnv(unittest.TestCase):
]
else:
remote_evs = []
optimizer = optimizer_cls(ev, remote_evs, {})
optimizer = optimizer_cls(ev, remote_evs)
for i in range(200):
ev.foreach_policy(lambda p, _: p.set_epsilon(
max(0.02, 1 - i * .02))
@@ -648,7 +648,7 @@ class TestMultiAgentEnv(unittest.TestCase):
policy_graph=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
optimizer = SyncSamplesOptimizer(ev, [], {})
optimizer = SyncSamplesOptimizer(ev, [])
for i in range(100):
optimizer.step()
result = collect_metrics(ev)
+55 -56
View File
@@ -27,8 +27,8 @@ class AsyncOptimizerTest(unittest.TestCase):
local = _MockEvaluator()
remotes = ray.remote(_MockEvaluator)
remote_evaluators = [remotes.remote() for i in range(5)]
test_optimizer = AsyncGradientsOptimizer(local, remote_evaluators,
{"grads_per_step": 10})
test_optimizer = AsyncGradientsOptimizer(
local, remote_evaluators, grads_per_step=10)
test_optimizer.step()
self.assertTrue(all(local.get_weights() == 0))
@@ -115,35 +115,34 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testSimple(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(local, remotes, {})
optimizer = AsyncSamplesOptimizer(local, remotes)
self._wait_for(optimizer, 1000, 1000)
def testMultiGPU(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(local, remotes, {
"num_gpus": 2,
"_fake_gpus": True
})
optimizer = AsyncSamplesOptimizer(
local, remotes, num_gpus=2, _fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
def testMultiGPUParallelLoad(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(local, remotes, {
"num_gpus": 2,
"num_data_loader_buffers": 2,
"_fake_gpus": True
})
optimizer = AsyncSamplesOptimizer(
local,
remotes,
num_gpus=2,
num_data_loader_buffers=2,
_fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
def testMultiplePasses(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(
local, remotes, {
"minibatch_buffer_size": 10,
"num_sgd_iter": 10,
"sample_batch_size": 10,
"train_batch_size": 50,
})
local,
remotes,
minibatch_buffer_size=10,
num_sgd_iter=10,
sample_batch_size=10,
train_batch_size=50)
self._wait_for(optimizer, 1000, 10000)
self.assertLess(optimizer.stats()["num_steps_sampled"], 5000)
self.assertGreater(optimizer.stats()["num_steps_trained"], 8000)
@@ -151,12 +150,13 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testReplay(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(
local, remotes, {
"replay_buffer_num_slots": 100,
"replay_proportion": 10,
"sample_batch_size": 10,
"train_batch_size": 10,
})
local,
remotes,
replay_buffer_num_slots=100,
replay_proportion=10,
sample_batch_size=10,
train_batch_size=10,
)
self._wait_for(optimizer, 1000, 1000)
stats = optimizer.stats()
self.assertLess(stats["num_steps_sampled"], 5000)
@@ -167,14 +167,14 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testReplayAndMultiplePasses(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(
local, remotes, {
"minibatch_buffer_size": 10,
"num_sgd_iter": 10,
"replay_buffer_num_slots": 100,
"replay_proportion": 10,
"sample_batch_size": 10,
"train_batch_size": 10,
})
local,
remotes,
minibatch_buffer_size=10,
num_sgd_iter=10,
replay_buffer_num_slots=100,
replay_proportion=10,
sample_batch_size=10,
train_batch_size=10)
self._wait_for(optimizer, 1000, 1000)
stats = optimizer.stats()
@@ -188,17 +188,16 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testMultiTierAggregationBadConf(self):
local, remotes = self._make_evs()
aggregators = TreeAggregator.precreate_aggregators(4)
optimizer = AsyncSamplesOptimizer(local, remotes,
{"num_aggregation_workers": 4})
optimizer = AsyncSamplesOptimizer(
local, remotes, num_aggregation_workers=4)
self.assertRaises(ValueError,
lambda: optimizer.aggregator.init(aggregators))
def testMultiTierAggregation(self):
local, remotes = self._make_evs()
aggregators = TreeAggregator.precreate_aggregators(1)
optimizer = AsyncSamplesOptimizer(local, remotes, {
"num_aggregation_workers": 1,
})
optimizer = AsyncSamplesOptimizer(
local, remotes, num_aggregation_workers=1)
optimizer.aggregator.init(aggregators)
self._wait_for(optimizer, 1000, 1000)
@@ -207,30 +206,30 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
self.assertRaises(
ValueError, lambda: AsyncSamplesOptimizer(
local, remotes,
{"num_data_loader_buffers": 2, "minibatch_buffer_size": 4}))
num_data_loader_buffers=2, minibatch_buffer_size=4))
optimizer = AsyncSamplesOptimizer(
local, remotes, {
"num_gpus": 2,
"train_batch_size": 100,
"sample_batch_size": 50,
"_fake_gpus": True
})
local,
remotes,
num_gpus=2,
train_batch_size=100,
sample_batch_size=50,
_fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
optimizer = AsyncSamplesOptimizer(
local, remotes, {
"num_gpus": 2,
"train_batch_size": 100,
"sample_batch_size": 25,
"_fake_gpus": True
})
local,
remotes,
num_gpus=2,
train_batch_size=100,
sample_batch_size=25,
_fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
optimizer = AsyncSamplesOptimizer(
local, remotes, {
"num_gpus": 2,
"train_batch_size": 100,
"sample_batch_size": 74,
"_fake_gpus": True
})
local,
remotes,
num_gpus=2,
train_batch_size=100,
sample_batch_size=74,
_fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
def _make_evs(self):