mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[RLlib] Type annotations for policy. (#9248)
This commit is contained in:
+2
-2
@@ -1134,10 +1134,10 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_dependency",
|
||||
name = "tests/test_dependency_tf",
|
||||
tags = ["tests_dir", "tests_dir_D"],
|
||||
size = "small",
|
||||
srcs = ["tests/test_dependency.py"]
|
||||
srcs = ["tests/test_dependency_tf.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
||||
@@ -179,7 +179,7 @@ class TorchCustomLossModel(TorchModelV2, nn.Module):
|
||||
# Add the imitation loss to each already calculated policy loss term.
|
||||
# Alternatively (if custom loss has its own optimizer):
|
||||
# return policy_loss + [10 * self.imitation_loss]
|
||||
return [l + 10 * self.imitation_loss for l in policy_loss]
|
||||
return [loss_ + 10 * self.imitation_loss for loss_ in policy_loss]
|
||||
|
||||
def custom_stats(self):
|
||||
return {
|
||||
|
||||
@@ -251,7 +251,7 @@ class TestModules(unittest.TestCase):
|
||||
self.train_tf_model(
|
||||
model, [x] + init_state,
|
||||
[y, value_labels, memory_labels, mlp_labels],
|
||||
num_epochs=50,
|
||||
num_epochs=200,
|
||||
minibatch_size=B)
|
||||
|
||||
|
||||
|
||||
@@ -1,27 +1,34 @@
|
||||
"""Graph mode TF policy built using build_tf_policy()."""
|
||||
|
||||
from collections import OrderedDict
|
||||
import gym
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class DynamicTFPolicy(TFPolicy):
|
||||
"""A TFPolicy that auto-defines placeholders dynamically at runtime.
|
||||
|
||||
Do not sub-class this class directly (neither should you sub-class
|
||||
TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy
|
||||
to generate your custom tf (graph-mode or eager) Policy classes.
|
||||
|
||||
Initialization of this class occurs in two phases.
|
||||
* Phase 1: the model is created and model variables are initialized.
|
||||
* Phase 2: a fake batch of data is created, sent to the trajectory
|
||||
@@ -39,61 +46,91 @@ class DynamicTFPolicy(TFPolicy):
|
||||
dist_class (type): TF action distribution class
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
loss_fn,
|
||||
stats_fn=None,
|
||||
grad_stats_fn=None,
|
||||
before_loss_init=None,
|
||||
make_model=None,
|
||||
action_sampler_fn=None,
|
||||
action_distribution_fn=None,
|
||||
existing_inputs=None,
|
||||
existing_model=None,
|
||||
get_batch_divisibility_req=None,
|
||||
obs_include_prev_action_reward=True):
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
loss_fn: Callable[
|
||||
[Policy, ModelV2, type, SampleBatch], TensorType],
|
||||
*,
|
||||
stats_fn: Optional[Callable[[Policy, SampleBatch],
|
||||
Dict[str, TensorType]]] = None,
|
||||
grad_stats_fn: Optional[Callable[
|
||||
[Policy, SampleBatch, ModelGradients],
|
||||
Dict[str, TensorType]]] = None,
|
||||
before_loss_init: Optional[Callable[
|
||||
[Policy, gym.spaces.Space, gym.spaces.Space,
|
||||
TrainerConfigDict], None]] = None,
|
||||
make_model: Optional[Callable[
|
||||
[Policy, gym.spaces.Space, gym.spaces.Space,
|
||||
TrainerConfigDict], ModelV2]] = None,
|
||||
action_sampler_fn: Optional[Callable[
|
||||
[TensorType, List[TensorType]], Tuple[
|
||||
TensorType, TensorType]]] = None,
|
||||
action_distribution_fn: Optional[Callable[
|
||||
[Policy, ModelV2, TensorType, TensorType, TensorType],
|
||||
Tuple[TensorType, type, List[TensorType]]]] = None,
|
||||
existing_inputs: Optional[Dict[
|
||||
str, "tf1.placeholder"]] = None,
|
||||
existing_model: Optional[ModelV2] = None,
|
||||
get_batch_divisibility_req: Optional[int] = None,
|
||||
obs_include_prev_action_reward: bool = True):
|
||||
"""Initialize a dynamic TF policy.
|
||||
|
||||
Arguments:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
action_space (gym.Space): Action space of the policy.
|
||||
config (dict): Policy-specific configuration data.
|
||||
loss_fn (func): function that returns a loss tensor the policy
|
||||
graph, and dict of experience tensor placeholders
|
||||
stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and batch input tensors
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
before_loss_init (Optional[callable]): Optional function to run
|
||||
prior to loss init that takes the same arguments as __init__.
|
||||
make_model (func): optional function that returns a ModelV2 object
|
||||
given (policy, obs_space, action_space, config).
|
||||
observation_space (gym.spaces.Space): Observation space of the
|
||||
policy.
|
||||
action_space (gym.spaces.Space): Action space of the policy.
|
||||
config (TrainerConfigDict): Policy-specific configuration data.
|
||||
loss_fn (Callable[[Policy, ModelV2, type, SampleBatch],
|
||||
TensorType]): Function that returns a loss tensor for the
|
||||
policy graph.
|
||||
stats_fn (Optional[Callable[[Policy, SampleBatch],
|
||||
Dict[str, TensorType]]]): Optional function that returns a dict
|
||||
of TF fetches given the policy and batch input tensors.
|
||||
grad_stats_fn (Optional[Callable[[Policy, SampleBatch,
|
||||
ModelGradients], Dict[str, TensorType]]]):
|
||||
Optional function that returns a dict of TF fetches given the
|
||||
policy, sample batch, and loss gradient tensors.
|
||||
before_loss_init (Optional[Callable[
|
||||
[Policy, gym.spaces.Space, gym.spaces.Space,
|
||||
TrainerConfigDict], None]]): Optional function to run prior to
|
||||
loss init that takes the same arguments as __init__.
|
||||
make_model (Optional[Callable[[Policy, gym.spaces.Space,
|
||||
gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
|
||||
function that returns a ModelV2 object given
|
||||
policy, obs_space, action_space, and policy config.
|
||||
All policy variables should be created in this function. If not
|
||||
specified, a default model will be created.
|
||||
action_sampler_fn (Optional[callable]): An optional callable
|
||||
returning a tuple of action and action prob tensors given
|
||||
(policy, model, input_dict, obs_space, action_space, config).
|
||||
If None, a default action distribution will be used.
|
||||
action_distribution_fn (Optional[callable]): A callable returning
|
||||
distribution inputs (parameters), a dist-class to generate an
|
||||
action distribution object from, and internal-state outputs
|
||||
(or an empty list if not applicable).
|
||||
action_sampler_fn (Optional[Callable[[Policy, ModelV2, Dict[
|
||||
str, TensorType], TensorType, TensorType], Tuple[TensorType,
|
||||
TensorType]]]): A callable returning a sampled action and its
|
||||
log-likelihood given Policy, ModelV2, input_dict, explore,
|
||||
timestep, and is_training.
|
||||
action_distribution_fn (Optional[Callable[[Policy, ModelV2,
|
||||
Dict[str, TensorType], TensorType, TensorType],
|
||||
Tuple[TensorType, type, List[TensorType]]]]): A callable
|
||||
returning distribution inputs (parameters), a dist-class to
|
||||
generate an action distribution object from, and
|
||||
internal-state outputs (or an empty list if not applicable).
|
||||
Note: No Exploration hooks have to be called from within
|
||||
`action_distribution_fn`. It's should only perform a simple
|
||||
forward pass through some model.
|
||||
If None, pass inputs through `self.model()` to get the
|
||||
distribution inputs.
|
||||
existing_inputs (OrderedDict): When copying a policy, this
|
||||
specifies an existing dict of placeholders to use instead of
|
||||
defining new ones
|
||||
existing_model (ModelV2): when copying a policy, this specifies
|
||||
an existing model to clone and share weights with
|
||||
get_batch_divisibility_req (func): optional function that returns
|
||||
the divisibility requirement for sample batches
|
||||
obs_include_prev_action_reward (bool): whether to include the
|
||||
previous action and reward in the model input
|
||||
If None, pass inputs through `self.model()` to get distribution
|
||||
inputs.
|
||||
The callable takes as inputs: Policy, ModelV2, input_dict,
|
||||
explore, timestep, is_training.
|
||||
existing_inputs (Optional[Dict[str, tf1.placeholder]]): When
|
||||
copying a policy, this specifies an existing dict of
|
||||
placeholders to use instead of defining new ones.
|
||||
existing_model (Optional[ModelV2]): When copying a policy, this
|
||||
specifies an existing model to clone and share weights with.
|
||||
get_batch_divisibility_req (Optional[Callable[[Policy], int]]]):
|
||||
Optional callable that returns the divisibility requirement
|
||||
for sample batches given the Policy.
|
||||
obs_include_prev_action_reward (bool): Whether to include the
|
||||
previous action and reward in the model input (default: True).
|
||||
"""
|
||||
self.observation_space = obs_space
|
||||
self.action_space = action_space
|
||||
@@ -258,10 +295,12 @@ class DynamicTFPolicy(TFPolicy):
|
||||
before_loss_init(self, obs_space, action_space, config)
|
||||
|
||||
if not existing_inputs:
|
||||
self._initialize_loss()
|
||||
self._initialize_loss_dynamically()
|
||||
|
||||
@override(TFPolicy)
|
||||
def copy(self, existing_inputs):
|
||||
@DeveloperAPI
|
||||
def copy(self,
|
||||
existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
|
||||
"""Creates a copy of self using existing input placeholders."""
|
||||
|
||||
# Note that there might be RNN state inputs at the end of the list
|
||||
@@ -306,13 +345,14 @@ class DynamicTFPolicy(TFPolicy):
|
||||
return instance
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
@DeveloperAPI
|
||||
def get_initial_state(self) -> List[TensorType]:
|
||||
if self.model:
|
||||
return self.model.get_initial_state()
|
||||
else:
|
||||
return []
|
||||
|
||||
def _initialize_loss(self):
|
||||
def _initialize_loss_dynamically(self):
|
||||
def fake_array(tensor):
|
||||
shape = tensor.shape.as_list()
|
||||
shape = [s if s is not None else 1 for s in shape]
|
||||
@@ -404,7 +444,7 @@ class DynamicTFPolicy(TFPolicy):
|
||||
self._grad_stats_fn(self, train_batch, self._grads))
|
||||
self._sess.run(tf1.global_variables_initializer())
|
||||
|
||||
def _do_loss_init(self, train_batch):
|
||||
def _do_loss_init(self, train_batch: SampleBatch):
|
||||
loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
|
||||
if self._stats_fn:
|
||||
self._stats_fetches.update(self._stats_fn(self, train_batch))
|
||||
|
||||
+186
-127
@@ -4,13 +4,15 @@ import numpy as np
|
||||
import tree
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.exploration.exploration import Exploration
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
|
||||
unbatch
|
||||
from ray.rllib.utils.types import AgentID
|
||||
from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \
|
||||
TensorType, TrainerConfigDict, Tuple, Union
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
@@ -41,7 +43,11 @@ class Policy(metaclass=ABCMeta):
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict):
|
||||
"""Initialize the graph.
|
||||
|
||||
This is the standard constructor for policies. The policy
|
||||
@@ -49,9 +55,10 @@ class Policy(metaclass=ABCMeta):
|
||||
these arguments.
|
||||
|
||||
Args:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
action_space (gym.Space): Action space of the policy.
|
||||
config (dict): Policy-specific configuration data.
|
||||
observation_space (gym.spaces.Space): Observation space of the
|
||||
policy.
|
||||
action_space (gym.spaces.Space): Action space of the policy.
|
||||
config (TrainerConfigDict): Policy-specific configuration data.
|
||||
"""
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
@@ -66,78 +73,95 @@ class Policy(metaclass=ABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
@DeveloperAPI
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
explore=None,
|
||||
timestep=None,
|
||||
**kwargs):
|
||||
def compute_actions(
|
||||
self,
|
||||
obs_batch: Union[List[TensorType], TensorType],
|
||||
state_batches: Optional[List[TensorType]] = None,
|
||||
prev_action_batch: Union[List[TensorType], TensorType] = None,
|
||||
prev_reward_batch: Union[List[TensorType], TensorType] = None,
|
||||
info_batch: Optional[Dict[str, list]] = None,
|
||||
episodes: Optional[List["MultiAgentEpisode"]] = None,
|
||||
explore: Optional[bool] = None,
|
||||
timestep: Optional[int] = None,
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
"""Computes actions for the current policy.
|
||||
|
||||
Args:
|
||||
obs_batch (Union[List, np.ndarray]): Batch of observations.
|
||||
state_batches (Optional[list]): List of RNN state input batches,
|
||||
if any.
|
||||
prev_action_batch (Optional[List, np.ndarray]): Batch of previous
|
||||
action values.
|
||||
prev_reward_batch (Optional[List, np.ndarray]): Batch of previous
|
||||
rewards.
|
||||
info_batch (info): Batch of info objects.
|
||||
episodes (list): MultiAgentEpisode for each obs in obs_batch.
|
||||
This provides access to all of the internal episode state,
|
||||
which may be useful for model-based or multiagent algorithms.
|
||||
explore (bool): Whether to pick an exploitation or exploration
|
||||
action (default: None -> use self.config["explore"]).
|
||||
timestep (int): The current (sampling) time step.
|
||||
obs_batch (Union[List[TensorType], TensorType]): Batch of
|
||||
observations.
|
||||
state_batches (Optional[List[TensorType]]): List of RNN state input
|
||||
batches, if any.
|
||||
prev_action_batch (Union[List[TensorType], TensorType]): Batch of
|
||||
previous action values.
|
||||
prev_reward_batch (Union[List[TensorType], TensorType]): Batch of
|
||||
previous rewards.
|
||||
info_batch (Optional[Dict[str, list]]): Batch of info objects.
|
||||
episodes (Optional[List[MultiAgentEpisode]] ): List of
|
||||
MultiAgentEpisode, one for each obs in obs_batch. This provides
|
||||
access to all of the internal episode state, which may be
|
||||
useful for model-based or multiagent algorithms.
|
||||
explore (Optional[bool]): Whether to pick an exploitation or
|
||||
exploration action. Set to None (default) for using the
|
||||
value of `self.config["explore"]`.
|
||||
timestep (Optional[int]): The current (sampling) time step.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: forward compatibility placeholder
|
||||
|
||||
Returns:
|
||||
actions (np.ndarray): batch of output actions, with shape like
|
||||
[BATCH_SIZE, ACTION_SHAPE].
|
||||
state_outs (list): list of RNN state output batches, if any, with
|
||||
shape like [STATE_SIZE, BATCH_SIZE].
|
||||
info (dict): dictionary of extra feature batches, if any, with
|
||||
shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
|
||||
Tuple:
|
||||
actions (TensorType): Batch of output actions, with shape like
|
||||
[BATCH_SIZE, ACTION_SHAPE].
|
||||
state_outs (List[TensorType]): List of RNN state output
|
||||
batches, if any, with shape like [STATE_SIZE, BATCH_SIZE].
|
||||
info (List[dict]): Dictionary of extra feature batches, if any,
|
||||
with shape like
|
||||
{"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_single_action(self,
|
||||
obs,
|
||||
state=None,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
episode=None,
|
||||
clip_actions=False,
|
||||
explore=None,
|
||||
timestep=None,
|
||||
**kwargs):
|
||||
def compute_single_action(
|
||||
self,
|
||||
obs: TensorType,
|
||||
state: Optional[List[TensorType]] = None,
|
||||
prev_action: Optional[TensorType] = None,
|
||||
prev_reward: Optional[TensorType] = None,
|
||||
info: dict = None,
|
||||
episode: Optional["MultiAgentEpisode"] = None,
|
||||
clip_actions: bool = False,
|
||||
explore: Optional[bool] = None,
|
||||
timestep: Optional[int] = None,
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
"""Unbatched version of compute_actions.
|
||||
|
||||
Arguments:
|
||||
obs (obj): Single observation.
|
||||
state (list): List of RNN state inputs, if any.
|
||||
prev_action (obj): Previous action value, if any.
|
||||
prev_reward (float): Previous reward, if any.
|
||||
info (dict): info object, if any
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multi-agent algorithms.
|
||||
Args:
|
||||
obs (TensorType): Single observation.
|
||||
state (Optional[List[TensorType]]): List of RNN state inputs, if
|
||||
any.
|
||||
prev_action (Optional[TensorType]): Previous action value, if any.
|
||||
prev_reward (Optional[TensorType]): Previous reward, if any.
|
||||
info (dict): Info object, if any.
|
||||
episode (Optional[MultiAgentEpisode]): this provides access to all
|
||||
of the internal episode state, which may be useful for
|
||||
model-based or multi-agent algorithms.
|
||||
clip_actions (bool): Should actions be clipped?
|
||||
explore (bool): Whether to pick an exploitation or exploration
|
||||
action (default: None -> use self.config["explore"]).
|
||||
timestep (int): The current (sampling) time step.
|
||||
explore (Optional[bool]): Whether to pick an exploitation or
|
||||
exploration action
|
||||
(default: None -> use self.config["explore"]).
|
||||
timestep (Optional[int]): The current (sampling) time step.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: forward compatibility placeholder
|
||||
|
||||
Returns:
|
||||
actions (obj): single action
|
||||
state_outs (list): list of RNN state outputs, if any
|
||||
info (dict): dictionary of extra features, if any
|
||||
Tuple:
|
||||
actions (TensorType): Single action.
|
||||
state_outs (List[TensorType]): List of RNN state outputs,
|
||||
if any.
|
||||
info (dict): Dictionary of extra features, if any.
|
||||
"""
|
||||
prev_action_batch = None
|
||||
prev_reward_batch = None
|
||||
@@ -196,7 +220,8 @@ class Policy(metaclass=ABCMeta):
|
||||
other_trajectories: Dict[AgentID, "Trajectory"],
|
||||
explore: bool = None,
|
||||
timestep: Optional[int] = None,
|
||||
**kwargs):
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
"""Computes actions for the current policy based on .
|
||||
|
||||
Note: This is an experimental API method.
|
||||
@@ -215,59 +240,68 @@ class Policy(metaclass=ABCMeta):
|
||||
kwargs: forward compatibility placeholder
|
||||
|
||||
Returns:
|
||||
actions (np.ndarray): batch of output actions, with shape like
|
||||
[BATCH_SIZE, ACTION_SHAPE].
|
||||
state_outs (list): list of RNN state output batches, if any, with
|
||||
shape like [STATE_SIZE, BATCH_SIZE].
|
||||
info (dict): dictionary of extra feature batches, if any, with
|
||||
shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
|
||||
Tuple:
|
||||
actions (TensorType): Batch of output actions, with shape
|
||||
like [BATCH_SIZE, ACTION_SHAPE].
|
||||
state_outs (List[TensorType]): List of RNN state output
|
||||
batches, if any, with shape like [STATE_SIZE, BATCH_SIZE].
|
||||
info (dict): Dictionary of extra feature batches, if any, with
|
||||
shape like
|
||||
{"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_log_likelihoods(self,
|
||||
actions,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None):
|
||||
def compute_log_likelihoods(
|
||||
self,
|
||||
actions: Union[List[TensorType], TensorType],
|
||||
obs_batch: Union[List[TensorType], TensorType],
|
||||
state_batches: Optional[List[TensorType]] = None,
|
||||
prev_action_batch: Optional[
|
||||
Union[List[TensorType], TensorType]] = None,
|
||||
prev_reward_batch: Optional[
|
||||
Union[List[TensorType], TensorType]] = None) -> TensorType:
|
||||
"""Computes the log-prob/likelihood for a given action and observation.
|
||||
|
||||
Args:
|
||||
actions (Union[List,np.ndarray]): Batch of actions, for which to
|
||||
retrieve the log-probs/likelihoods (given all other inputs:
|
||||
obs, states, ..).
|
||||
obs_batch (Union[List,np.ndarray]): Batch of observations.
|
||||
state_batches (Optional[list]): List of RNN state input batches,
|
||||
if any.
|
||||
prev_action_batch (Optional[List,np.ndarray]): Batch of previous
|
||||
action values.
|
||||
prev_reward_batch (Optional[List,np.ndarray]): Batch of previous
|
||||
rewards.
|
||||
actions (Union[List[TensorType], TensorType]): Batch of actions,
|
||||
for which to retrieve the log-probs/likelihoods (given all
|
||||
other inputs: obs, states, ..).
|
||||
obs_batch (Union[List[TensorType], TensorType]): Batch of
|
||||
observations.
|
||||
state_batches (Optional[List[TensorType]]): List of RNN state input
|
||||
batches, if any.
|
||||
prev_action_batch (Optional[Union[List[TensorType], TensorType]]):
|
||||
Batch of previous action values.
|
||||
prev_reward_batch (Optional[Union[List[TensorType], TensorType]]):
|
||||
Batch of previous rewards.
|
||||
|
||||
Returns:
|
||||
log-likelihoods (np.ndarray): Batch of log probs/likelihoods, with
|
||||
shape: [BATCH_SIZE].
|
||||
TensorType: Batch of log probs/likelihoods, with shape:
|
||||
[BATCH_SIZE].
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
def postprocess_trajectory(
|
||||
self,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[
|
||||
Dict[AgentID, Tuple["Policy", SampleBatch]]] = None,
|
||||
episode: Optional["MultiAgentEpisode"] = None) -> SampleBatch:
|
||||
"""Implements algorithm-specific trajectory postprocessing.
|
||||
|
||||
This will be called on each trajectory fragment computed during policy
|
||||
evaluation. Each fragment is guaranteed to be only from one episode.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
sample_batch (SampleBatch): batch of experiences for the policy,
|
||||
which will contain at most one episode trajectory.
|
||||
other_agent_batches (dict): In a multi-agent env, this contains a
|
||||
mapping of agent ids to (policy, agent_batch) tuples
|
||||
containing the policy and experiences of the other agents.
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
episode (Optional[MultiAgentEpisode]): An optional multi-agent
|
||||
episode object to provide access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multi-agent algorithms.
|
||||
|
||||
@@ -277,18 +311,22 @@ class Policy(metaclass=ABCMeta):
|
||||
return sample_batch
|
||||
|
||||
@DeveloperAPI
|
||||
def learn_on_batch(self, samples):
|
||||
def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
|
||||
"""Fused compute gradients and apply gradients call.
|
||||
|
||||
Either this or the combination of compute/apply grads must be
|
||||
implemented by subclasses.
|
||||
|
||||
Args:
|
||||
samples (SampleBatch): The SampleBatch object to learn from.
|
||||
|
||||
Returns:
|
||||
grad_info: dictionary of extra metadata from compute_gradients().
|
||||
Dict[str, TensorType]: Dictionary of extra metadata from
|
||||
compute_gradients().
|
||||
|
||||
Examples:
|
||||
>>> batch = ev.sample()
|
||||
>>> ev.learn_on_batch(samples)
|
||||
>>> sample_batch = ev.sample()
|
||||
>>> ev.learn_on_batch(sample_batch)
|
||||
"""
|
||||
|
||||
grads, grad_info = self.compute_gradients(samples)
|
||||
@@ -296,65 +334,76 @@ class Policy(metaclass=ABCMeta):
|
||||
return grad_info
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
def compute_gradients(self, postprocessed_batch: SampleBatch) -> \
|
||||
Tuple[ModelGradients, Dict[str, TensorType]]:
|
||||
"""Computes gradients against a batch of experiences.
|
||||
|
||||
Either this or learn_on_batch() must be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
postprocessed_batch (SampleBatch): The SampleBatch object to use
|
||||
for calculating gradients.
|
||||
|
||||
Returns:
|
||||
grads (list): List of gradient output values
|
||||
info (dict): Extra policy-specific values
|
||||
Tuple[ModelGradients, Dict[str, TensorType]]:
|
||||
- List of gradient output values.
|
||||
- Extra policy-specific info values.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def apply_gradients(self, gradients):
|
||||
def apply_gradients(self, gradients: ModelGradients) -> None:
|
||||
"""Applies previously computed gradients.
|
||||
|
||||
Either this or learn_on_batch() must be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
gradients (ModelGradients): The already calculated gradients to
|
||||
apply to this Policy.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def get_weights(self):
|
||||
def get_weights(self) -> ModelWeights:
|
||||
"""Returns model weights.
|
||||
|
||||
Returns:
|
||||
weights (obj): Serializable copy or view of model weights
|
||||
ModelWeights: Serializable copy or view of model weights.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights):
|
||||
def set_weights(self, weights: ModelWeights) -> None:
|
||||
"""Sets model weights.
|
||||
|
||||
Arguments:
|
||||
weights (obj): Serializable copy or view of model weights
|
||||
Args:
|
||||
weights (ModelWeights): Serializable copy or view of model weights.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def get_exploration_info(self):
|
||||
def get_exploration_info(self) -> Dict[str, TensorType]:
|
||||
"""Returns the current exploration information of this policy.
|
||||
|
||||
This information depends on the policy's Exploration object.
|
||||
|
||||
Returns:
|
||||
any: Serializable information on the `self.exploration` object.
|
||||
Dict[str, TensorType]: Serializable information on the
|
||||
`self.exploration` object.
|
||||
"""
|
||||
return self.exploration.get_info()
|
||||
|
||||
@DeveloperAPI
|
||||
def is_recurrent(self):
|
||||
def is_recurrent(self) -> bool:
|
||||
"""Whether this Policy holds a recurrent Model.
|
||||
|
||||
Returns:
|
||||
bool: True if this Policy has-a RNN-based Model.
|
||||
"""
|
||||
return 0
|
||||
return False
|
||||
|
||||
@DeveloperAPI
|
||||
def num_state_tensors(self):
|
||||
def num_state_tensors(self) -> int:
|
||||
"""The number of internal states needed by the RNN-Model of the Policy.
|
||||
|
||||
Returns:
|
||||
@@ -363,73 +412,83 @@ class Policy(metaclass=ABCMeta):
|
||||
return 0
|
||||
|
||||
@DeveloperAPI
|
||||
def get_initial_state(self):
|
||||
"""Returns initial RNN state for the current policy."""
|
||||
def get_initial_state(self) -> List[TensorType]:
|
||||
"""Returns initial RNN state for the current policy.
|
||||
|
||||
Returns:
|
||||
List[TensorType]: Initial RNN state for the current policy.
|
||||
"""
|
||||
return []
|
||||
|
||||
@DeveloperAPI
|
||||
def get_state(self):
|
||||
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
|
||||
"""Saves all local state.
|
||||
|
||||
Returns:
|
||||
state (obj): Serialized local state.
|
||||
Union[Dict[str, TensorType], List[TensorType]]: Serialized local
|
||||
state.
|
||||
"""
|
||||
return self.get_weights()
|
||||
|
||||
@DeveloperAPI
|
||||
def set_state(self, state):
|
||||
def set_state(self, state: object) -> None:
|
||||
"""Restores all local state.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
state (obj): Serialized local state.
|
||||
"""
|
||||
self.set_weights(state)
|
||||
|
||||
@DeveloperAPI
|
||||
def on_global_var_update(self, global_vars):
|
||||
def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
|
||||
"""Called on an update to global vars.
|
||||
|
||||
Arguments:
|
||||
global_vars (dict): Global variables broadcast from the driver.
|
||||
Args:
|
||||
global_vars (Dict[str, TensorType]): Global variables by str key,
|
||||
broadcast from the driver.
|
||||
"""
|
||||
# Store the current global time step (sum over all policies' sample
|
||||
# steps).
|
||||
self.global_timestep = global_vars["timestep"]
|
||||
|
||||
@DeveloperAPI
|
||||
def export_model(self, export_dir):
|
||||
def export_model(self, export_dir: str) -> None:
|
||||
"""Export Policy to local directory for serving.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
export_dir (str): Local writable directory.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def export_checkpoint(self, export_dir):
|
||||
def export_checkpoint(self, export_dir: str) -> None:
|
||||
"""Export Policy checkpoint to local directory.
|
||||
|
||||
Argument:
|
||||
Args:
|
||||
export_dir (str): Local writable directory.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def import_model_from_h5(self, import_file):
|
||||
def import_model_from_h5(self, import_file: str) -> None:
|
||||
"""Imports Policy from local file.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
import_file (str): Local readable file.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _create_exploration(self):
|
||||
def _create_exploration(self) -> Exploration:
|
||||
"""Creates the Policy's Exploration object.
|
||||
|
||||
This method only exists b/c some Trainers do not use TfPolicy nor
|
||||
TorchPolicy, but inherit directly from Policy. Others inherit from
|
||||
TfPolicy w/o using DynamicTfPolicy.
|
||||
TODO(sven): unify these cases."""
|
||||
TODO(sven): unify these cases.
|
||||
|
||||
Returns:
|
||||
Exploration: The Exploration object to be used by this Policy.
|
||||
"""
|
||||
if getattr(self, "exploration", None) is not None:
|
||||
return self.exploration
|
||||
|
||||
|
||||
@@ -230,8 +230,8 @@ def chop_into_sequences(episode_ids,
|
||||
f_pad = np.zeros((length, ) + np.shape(f)[1:])
|
||||
seq_base = 0
|
||||
i = 0
|
||||
for l in seq_lens:
|
||||
for seq_offset in range(l):
|
||||
for len_ in seq_lens:
|
||||
for seq_offset in range(len_):
|
||||
f_pad[seq_base + seq_offset] = f[i]
|
||||
i += 1
|
||||
seq_base += max_seq_len
|
||||
@@ -243,9 +243,9 @@ def chop_into_sequences(episode_ids,
|
||||
s = np.array(s)
|
||||
s_init = []
|
||||
i = 0
|
||||
for l in seq_lens:
|
||||
for len_ in seq_lens:
|
||||
s_init.append(s[i])
|
||||
i += l
|
||||
i += len_
|
||||
initial_states.append(np.array(s_init))
|
||||
|
||||
if shuffle:
|
||||
|
||||
+143
-35
@@ -2,12 +2,13 @@ import collections
|
||||
import numpy as np
|
||||
import sys
|
||||
import itertools
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Union
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.compression import pack, unpack, is_compressed
|
||||
from ray.rllib.utils.memory import concat_aligned
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.types import TensorType
|
||||
|
||||
# Default policy id for single agent environments
|
||||
DEFAULT_POLICY_ID = "default_policy"
|
||||
@@ -71,15 +72,16 @@ class SampleBatch:
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def concat_samples(samples):
|
||||
def concat_samples(samples: List[Dict[str, TensorType]]) -> \
|
||||
Union["SampleBatch", "MultiAgentBatch"]:
|
||||
"""Concatenates n data dicts or MultiAgentBatches.
|
||||
|
||||
Args:
|
||||
samples (List[Dict[np.ndarray]]]): List of dicts of data (numpy).
|
||||
samples (List[Dict[TensorType]]]): List of dicts of data (numpy).
|
||||
|
||||
Returns:
|
||||
Union[SampleBatch,MultiAgentBatch]: A new (compressed) SampleBatch/
|
||||
MultiAgentBatch.
|
||||
Union[SampleBatch, MultiAgentBatch]: A new (compressed)
|
||||
SampleBatch or MultiAgentBatch.
|
||||
"""
|
||||
if isinstance(samples[0], MultiAgentBatch):
|
||||
return MultiAgentBatch.concat_samples(samples)
|
||||
@@ -90,9 +92,17 @@ class SampleBatch:
|
||||
return SampleBatch(out)
|
||||
|
||||
@PublicAPI
|
||||
def concat(self, other):
|
||||
def concat(self, other: "SampleBatch") -> "SampleBatch":
|
||||
"""Returns a new SampleBatch with each data column concatenated.
|
||||
|
||||
Args:
|
||||
other (SampleBatch): The other SampleBatch object to concat to this
|
||||
one.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The new SampleBatch, resulting from concating `other`
|
||||
to `self`.
|
||||
|
||||
Examples:
|
||||
>>> b1 = SampleBatch({"a": [1, 2]})
|
||||
>>> b2 = SampleBatch({"a": [3, 4, 5]})
|
||||
@@ -110,15 +120,24 @@ class SampleBatch:
|
||||
return SampleBatch(out)
|
||||
|
||||
@PublicAPI
|
||||
def copy(self):
|
||||
def copy(self) -> "SampleBatch":
|
||||
"""Creates a (deep) copy of this SampleBatch and returns it.
|
||||
|
||||
Returns:
|
||||
SampleBatch: A (deep) copy of this SampleBatch object.
|
||||
"""
|
||||
return SampleBatch(
|
||||
{k: np.array(v, copy=True)
|
||||
for (k, v) in self.data.items()})
|
||||
|
||||
@PublicAPI
|
||||
def rows(self):
|
||||
def rows(self) -> Dict[str, TensorType]:
|
||||
"""Returns an iterator over data rows, i.e. dicts with column values.
|
||||
|
||||
Yields:
|
||||
Dict[str, TensorType]: The column values of the row in this
|
||||
iteration.
|
||||
|
||||
Examples:
|
||||
>>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
|
||||
>>> for row in batch.rows():
|
||||
@@ -135,7 +154,7 @@ class SampleBatch:
|
||||
yield row
|
||||
|
||||
@PublicAPI
|
||||
def columns(self, keys):
|
||||
def columns(self, keys: List[str]) -> List[any]:
|
||||
"""Returns a list of the batch-data in the specified columns.
|
||||
|
||||
Args:
|
||||
@@ -157,7 +176,7 @@ class SampleBatch:
|
||||
return out
|
||||
|
||||
@PublicAPI
|
||||
def shuffle(self):
|
||||
def shuffle(self) -> None:
|
||||
"""Shuffles the rows of this batch in-place."""
|
||||
|
||||
permutation = np.random.permutation(self.count)
|
||||
@@ -165,7 +184,7 @@ class SampleBatch:
|
||||
self[key] = val[permutation]
|
||||
|
||||
@PublicAPI
|
||||
def split_by_episode(self):
|
||||
def split_by_episode(self) -> List["SampleBatch"]:
|
||||
"""Splits this batch's data by `eps_id`.
|
||||
|
||||
Returns:
|
||||
@@ -189,7 +208,7 @@ class SampleBatch:
|
||||
return slices
|
||||
|
||||
@PublicAPI
|
||||
def slice(self, start, end):
|
||||
def slice(self, start: int, end: int) -> "SampleBatch":
|
||||
"""Returns a slice of the row data of this batch.
|
||||
|
||||
Args:
|
||||
@@ -197,13 +216,25 @@ class SampleBatch:
|
||||
end (int): Ending index.
|
||||
|
||||
Returns:
|
||||
SampleBatch which has a slice of this batch's data.
|
||||
SampleBatch: A new SampleBatch, which has a slice of this batch's
|
||||
data.
|
||||
"""
|
||||
|
||||
return SampleBatch({k: v[start:end] for k, v in self.data.items()})
|
||||
|
||||
@PublicAPI
|
||||
def timeslices(self, k: int) -> List["SampleBatch"]:
|
||||
"""Returns SampleBatches, each one representing a k-slice of this one.
|
||||
|
||||
Will start from timestep 0 and produce slices of size=k.
|
||||
|
||||
Args:
|
||||
k (int): The size (in timesteps) of each returned SampleBatch.
|
||||
|
||||
Returns:
|
||||
List[SampleBatch]: The list of (new) SampleBatches (each one of
|
||||
size k).
|
||||
"""
|
||||
out = []
|
||||
i = 0
|
||||
while i < self.count:
|
||||
@@ -212,31 +243,78 @@ class SampleBatch:
|
||||
return out
|
||||
|
||||
@PublicAPI
|
||||
def keys(self):
|
||||
def keys(self) -> Iterable[str]:
|
||||
"""
|
||||
Returns:
|
||||
Iterable[str]: The keys() iterable over `self.data`.
|
||||
"""
|
||||
return self.data.keys()
|
||||
|
||||
@PublicAPI
|
||||
def items(self):
|
||||
def items(self) -> Iterable[TensorType]:
|
||||
"""
|
||||
Returns:
|
||||
Iterable[TensorType]: The values() iterable over `self.data`.
|
||||
"""
|
||||
return self.data.items()
|
||||
|
||||
@PublicAPI
|
||||
def get(self, key):
|
||||
def get(self, key: str) -> Optional[TensorType]:
|
||||
"""Returns one column (by key) from the data or None if key not found.
|
||||
|
||||
Args:
|
||||
key (str): The key (column name) to return.
|
||||
|
||||
Returns:
|
||||
Optional[TensorType]: The data under the given key. None if key
|
||||
not found in data.
|
||||
"""
|
||||
return self.data.get(key)
|
||||
|
||||
@PublicAPI
|
||||
def size_bytes(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
int: The overall size in bytes of the data buffer (all columns).
|
||||
"""
|
||||
return sum(sys.getsizeof(d) for d in self.data)
|
||||
|
||||
@PublicAPI
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> TensorType:
|
||||
"""Returns one column (by key) from the data.
|
||||
|
||||
Args:
|
||||
key (str): The key (column name) to return.
|
||||
|
||||
Returns:
|
||||
TensorType]: The data under the given key.
|
||||
"""
|
||||
return self.data[key]
|
||||
|
||||
@PublicAPI
|
||||
def __setitem__(self, key, item):
|
||||
def __setitem__(self, key, item) -> None:
|
||||
"""Inserts (overrides) an entire column (by key) in the data buffer.
|
||||
|
||||
Args:
|
||||
key (str): The column name to set a value for.
|
||||
item (TensorType): The data to insert.
|
||||
"""
|
||||
self.data[key] = item
|
||||
|
||||
@DeveloperAPI
|
||||
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
|
||||
def compress(
|
||||
self,
|
||||
bulk: bool = False,
|
||||
columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
|
||||
"""Compresses the data buffers (by column) in place.
|
||||
|
||||
Args:
|
||||
bulk (bool): Whether to compress across the batch dimension (0)
|
||||
as well. If False will compress n separate list items, where n
|
||||
is the batch size.
|
||||
columns (Set[str]): The columns to compress. Default: Only
|
||||
compress the obs and new_obs columns.
|
||||
"""
|
||||
for key in columns:
|
||||
if key in self.data:
|
||||
if bulk:
|
||||
@@ -246,7 +324,19 @@ class SampleBatch:
|
||||
[pack(o) for o in self.data[key]])
|
||||
|
||||
@DeveloperAPI
|
||||
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
|
||||
def decompress_if_needed(
|
||||
self,
|
||||
columns: Set[str] = frozenset(
|
||||
["obs", "new_obs"])) -> "SampleBatch":
|
||||
"""Decompresses data buffers (per column if not compressed) in place.
|
||||
|
||||
Args:
|
||||
columns (Set[str]): The columns to decompress. Default: Only
|
||||
decompress the obs and new_obs columns.
|
||||
|
||||
Returns:
|
||||
SampleBatch: This very SampleBatch.
|
||||
"""
|
||||
for key in columns:
|
||||
if key in self.data:
|
||||
arr = self.data[key]
|
||||
@@ -272,10 +362,17 @@ class SampleBatch:
|
||||
|
||||
@PublicAPI
|
||||
class MultiAgentBatch:
|
||||
"""A batch of experiences from multiple agents in the environment."""
|
||||
"""A batch of experiences from multiple agents in the environment.
|
||||
|
||||
Attributes:
|
||||
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
|
||||
ids to SampleBatches of experiences.
|
||||
count (int): The number of env steps in this batch.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, policy_batches: Dict[PolicyID, SampleBatch],
|
||||
def __init__(self,
|
||||
policy_batches: Dict[PolicyID, SampleBatch],
|
||||
env_steps: int):
|
||||
"""Initialize a MultiAgentBatch object.
|
||||
|
||||
@@ -285,12 +382,8 @@ class MultiAgentBatch:
|
||||
env_steps (int): The number of timesteps in the environment this
|
||||
batch contains. This will be less than the number of
|
||||
transitions this batch contains across all policies in total.
|
||||
|
||||
Attributes:
|
||||
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
|
||||
ids to SampleBatches of experiences.
|
||||
count (int): the number of env steps in this batch.
|
||||
"""
|
||||
|
||||
for v in policy_batches.values():
|
||||
assert isinstance(v, SampleBatch)
|
||||
self.policy_batches = policy_batches
|
||||
@@ -303,7 +396,7 @@ class MultiAgentBatch:
|
||||
"""The number of env steps (there are >= 1 agent steps per env step).
|
||||
|
||||
Returns:
|
||||
int: the number of environment steps contained in this batch.
|
||||
int: The number of environment steps contained in this batch.
|
||||
"""
|
||||
return self.count
|
||||
|
||||
@@ -312,7 +405,7 @@ class MultiAgentBatch:
|
||||
"""The number of agent steps (there are >= 1 agent steps per env step).
|
||||
|
||||
Returns:
|
||||
int: the number of agent steps total in this batch.
|
||||
int: The number of agent steps total in this batch.
|
||||
"""
|
||||
ct = 0
|
||||
for batch in self.policy_batches.values():
|
||||
@@ -379,8 +472,9 @@ class MultiAgentBatch:
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def wrap_as_needed(policy_batches: Dict[PolicyID, SampleBatch],
|
||||
env_steps: int) -> Any:
|
||||
def wrap_as_needed(
|
||||
policy_batches: Dict[PolicyID, SampleBatch],
|
||||
env_steps: int) -> Union[SampleBatch, "MultiAgentBatch"]:
|
||||
"""Returns SampleBatch or MultiAgentBatch, depending on given policies.
|
||||
|
||||
Args:
|
||||
@@ -437,11 +531,19 @@ class MultiAgentBatch:
|
||||
|
||||
@PublicAPI
|
||||
def size_bytes(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
int: The overall size in bytes of all policy batches (all columns).
|
||||
"""
|
||||
return sum(b.size_bytes() for b in self.policy_batches.values())
|
||||
|
||||
@DeveloperAPI
|
||||
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
|
||||
"""Compresses each policy batch.
|
||||
def compress(
|
||||
self,
|
||||
bulk: bool = False,
|
||||
columns: Set[str] = frozenset(
|
||||
["obs", "new_obs"])) -> None:
|
||||
"""Compresses each policy batch (per column) in place.
|
||||
|
||||
Args:
|
||||
bulk (bool): Whether to compress across the batch dimension (0)
|
||||
@@ -453,11 +555,17 @@ class MultiAgentBatch:
|
||||
batch.compress(bulk=bulk, columns=columns)
|
||||
|
||||
@DeveloperAPI
|
||||
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
|
||||
"""Decompresses each policy batch, if already compressed.
|
||||
def decompress_if_needed(
|
||||
self,
|
||||
columns: Set[str] = frozenset(
|
||||
["obs", "new_obs"])) -> "MultiAgentBatch":
|
||||
"""Decompresses each policy batch (per column), if already compressed.
|
||||
|
||||
Args:
|
||||
columns (Set[str]): Set of column names to decompress.
|
||||
|
||||
Returns:
|
||||
MultiAgentBatch: This very MultiAgentBatch.
|
||||
"""
|
||||
for batch in self.policy_batches.values():
|
||||
batch.decompress_if_needed(columns)
|
||||
|
||||
+237
-131
@@ -1,7 +1,9 @@
|
||||
import errno
|
||||
import gym
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
@@ -15,6 +17,7 @@ from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,6 +27,11 @@ logger = logging.getLogger(__name__)
|
||||
class TFPolicy(Policy):
|
||||
"""An agent policy and loss implemented in TensorFlow.
|
||||
|
||||
Do not sub-class this class directly (neither should you sub-class
|
||||
DynamicTFPolicy), but rather use
|
||||
rllib.policy.tf_policy_template.build_tf_policy
|
||||
to generate your custom tf (graph-mode or eager) Policy classes.
|
||||
|
||||
Extending this class enables RLlib to perform TensorFlow specific
|
||||
optimizations on the policy, e.g., parallelization across gpus or
|
||||
fusing multiple graphs together in the multi-agent setting.
|
||||
@@ -48,77 +56,83 @@ class TFPolicy(Policy):
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
config,
|
||||
sess,
|
||||
obs_input,
|
||||
sampled_action,
|
||||
loss,
|
||||
loss_inputs,
|
||||
model=None,
|
||||
sampled_action_logp=None,
|
||||
action_input=None,
|
||||
log_likelihood=None,
|
||||
dist_inputs=None,
|
||||
dist_class=None,
|
||||
state_inputs=None,
|
||||
state_outputs=None,
|
||||
prev_action_input=None,
|
||||
prev_reward_input=None,
|
||||
seq_lens=None,
|
||||
max_seq_len=20,
|
||||
batch_divisibility_req=1,
|
||||
update_ops=None,
|
||||
explore=None,
|
||||
timestep=None):
|
||||
"""Initialize the policy.
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
sess: "tf1.Session",
|
||||
obs_input: TensorType,
|
||||
sampled_action: TensorType,
|
||||
loss: TensorType,
|
||||
loss_inputs: List[Tuple[str, TensorType]],
|
||||
model: ModelV2 = None,
|
||||
sampled_action_logp: Optional[TensorType] = None,
|
||||
action_input: Optional[TensorType] = None,
|
||||
log_likelihood: Optional[TensorType] = None,
|
||||
dist_inputs: Optional[TensorType] = None,
|
||||
dist_class: Optional[type] = None,
|
||||
state_inputs: Optional[List[TensorType]] = None,
|
||||
state_outputs: Optional[List[TensorType]] = None,
|
||||
prev_action_input: Optional[TensorType] = None,
|
||||
prev_reward_input: Optional[TensorType] = None,
|
||||
seq_lens: Optional[TensorType] = None,
|
||||
max_seq_len: int = 20,
|
||||
batch_divisibility_req: int = 1,
|
||||
update_ops: List[TensorType] = None,
|
||||
explore: Optional[TensorType] = None,
|
||||
timestep: Optional[TensorType] = None):
|
||||
"""Initializes a Policy object.
|
||||
|
||||
Arguments:
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
action_space (gym.Space): Action space of the env.
|
||||
config (dict): The Policy config dict.
|
||||
sess (Session): The TensorFlow session to use.
|
||||
obs_input (Tensor): Input placeholder for observations, of shape
|
||||
[BATCH_SIZE, obs...].
|
||||
sampled_action (Tensor): Tensor for sampling an action, of shape
|
||||
[BATCH_SIZE, action...]
|
||||
loss (Tensor): Scalar policy loss output tensor.
|
||||
loss_inputs (list): A (name, placeholder) tuple for each loss
|
||||
input argument. Each placeholder name must correspond to a
|
||||
SampleBatch column key returned by postprocess_trajectory(),
|
||||
and has shape [BATCH_SIZE, data...]. These keys will be read
|
||||
from postprocessed sample batches and fed into the specified
|
||||
placeholders during loss computation.
|
||||
model (rllib.models.Model): used to integrate custom losses and
|
||||
Args:
|
||||
observation_space (gym.spaces.Space): Observation space of the env.
|
||||
action_space (gym.spaces.Space): Action space of the env.
|
||||
config (TrainerConfigDict): The Policy config dict.
|
||||
sess (tf1.Session): The TensorFlow session to use.
|
||||
obs_input (TensorType): Input placeholder for observations, of
|
||||
shape [BATCH_SIZE, obs...].
|
||||
sampled_action (TensorType): Tensor for sampling an action, of
|
||||
shape [BATCH_SIZE, action...]
|
||||
loss (TensorType): Scalar policy loss output tensor.
|
||||
loss_inputs (List[Tuple[str, TensorType]]): A (name, placeholder)
|
||||
tuple for each loss input argument. Each placeholder name must
|
||||
correspond to a SampleBatch column key returned by
|
||||
postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
|
||||
These keys will be read from postprocessed sample batches and
|
||||
fed into the specified placeholders during loss computation.
|
||||
model (ModelV2): used to integrate custom losses and
|
||||
stats from user-defined RLlib models.
|
||||
sampled_action_logp (Tensor): log probability of the sampled
|
||||
action.
|
||||
action_input (Optional[Tensor]): Input placeholder for actions for
|
||||
logp/log-likelihood calculations.
|
||||
log_likelihood (Optional[Tensor]): Tensor to calculate the
|
||||
sampled_action_logp (Optional[TensorType]): log probability of the
|
||||
sampled action.
|
||||
action_input (Optional[TensorType]): Input placeholder for actions
|
||||
for logp/log-likelihood calculations.
|
||||
log_likelihood (Optional[TensorType]): Tensor to calculate the
|
||||
log_likelihood (given action_input and obs_input).
|
||||
dist_class (Optional[type): An optional ActionDistribution class
|
||||
dist_class (Optional[type]): An optional ActionDistribution class
|
||||
to use for generating a dist object from distribution inputs.
|
||||
dist_inputs (Optional[Tensor]): Tensor to calculate the
|
||||
dist_inputs (Optional[TensorType]): Tensor to calculate the
|
||||
distribution inputs/parameters.
|
||||
state_inputs (list): list of RNN state input Tensors.
|
||||
state_outputs (list): list of RNN state output Tensors.
|
||||
prev_action_input (Tensor): placeholder for previous actions
|
||||
prev_reward_input (Tensor): placeholder for previous rewards
|
||||
seq_lens (Tensor): Placeholder for RNN sequence lengths, of shape
|
||||
[NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
|
||||
state_inputs (Optional[List[TensorType]]): List of RNN state input
|
||||
Tensors.
|
||||
state_outputs (Optional[List[TensorType]]): List of RNN state
|
||||
output Tensors.
|
||||
prev_action_input (Optional[TensorType]): placeholder for previous
|
||||
actions.
|
||||
prev_reward_input (Optional[TensorType]): placeholder for previous
|
||||
rewards.
|
||||
seq_lens (Optional[TensorType]): Placeholder for RNN sequence
|
||||
lengths, of shape [NUM_SEQUENCES].
|
||||
Note that NUM_SEQUENCES << BATCH_SIZE. See
|
||||
policy/rnn_sequencing.py for more information.
|
||||
max_seq_len (int): Max sequence length for LSTM training.
|
||||
batch_divisibility_req (int): pad all agent experiences batches to
|
||||
multiples of this value. This only has an effect if not using
|
||||
a LSTM model.
|
||||
update_ops (list): override the batchnorm update ops to run when
|
||||
applying gradients. Otherwise we run all update ops found in
|
||||
the current variable scope.
|
||||
explore (Tensor): Placeholder for `explore` parameter into
|
||||
call to Exploration.get_exploration_action.
|
||||
timestep (Tensor): Placeholder for the global sampling timestep.
|
||||
update_ops (List[TensorType]): override the batchnorm update ops to
|
||||
run when applying gradients. Otherwise we run all update ops
|
||||
found in the current variable scope.
|
||||
explore (Optional[TensorType]): Placeholder for `explore` parameter
|
||||
into call to Exploration.get_exploration_action.
|
||||
timestep (Optional[TensorType]): Placeholder for the global
|
||||
sampling timestep.
|
||||
"""
|
||||
self.framework = "tf"
|
||||
super().__init__(observation_space, action_space, config)
|
||||
@@ -192,33 +206,49 @@ class TFPolicy(Policy):
|
||||
"""Return the list of all savable variables for this policy."""
|
||||
return self.model.variables()
|
||||
|
||||
def get_placeholder(self, name):
|
||||
def get_placeholder(self, name) -> "tf1.placeholder":
|
||||
"""Returns the given action or loss input placeholder by name.
|
||||
|
||||
If the loss has not been initialized and a loss input placeholder is
|
||||
requested, an error is raised.
|
||||
|
||||
Args:
|
||||
name (str): The name of the placeholder to return. One of
|
||||
SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
|
||||
`self._loss_input_dict`.
|
||||
|
||||
Returns:
|
||||
tf1.placeholder: The placeholder under the given str key.
|
||||
"""
|
||||
obs_inputs = {
|
||||
SampleBatch.CUR_OBS: self._obs_input,
|
||||
SampleBatch.PREV_ACTIONS: self._prev_action_input,
|
||||
SampleBatch.PREV_REWARDS: self._prev_reward_input,
|
||||
}
|
||||
if name in obs_inputs:
|
||||
return obs_inputs[name]
|
||||
if name == SampleBatch.CUR_OBS:
|
||||
return self._obs_input
|
||||
elif name == SampleBatch.PREV_ACTIONS:
|
||||
return self._prev_action_input
|
||||
elif name == SampleBatch.PREV_REWARDS:
|
||||
return self._prev_reward_input
|
||||
|
||||
assert self._loss_input_dict is not None, \
|
||||
"Should have set this before get_placeholder can be called"
|
||||
return self._loss_input_dict[name]
|
||||
|
||||
def get_session(self):
|
||||
def get_session(self) -> "tf1.Session":
|
||||
"""Returns a reference to the TF session for this policy."""
|
||||
return self._sess
|
||||
|
||||
def loss_initialized(self):
|
||||
def loss_initialized(self) -> bool:
|
||||
"""Returns whether the loss function has been initialized."""
|
||||
return self._loss is not None
|
||||
|
||||
def _initialize_loss(self, loss, loss_inputs):
|
||||
def _initialize_loss(self,
|
||||
loss: TensorType,
|
||||
loss_inputs: List[Tuple[str, TensorType]]) -> None:
|
||||
"""Initializes the loss op from given loss tensor and placeholders.
|
||||
|
||||
Args:
|
||||
loss (TensorType): The loss op generated by some loss function.
|
||||
loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples:
|
||||
(name, tf1.placeholders) needed for calculating the loss.
|
||||
"""
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
for i, ph in enumerate(self._state_inputs):
|
||||
@@ -270,16 +300,18 @@ class TFPolicy(Policy):
|
||||
self._optimizer.variables(), self._sess)
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
explore=None,
|
||||
timestep=None,
|
||||
**kwargs):
|
||||
def compute_actions(
|
||||
self,
|
||||
obs_batch: Union[List[TensorType], TensorType],
|
||||
state_batches: Optional[List[TensorType]] = None,
|
||||
prev_action_batch: Union[List[TensorType], TensorType] = None,
|
||||
prev_reward_batch: Union[List[TensorType], TensorType] = None,
|
||||
info_batch: Optional[Dict[str, list]] = None,
|
||||
episodes: Optional[List["MultiAgentEpisode"]] = None,
|
||||
explore: Optional[bool] = None,
|
||||
timestep: Optional[int] = None,
|
||||
**kwargs):
|
||||
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
|
||||
@@ -299,12 +331,16 @@ class TFPolicy(Policy):
|
||||
return fetched
|
||||
|
||||
@override(Policy)
|
||||
def compute_log_likelihoods(self,
|
||||
actions,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None):
|
||||
def compute_log_likelihoods(
|
||||
self,
|
||||
actions: Union[List[TensorType], TensorType],
|
||||
obs_batch: Union[List[TensorType], TensorType],
|
||||
state_batches: Optional[List[TensorType]] = None,
|
||||
prev_action_batch: Optional[
|
||||
Union[List[TensorType], TensorType]] = None,
|
||||
prev_reward_batch: Optional[
|
||||
Union[List[TensorType], TensorType]] = None) -> TensorType:
|
||||
|
||||
if self._log_likelihood is None:
|
||||
raise ValueError("Cannot compute log-prob/likelihood w/o a "
|
||||
"self._log_likelihood op!")
|
||||
@@ -341,40 +377,51 @@ class TFPolicy(Policy):
|
||||
return builder.get(fetches)[0]
|
||||
|
||||
@override(Policy)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "compute_gradients")
|
||||
fetches = self._build_compute_gradients(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@override(Policy)
|
||||
def apply_gradients(self, gradients):
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "apply_gradients")
|
||||
fetches = self._build_apply_gradients(builder, gradients)
|
||||
builder.get(fetches)
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
@DeveloperAPI
|
||||
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
|
||||
str, TensorType]:
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "learn_on_batch")
|
||||
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@override(Policy)
|
||||
def get_exploration_info(self):
|
||||
@DeveloperAPI
|
||||
def compute_gradients(
|
||||
self,
|
||||
postprocessed_batch: SampleBatch) -> \
|
||||
Tuple[ModelGradients, Dict[str, TensorType]]:
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "compute_gradients")
|
||||
fetches = self._build_compute_gradients(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def apply_gradients(self, gradients: ModelGradients) -> None:
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "apply_gradients")
|
||||
fetches = self._build_apply_gradients(builder, gradients)
|
||||
builder.get(fetches)
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def get_exploration_info(self) -> Dict[str, TensorType]:
|
||||
return self.exploration.get_info(sess=self.get_session())
|
||||
|
||||
@override(Policy)
|
||||
def get_weights(self):
|
||||
@DeveloperAPI
|
||||
def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
|
||||
return self._variables.get_weights()
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights) -> None:
|
||||
return self._variables.set_weights(weights)
|
||||
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
@DeveloperAPI
|
||||
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
|
||||
# For tf Policies, return Policy weights and optimizer var values.
|
||||
state = super().get_state()
|
||||
if self._optimizer_variables and \
|
||||
@@ -384,7 +431,8 @@ class TFPolicy(Policy):
|
||||
return state
|
||||
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
@DeveloperAPI
|
||||
def set_state(self, state) -> None:
|
||||
state = state.copy() # shallow copy
|
||||
# Set optimizer vars first.
|
||||
optimizer_vars = state.pop("_optimizer_variables", None)
|
||||
@@ -394,7 +442,8 @@ class TFPolicy(Policy):
|
||||
super().set_state(state)
|
||||
|
||||
@override(Policy)
|
||||
def export_model(self, export_dir):
|
||||
@DeveloperAPI
|
||||
def export_model(self, export_dir: str) -> None:
|
||||
"""Export tensorflow graph to export_dir for serving."""
|
||||
with self._sess.graph.as_default():
|
||||
builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
|
||||
@@ -407,7 +456,9 @@ class TFPolicy(Policy):
|
||||
builder.save()
|
||||
|
||||
@override(Policy)
|
||||
def export_checkpoint(self, export_dir, filename_prefix="model"):
|
||||
@DeveloperAPI
|
||||
def export_checkpoint(self, export_dir: str,
|
||||
filename_prefix: str = "model") -> None:
|
||||
"""Export tensorflow checkpoint to export_dir."""
|
||||
try:
|
||||
os.makedirs(export_dir)
|
||||
@@ -421,7 +472,8 @@ class TFPolicy(Policy):
|
||||
saver.save(self._sess, save_path)
|
||||
|
||||
@override(Policy)
|
||||
def import_model_from_h5(self, import_file):
|
||||
@DeveloperAPI
|
||||
def import_model_from_h5(self, import_file: str) -> None:
|
||||
"""Imports weights into tf model."""
|
||||
# Make sure the session is the right one (see issue #7046).
|
||||
with self._sess.graph.as_default():
|
||||
@@ -429,31 +481,53 @@ class TFPolicy(Policy):
|
||||
return self.model.import_from_h5(import_file)
|
||||
|
||||
@DeveloperAPI
|
||||
def copy(self, existing_inputs):
|
||||
def copy(self,
|
||||
existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> \
|
||||
"TFPolicy":
|
||||
"""Creates a copy of self using existing input placeholders.
|
||||
|
||||
Optional, only required to work with the multi-GPU optimizer."""
|
||||
Optional: Only required to work with the multi-GPU optimizer.
|
||||
|
||||
Args:
|
||||
existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping
|
||||
names (str) to tf1.placeholders to re-use (share) with the
|
||||
returned copy of self.
|
||||
|
||||
Returns:
|
||||
TFPolicy: A copy of self.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(Policy)
|
||||
def is_recurrent(self):
|
||||
@DeveloperAPI
|
||||
def is_recurrent(self) -> bool:
|
||||
return len(self._state_inputs) > 0
|
||||
|
||||
@override(Policy)
|
||||
def num_state_tensors(self):
|
||||
@DeveloperAPI
|
||||
def num_state_tensors(self) -> int:
|
||||
return len(self._state_inputs)
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_action_feed_dict(self):
|
||||
"""Extra dict to pass to the compute actions session run."""
|
||||
def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
|
||||
"""Extra dict to pass to the compute actions session run.
|
||||
|
||||
Returns:
|
||||
Dict[TensorType, TensorType]: A feed dict to be added to the
|
||||
feed_dict passed to the compute_actions session.run() call.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_action_fetches(self):
|
||||
def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
|
||||
"""Extra values to fetch and return from compute_actions().
|
||||
|
||||
By default we return action probability/log-likelihood info
|
||||
and action distribution inputs (if present).
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: An extra fetch-dict to be passed to and
|
||||
returned from the compute_actions() call.
|
||||
"""
|
||||
extra_fetches = {}
|
||||
# Action-logp and action-prob.
|
||||
@@ -466,38 +540,70 @@ class TFPolicy(Policy):
|
||||
return extra_fetches
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_grad_feed_dict(self):
|
||||
"""Extra dict to pass to the compute gradients session run."""
|
||||
def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
|
||||
"""Extra dict to pass to the compute gradients session run.
|
||||
|
||||
Returns:
|
||||
Dict[TensorType, TensorType]: Extra feed_dict to be passed to the
|
||||
compute_gradients Session.run() call.
|
||||
"""
|
||||
return {} # e.g, kl_coeff
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_grad_fetches(self):
|
||||
"""Extra values to fetch and return from compute_gradients()."""
|
||||
def extra_compute_grad_fetches(self) -> Dict[str, any]:
|
||||
"""Extra values to fetch and return from compute_gradients().
|
||||
|
||||
Returns:
|
||||
Dict[str, any]: Extra fetch dict to be added to the fetch dict
|
||||
of the compute_gradients Session.run() call.
|
||||
"""
|
||||
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
|
||||
|
||||
@DeveloperAPI
|
||||
def optimizer(self):
|
||||
"""TF optimizer to use for policy optimization."""
|
||||
def optimizer(self) -> "tf.keras.optimizers.Optimizer":
|
||||
"""TF optimizer to use for policy optimization.
|
||||
|
||||
Returns:
|
||||
tf.keras.optimizers.Optimizer: The local optimizer to use for this
|
||||
Policy's Model.
|
||||
"""
|
||||
if hasattr(self, "config"):
|
||||
return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
|
||||
else:
|
||||
return tf1.train.AdamOptimizer()
|
||||
|
||||
@DeveloperAPI
|
||||
def gradients(self, optimizer, loss):
|
||||
"""Override for custom gradient computation."""
|
||||
def gradients(self,
|
||||
optimizer: "tf.keras.optimizers.Optimizer",
|
||||
loss: TensorType) -> List[Tuple[TensorType, TensorType]]:
|
||||
"""Override this for a custom gradient computation behavior.
|
||||
|
||||
Returns:
|
||||
List[Tuple[TensorType, TensorType]]: List of tuples with grad
|
||||
values and the grad-value's corresponding tf.variable in it.
|
||||
"""
|
||||
return optimizer.compute_gradients(loss)
|
||||
|
||||
@DeveloperAPI
|
||||
def build_apply_op(self, optimizer, grads_and_vars):
|
||||
"""Override for custom gradient apply computation."""
|
||||
def build_apply_op(
|
||||
self,
|
||||
optimizer: "tf.keras.optimizers.Optimizer",
|
||||
grads_and_vars: List[Tuple[TensorType, TensorType]]) -> \
|
||||
"tf.Operation":
|
||||
"""Override this for a custom gradient apply computation behavior.
|
||||
|
||||
# specify global_step for TD3 which needs to count the num updates
|
||||
Args:
|
||||
optimizer (tf.keras.optimizers.Optimizer): The local tf optimizer
|
||||
to use for applying the grads and vars.
|
||||
grads_and_vars (List[Tuple[TensorType, TensorType]]): List of
|
||||
tuples with grad values and the grad-value's corresponding
|
||||
tf.variable in it.
|
||||
"""
|
||||
# Specify global_step for TD3 which needs to count the num updates.
|
||||
return optimizer.apply_gradients(
|
||||
self._grads_and_vars,
|
||||
global_step=tf1.train.get_or_create_global_step())
|
||||
|
||||
@DeveloperAPI
|
||||
def _get_is_training_placeholder(self):
|
||||
"""Get the placeholder for _is_training, i.e., for batch norm layers.
|
||||
|
||||
@@ -670,7 +776,7 @@ class TFPolicy(Policy):
|
||||
def _get_loss_inputs_dict(self, batch, shuffle):
|
||||
"""Return a feed dict from a batch.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
batch (SampleBatch): batch of data to derive inputs from
|
||||
shuffle (bool): whether to shuffle batch sequences. Shuffle may
|
||||
be done in-place. This only makes sense if you're further
|
||||
|
||||
@@ -131,10 +131,10 @@ def build_tf_policy(name,
|
||||
|
||||
DynamicTFPolicy.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
loss_fn,
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
config=config,
|
||||
loss_fn=loss_fn,
|
||||
stats_fn=stats_fn,
|
||||
grad_stats_fn=grad_stats_fn,
|
||||
before_loss_init=before_loss_init_wrapper,
|
||||
|
||||
+176
-99
@@ -1,8 +1,11 @@
|
||||
import functools
|
||||
import gym
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
@@ -15,11 +18,13 @@ from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
|
||||
convert_to_torch_tensor
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
from ray.rllib.utils.types import AgentID
|
||||
from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \
|
||||
TensorType, TrainerConfigDict
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class TorchPolicy(Policy):
|
||||
"""Template for a PyTorch policy and loss to use with RLlib.
|
||||
|
||||
@@ -33,49 +38,63 @@ class TorchPolicy(Policy):
|
||||
dist_class (type): Torch action distribution class.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
config,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
*,
|
||||
model,
|
||||
loss,
|
||||
action_distribution_class,
|
||||
action_sampler_fn=None,
|
||||
action_distribution_fn=None,
|
||||
max_seq_len=20,
|
||||
get_batch_divisibility_req=None):
|
||||
model: ModelV2,
|
||||
loss: Callable[
|
||||
[Policy, ModelV2, type, SampleBatch], TensorType],
|
||||
action_distribution_class: TorchDistributionWrapper,
|
||||
action_sampler_fn: Callable[
|
||||
[TensorType, List[TensorType]], Tuple[
|
||||
TensorType, TensorType]] = None,
|
||||
action_distribution_fn: Optional[Callable[
|
||||
[Policy, ModelV2, TensorType, TensorType, TensorType],
|
||||
Tuple[TensorType, type, List[TensorType]]]] = None,
|
||||
max_seq_len: int = 20,
|
||||
get_batch_divisibility_req: Optional[int] = None):
|
||||
"""Build a policy from policy and loss torch modules.
|
||||
|
||||
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
|
||||
is set. Only single GPU is supported for now.
|
||||
|
||||
Arguments:
|
||||
observation_space (gym.Space): observation space of the policy.
|
||||
action_space (gym.Space): action space of the policy.
|
||||
config (dict): The Policy config dict.
|
||||
model (nn.Module): PyTorch policy module. Given observations as
|
||||
Args:
|
||||
observation_space (gym.spaces.Space): observation space of the
|
||||
policy.
|
||||
action_space (gym.spaces.Space): action space of the policy.
|
||||
config (TrainerConfigDict): The Policy config dict.
|
||||
model (ModelV2): PyTorch policy module. Given observations as
|
||||
input, this module must return a list of outputs where the
|
||||
first item is action logits, and the rest can be any value.
|
||||
loss (func): Function that takes (policy, model, dist_class,
|
||||
train_batch) and returns a single scalar loss.
|
||||
action_distribution_class (ActionDistribution): Class for action
|
||||
distribution.
|
||||
action_sampler_fn (Optional[callable]): A callable returning a
|
||||
sampled action and its log-likelihood given some (obs and
|
||||
state) inputs.
|
||||
action_distribution_fn (Optional[callable]): A callable returning
|
||||
distribution inputs (parameters), a dist-class to generate an
|
||||
action distribution object from, and internal-state outputs
|
||||
(or an empty list if not applicable).
|
||||
loss (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
|
||||
Function that takes (policy, model, dist_class, train_batch)
|
||||
and returns a single scalar loss.
|
||||
action_distribution_class (TorchDistributionWrapper): Class for
|
||||
a torch action distribution.
|
||||
action_sampler_fn (Callable[[TensorType, List[TensorType]],
|
||||
Tuple[TensorType, TensorType]]): A callable returning a
|
||||
sampled action and its log-likelihood given Policy, ModelV2,
|
||||
input_dict, explore, timestep, and is_training.
|
||||
action_distribution_fn (Optional[Callable[[Policy, ModelV2,
|
||||
Dict[str, TensorType], TensorType, TensorType],
|
||||
Tuple[TensorType, type, List[TensorType]]]]): A callable
|
||||
returning distribution inputs (parameters), a dist-class to
|
||||
generate an action distribution object from, and
|
||||
internal-state outputs (or an empty list if not applicable).
|
||||
Note: No Exploration hooks have to be called from within
|
||||
`action_distribution_fn`. It's should only perform a simple
|
||||
forward pass through some model.
|
||||
If None, pass inputs through `self.model()` to get the
|
||||
distribution inputs.
|
||||
If None, pass inputs through `self.model()` to get distribution
|
||||
inputs.
|
||||
The callable takes as inputs: Policy, ModelV2, input_dict,
|
||||
explore, timestep, is_training.
|
||||
max_seq_len (int): Max sequence length for LSTM training.
|
||||
get_batch_divisibility_req (Optional[callable]): Optional callable
|
||||
that returns the divisibility requirement for sample batches.
|
||||
get_batch_divisibility_req (Optional[Callable[[Policy], int]]]):
|
||||
Optional callable that returns the divisibility requirement
|
||||
for sample batches given the Policy.
|
||||
"""
|
||||
self.framework = "torch"
|
||||
super().__init__(observation_space, action_space, config)
|
||||
@@ -100,16 +119,19 @@ class TorchPolicy(Policy):
|
||||
else 1
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
explore=None,
|
||||
timestep=None,
|
||||
**kwargs):
|
||||
@DeveloperAPI
|
||||
def compute_actions(
|
||||
self,
|
||||
obs_batch: Union[List[TensorType], TensorType],
|
||||
state_batches: Optional[List[TensorType]] = None,
|
||||
prev_action_batch: Union[List[TensorType], TensorType] = None,
|
||||
prev_reward_batch: Union[List[TensorType], TensorType] = None,
|
||||
info_batch: Optional[Dict[str, list]] = None,
|
||||
episodes: Optional[List["MultiAgentEpisode"]] = None,
|
||||
explore: Optional[bool] = None,
|
||||
timestep: Optional[int] = None,
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
@@ -149,7 +171,8 @@ class TorchPolicy(Policy):
|
||||
other_trajectories: Dict[AgentID, "Trajectory"],
|
||||
explore: bool = None,
|
||||
timestep: Optional[int] = None,
|
||||
**kwargs):
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
@@ -237,12 +260,16 @@ class TorchPolicy(Policy):
|
||||
return actions, state_out, extra_fetches, logp
|
||||
|
||||
@override(Policy)
|
||||
def compute_log_likelihoods(self,
|
||||
actions,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None):
|
||||
@DeveloperAPI
|
||||
def compute_log_likelihoods(
|
||||
self,
|
||||
actions: Union[List[TensorType], TensorType],
|
||||
obs_batch: Union[List[TensorType], TensorType],
|
||||
state_batches: Optional[List[TensorType]] = None,
|
||||
prev_action_batch: Optional[
|
||||
Union[List[TensorType], TensorType]] = None,
|
||||
prev_reward_batch: Optional[
|
||||
Union[List[TensorType], TensorType]] = None) -> TensorType:
|
||||
|
||||
if self.action_sampler_fn and self.action_distribution_fn is None:
|
||||
raise ValueError("Cannot compute log-prob/likelihood w/o an "
|
||||
@@ -282,7 +309,9 @@ class TorchPolicy(Policy):
|
||||
return log_likelihoods
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
@DeveloperAPI
|
||||
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
|
||||
str, TensorType]:
|
||||
# Get batch ready for RNNs, if applicable.
|
||||
pad_batch_to_sequences_of_same_size(
|
||||
postprocessed_batch,
|
||||
@@ -340,7 +369,9 @@ class TorchPolicy(Policy):
|
||||
return {LEARNER_STATS_KEY: grad_info}
|
||||
|
||||
@override(Policy)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self,
|
||||
postprocessed_batch: SampleBatch) -> ModelGradients:
|
||||
train_batch = self._lazy_tensor_dict(postprocessed_batch)
|
||||
loss_out = force_list(
|
||||
self._loss(self, self.model, self.dist_class, train_batch))
|
||||
@@ -367,7 +398,8 @@ class TorchPolicy(Policy):
|
||||
return grads, {LEARNER_STATS_KEY: grad_info}
|
||||
|
||||
@override(Policy)
|
||||
def apply_gradients(self, gradients):
|
||||
@DeveloperAPI
|
||||
def apply_gradients(self, gradients: ModelGradients) -> None:
|
||||
# TODO(sven): Not supported for multiple optimizers yet.
|
||||
assert len(self._optimizers) == 1
|
||||
for g, p in zip(gradients, self.model.parameters()):
|
||||
@@ -377,19 +409,39 @@ class TorchPolicy(Policy):
|
||||
self._optimizers[0].step()
|
||||
|
||||
@override(Policy)
|
||||
def get_weights(self):
|
||||
@DeveloperAPI
|
||||
def get_weights(self) -> ModelWeights:
|
||||
return {
|
||||
k: v.cpu().detach().numpy()
|
||||
for k, v in self.model.state_dict().items()
|
||||
}
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights: ModelWeights) -> None:
|
||||
weights = convert_to_torch_tensor(weights, device=self.device)
|
||||
self.model.load_state_dict(weights)
|
||||
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
@DeveloperAPI
|
||||
def is_recurrent(self) -> bool:
|
||||
return len(self.model.get_initial_state()) > 0
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def num_state_tensors(self) -> int:
|
||||
return len(self.model.get_initial_state())
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def get_initial_state(self) -> List[TensorType]:
|
||||
return [
|
||||
s.cpu().detach().numpy() for s in self.model.get_initial_state()
|
||||
]
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
|
||||
state = super().get_state()
|
||||
state["_optimizer_variables"] = []
|
||||
for i, o in enumerate(self._optimizers):
|
||||
@@ -397,7 +449,8 @@ class TorchPolicy(Policy):
|
||||
return state
|
||||
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
@DeveloperAPI
|
||||
def set_state(self, state: object) -> None:
|
||||
state = state.copy() # shallow copy
|
||||
# Set optimizer vars first.
|
||||
optimizer_vars = state.pop("_optimizer_variables", None)
|
||||
@@ -408,21 +461,11 @@ class TorchPolicy(Policy):
|
||||
# Then the Policy's (NN) weights.
|
||||
super().set_state(state)
|
||||
|
||||
@override(Policy)
|
||||
def is_recurrent(self):
|
||||
return len(self.model.get_initial_state()) > 0
|
||||
|
||||
@override(Policy)
|
||||
def num_state_tensors(self):
|
||||
return len(self.model.get_initial_state())
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return [
|
||||
s.cpu().detach().numpy() for s in self.model.get_initial_state()
|
||||
]
|
||||
|
||||
def extra_grad_process(self, optimizer, loss):
|
||||
@DeveloperAPI
|
||||
def extra_grad_process(
|
||||
self,
|
||||
optimizer: "torch.optim.Optimizer",
|
||||
loss: TensorType):
|
||||
"""Called after each optimizer.zero_grad() + loss.backward() call.
|
||||
|
||||
Called for each self._optimizers/loss-value pair.
|
||||
@@ -431,60 +474,94 @@ class TorchPolicy(Policy):
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): A torch optimizer object.
|
||||
loss (torch.Tensor): The loss tensor associated with the optimizer.
|
||||
loss (TensorType): The loss tensor associated with the optimizer.
|
||||
|
||||
Returns:
|
||||
dict: An info dict.
|
||||
Dict[str, TensorType]: An dict with information on the gradient
|
||||
processing step.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def extra_action_out(self, input_dict, state_batches, model, action_dist):
|
||||
@DeveloperAPI
|
||||
def extra_action_out(
|
||||
self,
|
||||
input_dict: Dict[str, TensorType],
|
||||
state_batches: List[TensorType],
|
||||
model: TorchModelV2,
|
||||
action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
|
||||
"""Returns dict of extra info to include in experience batch.
|
||||
|
||||
Args:
|
||||
input_dict (dict): Dict of model input tensors.
|
||||
state_batches (list): List of state tensors.
|
||||
model (TorchModelV2): Reference to the model.
|
||||
action_dist (TorchActionDistribution): Torch action dist object
|
||||
input_dict (Dict[str, TensorType]): Dict of model input tensors.
|
||||
state_batches (List[TensorType]): List of state tensors.
|
||||
model (TorchModelV2): Reference to the model object.
|
||||
action_dist (TorchDistributionWrapper): Torch action dist object
|
||||
to get log-probs (e.g. for already sampled actions).
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: Extra outputs to return in a
|
||||
compute_actions() call (3rd return value).
|
||||
"""
|
||||
return {}
|
||||
|
||||
def extra_grad_info(self, train_batch):
|
||||
"""Return dict of extra grad info."""
|
||||
@DeveloperAPI
|
||||
def extra_grad_info(self, train_batch: SampleBatch) -> Dict[
|
||||
str, TensorType]:
|
||||
"""Return dict of extra grad info.
|
||||
|
||||
Args:
|
||||
train_batch (SampleBatch): The training batch for which to produce
|
||||
extra grad info for.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: The info dict carrying grad info per str
|
||||
key.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def optimizer(self):
|
||||
"""Custom PyTorch optimizer to use."""
|
||||
@DeveloperAPI
|
||||
def optimizer(self) -> Union[
|
||||
List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
|
||||
"""Custom the local PyTorch optimizer(s) to use.
|
||||
|
||||
Returns:
|
||||
Union[List[torch.optim.Optimizer], torch.optim.Optimizer]:
|
||||
The local PyTorch optimizer(s) to use for this Policy.
|
||||
"""
|
||||
if hasattr(self, "config"):
|
||||
return torch.optim.Adam(
|
||||
self.model.parameters(), lr=self.config["lr"])
|
||||
else:
|
||||
return torch.optim.Adam(self.model.parameters())
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def export_model(self, export_dir: str) -> None:
|
||||
"""TODO(sven): implement for torch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def export_checkpoint(self, export_dir: str) -> None:
|
||||
"""TODO(sven): implement for torch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def import_model_from_h5(self, import_file: str) -> None:
|
||||
"""Imports weights into torch model."""
|
||||
return self.model.import_from_h5(import_file)
|
||||
|
||||
def _lazy_tensor_dict(self, postprocessed_batch):
|
||||
train_batch = UsageTrackingDict(postprocessed_batch)
|
||||
train_batch.set_get_interceptor(convert_to_torch_tensor)
|
||||
return train_batch
|
||||
|
||||
@override(Policy)
|
||||
def export_model(self, export_dir):
|
||||
"""TODO(sven): implement for torch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(Policy)
|
||||
def export_checkpoint(self, export_dir):
|
||||
"""TODO(sven): implement for torch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(Policy)
|
||||
def import_model_from_h5(self, import_file):
|
||||
"""Imports weights into torch model."""
|
||||
return self.model.import_from_h5(import_file)
|
||||
|
||||
|
||||
# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
|
||||
# and for all possible hyperparams, not just lr.
|
||||
@DeveloperAPI
|
||||
class LearningRateSchedule:
|
||||
"""Mixin for TFPolicy that adds a learning rate schedule."""
|
||||
|
||||
@@ -23,3 +23,5 @@ if __name__ == "__main__":
|
||||
|
||||
# Clean up.
|
||||
del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
|
||||
|
||||
print("ok")
|
||||
|
||||
@@ -5,8 +5,8 @@ from ray.rllib.utils.schedules.schedule import Schedule
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
def _linear_interpolation(l, r, alpha):
|
||||
return l + alpha * (r - l)
|
||||
def _linear_interpolation(left, right, alpha):
|
||||
return left + alpha * (right - left)
|
||||
|
||||
|
||||
class PiecewiseSchedule(Schedule):
|
||||
|
||||
+10
-5
@@ -2,6 +2,9 @@ from typing import Any, Dict, List, Tuple, Union
|
||||
import gym
|
||||
|
||||
# Represents a fully filled out config of a Trainer class.
|
||||
# Note: Policy config dicts are usually the same as TrainerConfigDict, but
|
||||
# parts of it may sometimes be altered in e.g. a multi-agent setup,
|
||||
# where we have >1 Policies in the same Trainer.
|
||||
TrainerConfigDict = dict
|
||||
|
||||
# A trainer config dict that only has overrides. It needs to be combined with
|
||||
@@ -63,8 +66,13 @@ GradInfoDict = dict
|
||||
# policy id.
|
||||
LearnerStatsDict = dict
|
||||
|
||||
# Type of dict returned by compute_gradients() representing model gradients.
|
||||
ModelGradients = dict
|
||||
# Represents a generic tensor type.
|
||||
# This could be an np.ndarray, tf.Tensor, or a torch.Tensor.
|
||||
TensorType = Any
|
||||
|
||||
# List of grads+var tuples (tf) or list of gradient tensors (torch)
|
||||
# representing model gradients and returned by compute_gradients().
|
||||
ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
|
||||
|
||||
# Type of dict returned by get_weights() representing model weights.
|
||||
ModelWeights = dict
|
||||
@@ -72,9 +80,6 @@ ModelWeights = dict
|
||||
# Some kind of sample batch.
|
||||
SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
|
||||
|
||||
# Represents a generic tensor type.
|
||||
TensorType = Any
|
||||
|
||||
# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
|
||||
TensorStructType = Union[TensorType, dict, tuple]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user