diff --git a/python/ray/experimental/tf_utils.py b/python/ray/experimental/tf_utils.py index bb424134e..d2f1b2599 100644 --- a/python/ray/experimental/tf_utils.py +++ b/python/ray/experimental/tf_utils.py @@ -76,7 +76,8 @@ class TensorFlowVariables(object): if control not in explored_inputs: queue.append(control) explored_inputs.add(control) - if "Variable" in tf_obj.node_def.op: + if ("Variable" in tf_obj.node_def.op + or "VarHandle" in tf_obj.node_def.op): variable_names.append(tf_obj.node_def.name) self.variables = OrderedDict() variable_list = [