diff --git a/examples/resnet/resnet_main.py b/examples/resnet/resnet_main.py index 51bc4f761..a8f6d711e 100644 --- a/examples/resnet/resnet_main.py +++ b/examples/resnet/resnet_main.py @@ -97,6 +97,7 @@ class ResNetTrainActor(object): self.model = resnet_model.ResNet(hps, images, labels, "train") self.model.build_graph() config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True sess = tf.Session(config=config) self.model.variables.set_session(sess) self.coord = tf.train.Coordinator() @@ -144,6 +145,7 @@ class ResNetTestActor(object): self.model = resnet_model.ResNet(hps, images, labels, "eval") self.model.build_graph() config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True sess = tf.Session(config=config) self.model.variables.set_session(sess) self.coord = tf.train.Coordinator()