mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 12:38:39 +08:00
[rllib] Use nested scope in custom loss example
This commit is contained in:
@@ -31,7 +31,7 @@ parser.add_argument(
|
||||
type=str,
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../test/data/cartpole_small"))
|
||||
"../tests/data/cartpole_small"))
|
||||
|
||||
|
||||
class CustomLossModel(Model):
|
||||
@@ -39,8 +39,9 @@ class CustomLossModel(Model):
|
||||
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
self.obs_in = input_dict["obs"]
|
||||
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
|
||||
num_outputs, options)
|
||||
with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):
|
||||
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
|
||||
num_outputs, options)
|
||||
return self.fcnet.outputs, self.fcnet.last_layer
|
||||
|
||||
def custom_loss(self, policy_loss, loss_inputs):
|
||||
@@ -49,12 +50,10 @@ class CustomLossModel(Model):
|
||||
input_ops = reader.tf_input_ops()
|
||||
|
||||
# define a secondary loss by building a graph copy with weight sharing
|
||||
with tf.variable_scope(
|
||||
self.scope, reuse=tf.AUTO_REUSE, auxiliary_name_scope=False):
|
||||
logits, _ = self._build_layers_v2({
|
||||
"obs": restore_original_dimensions(input_ops["obs"],
|
||||
self.obs_space)
|
||||
}, self.num_outputs, self.options)
|
||||
logits, _ = self._build_layers_v2({
|
||||
"obs": restore_original_dimensions(input_ops["obs"],
|
||||
self.obs_space)
|
||||
}, self.num_outputs, self.options)
|
||||
|
||||
# You can also add self-supervised losses easily by referencing tensors
|
||||
# created during _build_layers_v2(). For example, an autoencoder-style
|
||||
|
||||
@@ -45,12 +45,9 @@ class InputReader(object):
|
||||
... def custom_loss(self, policy_loss, loss_inputs):
|
||||
... reader = JsonReader(...)
|
||||
... input_ops = reader.tf_input_ops()
|
||||
... with tf.variable_scope(
|
||||
... self.scope, reuse=tf.AUTO_REUSE,
|
||||
... auxiliary_name_scope=False):
|
||||
... logits, _ = self._build_layers_v2(
|
||||
... {"obs": input_ops["obs"]},
|
||||
... self.num_outputs, self.options)
|
||||
... logits, _ = self._build_layers_v2(
|
||||
... {"obs": input_ops["obs"]},
|
||||
... self.num_outputs, self.options)
|
||||
... il_loss = imitation_loss(logits, input_ops["action"])
|
||||
... return policy_loss + il_loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user