diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index 696876ae1..c0d4158c0 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -2,13 +2,12 @@ import ray from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy import LearningRateSchedule -from ray.rllib.utils.tf_ops import make_tf_callable -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable tf = try_import_tf() diff --git a/rllib/agents/ddpg/ddpg_tf_model.py b/rllib/agents/ddpg/ddpg_tf_model.py index 21b505332..dcaa17aab 100644 --- a/rllib/agents/ddpg/ddpg_tf_model.py +++ b/rllib/agents/ddpg/ddpg_tf_model.py @@ -1,7 +1,7 @@ import numpy as np from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index b57b53de1..a8862beff 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -18,7 +18,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, \ make_tf_callable diff --git a/rllib/agents/ddpg/noop_model.py b/rllib/agents/ddpg/noop_model.py index 6c9145e1b..8da8af4ed 100644 --- a/rllib/agents/ddpg/noop_model.py +++ b/rllib/agents/ddpg/noop_model.py @@ -2,7 +2,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.annotations import override -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/dqn/distributional_q_tf_model.py b/rllib/agents/dqn/distributional_q_tf_model.py index c2220e0da..a3cf1de8b 100644 --- a/rllib/agents/dqn/distributional_q_tf_model.py +++ b/rllib/agents/dqn/distributional_q_tf_model.py @@ -1,7 +1,7 @@ import numpy as np from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/dqn/dqn_torch_model.py b/rllib/agents/dqn/dqn_torch_model.py index 2b7e9ca9e..25f0fb056 100644 --- a/rllib/agents/dqn/dqn_torch_model.py +++ b/rllib/agents/dqn/dqn_torch_model.py @@ -1,7 +1,7 @@ import numpy as np from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index 22bf7905f..f9d00f82d 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -13,8 +13,8 @@ from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.exploration.parameter_noise import ParameterNoise +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import huber_loss, reduce_mean_ignore_inf -from ray.rllib.utils import try_import_torch torch, nn = try_import_torch() F = None diff --git a/rllib/agents/dqn/simple_q_model.py b/rllib/agents/dqn/simple_q_model.py index 38febe052..432071775 100644 --- a/rllib/agents/dqn/simple_q_model.py +++ b/rllib/agents/dqn/simple_q_model.py @@ -1,5 +1,5 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/dqn/simple_q_tf_policy.py b/rllib/agents/dqn/simple_q_tf_policy.py index bc76e4954..a9879c434 100644 --- a/rllib/agents/dqn/simple_q_tf_policy.py +++ b/rllib/agents/dqn/simple_q_tf_policy.py @@ -12,7 +12,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable tf = try_import_tf() diff --git a/rllib/agents/dqn/simple_q_torch_policy.py b/rllib/agents/dqn/simple_q_torch_policy.py index a240f721d..da584a465 100644 --- a/rllib/agents/dqn/simple_q_torch_policy.py +++ b/rllib/agents/dqn/simple_q_torch_policy.py @@ -8,7 +8,7 @@ from ray.rllib.agents.dqn.simple_q_tf_policy import build_q_models, \ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.torch_policy_template import build_torch_policy -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import huber_loss torch, nn = try_import_torch() diff --git a/rllib/agents/impala/vtrace_tf.py b/rllib/agents/impala/vtrace_tf.py index 9d3e113f0..aa6ab5c7a 100644 --- a/rllib/agents/impala/vtrace_tf.py +++ b/rllib/agents/impala/vtrace_tf.py @@ -31,7 +31,7 @@ tensors. import collections from ray.rllib.models.tf.tf_action_dist import Categorical -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index 05d173738..61816ff39 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -13,8 +13,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule -from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import explained_variance tf = try_import_tf() diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index 7a0a3cec1..f1ee1edb3 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -11,9 +11,9 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.torch_policy_template import build_torch_policy -from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import global_norm, sequence_mask +from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ + sequence_mask torch, nn = try_import_torch() @@ -239,8 +239,7 @@ def stats(policy, train_batch): "vf_loss": policy.loss.vf_loss, "vf_explained_var": explained_variance( torch.reshape(policy.loss.value_targets, [-1]), - torch.reshape(values_batched, [-1]), - framework="torch"), + torch.reshape(values_batched, [-1])), } diff --git a/rllib/agents/marwil/marwil_tf_policy.py b/rllib/agents/marwil/marwil_tf_policy.py index eab6f4ee9..06d5f6848 100644 --- a/rllib/agents/marwil/marwil_tf_policy.py +++ b/rllib/agents/marwil/marwil_tf_policy.py @@ -1,11 +1,10 @@ import ray from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.tf_ops import make_tf_callable -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable tf = try_import_tf() diff --git a/rllib/agents/marwil/marwil_torch_policy.py b/rllib/agents/marwil/marwil_torch_policy.py index ea06b1b24..fa2452d92 100644 --- a/rllib/agents/marwil/marwil_torch_policy.py +++ b/rllib/agents/marwil/marwil_torch_policy.py @@ -3,8 +3,8 @@ from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy_template import build_torch_policy -from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import explained_variance torch, _ = try_import_torch() @@ -52,8 +52,7 @@ def marwil_loss(policy, model, dist_class, train_batch): # Combine both losses. policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * \ policy.v_loss - explained_var = explained_variance( - advantages, state_values, framework="torch") + explained_var = explained_variance(advantages, state_values) policy.explained_variance = torch.mean(explained_var) return policy.total_loss diff --git a/rllib/agents/pg/pg_tf_policy.py b/rllib/agents/pg/pg_tf_policy.py index 79b84673a..8f937a8de 100644 --- a/rllib/agents/pg/pg_tf_policy.py +++ b/rllib/agents/pg/pg_tf_policy.py @@ -3,7 +3,7 @@ from ray.rllib.evaluation.postprocessing import Postprocessing, \ compute_advantages from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index d04762c4b..af833a8c9 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -13,14 +13,13 @@ from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.postprocessing import compute_advantages -from ray.rllib.utils import try_import_tf from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin from ray.rllib.models import ModelCatalog from ray.rllib.utils.annotations import override -from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.utils.tf_ops import make_tf_callable +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable tf = try_import_tf() diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index b5dfbc961..a1c998e48 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ b/rllib/agents/ppo/appo_torch_policy.py @@ -19,9 +19,9 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.policy.torch_policy_template import build_torch_policy -from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import global_norm, sequence_mask +from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ + sequence_mask torch, nn = try_import_torch() @@ -353,8 +353,7 @@ def stats(policy, train_batch): "vf_loss": policy.loss.vf_loss, "vf_explained_var": explained_variance( torch.reshape(policy.loss.value_targets, [-1]), - torch.reshape(values_batched, [-1]), - framework="torch"), + torch.reshape(values_batched, [-1])), } if policy.config["vtrace"]: diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index d199f483d..082a69696 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -7,7 +7,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \ StandardizeFields, SelectExperiences from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU from ray.rllib.execution.metric_ops import StandardMetricsReporting -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 409bb0291..54f1386ea 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -7,9 +7,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.utils.tf_ops import make_tf_callable -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable tf = try_import_tf() diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index a878b70d8..0c4c3cd7a 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -9,9 +9,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \ LearningRateSchedule from ray.rllib.policy.torch_policy_template import build_torch_policy -from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.utils.torch_ops import sequence_mask -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import explained_variance, sequence_mask torch, nn = try_import_torch() @@ -152,8 +151,7 @@ def kl_and_loss_stats(policy, train_batch): "vf_loss": policy.loss_obj.mean_vf_loss, "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], - policy.model.value_function(), - framework="torch"), + policy.model.value_function()), "kl": policy.loss_obj.mean_kl, "entropy": policy.loss_obj.mean_entropy, "entropy_coeff": policy.entropy_coeff, diff --git a/rllib/agents/ppo/tests/test.py b/rllib/agents/ppo/tests/test.py index 9deb84507..956338d75 100644 --- a/rllib/agents/ppo/tests/test.py +++ b/rllib/agents/ppo/tests/test.py @@ -3,7 +3,7 @@ import numpy as np from numpy.testing import assert_allclose from ray.rllib.agents.ppo.utils import flatten, concatenate -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/agents/qmix/model.py b/rllib/agents/qmix/model.py index 654f62648..0c7a6d117 100644 --- a/rllib/agents/qmix/model.py +++ b/rllib/agents/qmix/model.py @@ -2,7 +2,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.annotations import override -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 503bc4979..a3444ce0a 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -11,7 +11,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.models.torch.torch_action_dist import ( TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta) -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() F = nn.functional diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py index 3fab613af..2ae81299f 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py @@ -2,9 +2,9 @@ import numpy as np from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.torch_policy import TorchPolicy -from ray.rllib.utils.annotations import override from ray.rllib.contrib.alpha_zero.core.mcts import Node, RootParentNode -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch torch, _ = try_import_torch() diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py index 61c98d802..fa0345455 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py @@ -12,7 +12,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.model import restore_original_dimensions from ray.rllib.models.torch.torch_action_dist import TorchCategorical -from ray.rllib.utils import try_import_tf, try_import_torch +from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.tune.registry import ENV_CREATOR, _global_registry from ray.rllib.contrib.alpha_zero.core.alpha_zero_policy import AlphaZeroPolicy diff --git a/rllib/contrib/alpha_zero/models/custom_torch_models.py b/rllib/contrib/alpha_zero/models/custom_torch_models.py index 260c55837..1b4e4da78 100644 --- a/rllib/contrib/alpha_zero/models/custom_torch_models.py +++ b/rllib/contrib/alpha_zero/models/custom_torch_models.py @@ -4,7 +4,7 @@ import numpy as np from ray.rllib.models.model import restore_original_dimensions from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/contrib/bandits/models/linear_regression.py b/rllib/contrib/bandits/models/linear_regression.py index 76ff2203a..8dbf76980 100644 --- a/rllib/contrib/bandits/models/linear_regression.py +++ b/rllib/contrib/bandits/models/linear_regression.py @@ -2,8 +2,8 @@ import gym from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils import try_import_torch from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/contrib/maddpg/maddpg_policy.py b/rllib/contrib/maddpg/maddpg_policy.py index 67215a45f..963de76fd 100644 --- a/rllib/contrib/maddpg/maddpg_policy.py +++ b/rllib/contrib/maddpg/maddpg_policy.py @@ -7,7 +7,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import TFPolicy -from ray.rllib.utils import try_import_tf, try_import_tfp +from ray.rllib.utils.framework import try_import_tf, try_import_tfp import logging from gym.spaces import Box, Discrete diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 69acba5d2..75521daa0 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -31,9 +31,9 @@ from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.filter import get_filter +from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils import try_import_tf, try_import_torch tf = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/examples/attention_net.py b/rllib/examples/attention_net.py index c5a8899c9..02c8d96b8 100644 --- a/rllib/examples/attention_net.py +++ b/rllib/examples/attention_net.py @@ -2,12 +2,12 @@ import argparse import ray from ray import tune -from ray.rllib.utils import try_import_tf from ray.rllib.models.tf.attention_net import GTrXLNet from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved from ray.tune import registry diff --git a/rllib/examples/batch_norm_model.py b/rllib/examples/batch_norm_model.py index e3ae8633d..fa41a0add 100644 --- a/rllib/examples/batch_norm_model.py +++ b/rllib/examples/batch_norm_model.py @@ -7,7 +7,7 @@ from ray import tune from ray.rllib.examples.models.batch_norm_model import BatchNormModel, \ TorchBatchNormModel from ray.rllib.models import ModelCatalog -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved tf = try_import_tf() diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 32c63ed3a..260d8494e 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -34,10 +34,9 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.torch_policy import LearningRateSchedule as TorchLR, \ EntropyCoeffSchedule as TorchEntropyCoeffSchedule -from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved -from ray.rllib.utils.tf_ops import make_tf_callable +from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable from ray.rllib.utils.torch_ops import convert_to_torch_tensor tf = try_import_tf() diff --git a/rllib/examples/custom_keras_model.py b/rllib/examples/custom_keras_model.py index 1596c53fc..aac7e41d4 100644 --- a/rllib/examples/custom_keras_model.py +++ b/rllib/examples/custom_keras_model.py @@ -4,13 +4,13 @@ import argparse import ray from ray import tune +from ray.rllib.agents.dqn.distributional_q_tf_model import \ + DistributionalQTFModel from ray.rllib.models import ModelCatalog from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.agents.dqn.distributional_q_tf_model import \ - DistributionalQTFModel -from ray.rllib.utils import try_import_tf from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/examples/custom_loss.py b/rllib/examples/custom_loss.py index 315e6799b..9d3d90348 100644 --- a/rllib/examples/custom_loss.py +++ b/rllib/examples/custom_loss.py @@ -19,7 +19,7 @@ from ray import tune from ray.rllib.examples.models.custom_loss_model import CustomLossModel, \ TorchCustomLossModel from ray.rllib.models import ModelCatalog -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/examples/custom_tf_policy.py b/rllib/examples/custom_tf_policy.py index ccb9b85d2..e2a919273 100644 --- a/rllib/examples/custom_tf_policy.py +++ b/rllib/examples/custom_tf_policy.py @@ -5,7 +5,7 @@ from ray import tune from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.evaluation.postprocessing import discount from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/examples/export/cartpole_dqn_export.py b/rllib/examples/export/cartpole_dqn_export.py index d500644e0..46ab741a9 100644 --- a/rllib/examples/export/cartpole_dqn_export.py +++ b/rllib/examples/export/cartpole_dqn_export.py @@ -4,7 +4,7 @@ import os import ray from ray.rllib.agents.registry import get_agent_class -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/examples/mobilenet_v2_with_lstm.py b/rllib/examples/mobilenet_v2_with_lstm.py index ff578feff..e0f066a13 100644 --- a/rllib/examples/mobilenet_v2_with_lstm.py +++ b/rllib/examples/mobilenet_v2_with_lstm.py @@ -11,7 +11,7 @@ from ray.rllib.examples.env.random_env import RandomEnv from ray.rllib.examples.models.mobilenet_v2_with_lstm_models import \ MobileV2PlusRNNModel, TorchMobileV2PlusRNNModel from ray.rllib.models import ModelCatalog -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/examples/models/batch_norm_model.py b/rllib/examples/models/batch_norm_model.py index 3330abd90..762793de2 100644 --- a/rllib/examples/models/batch_norm_model.py +++ b/rllib/examples/models/batch_norm_model.py @@ -6,8 +6,8 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.torch.misc import SlimFC, normc_initializer as \ torch_normc_initializer from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils import try_import_tf, try_import_torch from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf, try_import_torch tf = try_import_tf() torch, nn = try_import_torch() diff --git a/rllib/examples/models/shared_weights_model.py b/rllib/examples/models/shared_weights_model.py index 6d2d96253..137396a2f 100644 --- a/rllib/examples/models/shared_weights_model.py +++ b/rllib/examples/models/shared_weights_model.py @@ -5,7 +5,7 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.annotations import override -from ray.rllib.utils import try_import_tf, try_import_torch +from ray.rllib.utils.framework import try_import_tf, try_import_torch tf = try_import_tf() torch, nn = try_import_torch() diff --git a/rllib/examples/models/simple_rpg_model.py b/rllib/examples/models/simple_rpg_model.py index a1b54962f..b77428745 100644 --- a/rllib/examples/models/simple_rpg_model.py +++ b/rllib/examples/models/simple_rpg_model.py @@ -1,8 +1,8 @@ -from ray.rllib.utils import try_import_tf, try_import_torch from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as TFFCNet from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet +from ray.rllib.utils.framework import try_import_tf, try_import_torch tf = try_import_tf() torch, nn = try_import_torch() diff --git a/rllib/execution/multi_gpu_impl.py b/rllib/execution/multi_gpu_impl.py index f307af7f8..0771bb18b 100644 --- a/rllib/execution/multi_gpu_impl.py +++ b/rllib/execution/multi_gpu_impl.py @@ -3,7 +3,7 @@ import logging from ray.util.debug import log_once from ray.rllib.utils.debug import summarize -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/execution/multi_gpu_learner.py b/rllib/execution/multi_gpu_learner.py index d4e8042ef..5d2d2c220 100644 --- a/rllib/execution/multi_gpu_learner.py +++ b/rllib/execution/multi_gpu_learner.py @@ -10,8 +10,8 @@ from ray.rllib.execution.learner_thread import LearnerThread from ray.rllib.execution.minibatch_buffer import MinibatchBuffer from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils import try_import_tf tf = try_import_tf() diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index ef4d465bb..860337cb0 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -17,7 +17,7 @@ from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.policy.policy import PolicyID from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.sgd import do_minibatch_sgd, averaged tf = try_import_tf() diff --git a/rllib/models/tf/misc.py b/rllib/models/tf/misc.py index bddbdeba8..64034407a 100644 --- a/rllib/models/tf/misc.py +++ b/rllib/models/tf/misc.py @@ -1,5 +1,5 @@ import numpy as np -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/models/tf/modelv1_compat.py b/rllib/models/tf/modelv1_compat.py index d5ff99150..fb90c2bbf 100644 --- a/rllib/models/tf/modelv1_compat.py +++ b/rllib/models/tf/modelv1_compat.py @@ -6,7 +6,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.misc import linear, normc_initializer from ray.rllib.utils.annotations import override -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_ops import scope_vars tf = try_import_tf() diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index 843efb430..47157465a 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -4,7 +4,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/models/tf/tf_modelv2.py b/rllib/models/tf/tf_modelv2.py index 9bd54755c..f8b5859ee 100644 --- a/rllib/models/tf/tf_modelv2.py +++ b/rllib/models/tf/tf_modelv2.py @@ -1,7 +1,6 @@ from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils import try_import_tf -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/models/torch/fcnet.py b/rllib/models/torch/fcnet.py index 12c4eebb3..da100f711 100644 --- a/rllib/models/torch/fcnet.py +++ b/rllib/models/torch/fcnet.py @@ -5,8 +5,7 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer, \ normc_initializer from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import get_activation_fn -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import get_activation_fn, try_import_torch torch, nn = try_import_torch() diff --git a/rllib/models/torch/misc.py b/rllib/models/torch/misc.py index 71b223187..1679e4715 100644 --- a/rllib/models/torch/misc.py +++ b/rllib/models/torch/misc.py @@ -1,7 +1,7 @@ """ Code adapted from https://github.com/ikostrikov/pytorch-a3c""" import numpy as np -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index bbb4338b8..cfbe48ad3 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -1,6 +1,6 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import try_import_torch _, nn = try_import_torch() diff --git a/rllib/models/torch/visionnet.py b/rllib/models/torch/visionnet.py index 66bb04f78..c3aa3d51c 100644 --- a/rllib/models/torch/visionnet.py +++ b/rllib/models/torch/visionnet.py @@ -3,8 +3,7 @@ from ray.rllib.models.torch.misc import normc_initializer, valid_padding, \ SlimConv2d, SlimFC from ray.rllib.models.tf.visionnet_v1 import _get_filter_config from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import get_activation_fn -from ray.rllib.utils import try_import_torch +from ray.rllib.utils.framework import get_activation_fn, try_import_torch _, nn = try_import_torch() diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py index ab22825b5..9fe5f4309 100644 --- a/rllib/offline/input_reader.py +++ b/rllib/offline/input_reader.py @@ -4,7 +4,7 @@ import threading from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/optimizers/aso_multi_gpu_learner.py b/rllib/optimizers/aso_multi_gpu_learner.py index 75fc39ba6..1935e78c5 100644 --- a/rllib/optimizers/aso_multi_gpu_learner.py +++ b/rllib/optimizers/aso_multi_gpu_learner.py @@ -12,8 +12,8 @@ from ray.rllib.optimizers.aso_learner import LearnerThread from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils import try_import_tf tf = try_import_tf() diff --git a/rllib/optimizers/multi_gpu_impl.py b/rllib/optimizers/multi_gpu_impl.py index f307af7f8..0771bb18b 100644 --- a/rllib/optimizers/multi_gpu_impl.py +++ b/rllib/optimizers/multi_gpu_impl.py @@ -3,7 +3,7 @@ import logging from ray.util.debug import log_once from ray.rllib.utils.debug import summarize -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/optimizers/multi_gpu_optimizer.py b/rllib/optimizers/multi_gpu_optimizer.py index cf33f44b0..20883ff83 100644 --- a/rllib/optimizers/multi_gpu_optimizer.py +++ b/rllib/optimizers/multi_gpu_optimizer.py @@ -9,12 +9,12 @@ from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.optimizers.rollout import collect_samples -from ray.rllib.utils.annotations import override -from ray.rllib.utils.sgd import averaged -from ray.rllib.utils.timer import TimerStat from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.sgd import averaged +from ray.rllib.utils.timer import TimerStat tf = try_import_tf() diff --git a/rllib/optimizers/tests/test_optimizers.py b/rllib/optimizers/tests/test_optimizers.py index ebf431e45..35ff838de 100644 --- a/rllib/optimizers/tests/test_optimizers.py +++ b/rllib/optimizers/tests/test_optimizers.py @@ -12,7 +12,7 @@ from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.tests.mock_worker import _MockWorker -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 64c7f9cee..d8502e91b 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -9,8 +9,9 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils import try_import_tf, override +from ray.rllib.utils.annotations import override from ray.rllib.utils.debug import summarize +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tracking_dict import UsageTrackingDict tf = try_import_tf() diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index c932ecde3..26bfd51e2 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -4,7 +4,7 @@ from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py index 89f40d039..dabc85be1 100644 --- a/rllib/tests/test_nested_observation_spaces.py +++ b/rllib/tests/test_nested_observation_spaces.py @@ -19,7 +19,7 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.rollout import rollout from ray.rllib.tests.test_external_env import SimpleServing from ray.tune.registry import register_env -from ray.rllib.utils import try_import_tf, try_import_torch +from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.spaces.repeated import Repeated tf = try_import_tf() diff --git a/rllib/tuned_examples/debug_learning_failure_git_bisect.py b/rllib/tuned_examples/debug_learning_failure_git_bisect.py index 309ca0de3..84c0418a6 100644 --- a/rllib/tuned_examples/debug_learning_failure_git_bisect.py +++ b/rllib/tuned_examples/debug_learning_failure_git_bisect.py @@ -25,7 +25,7 @@ import yaml import ray from ray import tune -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved tf = try_import_tf() diff --git a/rllib/utils/explained_variance.py b/rllib/utils/explained_variance.py deleted file mode 100644 index 555396a78..000000000 --- a/rllib/utils/explained_variance.py +++ /dev/null @@ -1,21 +0,0 @@ -from ray.rllib.utils import try_import_tf, try_import_torch - -tf = try_import_tf() -torch, nn = try_import_torch() - - -def explained_variance(y, pred, framework="tf"): - if framework == "tf": - _, y_var = tf.nn.moments(y, axes=[0]) - _, diff_var = tf.nn.moments(y - pred, axes=[0]) - return tf.maximum(-1.0, 1 - (diff_var / y_var)) - else: - y_var = torch.var(y, dim=[0]) - diff_var = torch.var(y - pred, dim=[0]) - min_ = torch.Tensor([-1.0]) - return torch.max( - min_.to( - device=torch.device("cuda") - ) if torch.cuda.is_available() else min_, - 1 - (diff_var / y_var) - ) diff --git a/rllib/utils/seed.py b/rllib/utils/seed.py deleted file mode 100644 index 513c34fc7..000000000 --- a/rllib/utils/seed.py +++ /dev/null @@ -1,11 +0,0 @@ -import numpy as np -import random -from ray.rllib.utils import try_import_tf - -tf = try_import_tf() - - -def seed(np_seed=0, random_seed=0, tf_seed=0): - np.random.seed(np_seed) - random.seed(random_seed) - tf.set_random_seed(tf_seed) diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index ea636f744..b415b8689 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -1,8 +1,14 @@ -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() +def explained_variance(y, pred): + _, y_var = tf.nn.moments(y, axes=[0]) + _, diff_var = tf.nn.moments(y - pred, axes=[0]) + return tf.maximum(-1.0, 1 - (diff_var / y_var)) + + def huber_loss(x, delta=1.0): """Reference: https://en.wikipedia.org/wiki/Huber_loss""" return tf.where( diff --git a/rllib/utils/tf_run_builder.py b/rllib/utils/tf_run_builder.py index 5a90bcd7d..4d891fbfa 100644 --- a/rllib/utils/tf_run_builder.py +++ b/rllib/utils/tf_run_builder.py @@ -3,7 +3,7 @@ import os import time from ray.util.debug import log_once -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index ef02f8898..f72ff070a 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -1,11 +1,22 @@ import numpy as np -from ray.rllib.utils import try_import_torch, try_import_tree +from ray.rllib.utils import try_import_tree +from ray.rllib.utils.framework import try_import_torch torch, _ = try_import_torch() tree = try_import_tree() +def explained_variance(y, pred): + y_var = torch.var(y, dim=[0]) + diff_var = torch.var(y - pred, dim=[0]) + min_ = torch.Tensor([-1.0]) + return torch.max( + min_.to(device=torch.device("cuda")) + if torch.cuda.is_available() else min_, + 1 - (diff_var / y_var)) + + def global_norm(tensors): """Returns the global L2 norm over a list of tensors.