[rllib] Use nested scope in custom loss example

This commit is contained in:
Eric Liang
2019-03-04 18:29:22 -08:00
committed by GitHub
parent df9beb7123
commit 30bf8e46c7
3 changed files with 14 additions and 15 deletions
+8 -9
View File
@@ -31,7 +31,7 @@ parser.add_argument(
type=str,
default=os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../test/data/cartpole_small"))
"../tests/data/cartpole_small"))
class CustomLossModel(Model):
@@ -39,8 +39,9 @@ class CustomLossModel(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
self.obs_in = input_dict["obs"]
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
num_outputs, options)
with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
num_outputs, options)
return self.fcnet.outputs, self.fcnet.last_layer
def custom_loss(self, policy_loss, loss_inputs):
@@ -49,12 +50,10 @@ class CustomLossModel(Model):
input_ops = reader.tf_input_ops()
# Define a secondary loss by building a graph copy with weight sharing
with tf.variable_scope(
self.scope, reuse=tf.AUTO_REUSE, auxiliary_name_scope=False):
logits, _ = self._build_layers_v2({
"obs": restore_original_dimensions(input_ops["obs"],
self.obs_space)
}, self.num_outputs, self.options)
logits, _ = self._build_layers_v2({
"obs": restore_original_dimensions(input_ops["obs"],
self.obs_space)
}, self.num_outputs, self.options)
# You can also add self-supervised losses easily by referencing tensors
# created during _build_layers_v2(). For example, an autoencoder-style
+3 -6
View File
@@ -45,12 +45,9 @@ class InputReader(object):
... def custom_loss(self, policy_loss, loss_inputs):
... reader = JsonReader(...)
... input_ops = reader.tf_input_ops()
... with tf.variable_scope(
... self.scope, reuse=tf.AUTO_REUSE,
... auxiliary_name_scope=False):
... logits, _ = self._build_layers_v2(
... {"obs": input_ops["obs"]},
... self.num_outputs, self.options)
... logits, _ = self._build_layers_v2(
... {"obs": input_ops["obs"]},
... self.num_outputs, self.options)
... il_loss = imitation_loss(logits, input_ops["action"])
... return policy_loss + il_loss