mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 12:54:27 +08:00
[tune] Fix for keras threading (#5517)
This commit is contained in:
committed by
Philipp Moritz
parent
dbf7089c79
commit
d2a6f7958a
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import keras
|
||||
from keras.datasets import mnist
|
||||
from keras import backend as K
|
||||
@@ -52,8 +53,8 @@ def set_keras_threads(threads):
|
||||
# We set threads here to avoid contention, as Keras
|
||||
# is heavily parallelized across multiple cores.
|
||||
K.set_session(
|
||||
K.tf.Session(
|
||||
config=K.tf.ConfigProto(
|
||||
tf.Session(
|
||||
config=tf.ConfigProto(
|
||||
intra_op_parallelism_threads=threads,
|
||||
inter_op_parallelism_threads=threads)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user