mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 14:02:36 +08:00
60d4d5e1aa
* Remove all __future__ imports from RLlib. * Remove (object) again from tf_run_builder.py::TFRunBuilder. * Fix 2xLINT warnings. * Fix broken appo_policy import (must be appo_tf_policy) * Remove future imports from all other ray files (not just RLlib). * Remove future imports from all other ray files (not just RLlib). * Remove future import blocks that contain `unicode_literals` as well. Revert appo_tf_policy.py to appo_policy.py (belongs to another PR). * Add two empty lines before Schedule class. * Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
63 lines
1.4 KiB
Python
63 lines
1.4 KiB
Python
import logging
import os

# Module-wide logger; the try_import_* helpers in this file use it to warn
# when an import is deliberately skipped via an environment variable.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def try_import_tf():
    """Try to import TensorFlow, configured for TF1.x-style usage.

    Prefers ``tensorflow.compat.v1`` (the TF1 API as exposed by TF 2.x).
    On that path, TF's Python logging is lowered to ERROR and v2 behavior
    is disabled. If ``tensorflow.compat.v1`` cannot be imported, it falls
    back to a plain ``import tensorflow`` (TF 1.x), which is returned
    unconfigured.

    Unless the user already set it, ``TF_CPP_MIN_LOG_LEVEL`` is set to "3"
    before importing, to quiet TensorFlow's native logging.

    Returns:
        The tf module (``tensorflow.compat.v1`` or the top-level
        ``tensorflow`` module), or None if TensorFlow cannot be imported or
        the ``RLLIB_TEST_NO_TF_IMPORT`` env var is set.
    """
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes")
        return None

    # Respect a user-provided value; only default it when absent.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

    try:
        import tensorflow.compat.v1 as tf
    except ImportError:
        # No compat.v1 module (or no TensorFlow at all): try the top level.
        try:
            import tensorflow as tf
        except ImportError:
            return None
        return tf

    # Configuration lives outside the `try` on purpose: an ImportError here
    # must not be mistaken for "compat.v1 unavailable" and silently return
    # a module that still has v2 behavior enabled.
    tf.logging.set_verbosity(tf.logging.ERROR)
    tf.disable_v2_behavior()
    return tf
|
|
|
|
|
|
def try_import_tfp():
    """Import TensorFlow Probability if it can be imported.

    Uses the same ``RLLIB_TEST_NO_TF_IMPORT`` opt-out env var as the
    TensorFlow import helper.

    Returns:
        The ``tensorflow_probability`` module, or None when it cannot be
        imported or the import is skipped on purpose.
    """
    tfp_module = None

    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow Probability for test "
                       "purposes.")
        return tfp_module

    try:
        # On failure the name is never rebound, so it stays None.
        import tensorflow_probability as tfp_module
    except ImportError:
        pass
    return tfp_module
|
|
|
|
|
|
def try_import_torch():
    """Import PyTorch together with its ``nn`` package, if available.

    Returns:
        tuple: ``(torch, torch.nn)`` on success; ``(None, None)`` if either
        import fails or the ``RLLIB_TEST_NO_TORCH_IMPORT`` env var is set.
    """
    if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
        logger.warning("Not importing Torch for test purposes.")
        return None, None

    try:
        import torch
        from torch import nn
    except ImportError:
        return None, None
    else:
        return torch, nn
|