mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[RLlib] Minor rllib.utils cleanup. (#8932)
This commit is contained in:
@@ -2,13 +2,12 @@
|
||||
|
||||
import ray
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
from ray.rllib.utils.tf_ops import make_tf_callable
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, \
|
||||
make_tf_callable
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ from ray.rllib.policy.torch_policy import LearningRateSchedule
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import huber_loss, reduce_mean_ignore_inf
|
||||
from ray.rllib.utils import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
F = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.agents.dqn.simple_q_tf_policy import build_q_models, \
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import huber_loss
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -31,7 +31,7 @@ tensors.
|
||||
import collections
|
||||
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
|
||||
EntropyCoeffSchedule
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import explained_variance
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -11,9 +11,9 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import LearningRateSchedule, \
|
||||
EntropyCoeffSchedule
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import global_norm, sequence_mask
|
||||
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
|
||||
sequence_mask
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -239,8 +239,7 @@ def stats(policy, train_batch):
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
torch.reshape(policy.loss.value_targets, [-1]),
|
||||
torch.reshape(values_batched, [-1]),
|
||||
framework="torch"),
|
||||
torch.reshape(values_batched, [-1])),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import ray
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils.tf_ops import make_tf_callable
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import explained_variance
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
@@ -52,8 +52,7 @@ def marwil_loss(policy, model, dist_class, train_batch):
|
||||
# Combine both losses.
|
||||
policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * \
|
||||
policy.v_loss
|
||||
explained_var = explained_variance(
|
||||
advantages, state_values, framework="torch")
|
||||
explained_var = explained_variance(advantages, state_values)
|
||||
policy.explained_variance = torch.mean(explained_var)
|
||||
|
||||
return policy.total_loss
|
||||
|
||||
@@ -3,7 +3,7 @@ from ray.rllib.evaluation.postprocessing import Postprocessing, \
|
||||
compute_advantages
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -13,14 +13,13 @@ from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.tf_ops import make_tf_callable
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -19,9 +19,9 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import LearningRateSchedule
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import global_norm, sequence_mask
|
||||
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
|
||||
sequence_mask
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -353,8 +353,7 @@ def stats(policy, train_batch):
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
torch.reshape(policy.loss.value_targets, [-1]),
|
||||
torch.reshape(values_batched, [-1]),
|
||||
framework="torch"),
|
||||
torch.reshape(values_batched, [-1])),
|
||||
}
|
||||
|
||||
if policy.config["vtrace"]:
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \
|
||||
StandardizeFields, SelectExperiences
|
||||
from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -7,9 +7,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
|
||||
EntropyCoeffSchedule
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.tf_ops import make_tf_callable
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.torch_ops import sequence_mask
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import explained_variance, sequence_mask
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -152,8 +151,7 @@ def kl_and_loss_stats(policy, train_batch):
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
train_batch[Postprocessing.VALUE_TARGETS],
|
||||
policy.model.value_function(),
|
||||
framework="torch"),
|
||||
policy.model.value_function()),
|
||||
"kl": policy.loss_obj.mean_kl,
|
||||
"entropy": policy.loss_obj.mean_entropy,
|
||||
"entropy_coeff": policy.entropy_coeff,
|
||||
|
||||
@@ -3,7 +3,7 @@ import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from ray.rllib.agents.ppo.utils import flatten, concatenate
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.models.torch.torch_action_dist import (
|
||||
TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
F = nn.functional
|
||||
|
||||
@@ -2,9 +2,9 @@ import numpy as np
|
||||
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.contrib.alpha_zero.core.mcts import Node, RootParentNode
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
|
||||
from ray.rllib.contrib.alpha_zero.core.alpha_zero_policy import AlphaZeroPolicy
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ import gym
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils import try_import_tf, try_import_tfp
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
||||
|
||||
import logging
|
||||
from gym.spaces import Box, Discrete
|
||||
|
||||
@@ -31,9 +31,9 @@ from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
@@ -2,12 +2,12 @@ import argparse
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.models.tf.attention_net import GTrXLNet
|
||||
from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot
|
||||
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
|
||||
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
|
||||
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved
|
||||
from ray.tune import registry
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray import tune
|
||||
from ray.rllib.examples.models.batch_norm_model import BatchNormModel, \
|
||||
TorchBatchNormModel
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -34,10 +34,9 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \
|
||||
EntropyCoeffSchedule
|
||||
from ray.rllib.policy.torch_policy import LearningRateSchedule as TorchLR, \
|
||||
EntropyCoeffSchedule as TorchEntropyCoeffSchedule
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved
|
||||
from ray.rllib.utils.tf_ops import make_tf_callable
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -4,13 +4,13 @@ import argparse
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.dqn.distributional_q_tf_model import \
|
||||
DistributionalQTFModel
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.agents.dqn.distributional_q_tf_model import \
|
||||
DistributionalQTFModel
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from ray import tune
|
||||
from ray.rllib.examples.models.custom_loss_model import CustomLossModel, \
|
||||
TorchCustomLossModel
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from ray import tune
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.evaluation.postprocessing import discount
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import ray
|
||||
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.examples.env.random_env import RandomEnv
|
||||
from ray.rllib.examples.models.mobilenet_v2_with_lstm_models import \
|
||||
MobileV2PlusRNNModel, TorchMobileV2PlusRNNModel
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.misc import SlimFC, normc_initializer as \
|
||||
torch_normc_initializer
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -5,7 +5,7 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as TFFCNet
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from ray.rllib.execution.learner_thread import LearnerThread
|
||||
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
|
||||
from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.policy.policy import PolicyID
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd, averaged
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.misc import linear, normc_initializer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import scope_vars
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -4,7 +4,7 @@ from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.policy.rnn_sequencing import add_time_dimension
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer, \
|
||||
normc_initializer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import get_activation_fn
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
""" Code adapted from https://github.com/ikostrikov/pytorch-a3c"""
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
_, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -3,8 +3,7 @@ from ray.rllib.models.torch.misc import normc_initializer, valid_padding, \
|
||||
SlimConv2d, SlimFC
|
||||
from ray.rllib.models.tf.visionnet_v1 import _get_filter_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import get_activation_fn
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
|
||||
|
||||
_, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import threading
|
||||
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ from ray.rllib.optimizers.aso_learner import LearnerThread
|
||||
from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@ from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.optimizers.rollout import collect_samples
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.sgd import averaged
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.sgd import averaged
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.tests.mock_worker import _MockWorker
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils import try_import_tf, override
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -4,7 +4,7 @@ from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.rollout import rollout
|
||||
from ray.rllib.tests.test_external_env import SimpleServing
|
||||
from ray.tune.registry import register_env
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.spaces.repeated import Repeated
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -25,7 +25,7 @@ import yaml
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
def explained_variance(y, pred, framework="tf"):
|
||||
if framework == "tf":
|
||||
_, y_var = tf.nn.moments(y, axes=[0])
|
||||
_, diff_var = tf.nn.moments(y - pred, axes=[0])
|
||||
return tf.maximum(-1.0, 1 - (diff_var / y_var))
|
||||
else:
|
||||
y_var = torch.var(y, dim=[0])
|
||||
diff_var = torch.var(y - pred, dim=[0])
|
||||
min_ = torch.Tensor([-1.0])
|
||||
return torch.max(
|
||||
min_.to(
|
||||
device=torch.device("cuda")
|
||||
) if torch.cuda.is_available() else min_,
|
||||
1 - (diff_var / y_var)
|
||||
)
|
||||
@@ -1,11 +0,0 @@
|
||||
import numpy as np
|
||||
import random
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
def seed(np_seed=0, random_seed=0, tf_seed=0):
|
||||
np.random.seed(np_seed)
|
||||
random.seed(random_seed)
|
||||
tf.set_random_seed(tf_seed)
|
||||
@@ -1,8 +1,14 @@
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
def explained_variance(y, pred):
|
||||
_, y_var = tf.nn.moments(y, axes=[0])
|
||||
_, diff_var = tf.nn.moments(y - pred, axes=[0])
|
||||
return tf.maximum(-1.0, 1 - (diff_var / y_var))
|
||||
|
||||
|
||||
def huber_loss(x, delta=1.0):
|
||||
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
|
||||
return tf.where(
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import time
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils import try_import_torch, try_import_tree
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
tree = try_import_tree()
|
||||
|
||||
|
||||
def explained_variance(y, pred):
|
||||
y_var = torch.var(y, dim=[0])
|
||||
diff_var = torch.var(y - pred, dim=[0])
|
||||
min_ = torch.Tensor([-1.0])
|
||||
return torch.max(
|
||||
min_.to(device=torch.device("cuda"))
|
||||
if torch.cuda.is_available() else min_,
|
||||
1 - (diff_var / y_var))
|
||||
|
||||
|
||||
def global_norm(tensors):
|
||||
"""Returns the global L2 norm over a list of tensors.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user