mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:46:49 +08:00
[rllib] Flip sign of A2C, IMPALA entropy coefficient; raise DeprecationWarning if negative (#4374)
This commit is contained in:
@@ -27,7 +27,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# Value Function Loss coefficient
|
||||
"vf_loss_coeff": 0.5,
|
||||
# Entropy coefficient
|
||||
- "entropy_coeff": -0.01,
|
||||
+ "entropy_coeff": 0.01,
|
||||
# Min time per iteration
|
||||
"min_iter_time_s": 5,
|
||||
# Workers sample async. Note that this increases the effective
|
||||
@@ -54,6 +54,9 @@ class A3CAgent(Agent):
|
||||
else:
|
||||
policy_cls = self._policy_graph
|
||||
|
||||
+ if self.config["entropy_coeff"] < 0:
|
||||
+ raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, policy_cls)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
|
||||
@@ -26,7 +26,7 @@ class A3CLoss(object):
|
||||
v_target,
|
||||
vf,
|
||||
vf_loss_coeff=0.5,
|
||||
- entropy_coeff=-0.01):
|
||||
+ entropy_coeff=0.01):
|
||||
log_prob = action_dist.logp(actions)
|
||||
|
||||
# The "policy gradients" loss
|
||||
@@ -35,7 +35,7 @@ class A3CLoss(object):
|
||||
delta = vf - v_target
|
||||
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
|
||||
self.entropy = tf.reduce_sum(action_dist.entropy())
|
||||
- self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
|
||||
+ self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class A3CLoss(nn.Module):
|
||||
- def __init__(self, policy_model, vf_loss_coeff=0.5, entropy_coeff=-0.01):
|
||||
+ def __init__(self, policy_model, vf_loss_coeff=0.5, entropy_coeff=0.01):
|
||||
nn.Module.__init__(self)
|
||||
self.policy_model = policy_model
|
||||
self.vf_loss_coeff = vf_loss_coeff
|
||||
@@ -32,7 +32,7 @@ class A3CLoss(nn.Module):
|
||||
overall_err = sum([
|
||||
pi_err,
|
||||
self.vf_loss_coeff * value_err,
|
||||
- self.entropy_coeff * entropy,
|
||||
+ -self.entropy_coeff * entropy,
|
||||
])
|
||||
return overall_err
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"epsilon": 0.1,
|
||||
# balancing the three losses
|
||||
"vf_loss_coeff": 0.5,
|
||||
- "entropy_coeff": -0.01,
|
||||
+ "entropy_coeff": 0.01,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
@@ -110,6 +110,8 @@ class ImpalaAgent(Agent):
|
||||
self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
+ if self.config["entropy_coeff"] < 0:
|
||||
+ raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
@override(Agent)
|
||||
def _train(self):
|
||||
|
||||
@@ -35,7 +35,7 @@ class VTraceLoss(object):
|
||||
bootstrap_value,
|
||||
valid_mask,
|
||||
vf_loss_coeff=0.5,
|
||||
- entropy_coeff=-0.01,
|
||||
+ entropy_coeff=0.01,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0):
|
||||
"""Policy gradient loss with vtrace importance weighting.
|
||||
@@ -94,7 +94,7 @@ class VTraceLoss(object):
|
||||
tf.boolean_mask(actions_entropy, valid_mask))
|
||||
|
||||
# The summed weighted loss
|
||||
- self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
|
||||
+ self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
|
||||
"momentum": 0.0,
|
||||
"epsilon": 0.1,
|
||||
"vf_loss_coeff": 0.5,
|
||||
- "entropy_coeff": -0.01,
|
||||
+ "entropy_coeff": 0.01,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
@@ -48,7 +48,7 @@ class PPOSurrogateLoss(object):
|
||||
advantages,
|
||||
value_targets,
|
||||
vf_loss_coeff=0.5,
|
||||
- entropy_coeff=-0.01,
|
||||
+ entropy_coeff=0.01,
|
||||
clip_param=0.3):
|
||||
|
||||
logp_ratio = tf.exp(actions_logp - prev_actions_logp)
|
||||
@@ -71,7 +71,7 @@ class PPOSurrogateLoss(object):
|
||||
tf.boolean_mask(actions_entropy, valid_mask))
|
||||
|
||||
# The summed weighted loss
|
||||
- self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
|
||||
+ self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class VTraceSurrogateLoss(object):
|
||||
bootstrap_value,
|
||||
valid_mask,
|
||||
vf_loss_coeff=0.5,
|
||||
- entropy_coeff=-0.01,
|
||||
+ entropy_coeff=0.01,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0,
|
||||
clip_param=0.3):
|
||||
@@ -152,7 +152,7 @@ class VTraceSurrogateLoss(object):
|
||||
tf.boolean_mask(actions_entropy, valid_mask))
|
||||
|
||||
# The summed weighted loss
|
||||
- self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
|
||||
+ self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
|
||||
@@ -150,6 +150,8 @@ class PPOAgent(Agent):
|
||||
return res
|
||||
|
||||
def _validate_config(self):
|
||||
+ if self.config["entropy_coeff"] < 0:
|
||||
+ raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]:
|
||||
raise ValueError(
|
||||
"Minibatch size {} must be <= train batch size {}.".format(
|
||||
|
||||
@@ -6,7 +6,7 @@ pong-a3c-pytorch-cnn:
|
||||
sample_batch_size: 20
|
||||
use_pytorch: true
|
||||
vf_loss_coeff: 0.5
|
||||
- entropy_coeff: -0.01
|
||||
+ entropy_coeff: 0.01
|
||||
gamma: 0.99
|
||||
grad_clip: 40.0
|
||||
lambda: 1.0
|
||||
|
||||
@@ -8,7 +8,7 @@ pong-a3c:
|
||||
sample_batch_size: 20
|
||||
use_pytorch: false
|
||||
vf_loss_coeff: 0.5
|
||||
- entropy_coeff: -0.01
|
||||
+ entropy_coeff: 0.01
|
||||
gamma: 0.99
|
||||
grad_clip: 40.0
|
||||
lambda: 1.0
|
||||
|
||||
Reference in New Issue
Block a user