mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 16:46:43 +08:00
[rllib] more user-friendly Optimizer signature + compute_apply (#2335)
* Move signature of optimizers * fix * expose compute_apply for policy_graphs * dictionaries and such * test for multiagent
This commit is contained in:
@@ -94,8 +94,8 @@ class A3CAgent(Agent):
|
||||
self.env_creator, policy_cls, self.config["num_workers"],
|
||||
{"num_gpus": 1 if self.config["use_gpu_for_workers"] else 0})
|
||||
self.optimizer = AsyncGradientsOptimizer(
|
||||
self.config["optimizer"], self.local_evaluator,
|
||||
self.remote_evaluators)
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
|
||||
def _train(self):
|
||||
self.optimizer.step()
|
||||
|
||||
@@ -72,8 +72,8 @@ class BCAgent(Agent):
|
||||
remote_cls.remote(self.env_creator, self.config, self.logdir)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
self.optimizer = AsyncGradientsOptimizer(
|
||||
self.config["optimizer"], self.local_evaluator,
|
||||
self.remote_evaluators)
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
|
||||
def _train(self):
|
||||
self.optimizer.step()
|
||||
|
||||
@@ -136,8 +136,8 @@ class DQNAgent(Agent):
|
||||
{"num_cpus": self.config["num_cpus_per_worker"],
|
||||
"num_gpus": self.config["num_gpus_per_worker"]})
|
||||
self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
|
||||
self.config["optimizer"], self.local_evaluator,
|
||||
self.remote_evaluators)
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
|
||||
self.last_target_update_ts = 0
|
||||
self.num_target_updates = 0
|
||||
|
||||
@@ -45,8 +45,8 @@ class PGAgent(Agent):
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
self.env_creator, PGPolicyGraph, self.config["num_workers"], {})
|
||||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.config["optimizer"], self.local_evaluator,
|
||||
self.remote_evaluators)
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
|
||||
def _train(self):
|
||||
self.optimizer.step()
|
||||
|
||||
@@ -74,11 +74,11 @@ class PPOAgent(Agent):
|
||||
{"num_cpus": self.config["num_cpus_per_worker"],
|
||||
"num_gpus": self.config["num_gpus_per_worker"]})
|
||||
self.optimizer = LocalMultiGPUOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
{"sgd_batch_size": self.config["sgd_batchsize"],
|
||||
"sgd_stepsize": self.config["sgd_stepsize"],
|
||||
"num_sgd_iter": self.config["num_sgd_iter"],
|
||||
"timesteps_per_batch": self.config["timesteps_per_batch"]},
|
||||
self.local_evaluator, self.remote_evaluators)
|
||||
"timesteps_per_batch": self.config["timesteps_per_batch"]})
|
||||
|
||||
def _train(self):
|
||||
def postprocess_samples(batch):
|
||||
|
||||
@@ -106,6 +106,22 @@ class PolicyGraph(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_apply(self, samples):
|
||||
"""Fused compute gradients and apply gradients call.
|
||||
|
||||
Returns:
|
||||
grad_info: dictionary of extra metadata from compute_gradients().
|
||||
apply_info: dictionary of extra metadata from apply_gradients().
|
||||
|
||||
Examples:
|
||||
>>> batch = ev.sample()
|
||||
>>> ev.compute_apply(samples)
|
||||
"""
|
||||
|
||||
grads, grad_info = self.compute_gradients(samples)
|
||||
apply_info = self.apply_gradients(grads)
|
||||
return grad_info, apply_info
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns model weights.
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class PolicyOptimizer(object):
|
||||
evaluators created by this optimizer.
|
||||
"""
|
||||
|
||||
def __init__(self, config, local_evaluator, remote_evaluators):
|
||||
def __init__(self, local_evaluator, remote_evaluators=None, config=None):
|
||||
"""Create an optimizer instance.
|
||||
|
||||
Args:
|
||||
@@ -41,10 +41,10 @@ class PolicyOptimizer(object):
|
||||
evaluators instances. If empty, the optimizer should fall back
|
||||
to using only the local evaluator.
|
||||
"""
|
||||
self.config = config
|
||||
self.local_evaluator = local_evaluator
|
||||
self.remote_evaluators = remote_evaluators
|
||||
self._init(**config)
|
||||
self.remote_evaluators = remote_evaluators or []
|
||||
self.config = config or {}
|
||||
self._init(**self.config)
|
||||
|
||||
# Counters that should be updated by sub-classes
|
||||
self.num_steps_trained = 0
|
||||
|
||||
@@ -40,8 +40,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
samples = self.local_evaluator.sample()
|
||||
|
||||
with self.grad_timer:
|
||||
grad, _ = self.local_evaluator.compute_gradients(samples)
|
||||
self.local_evaluator.apply_gradients(grad)
|
||||
self.local_evaluator.compute_apply(samples)
|
||||
self.grad_timer.push_units_processed(samples.count)
|
||||
|
||||
self.num_steps_sampled += samples.count
|
||||
|
||||
@@ -296,7 +296,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
batch_steps=50)]
|
||||
else:
|
||||
remote_evs = []
|
||||
optimizer = optimizer_cls({}, ev, remote_evs)
|
||||
optimizer = optimizer_cls(ev, remote_evs, {})
|
||||
for i in range(200):
|
||||
ev.foreach_policy(
|
||||
lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02))
|
||||
@@ -338,7 +338,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
policy_graph=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
optimizer = SyncSamplesOptimizer({}, ev, [])
|
||||
optimizer = SyncSamplesOptimizer(ev, [], {})
|
||||
for i in range(100):
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev)
|
||||
|
||||
@@ -21,9 +21,8 @@ class AsyncOptimizerTest(unittest.TestCase):
|
||||
local = _MockEvaluator()
|
||||
remotes = ray.remote(_MockEvaluator)
|
||||
remote_evaluators = [remotes.remote() for i in range(5)]
|
||||
test_optimizer = AsyncGradientsOptimizer({
|
||||
"grads_per_step": 10
|
||||
}, local, remote_evaluators)
|
||||
test_optimizer = AsyncGradientsOptimizer(
|
||||
local, remote_evaluators, {"grads_per_step": 10})
|
||||
test_optimizer.step()
|
||||
self.assertTrue(all(local.get_weights() == 0))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user