mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 01:43:50 +08:00
[RLlib] Rename rllib.utils.types into typing to match built-in python module's name. (#10114)
This commit is contained in:
@@ -6,7 +6,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.types import AgentID, PolicyID
|
||||
from ray.rllib.utils.typing import AgentID, PolicyID
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.replay_buffer import ReplayActor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.actors import create_colocated
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
@@ -6,7 +6,7 @@ from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -25,7 +25,7 @@ from ray.rllib.utils.framework import try_import_tf, TensorStructType
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.types import TrainerConfigDict, \
|
||||
from ray.rllib.utils.typing import TrainerConfigDict, \
|
||||
PartialTrainerConfigDict, EnvInfoDict, ResultDict, EnvType, PolicyID
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
|
||||
@@ -9,7 +9,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.types import TrainerConfigDict, ResultDict
|
||||
from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Vendored
+1
-1
@@ -5,7 +5,7 @@ from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.types import EnvType, MultiEnvDict, EnvID, \
|
||||
from ray.rllib.utils.typing import EnvType, MultiEnvDict, EnvID, \
|
||||
AgentID, MultiAgentDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
Vendored
+1
-1
@@ -1,5 +1,5 @@
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.types import EnvConfigDict
|
||||
from ray.rllib.utils.typing import EnvConfigDict
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
||||
Vendored
+1
-1
@@ -5,7 +5,7 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.types import EnvActionType, EnvObsType, EnvInfoDict
|
||||
from ray.rllib.utils.typing import EnvActionType, EnvObsType, EnvInfoDict
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ from typing import Optional
|
||||
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode
|
||||
from ray.rllib.utils.types import MultiAgentDict
|
||||
from ray.rllib.utils.typing import MultiAgentDict
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
||||
Vendored
+1
-1
@@ -2,7 +2,7 @@ from typing import Tuple, Dict, List
|
||||
import gym
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.types import MultiAgentDict, AgentID
|
||||
from ray.rllib.utils.typing import MultiAgentDict, AgentID
|
||||
|
||||
# If the obs space is Dict type, look for the global state under this key.
|
||||
ENV_STATE = "state"
|
||||
|
||||
Vendored
+1
-1
@@ -14,7 +14,7 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.types import MultiAgentDict, EnvInfoDict, EnvObsType, \
|
||||
from ray.rllib.utils.typing import MultiAgentDict, EnvInfoDict, EnvObsType, \
|
||||
EnvActionType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Vendored
+1
-1
@@ -4,7 +4,7 @@ from typing import Tuple, Callable, Optional
|
||||
import ray
|
||||
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.types import MultiEnvDict, EnvType, EnvID, MultiAgentDict
|
||||
from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID, MultiAgentDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Vendored
+1
-1
@@ -5,7 +5,7 @@ from typing import Callable, Tuple
|
||||
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.types import MultiAgentDict, PolicyID, AgentID
|
||||
from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Vendored
+1
-1
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.types import EnvType, EnvConfigDict, EnvObsType, \
|
||||
from ray.rllib.utils.typing import EnvType, EnvConfigDict, EnvObsType, \
|
||||
EnvInfoDict, EnvActionType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
|
||||
from ray.rllib.utils.types import SampleBatchType, AgentID, PolicyID, \
|
||||
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
|
||||
EnvObsType, EnvInfoDict, EnvActionType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -9,7 +9,7 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.types import GradInfoDict, LearnerStatsDict, ResultDict
|
||||
from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from ray.rllib.env import BaseEnv
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
|
||||
from ray.rllib.utils.framework import TensorType
|
||||
from ray.rllib.utils.types import AgentID, PolicyID
|
||||
from ray.rllib.utils.typing import AgentID, PolicyID
|
||||
|
||||
|
||||
class ObservationFunction:
|
||||
|
||||
@@ -36,7 +36,7 @@ from ray.rllib.utils.filter import get_filter, Filter
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils.types import EnvType, AgentID, PolicyID, EnvConfigDict, \
|
||||
from ray.rllib.utils.typing import EnvType, AgentID, PolicyID, EnvConfigDict, \
|
||||
ModelConfigDict, TrainerConfigDict, SampleBatchType, ModelWeights, \
|
||||
ModelGradients, MultiAgentPolicyConfigDict
|
||||
from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.types import PolicyID, AgentID
|
||||
from ray.rllib.utils.typing import PolicyID, AgentID
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.util.debug import log_once
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Dict, Optional
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.types import AgentID, EpisodeID, PolicyID, \
|
||||
from ray.rllib.utils.typing import AgentID, EpisodeID, PolicyID, \
|
||||
TensorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +25,7 @@ from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \
|
||||
unbatch
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils.types import SampleBatchType, AgentID, PolicyID, \
|
||||
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
|
||||
EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \
|
||||
TensorStructType
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.types import PolicyID, TrainerConfigDict, EnvType
|
||||
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict, EnvType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.util.iter import ParallelIteratorWorker
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
# Constant that represents all policies in lockstep replay mode.
|
||||
_ALL_POLICIES = "__all__"
|
||||
|
||||
@@ -6,7 +6,7 @@ from ray.util.iter_metrics import SharedMetrics
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.execution.common import \
|
||||
STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
|
||||
class StoreToReplayBuffer:
|
||||
|
||||
@@ -13,7 +13,7 @@ from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, LEARNER_INFO, \
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.sgd import standardized
|
||||
from ray.rllib.utils.types import PolicyID, SampleBatchType, ModelGradients
|
||||
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd, averaged
|
||||
from ray.rllib.utils.types import PolicyID, SampleBatchType
|
||||
from ray.rllib.utils.typing import PolicyID, SampleBatchType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.utils.actors import create_colocated
|
||||
from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
|
||||
from_actors
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import gym
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.types import TensorType, List, Union, ModelConfigDict
|
||||
from ray.rllib.utils.typing import TensorType, List, Union, ModelConfigDict
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
|
||||
@@ -27,7 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_space
|
||||
from ray.rllib.utils.types import ModelConfigDict, TensorType
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
||||
TensorType
|
||||
from ray.rllib.utils.spaces.repeated import Repeated
|
||||
from ray.rllib.utils.types import ModelConfigDict, TensorStructType
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorStructType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.types import TensorType, List
|
||||
from ray.rllib.utils.typing import TensorType, List
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
tfp = try_import_tfp()
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.types import ModelConfigDict, TensorType
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
|
||||
MAX_LOG_NN_OUTPUT
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.torch_ops import atanh
|
||||
from ray.rllib.utils.types import TensorType, List
|
||||
from ray.rllib.utils.typing import TensorType, List
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.types import ModelConfigDict, TensorType
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
_, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from typing import Dict, List
|
||||
from ray.rllib.utils.types import TensorType, SampleBatchType
|
||||
from ray.rllib.utils.typing import TensorType, SampleBatchType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
|
||||
class ImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.compression import unpack_if_needed
|
||||
from ray.rllib.utils.types import FileType, SampleBatchType
|
||||
from ray.rllib.utils.typing import FileType, SampleBatchType
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.offline.output_writer import OutputWriter
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.compression import pack, compression_supported
|
||||
from ray.rllib.utils.types import FileType, SampleBatchType
|
||||
from ray.rllib.utils.typing import FileType, SampleBatchType
|
||||
from typing import Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -4,7 +4,7 @@ from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from typing import Dict
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.utils.types import TensorType, SampleBatchType
|
||||
from ray.rllib.utils.typing import TensorType, SampleBatchType
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
|
||||
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
|
||||
@@ -14,7 +14,8 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import ModelGradients, TensorType, \
|
||||
TrainerConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
|
||||
unbatch
|
||||
from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \
|
||||
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
|
||||
TensorType, TrainerConfigDict, Tuple, Union
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.compression import pack, unpack, is_compressed
|
||||
from ray.rllib.utils.memory import concat_aligned
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.types import TensorType
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
# Default policy id for single agent environments
|
||||
DEFAULT_POLICY_ID = "default_policy"
|
||||
|
||||
@@ -17,7 +17,8 @@ from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.framework import try_import_tf, get_variable
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import ModelGradients, TensorType, \
|
||||
TrainerConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -9,7 +9,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import ModelGradients, TensorType, \
|
||||
TrainerConfigDict
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
|
||||
@@ -19,7 +19,7 @@ from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
|
||||
convert_to_torch_tensor
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
from ray.rllib.utils.types import ModelGradients, ModelWeights, \
|
||||
from ray.rllib.utils.typing import ModelGradients, ModelWeights, \
|
||||
TensorType, TrainerConfigDict
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
@@ -13,7 +13,7 @@ from ray.rllib.utils import add_mixins, force_list
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
|
||||
from ray.rllib.utils.types import TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.utils.exploration.exploration import Exploration
|
||||
from ray.rllib.utils.framework import try_import_torch, TensorType
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.types import SampleBatchType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from ray.rllib.utils.types import TensorStructType, TensorShape, TensorType
|
||||
from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Union
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.schedules.schedule import Schedule
|
||||
from ray.rllib.utils.types import TensorType
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
Reference in New Issue
Block a user