From 2256047876a79db7e2de2e9fb11b34b8b91d5d2a Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sat, 15 Aug 2020 13:24:22 +0200 Subject: [PATCH] [RLlib] Rename rllib.utils.types into typing to match built-in python module's name. (#10114) --- rllib/agents/callbacks.py | 2 +- rllib/agents/dqn/apex.py | 2 +- rllib/agents/mbmpo/model_ensemble.py | 2 +- rllib/agents/trainer.py | 2 +- rllib/agents/trainer_template.py | 2 +- rllib/env/base_env.py | 2 +- rllib/env/env_context.py | 2 +- rllib/env/external_env.py | 2 +- rllib/env/external_multi_agent_env.py | 2 +- rllib/env/multi_agent_env.py | 2 +- rllib/env/policy_client.py | 2 +- rllib/env/remote_vector_env.py | 2 +- rllib/env/unity3d_env.py | 2 +- rllib/env/vector_env.py | 2 +- rllib/evaluation/episode.py | 2 +- rllib/evaluation/metrics.py | 2 +- rllib/evaluation/observation_function.py | 2 +- rllib/evaluation/rollout_worker.py | 2 +- rllib/evaluation/sample_batch_builder.py | 2 +- rllib/evaluation/sample_collector.py | 2 +- rllib/evaluation/sampler.py | 2 +- rllib/evaluation/worker_set.py | 2 +- rllib/execution/replay_buffer.py | 2 +- rllib/execution/replay_ops.py | 2 +- rllib/execution/rollout_ops.py | 2 +- rllib/execution/train_ops.py | 2 +- rllib/execution/tree_agg.py | 2 +- rllib/models/action_dist.py | 2 +- rllib/models/catalog.py | 2 +- rllib/models/modelv2.py | 2 +- rllib/models/tf/tf_action_dist.py | 2 +- rllib/models/tf/tf_modelv2.py | 2 +- rllib/models/torch/torch_action_dist.py | 2 +- rllib/models/torch/torch_modelv2.py | 2 +- rllib/offline/input_reader.py | 2 +- rllib/offline/is_estimator.py | 2 +- rllib/offline/json_reader.py | 2 +- rllib/offline/json_writer.py | 2 +- rllib/offline/mixed_input.py | 2 +- rllib/offline/off_policy_estimator.py | 2 +- rllib/offline/output_writer.py | 2 +- rllib/offline/shuffled_input.py | 2 +- rllib/offline/wis_estimator.py | 2 +- rllib/policy/dynamic_tf_policy.py | 3 ++- rllib/policy/policy.py | 2 +- rllib/policy/sample_batch.py | 2 +- rllib/policy/tf_policy.py | 3 ++- rllib/policy/tf_policy_template.py | 3 ++- rllib/policy/torch_policy.py | 2 +- rllib/policy/torch_policy_template.py | 2 +- rllib/utils/exploration/curiosity.py | 2 +- rllib/utils/framework.py | 2 +- rllib/utils/schedules/polynomial_schedule.py | 2 +- rllib/utils/{types.py => typing.py} | 0 54 files changed, 56 insertions(+), 53 deletions(-) rename rllib/utils/{types.py => typing.py} (100%) diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index 7ef06bed7..2111032cb 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -6,7 +6,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.deprecation import deprecation_warning -from ray.rllib.utils.types import AgentID, PolicyID +from ray.rllib.utils.typing import AgentID, PolicyID @PublicAPI diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index eb1bb57c6..2ea8ebf91 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -16,7 +16,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.replay_buffer import ReplayActor from ray.rllib.utils import merge_dicts from ray.rllib.utils.actors import create_colocated -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType # yapf: disable # __sphinx_doc_begin__ diff --git a/rllib/agents/mbmpo/model_ensemble.py b/rllib/agents/mbmpo/model_ensemble.py index bf37cff5a..c252e0464 100644 --- a/rllib/agents/mbmpo/model_ensemble.py +++ b/rllib/agents/mbmpo/model_ensemble.py @@ -6,7 +6,7 @@ from ray.rllib.utils.framework import try_import_torch from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType from ray.rllib.utils.torch_ops import convert_to_torch_tensor torch, nn = try_import_torch() diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 400d8a287..b2a3c1972 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -25,7 +25,7 @@ from ray.rllib.utils.framework import try_import_tf, TensorStructType from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.from_config import from_config -from ray.rllib.utils.types import TrainerConfigDict, \ +from ray.rllib.utils.typing import TrainerConfigDict, \ PartialTrainerConfigDict, EnvInfoDict, ResultDict, EnvType, PolicyID from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.trainable import Trainable diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index f7daff848..c835b776f 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -9,7 +9,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.policy import Policy from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.types import TrainerConfigDict, ResultDict +from ray.rllib.utils.typing import TrainerConfigDict, ResultDict logger = logging.getLogger(__name__) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 43389bff4..20444ec49 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -5,7 +5,7 @@ from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils.types import EnvType, MultiEnvDict, EnvID, \ +from ray.rllib.utils.typing import EnvType, MultiEnvDict, EnvID, \ AgentID, MultiAgentDict if TYPE_CHECKING: diff --git a/rllib/env/env_context.py b/rllib/env/env_context.py index e8ed98b93..c5f9dd62d 100644 --- a/rllib/env/env_context.py +++ b/rllib/env/env_context.py @@ -1,5 +1,5 @@ from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.types import EnvConfigDict +from ray.rllib.utils.typing import EnvConfigDict @PublicAPI diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py index fdac97fb1..fdb382298 100644 --- a/rllib/env/external_env.py +++ b/rllib/env/external_env.py @@ -5,7 +5,7 @@ import uuid from typing import Optional from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.types import EnvActionType, EnvObsType, EnvInfoDict +from ray.rllib.utils.typing import EnvActionType, EnvObsType, EnvInfoDict @PublicAPI diff --git a/rllib/env/external_multi_agent_env.py b/rllib/env/external_multi_agent_env.py index 775888341..42cd11c46 100644 --- a/rllib/env/external_multi_agent_env.py +++ b/rllib/env/external_multi_agent_env.py @@ -4,7 +4,7 @@ from typing import Optional from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode -from ray.rllib.utils.types import MultiAgentDict +from ray.rllib.utils.typing import MultiAgentDict @PublicAPI diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 4034fb81c..96db0637d 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -2,7 +2,7 @@ from typing import Tuple, Dict, List import gym from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.types import MultiAgentDict, AgentID +from ray.rllib.utils.typing import MultiAgentDict, AgentID # If the obs space is Dict type, look for the global state under this key. ENV_STATE = "state" diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index 67835efde..5aa17fae0 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -14,7 +14,7 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.types import MultiAgentDict, EnvInfoDict, EnvObsType, \ +from ray.rllib.utils.typing import MultiAgentDict, EnvInfoDict, EnvObsType, \ EnvActionType logger = logging.getLogger(__name__) diff --git a/rllib/env/remote_vector_env.py b/rllib/env/remote_vector_env.py index 16c5dde35..b1bbc3f83 100644 --- a/rllib/env/remote_vector_env.py +++ b/rllib/env/remote_vector_env.py @@ -4,7 +4,7 @@ from typing import Tuple, Callable, Optional import ray from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils.types import MultiEnvDict, EnvType, EnvID, MultiAgentDict +from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID, MultiAgentDict logger = logging.getLogger(__name__) diff --git a/rllib/env/unity3d_env.py b/rllib/env/unity3d_env.py index 45a374cde..61ad14348 100644 --- a/rllib/env/unity3d_env.py +++ b/rllib/env/unity3d_env.py @@ -5,7 +5,7 @@ from typing import Callable, Tuple from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.utils.annotations import override -from ray.rllib.utils.types import MultiAgentDict, PolicyID, AgentID +from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID logger = logging.getLogger(__name__) diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 4267905db..b97050933 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -4,7 +4,7 @@ import numpy as np from typing import Callable, List, Tuple from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils.types import EnvType, EnvConfigDict, EnvObsType, \ +from ray.rllib.utils.typing import EnvType, EnvConfigDict, EnvObsType, \ EnvInfoDict, EnvActionType logger = logging.getLogger(__name__) diff --git a/rllib/evaluation/episode.py b/rllib/evaluation/episode.py index 0f578770f..b09171c4e 100644 --- a/rllib/evaluation/episode.py +++ b/rllib/evaluation/episode.py @@ -7,7 +7,7 @@ from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray -from ray.rllib.utils.types import SampleBatchType, AgentID, PolicyID, \ +from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \ EnvObsType, EnvInfoDict, EnvActionType if TYPE_CHECKING: diff --git a/rllib/evaluation/metrics.py b/rllib/evaluation/metrics.py index 3f0ef4122..5ac1fa74e 100644 --- a/rllib/evaluation/metrics.py +++ b/rllib/evaluation/metrics.py @@ -9,7 +9,7 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.types import GradInfoDict, LearnerStatsDict, ResultDict +from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict logger = logging.getLogger(__name__) diff --git a/rllib/evaluation/observation_function.py b/rllib/evaluation/observation_function.py index 48661aa0b..91849a735 100644 --- a/rllib/evaluation/observation_function.py +++ b/rllib/evaluation/observation_function.py @@ -4,7 +4,7 @@ from ray.rllib.env import BaseEnv from ray.rllib.policy import Policy from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker from ray.rllib.utils.framework import TensorType -from ray.rllib.utils.types import AgentID, PolicyID +from ray.rllib.utils.typing import AgentID, PolicyID class ObservationFunction: diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index f198364d5..5c2c24bf6 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -36,7 +36,7 @@ from ray.rllib.utils.filter import get_filter, Filter from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils.types import EnvType, AgentID, PolicyID, EnvConfigDict, \ +from ray.rllib.utils.typing import EnvType, AgentID, PolicyID, EnvConfigDict, \ ModelConfigDict, TrainerConfigDict, SampleBatchType, ModelWeights, \ ModelGradients, MultiAgentPolicyConfigDict from ray.util.debug import log_once, disable_log_once_globally, \ diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index ae68d2f5a..15d471240 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -8,7 +8,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI from ray.rllib.utils.debug import summarize -from ray.rllib.utils.types import PolicyID, AgentID +from ray.rllib.utils.typing import PolicyID, AgentID from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.util.debug import log_once diff --git a/rllib/evaluation/sample_collector.py b/rllib/evaluation/sample_collector.py index af17fae9d..f532ba25a 100644 --- a/rllib/evaluation/sample_collector.py +++ b/rllib/evaluation/sample_collector.py @@ -4,7 +4,7 @@ from typing import Dict, Optional from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.utils.types import AgentID, EpisodeID, PolicyID, \ +from ray.rllib.utils.typing import AgentID, EpisodeID, PolicyID, \ TensorType logger = logging.getLogger(__name__) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 0cb9f9b93..bb6071062 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -25,7 +25,7 @@ from ray.rllib.utils.debug import summarize from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \ unbatch from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils.types import SampleBatchType, AgentID, PolicyID, \ +from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \ EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \ TensorStructType diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 46f6a86ff..a1b269b4c 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -12,7 +12,7 @@ from ray.rllib.env.env_context import EnvContext from ray.rllib.policy import Policy from ray.rllib.utils import merge_dicts from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.types import PolicyID, TrainerConfigDict, EnvType +from ray.rllib.utils.typing import PolicyID, TrainerConfigDict, EnvType tf1, tf, tfv = try_import_tf() diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index dba78e790..15a0719cd 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -13,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI from ray.util.iter import ParallelIteratorWorker from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType # Constant that represents all policies in lockstep replay mode. _ALL_POLICIES = "__all__" diff --git a/rllib/execution/replay_ops.py b/rllib/execution/replay_ops.py index c5245e78f..7c14a4ef4 100644 --- a/rllib/execution/replay_ops.py +++ b/rllib/execution/replay_ops.py @@ -6,7 +6,7 @@ from ray.util.iter_metrics import SharedMetrics from ray.rllib.execution.replay_buffer import LocalReplayBuffer from ray.rllib.execution.common import \ STEPS_SAMPLED_COUNTER, _get_shared_metrics -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType class StoreToReplayBuffer: diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index c199e15b9..7af144b81 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -13,7 +13,7 @@ from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, LEARNER_INFO, \ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.sgd import standardized -from ray.rllib.utils.types import PolicyID, SampleBatchType, ModelGradients +from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients logger = logging.getLogger(__name__) diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index edc8ea934..f5ac24e0e 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -18,7 +18,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.sgd import do_minibatch_sgd, averaged -from ray.rllib.utils.types import PolicyID, SampleBatchType +from ray.rllib.utils.typing import PolicyID, SampleBatchType tf1, tf, tfv = try_import_tf() diff --git a/rllib/execution/tree_agg.py b/rllib/execution/tree_agg.py index b7d07850b..344a22e20 100644 --- a/rllib/execution/tree_agg.py +++ b/rllib/execution/tree_agg.py @@ -10,7 +10,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches from ray.rllib.utils.actors import create_colocated from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \ from_actors -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType logger = logging.getLogger(__name__) diff --git a/rllib/models/action_dist.py b/rllib/models/action_dist.py index bf349355a..34e3f63c6 100644 --- a/rllib/models/action_dist.py +++ b/rllib/models/action_dist.py @@ -3,7 +3,7 @@ import gym from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.types import TensorType, List, Union, ModelConfigDict +from ray.rllib.utils.typing import TensorType, List, Union, ModelConfigDict @DeveloperAPI diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index afe39e9cf..f7c91801c 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -27,7 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.space_utils import flatten_space -from ray.rllib.utils.types import ModelConfigDict, TensorType +from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index a8b14bf39..5ead72459 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -13,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ TensorType from ray.rllib.utils.spaces.repeated import Repeated -from ray.rllib.utils.types import ModelConfigDict, TensorStructType +from ray.rllib.utils.typing import ModelConfigDict, TensorStructType tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index 94235e2ec..de7fb6f29 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -10,7 +10,7 @@ from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \ from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.types import TensorType, List +from ray.rllib.utils.typing import TensorType, List tf1, tf, tfv = try_import_tf() tfp = try_import_tfp() diff --git a/rllib/models/tf/tf_modelv2.py b/rllib/models/tf/tf_modelv2.py index 1d5408f13..09625781b 100644 --- a/rllib/models/tf/tf_modelv2.py +++ b/rllib/models/tf/tf_modelv2.py @@ -5,7 +5,7 @@ from typing import List from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.types import ModelConfigDict, TensorType +from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py index 9b7f70cfe..d0b28c77d 100644 --- a/rllib/models/torch/torch_action_dist.py +++ b/rllib/models/torch/torch_action_dist.py @@ -11,7 +11,7 @@ from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \ MAX_LOG_NN_OUTPUT from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space from ray.rllib.utils.torch_ops import atanh -from ray.rllib.utils.types import TensorType, List +from ray.rllib.utils.typing import TensorType, List torch, nn = try_import_torch() diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index 393d33ee4..091418120 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -4,7 +4,7 @@ from typing import List from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.types import ModelConfigDict, TensorType +from ray.rllib.utils.typing import ModelConfigDict, TensorType _, nn = try_import_torch() diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py index e9ea53e4b..cf18288c4 100644 --- a/rllib/offline/input_reader.py +++ b/rllib/offline/input_reader.py @@ -6,7 +6,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.framework import try_import_tf from typing import Dict, List -from ray.rllib.utils.types import TensorType, SampleBatchType +from ray.rllib.utils.typing import TensorType, SampleBatchType tf1, tf, tfv = try_import_tf() diff --git a/rllib/offline/is_estimator.py b/rllib/offline/is_estimator.py index 58b24c691..1591be84a 100644 --- a/rllib/offline/is_estimator.py +++ b/rllib/offline/is_estimator.py @@ -1,7 +1,7 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ OffPolicyEstimate from ray.rllib.utils.annotations import override -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType class ImportanceSamplingEstimator(OffPolicyEstimator): diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index 54ef3d504..1229bdd07 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -16,7 +16,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \ DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.compression import unpack_if_needed -from ray.rllib.utils.types import FileType, SampleBatchType +from ray.rllib.utils.typing import FileType, SampleBatchType from typing import List logger = logging.getLogger(__name__) diff --git a/rllib/offline/json_writer.py b/rllib/offline/json_writer.py index 3b5f4e895..f34545ffa 100644 --- a/rllib/offline/json_writer.py +++ b/rllib/offline/json_writer.py @@ -16,7 +16,7 @@ from ray.rllib.offline.io_context import IOContext from ray.rllib.offline.output_writer import OutputWriter from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.compression import pack, compression_supported -from ray.rllib.utils.types import FileType, SampleBatchType +from ray.rllib.utils.typing import FileType, SampleBatchType from typing import Any, List logger = logging.getLogger(__name__) diff --git a/rllib/offline/mixed_input.py b/rllib/offline/mixed_input.py index 45e4aa41a..039e9239c 100644 --- a/rllib/offline/mixed_input.py +++ b/rllib/offline/mixed_input.py @@ -4,7 +4,7 @@ from ray.rllib.offline.input_reader import InputReader from ray.rllib.offline.json_reader import JsonReader from ray.rllib.offline.io_context import IOContext from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType from typing import Dict diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index c0c1fa849..fff235f82 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -5,7 +5,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.policy import Policy from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.offline.io_context import IOContext -from ray.rllib.utils.types import TensorType, SampleBatchType +from ray.rllib.utils.typing import TensorType, SampleBatchType from typing import List logger = logging.getLogger(__name__) diff --git a/rllib/offline/output_writer.py b/rllib/offline/output_writer.py index cf9d0cc80..3e528fb87 100644 --- a/rllib/offline/output_writer.py +++ b/rllib/offline/output_writer.py @@ -1,6 +1,6 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType @PublicAPI diff --git a/rllib/offline/shuffled_input.py b/rllib/offline/shuffled_input.py index 10dfc8cb7..b829ee44a 100644 --- a/rllib/offline/shuffled_input.py +++ b/rllib/offline/shuffled_input.py @@ -3,7 +3,7 @@ import random from ray.rllib.offline.input_reader import InputReader from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType logger = logging.getLogger(__name__) diff --git a/rllib/offline/wis_estimator.py b/rllib/offline/wis_estimator.py index e1bb156bb..a99d6643f 100644 --- a/rllib/offline/wis_estimator.py +++ b/rllib/offline/wis_estimator.py @@ -2,7 +2,7 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ OffPolicyEstimate from ray.rllib.policy import Policy from ray.rllib.utils.annotations import override -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.utils.typing import SampleBatchType class WeightedImportanceSamplingEstimator(OffPolicyEstimator): diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 722a53248..ec1872a92 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -14,7 +14,8 @@ from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tracking_dict import UsageTrackingDict -from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict +from ray.rllib.utils.typing import ModelGradients, TensorType, \ + TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 920d40a1b..a5ffb3168 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -11,7 +11,7 @@ from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.from_config import from_config from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ unbatch -from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \ +from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ TensorType, TrainerConfigDict, Tuple, Union torch, _ = try_import_torch() diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 37fecb002..32e9c84a6 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -8,7 +8,7 @@ from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI from ray.rllib.utils.compression import pack, unpack, is_compressed from ray.rllib.utils.memory import concat_aligned from ray.rllib.utils.deprecation import deprecation_warning -from ray.rllib.utils.types import TensorType +from ray.rllib.utils.typing import TensorType # Default policy id for single agent environments DEFAULT_POLICY_ID = "default_policy" diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index ac9b768f8..e76a56eca 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -17,7 +17,8 @@ from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf, get_variable from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict +from ray.rllib.utils.typing import ModelGradients, TensorType, \ + TrainerConfigDict tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index 806e4fb80..6242ed611 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -9,7 +9,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict +from ray.rllib.utils.typing import ModelGradients, TensorType, \ + TrainerConfigDict @DeveloperAPI diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index b824f4b6a..5b4c86e4d 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -19,7 +19,7 @@ from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \ convert_to_torch_tensor from ray.rllib.utils.tracking_dict import UsageTrackingDict -from ray.rllib.utils.types import ModelGradients, ModelWeights, \ +from ray.rllib.utils.typing import ModelGradients, ModelWeights, \ TensorType, TrainerConfigDict torch, _ = try_import_torch() diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index 270cd4c3f..f7ce8ed2f 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -13,7 +13,7 @@ from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import convert_to_non_torch_type -from ray.rllib.utils.types import TensorType, TrainerConfigDict +from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, _ = try_import_torch() diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index e1e4ac2b7..4191366d4 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -27,7 +27,7 @@ from ray.rllib.models.torch.misc import SlimFC from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_torch, TensorType from ray.rllib.utils.from_config import from_config -from ray.rllib.utils.types import SampleBatchType, TrainerConfigDict +from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict torch, nn = try_import_torch() diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 9346b792b..d7a3b8db3 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -3,7 +3,7 @@ import os import sys from typing import Any, Optional -from ray.rllib.utils.types import TensorStructType, TensorShape, TensorType +from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType logger = logging.getLogger(__name__) diff --git a/rllib/utils/schedules/polynomial_schedule.py b/rllib/utils/schedules/polynomial_schedule.py index ba54ec542..270a99a2f 100644 --- a/rllib/utils/schedules/polynomial_schedule.py +++ b/rllib/utils/schedules/polynomial_schedule.py @@ -3,7 +3,7 @@ from typing import Union from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.schedules.schedule import Schedule -from ray.rllib.utils.types import TensorType +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/types.py b/rllib/utils/typing.py similarity index 100% rename from rllib/utils/types.py rename to rllib/utils/typing.py