From f43d934817ae08119b38b59fb5dbcb68d2c8120d Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sun, 5 Jul 2020 13:09:51 +0200 Subject: [PATCH] [RLlib] Type annotations for policy. (#9248) --- rllib/BUILD | 4 +- rllib/examples/models/custom_loss_model.py | 2 +- rllib/models/tests/test_attention_nets.py | 2 +- rllib/policy/dynamic_tf_policy.py | 148 ++++--- rllib/policy/policy.py | 313 +++++++++------ rllib/policy/rnn_sequencing.py | 8 +- rllib/policy/sample_batch.py | 178 +++++++-- rllib/policy/tf_policy.py | 368 +++++++++++------- rllib/policy/tf_policy_template.py | 8 +- rllib/policy/torch_policy.py | 275 ++++++++----- ...st_dependency.py => test_dependency_tf.py} | 0 rllib/tests/test_dependency_torch.py | 2 + rllib/utils/schedules/piecewise_schedule.py | 4 +- rllib/utils/types.py | 15 +- 14 files changed, 862 insertions(+), 465 deletions(-) rename rllib/tests/{test_dependency.py => test_dependency_tf.py} (100%) diff --git a/rllib/BUILD b/rllib/BUILD index 6e7c76538..243175a2b 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1134,10 +1134,10 @@ py_test( ) py_test( - name = "tests/test_dependency", + name = "tests/test_dependency_tf", tags = ["tests_dir", "tests_dir_D"], size = "small", - srcs = ["tests/test_dependency.py"] + srcs = ["tests/test_dependency_tf.py"] ) py_test( diff --git a/rllib/examples/models/custom_loss_model.py b/rllib/examples/models/custom_loss_model.py index a0fa41c2b..b9a8c31c2 100644 --- a/rllib/examples/models/custom_loss_model.py +++ b/rllib/examples/models/custom_loss_model.py @@ -179,7 +179,7 @@ class TorchCustomLossModel(TorchModelV2, nn.Module): # Add the imitation loss to each already calculated policy loss term. # Alternatively (if custom loss has its own optimizer): # return policy_loss + [10 * self.imitation_loss] - return [l + 10 * self.imitation_loss for l in policy_loss] + return [loss_ + 10 * self.imitation_loss for loss_ in policy_loss] def custom_stats(self): return { diff --git a/rllib/models/tests/test_attention_nets.py b/rllib/models/tests/test_attention_nets.py index e579c584f..3ca04a032 100644 --- a/rllib/models/tests/test_attention_nets.py +++ b/rllib/models/tests/test_attention_nets.py @@ -251,7 +251,7 @@ class TestModules(unittest.TestCase): self.train_tf_model( model, [x] + init_state, [y, value_labels, memory_labels, mlp_labels], - num_epochs=50, + num_epochs=200, minibatch_size=B) diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 68fde7339..6ccc5326e 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -1,27 +1,34 @@ -"""Graph mode TF policy built using build_tf_policy().""" - from collections import OrderedDict +import gym import logging import numpy as np +from typing import Callable, Dict, List, Optional, Tuple from ray.util.debug import log_once +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tracking_dict import UsageTrackingDict +from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) +@DeveloperAPI class DynamicTFPolicy(TFPolicy): """A TFPolicy that auto-defines placeholders dynamically at runtime. + Do not sub-class this class directly (neither should you sub-class + TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy + to generate your custom tf (graph-mode or eager) Policy classes. + Initialization of this class occurs in two phases. * Phase 1: the model is created and model variables are initialized. * Phase 2: a fake batch of data is created, sent to the trajectory @@ -39,61 +46,91 @@ class DynamicTFPolicy(TFPolicy): dist_class (type): TF action distribution class """ + @DeveloperAPI def __init__(self, - obs_space, - action_space, - config, - loss_fn, - stats_fn=None, - grad_stats_fn=None, - before_loss_init=None, - make_model=None, - action_sampler_fn=None, - action_distribution_fn=None, - existing_inputs=None, - existing_model=None, - get_batch_divisibility_req=None, - obs_include_prev_action_reward=True): + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, + loss_fn: Callable[ + [Policy, ModelV2, type, SampleBatch], TensorType], + *, + stats_fn: Optional[Callable[[Policy, SampleBatch], + Dict[str, TensorType]]] = None, + grad_stats_fn: Optional[Callable[ + [Policy, SampleBatch, ModelGradients], + Dict[str, TensorType]]] = None, + before_loss_init: Optional[Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, + TrainerConfigDict], None]] = None, + make_model: Optional[Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, + TrainerConfigDict], ModelV2]] = None, + action_sampler_fn: Optional[Callable[ + [TensorType, List[TensorType]], Tuple[ + TensorType, TensorType]]] = None, + action_distribution_fn: Optional[Callable[ + [Policy, ModelV2, TensorType, TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]]]] = None, + existing_inputs: Optional[Dict[ + str, "tf1.placeholder"]] = None, + existing_model: Optional[ModelV2] = None, + get_batch_divisibility_req: Optional[int] = None, + obs_include_prev_action_reward: bool = True): """Initialize a dynamic TF policy. Arguments: - observation_space (gym.Space): Observation space of the policy. - action_space (gym.Space): Action space of the policy. - config (dict): Policy-specific configuration data. - loss_fn (func): function that returns a loss tensor the policy - graph, and dict of experience tensor placeholders - stats_fn (func): optional function that returns a dict of - TF fetches given the policy and batch input tensors - grad_stats_fn (func): optional function that returns a dict of - TF fetches given the policy and loss gradient tensors - before_loss_init (Optional[callable]): Optional function to run - prior to loss init that takes the same arguments as __init__. - make_model (func): optional function that returns a ModelV2 object - given (policy, obs_space, action_space, config). + observation_space (gym.spaces.Space): Observation space of the + policy. + action_space (gym.spaces.Space): Action space of the policy. + config (TrainerConfigDict): Policy-specific configuration data. + loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], + TensorType]): Function that returns a loss tensor for the + policy graph. + stats_fn (Optional[Callable[[Policy, SampleBatch], + Dict[str, TensorType]]]): Optional function that returns a dict + of TF fetches given the policy and batch input tensors. + grad_stats_fn (Optional[Callable[[Policy, SampleBatch, + ModelGradients], Dict[str, TensorType]]]): + Optional function that returns a dict of TF fetches given the + policy, sample batch, and loss gradient tensors. + before_loss_init (Optional[Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, + TrainerConfigDict], None]]): Optional function to run prior to + loss init that takes the same arguments as __init__. + make_model (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional + function that returns a ModelV2 object given + policy, obs_space, action_space, and policy config. All policy variables should be created in this function. If not specified, a default model will be created. - action_sampler_fn (Optional[callable]): An optional callable - returning a tuple of action and action prob tensors given - (policy, model, input_dict, obs_space, action_space, config). - If None, a default action distribution will be used. - action_distribution_fn (Optional[callable]): A callable returning - distribution inputs (parameters), a dist-class to generate an - action distribution object from, and internal-state outputs - (or an empty list if not applicable). + action_sampler_fn (Optional[Callable[[Policy, ModelV2, Dict[ + str, TensorType], TensorType, TensorType], Tuple[TensorType, + TensorType]]]): A callable returning a sampled action and its + log-likelihood given Policy, ModelV2, input_dict, explore, + timestep, and is_training. + action_distribution_fn (Optional[Callable[[Policy, ModelV2, + Dict[str, TensorType], TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]]]]): A callable + returning distribution inputs (parameters), a dist-class to + generate an action distribution object from, and + internal-state outputs (or an empty list if not applicable). Note: No Exploration hooks have to be called from within `action_distribution_fn`. It's should only perform a simple forward pass through some model. - If None, pass inputs through `self.model()` to get the - distribution inputs. - existing_inputs (OrderedDict): When copying a policy, this - specifies an existing dict of placeholders to use instead of - defining new ones - existing_model (ModelV2): when copying a policy, this specifies - an existing model to clone and share weights with - get_batch_divisibility_req (func): optional function that returns - the divisibility requirement for sample batches - obs_include_prev_action_reward (bool): whether to include the - previous action and reward in the model input + If None, pass inputs through `self.model()` to get distribution + inputs. + The callable takes as inputs: Policy, ModelV2, input_dict, + explore, timestep, is_training. + existing_inputs (Optional[Dict[str, tf1.placeholder]]): When + copying a policy, this specifies an existing dict of + placeholders to use instead of defining new ones. + existing_model (Optional[ModelV2]): When copying a policy, this + specifies an existing model to clone and share weights with. + get_batch_divisibility_req (Optional[Callable[[Policy], int]]]): + Optional callable that returns the divisibility requirement + for sample batches given the Policy. + obs_include_prev_action_reward (bool): Whether to include the + previous action and reward in the model input (default: True). """ self.observation_space = obs_space self.action_space = action_space @@ -258,10 +295,12 @@ class DynamicTFPolicy(TFPolicy): before_loss_init(self, obs_space, action_space, config) if not existing_inputs: - self._initialize_loss() + self._initialize_loss_dynamically() @override(TFPolicy) - def copy(self, existing_inputs): + @DeveloperAPI + def copy(self, + existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy: """Creates a copy of self using existing input placeholders.""" # Note that there might be RNN state inputs at the end of the list @@ -306,13 +345,14 @@ class DynamicTFPolicy(TFPolicy): return instance @override(Policy) - def get_initial_state(self): + @DeveloperAPI + def get_initial_state(self) -> List[TensorType]: if self.model: return self.model.get_initial_state() else: return [] - def _initialize_loss(self): + def _initialize_loss_dynamically(self): def fake_array(tensor): shape = tensor.shape.as_list() shape = [s if s is not None else 1 for s in shape] @@ -404,7 +444,7 @@ class DynamicTFPolicy(TFPolicy): self._grad_stats_fn(self, train_batch, self._grads)) self._sess.run(tf1.global_variables_initializer()) - def _do_loss_init(self, train_batch): + def _do_loss_init(self, train_batch: SampleBatch): loss = self._loss_fn(self, self.model, self.dist_class, train_batch) if self._stats_fn: self._stats_fetches.update(self._stats_fn(self, train_batch)) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index dcedda7d5..58b46bf55 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -4,13 +4,15 @@ import numpy as np import tree from typing import Dict, List, Optional +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.from_config import from_config from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ unbatch -from ray.rllib.utils.types import AgentID +from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \ + TensorType, TrainerConfigDict, Tuple, Union torch, _ = try_import_torch() @@ -41,7 +43,11 @@ class Policy(metaclass=ABCMeta): """ @DeveloperAPI - def __init__(self, observation_space, action_space, config): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict): """Initialize the graph. This is the standard constructor for policies. The policy @@ -49,9 +55,10 @@ class Policy(metaclass=ABCMeta): these arguments. Args: - observation_space (gym.Space): Observation space of the policy. - action_space (gym.Space): Action space of the policy. - config (dict): Policy-specific configuration data. + observation_space (gym.spaces.Space): Observation space of the + policy. + action_space (gym.spaces.Space): Action space of the policy. + config (TrainerConfigDict): Policy-specific configuration data. """ self.observation_space = observation_space self.action_space = action_space @@ -66,78 +73,95 @@ class Policy(metaclass=ABCMeta): @abstractmethod @DeveloperAPI - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - explore=None, - timestep=None, - **kwargs): + def compute_actions( + self, + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorType], TensorType] = None, + prev_reward_batch: Union[List[TensorType], TensorType] = None, + info_batch: Optional[Dict[str, list]] = None, + episodes: Optional[List["MultiAgentEpisode"]] = None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + **kwargs) -> \ + Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: """Computes actions for the current policy. Args: - obs_batch (Union[List, np.ndarray]): Batch of observations. - state_batches (Optional[list]): List of RNN state input batches, - if any. - prev_action_batch (Optional[List, np.ndarray]): Batch of previous - action values. - prev_reward_batch (Optional[List, np.ndarray]): Batch of previous - rewards. - info_batch (info): Batch of info objects. - episodes (list): MultiAgentEpisode for each obs in obs_batch. - This provides access to all of the internal episode state, - which may be useful for model-based or multiagent algorithms. - explore (bool): Whether to pick an exploitation or exploration - action (default: None -> use self.config["explore"]). - timestep (int): The current (sampling) time step. + obs_batch (Union[List[TensorType], TensorType]): Batch of + observations. + state_batches (Optional[List[TensorType]]): List of RNN state input + batches, if any. + prev_action_batch (Union[List[TensorType], TensorType]): Batch of + previous action values. + prev_reward_batch (Union[List[TensorType], TensorType]): Batch of + previous rewards. + info_batch (Optional[Dict[str, list]]): Batch of info objects. + episodes (Optional[List[MultiAgentEpisode]] ): List of + MultiAgentEpisode, one for each obs in obs_batch. This provides + access to all of the internal episode state, which may be + useful for model-based or multiagent algorithms. + explore (Optional[bool]): Whether to pick an exploitation or + exploration action. Set to None (default) for using the + value of `self.config["explore"]`. + timestep (Optional[int]): The current (sampling) time step. + + Keyword Args: kwargs: forward compatibility placeholder Returns: - actions (np.ndarray): batch of output actions, with shape like - [BATCH_SIZE, ACTION_SHAPE]. - state_outs (list): list of RNN state output batches, if any, with - shape like [STATE_SIZE, BATCH_SIZE]. - info (dict): dictionary of extra feature batches, if any, with - shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. + Tuple: + actions (TensorType): Batch of output actions, with shape like + [BATCH_SIZE, ACTION_SHAPE]. + state_outs (List[TensorType]): List of RNN state output + batches, if any, with shape like [STATE_SIZE, BATCH_SIZE]. + info (List[dict]): Dictionary of extra feature batches, if any, + with shape like + {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. """ raise NotImplementedError @DeveloperAPI - def compute_single_action(self, - obs, - state=None, - prev_action=None, - prev_reward=None, - info=None, - episode=None, - clip_actions=False, - explore=None, - timestep=None, - **kwargs): + def compute_single_action( + self, + obs: TensorType, + state: Optional[List[TensorType]] = None, + prev_action: Optional[TensorType] = None, + prev_reward: Optional[TensorType] = None, + info: dict = None, + episode: Optional["MultiAgentEpisode"] = None, + clip_actions: bool = False, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + **kwargs) -> \ + Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: """Unbatched version of compute_actions. - Arguments: - obs (obj): Single observation. - state (list): List of RNN state inputs, if any. - prev_action (obj): Previous action value, if any. - prev_reward (float): Previous reward, if any. - info (dict): info object, if any - episode (MultiAgentEpisode): this provides access to all of the - internal episode state, which may be useful for model-based or - multi-agent algorithms. + Args: + obs (TensorType): Single observation. + state (Optional[List[TensorType]]): List of RNN state inputs, if + any. + prev_action (Optional[TensorType]): Previous action value, if any. + prev_reward (Optional[TensorType]): Previous reward, if any. + info (dict): Info object, if any. + episode (Optional[MultiAgentEpisode]): this provides access to all + of the internal episode state, which may be useful for + model-based or multi-agent algorithms. clip_actions (bool): Should actions be clipped? - explore (bool): Whether to pick an exploitation or exploration - action (default: None -> use self.config["explore"]). - timestep (int): The current (sampling) time step. + explore (Optional[bool]): Whether to pick an exploitation or + exploration action + (default: None -> use self.config["explore"]). + timestep (Optional[int]): The current (sampling) time step. + + Keyword Args: kwargs: forward compatibility placeholder Returns: - actions (obj): single action - state_outs (list): list of RNN state outputs, if any - info (dict): dictionary of extra features, if any + Tuple: + actions (TensorType): Single action. + state_outs (List[TensorType]): List of RNN state outputs, + if any. + info (dict): Dictionary of extra features, if any. """ prev_action_batch = None prev_reward_batch = None @@ -196,7 +220,8 @@ class Policy(metaclass=ABCMeta): other_trajectories: Dict[AgentID, "Trajectory"], explore: bool = None, timestep: Optional[int] = None, - **kwargs): + **kwargs) -> \ + Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: """Computes actions for the current policy based on . Note: This is an experimental API method. @@ -215,59 +240,68 @@ class Policy(metaclass=ABCMeta): kwargs: forward compatibility placeholder Returns: - actions (np.ndarray): batch of output actions, with shape like - [BATCH_SIZE, ACTION_SHAPE]. - state_outs (list): list of RNN state output batches, if any, with - shape like [STATE_SIZE, BATCH_SIZE]. - info (dict): dictionary of extra feature batches, if any, with - shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. + Tuple: + actions (TensorType): Batch of output actions, with shape + like [BATCH_SIZE, ACTION_SHAPE]. + state_outs (List[TensorType]): List of RNN state output + batches, if any, with shape like [STATE_SIZE, BATCH_SIZE]. + info (dict): Dictionary of extra feature batches, if any, with + shape like + {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. """ raise NotImplementedError @DeveloperAPI - def compute_log_likelihoods(self, - actions, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None): + def compute_log_likelihoods( + self, + actions: Union[List[TensorType], TensorType], + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Optional[ + Union[List[TensorType], TensorType]] = None, + prev_reward_batch: Optional[ + Union[List[TensorType], TensorType]] = None) -> TensorType: """Computes the log-prob/likelihood for a given action and observation. Args: - actions (Union[List,np.ndarray]): Batch of actions, for which to - retrieve the log-probs/likelihoods (given all other inputs: - obs, states, ..). - obs_batch (Union[List,np.ndarray]): Batch of observations. - state_batches (Optional[list]): List of RNN state input batches, - if any. - prev_action_batch (Optional[List,np.ndarray]): Batch of previous - action values. - prev_reward_batch (Optional[List,np.ndarray]): Batch of previous - rewards. + actions (Union[List[TensorType], TensorType]): Batch of actions, + for which to retrieve the log-probs/likelihoods (given all + other inputs: obs, states, ..). + obs_batch (Union[List[TensorType], TensorType]): Batch of + observations. + state_batches (Optional[List[TensorType]]): List of RNN state input + batches, if any. + prev_action_batch (Optional[Union[List[TensorType], TensorType]]): + Batch of previous action values. + prev_reward_batch (Optional[Union[List[TensorType], TensorType]]): + Batch of previous rewards. Returns: - log-likelihoods (np.ndarray): Batch of log probs/likelihoods, with - shape: [BATCH_SIZE]. + TensorType: Batch of log probs/likelihoods, with shape: + [BATCH_SIZE]. """ raise NotImplementedError @DeveloperAPI - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[ + Dict[AgentID, Tuple["Policy", SampleBatch]]] = None, + episode: Optional["MultiAgentEpisode"] = None) -> SampleBatch: """Implements algorithm-specific trajectory postprocessing. This will be called on each trajectory fragment computed during policy evaluation. Each fragment is guaranteed to be only from one episode. - Arguments: + Args: sample_batch (SampleBatch): batch of experiences for the policy, which will contain at most one episode trajectory. other_agent_batches (dict): In a multi-agent env, this contains a mapping of agent ids to (policy, agent_batch) tuples containing the policy and experiences of the other agents. - episode (MultiAgentEpisode): this provides access to all of the + episode (Optional[MultiAgentEpisode]): An optional multi-agent + episode object to provide access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms. @@ -277,18 +311,22 @@ class Policy(metaclass=ABCMeta): return sample_batch @DeveloperAPI - def learn_on_batch(self, samples): + def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]: """Fused compute gradients and apply gradients call. Either this or the combination of compute/apply grads must be implemented by subclasses. + Args: + samples (SampleBatch): The SampleBatch object to learn from. + Returns: - grad_info: dictionary of extra metadata from compute_gradients(). + Dict[str, TensorType]: Dictionary of extra metadata from + compute_gradients(). Examples: - >>> batch = ev.sample() - >>> ev.learn_on_batch(samples) + >>> sample_batch = ev.sample() + >>> ev.learn_on_batch(sample_batch) """ grads, grad_info = self.compute_gradients(samples) @@ -296,65 +334,76 @@ class Policy(metaclass=ABCMeta): return grad_info @DeveloperAPI - def compute_gradients(self, postprocessed_batch): + def compute_gradients(self, postprocessed_batch: SampleBatch) -> \ + Tuple[ModelGradients, Dict[str, TensorType]]: """Computes gradients against a batch of experiences. Either this or learn_on_batch() must be implemented by subclasses. + Args: + postprocessed_batch (SampleBatch): The SampleBatch object to use + for calculating gradients. + Returns: - grads (list): List of gradient output values - info (dict): Extra policy-specific values + Tuple[ModelGradients, Dict[str, TensorType]]: + - List of gradient output values. + - Extra policy-specific info values. """ raise NotImplementedError @DeveloperAPI - def apply_gradients(self, gradients): + def apply_gradients(self, gradients: ModelGradients) -> None: """Applies previously computed gradients. Either this or learn_on_batch() must be implemented by subclasses. + + Args: + gradients (ModelGradients): The already calculated gradients to + apply to this Policy. """ raise NotImplementedError @DeveloperAPI - def get_weights(self): + def get_weights(self) -> ModelWeights: """Returns model weights. Returns: - weights (obj): Serializable copy or view of model weights + ModelWeights: Serializable copy or view of model weights. """ raise NotImplementedError @DeveloperAPI - def set_weights(self, weights): + def set_weights(self, weights: ModelWeights) -> None: """Sets model weights. - Arguments: - weights (obj): Serializable copy or view of model weights + Args: + weights (ModelWeights): Serializable copy or view of model weights. """ raise NotImplementedError @DeveloperAPI - def get_exploration_info(self): + def get_exploration_info(self) -> Dict[str, TensorType]: """Returns the current exploration information of this policy. This information depends on the policy's Exploration object. Returns: - any: Serializable information on the `self.exploration` object. + Dict[str, TensorType]: Serializable information on the + `self.exploration` object. """ return self.exploration.get_info() @DeveloperAPI - def is_recurrent(self): + def is_recurrent(self) -> bool: """Whether this Policy holds a recurrent Model. Returns: bool: True if this Policy has-a RNN-based Model. """ - return 0 + return False @DeveloperAPI - def num_state_tensors(self): + def num_state_tensors(self) -> int: """The number of internal states needed by the RNN-Model of the Policy. Returns: @@ -363,73 +412,83 @@ class Policy(metaclass=ABCMeta): return 0 @DeveloperAPI - def get_initial_state(self): - """Returns initial RNN state for the current policy.""" + def get_initial_state(self) -> List[TensorType]: + """Returns initial RNN state for the current policy. + + Returns: + List[TensorType]: Initial RNN state for the current policy. + """ return [] @DeveloperAPI - def get_state(self): + def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]: """Saves all local state. Returns: - state (obj): Serialized local state. + Union[Dict[str, TensorType], List[TensorType]]: Serialized local + state. """ return self.get_weights() @DeveloperAPI - def set_state(self, state): + def set_state(self, state: object) -> None: """Restores all local state. - Arguments: + Args: state (obj): Serialized local state. """ self.set_weights(state) @DeveloperAPI - def on_global_var_update(self, global_vars): + def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None: """Called on an update to global vars. - Arguments: - global_vars (dict): Global variables broadcast from the driver. + Args: + global_vars (Dict[str, TensorType]): Global variables by str key, + broadcast from the driver. """ # Store the current global time step (sum over all policies' sample # steps). self.global_timestep = global_vars["timestep"] @DeveloperAPI - def export_model(self, export_dir): + def export_model(self, export_dir: str) -> None: """Export Policy to local directory for serving. - Arguments: + Args: export_dir (str): Local writable directory. """ raise NotImplementedError @DeveloperAPI - def export_checkpoint(self, export_dir): + def export_checkpoint(self, export_dir: str) -> None: """Export Policy checkpoint to local directory. - Argument: + Args: export_dir (str): Local writable directory. """ raise NotImplementedError @DeveloperAPI - def import_model_from_h5(self, import_file): + def import_model_from_h5(self, import_file: str) -> None: """Imports Policy from local file. - Arguments: + Args: import_file (str): Local readable file. """ raise NotImplementedError - def _create_exploration(self): + def _create_exploration(self) -> Exploration: """Creates the Policy's Exploration object. This method only exists b/c some Trainers do not use TfPolicy nor TorchPolicy, but inherit directly from Policy. Others inherit from TfPolicy w/o using DynamicTfPolicy. - TODO(sven): unify these cases.""" + TODO(sven): unify these cases. + + Returns: + Exploration: The Exploration object to be used by this Policy. + """ if getattr(self, "exploration", None) is not None: return self.exploration diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 5946938dc..910ad5e5c 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -230,8 +230,8 @@ def chop_into_sequences(episode_ids, f_pad = np.zeros((length, ) + np.shape(f)[1:]) seq_base = 0 i = 0 - for l in seq_lens: - for seq_offset in range(l): + for len_ in seq_lens: + for seq_offset in range(len_): f_pad[seq_base + seq_offset] = f[i] i += 1 seq_base += max_seq_len @@ -243,9 +243,9 @@ def chop_into_sequences(episode_ids, s = np.array(s) s_init = [] i = 0 - for l in seq_lens: + for len_ in seq_lens: s_init.append(s[i]) - i += l + i += len_ initial_states.append(np.array(s_init)) if shuffle: diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 2036a4503..8db340c9b 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -2,12 +2,13 @@ import collections import numpy as np import sys import itertools -from typing import Dict, List, Any +from typing import Any, Dict, Iterable, List, Optional, Set, Union from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI from ray.rllib.utils.compression import pack, unpack, is_compressed from ray.rllib.utils.memory import concat_aligned from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.types import TensorType # Default policy id for single agent environments DEFAULT_POLICY_ID = "default_policy" @@ -71,15 +72,16 @@ class SampleBatch: @staticmethod @PublicAPI - def concat_samples(samples): + def concat_samples(samples: List[Dict[str, TensorType]]) -> \ + Union["SampleBatch", "MultiAgentBatch"]: """Concatenates n data dicts or MultiAgentBatches. Args: - samples (List[Dict[np.ndarray]]]): List of dicts of data (numpy). + samples (List[Dict[TensorType]]]): List of dicts of data (numpy). Returns: - Union[SampleBatch,MultiAgentBatch]: A new (compressed) SampleBatch/ - MultiAgentBatch. + Union[SampleBatch, MultiAgentBatch]: A new (compressed) + SampleBatch or MultiAgentBatch. """ if isinstance(samples[0], MultiAgentBatch): return MultiAgentBatch.concat_samples(samples) @@ -90,9 +92,17 @@ class SampleBatch: return SampleBatch(out) @PublicAPI - def concat(self, other): + def concat(self, other: "SampleBatch") -> "SampleBatch": """Returns a new SampleBatch with each data column concatenated. + Args: + other (SampleBatch): The other SampleBatch object to concat to this + one. + + Returns: + SampleBatch: The new SampleBatch, resulting from concating `other` + to `self`. + Examples: >>> b1 = SampleBatch({"a": [1, 2]}) >>> b2 = SampleBatch({"a": [3, 4, 5]}) @@ -110,15 +120,24 @@ class SampleBatch: return SampleBatch(out) @PublicAPI - def copy(self): + def copy(self) -> "SampleBatch": + """Creates a (deep) copy of this SampleBatch and returns it. + + Returns: + SampleBatch: A (deep) copy of this SampleBatch object. + """ return SampleBatch( {k: np.array(v, copy=True) for (k, v) in self.data.items()}) @PublicAPI - def rows(self): + def rows(self) -> Dict[str, TensorType]: """Returns an iterator over data rows, i.e. dicts with column values. + Yields: + Dict[str, TensorType]: The column values of the row in this + iteration. + Examples: >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]}) >>> for row in batch.rows(): @@ -135,7 +154,7 @@ class SampleBatch: yield row @PublicAPI - def columns(self, keys): + def columns(self, keys: List[str]) -> List[any]: """Returns a list of the batch-data in the specified columns. Args: @@ -157,7 +176,7 @@ class SampleBatch: return out @PublicAPI - def shuffle(self): + def shuffle(self) -> None: """Shuffles the rows of this batch in-place.""" permutation = np.random.permutation(self.count) @@ -165,7 +184,7 @@ class SampleBatch: self[key] = val[permutation] @PublicAPI - def split_by_episode(self): + def split_by_episode(self) -> List["SampleBatch"]: """Splits this batch's data by `eps_id`. Returns: @@ -189,7 +208,7 @@ class SampleBatch: return slices @PublicAPI - def slice(self, start, end): + def slice(self, start: int, end: int) -> "SampleBatch": """Returns a slice of the row data of this batch. Args: @@ -197,13 +216,25 @@ class SampleBatch: end (int): Ending index. Returns: - SampleBatch which has a slice of this batch's data. + SampleBatch: A new SampleBatch, which has a slice of this batch's + data. """ return SampleBatch({k: v[start:end] for k, v in self.data.items()}) @PublicAPI def timeslices(self, k: int) -> List["SampleBatch"]: + """Returns SampleBatches, each one representing a k-slice of this one. + + Will start from timestep 0 and produce slices of size=k. + + Args: + k (int): The size (in timesteps) of each returned SampleBatch. + + Returns: + List[SampleBatch]: The list of (new) SampleBatches (each one of + size k). + """ out = [] i = 0 while i < self.count: @@ -212,31 +243,78 @@ class SampleBatch: return out @PublicAPI - def keys(self): + def keys(self) -> Iterable[str]: + """ + Returns: + Iterable[str]: The keys() iterable over `self.data`. + """ return self.data.keys() @PublicAPI - def items(self): + def items(self) -> Iterable[TensorType]: + """ + Returns: + Iterable[TensorType]: The values() iterable over `self.data`. + """ return self.data.items() @PublicAPI - def get(self, key): + def get(self, key: str) -> Optional[TensorType]: + """Returns one column (by key) from the data or None if key not found. + + Args: + key (str): The key (column name) to return. + + Returns: + Optional[TensorType]: The data under the given key. None if key + not found in data. + """ return self.data.get(key) @PublicAPI def size_bytes(self) -> int: + """ + Returns: + int: The overall size in bytes of the data buffer (all columns). + """ return sum(sys.getsizeof(d) for d in self.data) @PublicAPI - def __getitem__(self, key): + def __getitem__(self, key: str) -> TensorType: + """Returns one column (by key) from the data. + + Args: + key (str): The key (column name) to return. + + Returns: + TensorType]: The data under the given key. + """ return self.data[key] @PublicAPI - def __setitem__(self, key, item): + def __setitem__(self, key, item) -> None: + """Inserts (overrides) an entire column (by key) in the data buffer. + + Args: + key (str): The column name to set a value for. + item (TensorType): The data to insert. + """ self.data[key] = item @DeveloperAPI - def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): + def compress( + self, + bulk: bool = False, + columns: Set[str] = frozenset(["obs", "new_obs"])) -> None: + """Compresses the data buffers (by column) in place. + + Args: + bulk (bool): Whether to compress across the batch dimension (0) + as well. If False will compress n separate list items, where n + is the batch size. + columns (Set[str]): The columns to compress. Default: Only + compress the obs and new_obs columns. + """ for key in columns: if key in self.data: if bulk: @@ -246,7 +324,19 @@ class SampleBatch: [pack(o) for o in self.data[key]]) @DeveloperAPI - def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): + def decompress_if_needed( + self, + columns: Set[str] = frozenset( + ["obs", "new_obs"])) -> "SampleBatch": + """Decompresses data buffers (per column if not compressed) in place. + + Args: + columns (Set[str]): The columns to decompress. Default: Only + decompress the obs and new_obs columns. + + Returns: + SampleBatch: This very SampleBatch. + """ for key in columns: if key in self.data: arr = self.data[key] @@ -272,10 +362,17 @@ class SampleBatch: @PublicAPI class MultiAgentBatch: - """A batch of experiences from multiple agents in the environment.""" + """A batch of experiences from multiple agents in the environment. + + Attributes: + policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy + ids to SampleBatches of experiences. + count (int): The number of env steps in this batch. + """ @PublicAPI - def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], + def __init__(self, + policy_batches: Dict[PolicyID, SampleBatch], env_steps: int): """Initialize a MultiAgentBatch object. @@ -285,12 +382,8 @@ class MultiAgentBatch: env_steps (int): The number of timesteps in the environment this batch contains. This will be less than the number of transitions this batch contains across all policies in total. - - Attributes: - policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy - ids to SampleBatches of experiences. - count (int): the number of env steps in this batch. """ + for v in policy_batches.values(): assert isinstance(v, SampleBatch) self.policy_batches = policy_batches @@ -303,7 +396,7 @@ class MultiAgentBatch: """The number of env steps (there are >= 1 agent steps per env step). Returns: - int: the number of environment steps contained in this batch. + int: The number of environment steps contained in this batch. """ return self.count @@ -312,7 +405,7 @@ class MultiAgentBatch: """The number of agent steps (there are >= 1 agent steps per env step). Returns: - int: the number of agent steps total in this batch. + int: The number of agent steps total in this batch. """ ct = 0 for batch in self.policy_batches.values(): @@ -379,8 +472,9 @@ class MultiAgentBatch: @staticmethod @PublicAPI - def wrap_as_needed(policy_batches: Dict[PolicyID, SampleBatch], - env_steps: int) -> Any: + def wrap_as_needed( + policy_batches: Dict[PolicyID, SampleBatch], + env_steps: int) -> Union[SampleBatch, "MultiAgentBatch"]: """Returns SampleBatch or MultiAgentBatch, depending on given policies. Args: @@ -437,11 +531,19 @@ class MultiAgentBatch: @PublicAPI def size_bytes(self) -> int: + """ + Returns: + int: The overall size in bytes of all policy batches (all columns). + """ return sum(b.size_bytes() for b in self.policy_batches.values()) @DeveloperAPI - def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): - """Compresses each policy batch. + def compress( + self, + bulk: bool = False, + columns: Set[str] = frozenset( + ["obs", "new_obs"])) -> None: + """Compresses each policy batch (per column) in place. Args: bulk (bool): Whether to compress across the batch dimension (0) @@ -453,11 +555,17 @@ class MultiAgentBatch: batch.compress(bulk=bulk, columns=columns) @DeveloperAPI - def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): - """Decompresses each policy batch, if already compressed. + def decompress_if_needed( + self, + columns: Set[str] = frozenset( + ["obs", "new_obs"])) -> "MultiAgentBatch": + """Decompresses each policy batch (per column), if already compressed. Args: columns (Set[str]): Set of column names to decompress. + + Returns: + MultiAgentBatch: This very MultiAgentBatch. """ for batch in self.policy_batches.values(): batch.decompress_if_needed(columns) diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 6b4243190..c813a65f5 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -1,7 +1,9 @@ import errno +import gym import logging import numpy as np import os +from typing import Dict, List, Optional, Tuple, Union import ray import ray.experimental.tf_utils @@ -15,6 +17,7 @@ from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) @@ -24,6 +27,11 @@ logger = logging.getLogger(__name__) class TFPolicy(Policy): """An agent policy and loss implemented in TensorFlow. + Do not sub-class this class directly (neither should you sub-class + DynamicTFPolicy), but rather use + rllib.policy.tf_policy_template.build_tf_policy + to generate your custom tf (graph-mode or eager) Policy classes. + Extending this class enables RLlib to perform TensorFlow specific optimizations on the policy, e.g., parallelization across gpus or fusing multiple graphs together in the multi-agent setting. @@ -48,77 +56,83 @@ class TFPolicy(Policy): @DeveloperAPI def __init__(self, - observation_space, - action_space, - config, - sess, - obs_input, - sampled_action, - loss, - loss_inputs, - model=None, - sampled_action_logp=None, - action_input=None, - log_likelihood=None, - dist_inputs=None, - dist_class=None, - state_inputs=None, - state_outputs=None, - prev_action_input=None, - prev_reward_input=None, - seq_lens=None, - max_seq_len=20, - batch_divisibility_req=1, - update_ops=None, - explore=None, - timestep=None): - """Initialize the policy. + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, + sess: "tf1.Session", + obs_input: TensorType, + sampled_action: TensorType, + loss: TensorType, + loss_inputs: List[Tuple[str, TensorType]], + model: ModelV2 = None, + sampled_action_logp: Optional[TensorType] = None, + action_input: Optional[TensorType] = None, + log_likelihood: Optional[TensorType] = None, + dist_inputs: Optional[TensorType] = None, + dist_class: Optional[type] = None, + state_inputs: Optional[List[TensorType]] = None, + state_outputs: Optional[List[TensorType]] = None, + prev_action_input: Optional[TensorType] = None, + prev_reward_input: Optional[TensorType] = None, + seq_lens: Optional[TensorType] = None, + max_seq_len: int = 20, + batch_divisibility_req: int = 1, + update_ops: List[TensorType] = None, + explore: Optional[TensorType] = None, + timestep: Optional[TensorType] = None): + """Initializes a Policy object. - Arguments: - observation_space (gym.Space): Observation space of the env. - action_space (gym.Space): Action space of the env. - config (dict): The Policy config dict. - sess (Session): The TensorFlow session to use. - obs_input (Tensor): Input placeholder for observations, of shape - [BATCH_SIZE, obs...]. - sampled_action (Tensor): Tensor for sampling an action, of shape - [BATCH_SIZE, action...] - loss (Tensor): Scalar policy loss output tensor. - loss_inputs (list): A (name, placeholder) tuple for each loss - input argument. Each placeholder name must correspond to a - SampleBatch column key returned by postprocess_trajectory(), - and has shape [BATCH_SIZE, data...]. These keys will be read - from postprocessed sample batches and fed into the specified - placeholders during loss computation. - model (rllib.models.Model): used to integrate custom losses and + Args: + observation_space (gym.spaces.Space): Observation space of the env. + action_space (gym.spaces.Space): Action space of the env. + config (TrainerConfigDict): The Policy config dict. + sess (tf1.Session): The TensorFlow session to use. + obs_input (TensorType): Input placeholder for observations, of + shape [BATCH_SIZE, obs...]. + sampled_action (TensorType): Tensor for sampling an action, of + shape [BATCH_SIZE, action...] + loss (TensorType): Scalar policy loss output tensor. + loss_inputs (List[Tuple[str, TensorType]]): A (name, placeholder) + tuple for each loss input argument. Each placeholder name must + correspond to a SampleBatch column key returned by + postprocess_trajectory(), and has shape [BATCH_SIZE, data...]. + These keys will be read from postprocessed sample batches and + fed into the specified placeholders during loss computation. + model (ModelV2): used to integrate custom losses and stats from user-defined RLlib models. - sampled_action_logp (Tensor): log probability of the sampled - action. - action_input (Optional[Tensor]): Input placeholder for actions for - logp/log-likelihood calculations. - log_likelihood (Optional[Tensor]): Tensor to calculate the + sampled_action_logp (Optional[TensorType]): log probability of the + sampled action. + action_input (Optional[TensorType]): Input placeholder for actions + for logp/log-likelihood calculations. + log_likelihood (Optional[TensorType]): Tensor to calculate the log_likelihood (given action_input and obs_input). - dist_class (Optional[type): An optional ActionDistribution class + dist_class (Optional[type]): An optional ActionDistribution class to use for generating a dist object from distribution inputs. - dist_inputs (Optional[Tensor]): Tensor to calculate the + dist_inputs (Optional[TensorType]): Tensor to calculate the distribution inputs/parameters. - state_inputs (list): list of RNN state input Tensors. - state_outputs (list): list of RNN state output Tensors. - prev_action_input (Tensor): placeholder for previous actions - prev_reward_input (Tensor): placeholder for previous rewards - seq_lens (Tensor): Placeholder for RNN sequence lengths, of shape - [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See + state_inputs (Optional[List[TensorType]]): List of RNN state input + Tensors. + state_outputs (Optional[List[TensorType]]): List of RNN state + output Tensors. + prev_action_input (Optional[TensorType]): placeholder for previous + actions. + prev_reward_input (Optional[TensorType]): placeholder for previous + rewards. + seq_lens (Optional[TensorType]): Placeholder for RNN sequence + lengths, of shape [NUM_SEQUENCES]. + Note that NUM_SEQUENCES << BATCH_SIZE. See policy/rnn_sequencing.py for more information. max_seq_len (int): Max sequence length for LSTM training. batch_divisibility_req (int): pad all agent experiences batches to multiples of this value. This only has an effect if not using a LSTM model. - update_ops (list): override the batchnorm update ops to run when - applying gradients. Otherwise we run all update ops found in - the current variable scope. - explore (Tensor): Placeholder for `explore` parameter into - call to Exploration.get_exploration_action. - timestep (Tensor): Placeholder for the global sampling timestep. + update_ops (List[TensorType]): override the batchnorm update ops to + run when applying gradients. Otherwise we run all update ops + found in the current variable scope. + explore (Optional[TensorType]): Placeholder for `explore` parameter + into call to Exploration.get_exploration_action. + timestep (Optional[TensorType]): Placeholder for the global + sampling timestep. """ self.framework = "tf" super().__init__(observation_space, action_space, config) @@ -192,33 +206,49 @@ class TFPolicy(Policy): """Return the list of all savable variables for this policy.""" return self.model.variables() - def get_placeholder(self, name): + def get_placeholder(self, name) -> "tf1.placeholder": """Returns the given action or loss input placeholder by name. If the loss has not been initialized and a loss input placeholder is requested, an error is raised. + + Args: + name (str): The name of the placeholder to return. One of + SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from + `self._loss_input_dict`. + + Returns: + tf1.placeholder: The placeholder under the given str key. """ - obs_inputs = { - SampleBatch.CUR_OBS: self._obs_input, - SampleBatch.PREV_ACTIONS: self._prev_action_input, - SampleBatch.PREV_REWARDS: self._prev_reward_input, - } - if name in obs_inputs: - return obs_inputs[name] + if name == SampleBatch.CUR_OBS: + return self._obs_input + elif name == SampleBatch.PREV_ACTIONS: + return self._prev_action_input + elif name == SampleBatch.PREV_REWARDS: + return self._prev_reward_input assert self._loss_input_dict is not None, \ "Should have set this before get_placeholder can be called" return self._loss_input_dict[name] - def get_session(self): + def get_session(self) -> "tf1.Session": """Returns a reference to the TF session for this policy.""" return self._sess - def loss_initialized(self): + def loss_initialized(self) -> bool: """Returns whether the loss function has been initialized.""" return self._loss is not None - def _initialize_loss(self, loss, loss_inputs): + def _initialize_loss(self, + loss: TensorType, + loss_inputs: List[Tuple[str, TensorType]]) -> None: + """Initializes the loss op from given loss tensor and placeholders. + + Args: + loss (TensorType): The loss op generated by some loss function. + loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples: + (name, tf1.placeholders) needed for calculating the loss. + """ self._loss_inputs = loss_inputs self._loss_input_dict = dict(self._loss_inputs) for i, ph in enumerate(self._state_inputs): @@ -270,16 +300,18 @@ class TFPolicy(Policy): self._optimizer.variables(), self._sess) @override(Policy) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - explore=None, - timestep=None, - **kwargs): + def compute_actions( + self, + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorType], TensorType] = None, + prev_reward_batch: Union[List[TensorType], TensorType] = None, + info_batch: Optional[Dict[str, list]] = None, + episodes: Optional[List["MultiAgentEpisode"]] = None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + **kwargs): + explore = explore if explore is not None else self.config["explore"] timestep = timestep if timestep is not None else self.global_timestep @@ -299,12 +331,16 @@ class TFPolicy(Policy): return fetched @override(Policy) - def compute_log_likelihoods(self, - actions, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None): + def compute_log_likelihoods( + self, + actions: Union[List[TensorType], TensorType], + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Optional[ + Union[List[TensorType], TensorType]] = None, + prev_reward_batch: Optional[ + Union[List[TensorType], TensorType]] = None) -> TensorType: + if self._log_likelihood is None: raise ValueError("Cannot compute log-prob/likelihood w/o a " "self._log_likelihood op!") @@ -341,40 +377,51 @@ class TFPolicy(Policy): return builder.get(fetches)[0] @override(Policy) - def compute_gradients(self, postprocessed_batch): - assert self.loss_initialized() - builder = TFRunBuilder(self._sess, "compute_gradients") - fetches = self._build_compute_gradients(builder, postprocessed_batch) - return builder.get(fetches) - - @override(Policy) - def apply_gradients(self, gradients): - assert self.loss_initialized() - builder = TFRunBuilder(self._sess, "apply_gradients") - fetches = self._build_apply_gradients(builder, gradients) - builder.get(fetches) - - @override(Policy) - def learn_on_batch(self, postprocessed_batch): + @DeveloperAPI + def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[ + str, TensorType]: assert self.loss_initialized() builder = TFRunBuilder(self._sess, "learn_on_batch") fetches = self._build_learn_on_batch(builder, postprocessed_batch) return builder.get(fetches) @override(Policy) - def get_exploration_info(self): + @DeveloperAPI + def compute_gradients( + self, + postprocessed_batch: SampleBatch) -> \ + Tuple[ModelGradients, Dict[str, TensorType]]: + assert self.loss_initialized() + builder = TFRunBuilder(self._sess, "compute_gradients") + fetches = self._build_compute_gradients(builder, postprocessed_batch) + return builder.get(fetches) + + @override(Policy) + @DeveloperAPI + def apply_gradients(self, gradients: ModelGradients) -> None: + assert self.loss_initialized() + builder = TFRunBuilder(self._sess, "apply_gradients") + fetches = self._build_apply_gradients(builder, gradients) + builder.get(fetches) + + @override(Policy) + @DeveloperAPI + def get_exploration_info(self) -> Dict[str, TensorType]: return self.exploration.get_info(sess=self.get_session()) @override(Policy) - def get_weights(self): + @DeveloperAPI + def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]: return self._variables.get_weights() @override(Policy) - def set_weights(self, weights): + @DeveloperAPI + def set_weights(self, weights) -> None: return self._variables.set_weights(weights) @override(Policy) - def get_state(self): + @DeveloperAPI + def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]: # For tf Policies, return Policy weights and optimizer var values. state = super().get_state() if self._optimizer_variables and \ @@ -384,7 +431,8 @@ class TFPolicy(Policy): return state @override(Policy) - def set_state(self, state): + @DeveloperAPI + def set_state(self, state) -> None: state = state.copy() # shallow copy # Set optimizer vars first. optimizer_vars = state.pop("_optimizer_variables", None) @@ -394,7 +442,8 @@ class TFPolicy(Policy): super().set_state(state) @override(Policy) - def export_model(self, export_dir): + @DeveloperAPI + def export_model(self, export_dir: str) -> None: """Export tensorflow graph to export_dir for serving.""" with self._sess.graph.as_default(): builder = tf1.saved_model.builder.SavedModelBuilder(export_dir) @@ -407,7 +456,9 @@ class TFPolicy(Policy): builder.save() @override(Policy) - def export_checkpoint(self, export_dir, filename_prefix="model"): + @DeveloperAPI + def export_checkpoint(self, export_dir: str, + filename_prefix: str = "model") -> None: """Export tensorflow checkpoint to export_dir.""" try: os.makedirs(export_dir) @@ -421,7 +472,8 @@ class TFPolicy(Policy): saver.save(self._sess, save_path) @override(Policy) - def import_model_from_h5(self, import_file): + @DeveloperAPI + def import_model_from_h5(self, import_file: str) -> None: """Imports weights into tf model.""" # Make sure the session is the right one (see issue #7046). with self._sess.graph.as_default(): @@ -429,31 +481,53 @@ class TFPolicy(Policy): return self.model.import_from_h5(import_file) @DeveloperAPI - def copy(self, existing_inputs): + def copy(self, + existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> \ + "TFPolicy": """Creates a copy of self using existing input placeholders. - Optional, only required to work with the multi-GPU optimizer.""" + Optional: Only required to work with the multi-GPU optimizer. + + Args: + existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping + names (str) to tf1.placeholders to re-use (share) with the + returned copy of self. + + Returns: + TFPolicy: A copy of self. + """ raise NotImplementedError @override(Policy) - def is_recurrent(self): + @DeveloperAPI + def is_recurrent(self) -> bool: return len(self._state_inputs) > 0 @override(Policy) - def num_state_tensors(self): + @DeveloperAPI + def num_state_tensors(self) -> int: return len(self._state_inputs) @DeveloperAPI - def extra_compute_action_feed_dict(self): - """Extra dict to pass to the compute actions session run.""" + def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]: + """Extra dict to pass to the compute actions session run. + + Returns: + Dict[TensorType, TensorType]: A feed dict to be added to the + feed_dict passed to the compute_actions session.run() call. + """ return {} @DeveloperAPI - def extra_compute_action_fetches(self): + def extra_compute_action_fetches(self) -> Dict[str, TensorType]: """Extra values to fetch and return from compute_actions(). By default we return action probability/log-likelihood info and action distribution inputs (if present). + + Returns: + Dict[str, TensorType]: An extra fetch-dict to be passed to and + returned from the compute_actions() call. """ extra_fetches = {} # Action-logp and action-prob. @@ -466,38 +540,70 @@ class TFPolicy(Policy): return extra_fetches @DeveloperAPI - def extra_compute_grad_feed_dict(self): - """Extra dict to pass to the compute gradients session run.""" + def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]: + """Extra dict to pass to the compute gradients session run. + + Returns: + Dict[TensorType, TensorType]: Extra feed_dict to be passed to the + compute_gradients Session.run() call. + """ return {} # e.g, kl_coeff @DeveloperAPI - def extra_compute_grad_fetches(self): - """Extra values to fetch and return from compute_gradients().""" + def extra_compute_grad_fetches(self) -> Dict[str, any]: + """Extra values to fetch and return from compute_gradients(). + + Returns: + Dict[str, any]: Extra fetch dict to be added to the fetch dict + of the compute_gradients Session.run() call. + """ return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc. @DeveloperAPI - def optimizer(self): - """TF optimizer to use for policy optimization.""" + def optimizer(self) -> "tf.keras.optimizers.Optimizer": + """TF optimizer to use for policy optimization. + + Returns: + tf.keras.optimizers.Optimizer: The local optimizer to use for this + Policy's Model. + """ if hasattr(self, "config"): return tf1.train.AdamOptimizer(learning_rate=self.config["lr"]) else: return tf1.train.AdamOptimizer() @DeveloperAPI - def gradients(self, optimizer, loss): - """Override for custom gradient computation.""" + def gradients(self, + optimizer: "tf.keras.optimizers.Optimizer", + loss: TensorType) -> List[Tuple[TensorType, TensorType]]: + """Override this for a custom gradient computation behavior. + + Returns: + List[Tuple[TensorType, TensorType]]: List of tuples with grad + values and the grad-value's corresponding tf.variable in it. + """ return optimizer.compute_gradients(loss) @DeveloperAPI - def build_apply_op(self, optimizer, grads_and_vars): - """Override for custom gradient apply computation.""" + def build_apply_op( + self, + optimizer: "tf.keras.optimizers.Optimizer", + grads_and_vars: List[Tuple[TensorType, TensorType]]) -> \ + "tf.Operation": + """Override this for a custom gradient apply computation behavior. - # specify global_step for TD3 which needs to count the num updates + Args: + optimizer (tf.keras.optimizers.Optimizer): The local tf optimizer + to use for applying the grads and vars. + grads_and_vars (List[Tuple[TensorType, TensorType]]): List of + tuples with grad values and the grad-value's corresponding + tf.variable in it. + """ + # Specify global_step for TD3 which needs to count the num updates. return optimizer.apply_gradients( self._grads_and_vars, global_step=tf1.train.get_or_create_global_step()) - @DeveloperAPI def _get_is_training_placeholder(self): """Get the placeholder for _is_training, i.e., for batch norm layers. @@ -670,7 +776,7 @@ class TFPolicy(Policy): def _get_loss_inputs_dict(self, batch, shuffle): """Return a feed dict from a batch. - Arguments: + Args: batch (SampleBatch): batch of data to derive inputs from shuffle (bool): whether to shuffle batch sequences. Shuffle may be done in-place. This only makes sense if you're further diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index c355e6f4d..81b3ff0b2 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -131,10 +131,10 @@ def build_tf_policy(name, DynamicTFPolicy.__init__( self, - obs_space, - action_space, - config, - loss_fn, + obs_space=obs_space, + action_space=action_space, + config=config, + loss_fn=loss_fn, stats_fn=stats_fn, grad_stats_fn=grad_stats_fn, before_loss_init=before_loss_init_wrapper, diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index a94f8d6d9..bdada9de1 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -1,8 +1,11 @@ import functools +import gym import numpy as np import time -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple, Union +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch @@ -15,11 +18,13 @@ from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \ convert_to_torch_tensor from ray.rllib.utils.tracking_dict import UsageTrackingDict -from ray.rllib.utils.types import AgentID +from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \ + TensorType, TrainerConfigDict torch, _ = try_import_torch() +@DeveloperAPI class TorchPolicy(Policy): """Template for a PyTorch policy and loss to use with RLlib. @@ -33,49 +38,63 @@ class TorchPolicy(Policy): dist_class (type): Torch action distribution class. """ + @DeveloperAPI def __init__(self, - observation_space, - action_space, - config, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, *, - model, - loss, - action_distribution_class, - action_sampler_fn=None, - action_distribution_fn=None, - max_seq_len=20, - get_batch_divisibility_req=None): + model: ModelV2, + loss: Callable[ + [Policy, ModelV2, type, SampleBatch], TensorType], + action_distribution_class: TorchDistributionWrapper, + action_sampler_fn: Callable[ + [TensorType, List[TensorType]], Tuple[ + TensorType, TensorType]] = None, + action_distribution_fn: Optional[Callable[ + [Policy, ModelV2, TensorType, TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]]]] = None, + max_seq_len: int = 20, + get_batch_divisibility_req: Optional[int] = None): """Build a policy from policy and loss torch modules. Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES is set. Only single GPU is supported for now. - Arguments: - observation_space (gym.Space): observation space of the policy. - action_space (gym.Space): action space of the policy. - config (dict): The Policy config dict. - model (nn.Module): PyTorch policy module. Given observations as + Args: + observation_space (gym.spaces.Space): observation space of the + policy. + action_space (gym.spaces.Space): action space of the policy. + config (TrainerConfigDict): The Policy config dict. + model (ModelV2): PyTorch policy module. Given observations as input, this module must return a list of outputs where the first item is action logits, and the rest can be any value. - loss (func): Function that takes (policy, model, dist_class, - train_batch) and returns a single scalar loss. - action_distribution_class (ActionDistribution): Class for action - distribution. - action_sampler_fn (Optional[callable]): A callable returning a - sampled action and its log-likelihood given some (obs and - state) inputs. - action_distribution_fn (Optional[callable]): A callable returning - distribution inputs (parameters), a dist-class to generate an - action distribution object from, and internal-state outputs - (or an empty list if not applicable). + loss (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]): + Function that takes (policy, model, dist_class, train_batch) + and returns a single scalar loss. + action_distribution_class (TorchDistributionWrapper): Class for + a torch action distribution. + action_sampler_fn (Callable[[TensorType, List[TensorType]], + Tuple[TensorType, TensorType]]): A callable returning a + sampled action and its log-likelihood given Policy, ModelV2, + input_dict, explore, timestep, and is_training. + action_distribution_fn (Optional[Callable[[Policy, ModelV2, + Dict[str, TensorType], TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]]]]): A callable + returning distribution inputs (parameters), a dist-class to + generate an action distribution object from, and + internal-state outputs (or an empty list if not applicable). Note: No Exploration hooks have to be called from within `action_distribution_fn`. It's should only perform a simple forward pass through some model. - If None, pass inputs through `self.model()` to get the - distribution inputs. + If None, pass inputs through `self.model()` to get distribution + inputs. + The callable takes as inputs: Policy, ModelV2, input_dict, + explore, timestep, is_training. max_seq_len (int): Max sequence length for LSTM training. - get_batch_divisibility_req (Optional[callable]): Optional callable - that returns the divisibility requirement for sample batches. + get_batch_divisibility_req (Optional[Callable[[Policy], int]]]): + Optional callable that returns the divisibility requirement + for sample batches given the Policy. """ self.framework = "torch" super().__init__(observation_space, action_space, config) @@ -100,16 +119,19 @@ class TorchPolicy(Policy): else 1 @override(Policy) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - explore=None, - timestep=None, - **kwargs): + @DeveloperAPI + def compute_actions( + self, + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorType], TensorType] = None, + prev_reward_batch: Union[List[TensorType], TensorType] = None, + info_batch: Optional[Dict[str, list]] = None, + episodes: Optional[List["MultiAgentEpisode"]] = None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + **kwargs) -> \ + Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: explore = explore if explore is not None else self.config["explore"] timestep = timestep if timestep is not None else self.global_timestep @@ -149,7 +171,8 @@ class TorchPolicy(Policy): other_trajectories: Dict[AgentID, "Trajectory"], explore: bool = None, timestep: Optional[int] = None, - **kwargs): + **kwargs) -> \ + Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: explore = explore if explore is not None else self.config["explore"] timestep = timestep if timestep is not None else self.global_timestep @@ -237,12 +260,16 @@ class TorchPolicy(Policy): return actions, state_out, extra_fetches, logp @override(Policy) - def compute_log_likelihoods(self, - actions, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None): + @DeveloperAPI + def compute_log_likelihoods( + self, + actions: Union[List[TensorType], TensorType], + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Optional[ + Union[List[TensorType], TensorType]] = None, + prev_reward_batch: Optional[ + Union[List[TensorType], TensorType]] = None) -> TensorType: if self.action_sampler_fn and self.action_distribution_fn is None: raise ValueError("Cannot compute log-prob/likelihood w/o an " @@ -282,7 +309,9 @@ class TorchPolicy(Policy): return log_likelihoods @override(Policy) - def learn_on_batch(self, postprocessed_batch): + @DeveloperAPI + def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[ + str, TensorType]: # Get batch ready for RNNs, if applicable. pad_batch_to_sequences_of_same_size( postprocessed_batch, @@ -340,7 +369,9 @@ class TorchPolicy(Policy): return {LEARNER_STATS_KEY: grad_info} @override(Policy) - def compute_gradients(self, postprocessed_batch): + @DeveloperAPI + def compute_gradients(self, + postprocessed_batch: SampleBatch) -> ModelGradients: train_batch = self._lazy_tensor_dict(postprocessed_batch) loss_out = force_list( self._loss(self, self.model, self.dist_class, train_batch)) @@ -367,7 +398,8 @@ class TorchPolicy(Policy): return grads, {LEARNER_STATS_KEY: grad_info} @override(Policy) - def apply_gradients(self, gradients): + @DeveloperAPI + def apply_gradients(self, gradients: ModelGradients) -> None: # TODO(sven): Not supported for multiple optimizers yet. assert len(self._optimizers) == 1 for g, p in zip(gradients, self.model.parameters()): @@ -377,19 +409,39 @@ class TorchPolicy(Policy): self._optimizers[0].step() @override(Policy) - def get_weights(self): + @DeveloperAPI + def get_weights(self) -> ModelWeights: return { k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items() } @override(Policy) - def set_weights(self, weights): + @DeveloperAPI + def set_weights(self, weights: ModelWeights) -> None: weights = convert_to_torch_tensor(weights, device=self.device) self.model.load_state_dict(weights) @override(Policy) - def get_state(self): + @DeveloperAPI + def is_recurrent(self) -> bool: + return len(self.model.get_initial_state()) > 0 + + @override(Policy) + @DeveloperAPI + def num_state_tensors(self) -> int: + return len(self.model.get_initial_state()) + + @override(Policy) + @DeveloperAPI + def get_initial_state(self) -> List[TensorType]: + return [ + s.cpu().detach().numpy() for s in self.model.get_initial_state() + ] + + @override(Policy) + @DeveloperAPI + def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]: state = super().get_state() state["_optimizer_variables"] = [] for i, o in enumerate(self._optimizers): @@ -397,7 +449,8 @@ class TorchPolicy(Policy): return state @override(Policy) - def set_state(self, state): + @DeveloperAPI + def set_state(self, state: object) -> None: state = state.copy() # shallow copy # Set optimizer vars first. optimizer_vars = state.pop("_optimizer_variables", None) @@ -408,21 +461,11 @@ class TorchPolicy(Policy): # Then the Policy's (NN) weights. super().set_state(state) - @override(Policy) - def is_recurrent(self): - return len(self.model.get_initial_state()) > 0 - - @override(Policy) - def num_state_tensors(self): - return len(self.model.get_initial_state()) - - @override(Policy) - def get_initial_state(self): - return [ - s.cpu().detach().numpy() for s in self.model.get_initial_state() - ] - - def extra_grad_process(self, optimizer, loss): + @DeveloperAPI + def extra_grad_process( + self, + optimizer: "torch.optim.Optimizer", + loss: TensorType): """Called after each optimizer.zero_grad() + loss.backward() call. Called for each self._optimizers/loss-value pair. @@ -431,60 +474,94 @@ class TorchPolicy(Policy): Args: optimizer (torch.optim.Optimizer): A torch optimizer object. - loss (torch.Tensor): The loss tensor associated with the optimizer. + loss (TensorType): The loss tensor associated with the optimizer. Returns: - dict: An info dict. + Dict[str, TensorType]: An dict with information on the gradient + processing step. """ return {} - def extra_action_out(self, input_dict, state_batches, model, action_dist): + @DeveloperAPI + def extra_action_out( + self, + input_dict: Dict[str, TensorType], + state_batches: List[TensorType], + model: TorchModelV2, + action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]: """Returns dict of extra info to include in experience batch. Args: - input_dict (dict): Dict of model input tensors. - state_batches (list): List of state tensors. - model (TorchModelV2): Reference to the model. - action_dist (TorchActionDistribution): Torch action dist object + input_dict (Dict[str, TensorType]): Dict of model input tensors. + state_batches (List[TensorType]): List of state tensors. + model (TorchModelV2): Reference to the model object. + action_dist (TorchDistributionWrapper): Torch action dist object to get log-probs (e.g. for already sampled actions). + + Returns: + Dict[str, TensorType]: Extra outputs to return in a + compute_actions() call (3rd return value). """ return {} - def extra_grad_info(self, train_batch): - """Return dict of extra grad info.""" + @DeveloperAPI + def extra_grad_info(self, train_batch: SampleBatch) -> Dict[ + str, TensorType]: + """Return dict of extra grad info. + + Args: + train_batch (SampleBatch): The training batch for which to produce + extra grad info for. + + Returns: + Dict[str, TensorType]: The info dict carrying grad info per str + key. + """ return {} - def optimizer(self): - """Custom PyTorch optimizer to use.""" + @DeveloperAPI + def optimizer(self) -> Union[ + List["torch.optim.Optimizer"], "torch.optim.Optimizer"]: + """Custom the local PyTorch optimizer(s) to use. + + Returns: + Union[List[torch.optim.Optimizer], torch.optim.Optimizer]: + The local PyTorch optimizer(s) to use for this Policy. + """ if hasattr(self, "config"): return torch.optim.Adam( self.model.parameters(), lr=self.config["lr"]) else: return torch.optim.Adam(self.model.parameters()) + @override(Policy) + @DeveloperAPI + def export_model(self, export_dir: str) -> None: + """TODO(sven): implement for torch. + """ + raise NotImplementedError + + @override(Policy) + @DeveloperAPI + def export_checkpoint(self, export_dir: str) -> None: + """TODO(sven): implement for torch. + """ + raise NotImplementedError + + @override(Policy) + @DeveloperAPI + def import_model_from_h5(self, import_file: str) -> None: + """Imports weights into torch model.""" + return self.model.import_from_h5(import_file) + def _lazy_tensor_dict(self, postprocessed_batch): train_batch = UsageTrackingDict(postprocessed_batch) train_batch.set_get_interceptor(convert_to_torch_tensor) return train_batch - @override(Policy) - def export_model(self, export_dir): - """TODO(sven): implement for torch. - """ - raise NotImplementedError - - @override(Policy) - def export_checkpoint(self, export_dir): - """TODO(sven): implement for torch. - """ - raise NotImplementedError - - @override(Policy) - def import_model_from_h5(self, import_file): - """Imports weights into torch model.""" - return self.model.import_from_h5(import_file) - +# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch) +# and for all possible hyperparams, not just lr. @DeveloperAPI class LearningRateSchedule: """Mixin for TFPolicy that adds a learning rate schedule.""" diff --git a/rllib/tests/test_dependency.py b/rllib/tests/test_dependency_tf.py similarity index 100% rename from rllib/tests/test_dependency.py rename to rllib/tests/test_dependency_tf.py diff --git a/rllib/tests/test_dependency_torch.py b/rllib/tests/test_dependency_torch.py index 206302b49..12da36e22 100755 --- a/rllib/tests/test_dependency_torch.py +++ b/rllib/tests/test_dependency_torch.py @@ -23,3 +23,5 @@ if __name__ == "__main__": # Clean up. del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] + + print("ok") diff --git a/rllib/utils/schedules/piecewise_schedule.py b/rllib/utils/schedules/piecewise_schedule.py index b37fb1839..8d3f22c6a 100644 --- a/rllib/utils/schedules/piecewise_schedule.py +++ b/rllib/utils/schedules/piecewise_schedule.py @@ -5,8 +5,8 @@ from ray.rllib.utils.schedules.schedule import Schedule tf1, tf, tfv = try_import_tf() -def _linear_interpolation(l, r, alpha): - return l + alpha * (r - l) +def _linear_interpolation(left, right, alpha): + return left + alpha * (right - left) class PiecewiseSchedule(Schedule): diff --git a/rllib/utils/types.py b/rllib/utils/types.py index a12e78243..3f2e02f89 100644 --- a/rllib/utils/types.py +++ b/rllib/utils/types.py @@ -2,6 +2,9 @@ from typing import Any, Dict, List, Tuple, Union import gym # Represents a fully filled out config of a Trainer class. +# Note: Policy config dicts are usually the same as TrainerConfigDict, but +# parts of it may sometimes be altered in e.g. a multi-agent setup, +# where we have >1 Policies in the same Trainer. TrainerConfigDict = dict # A trainer config dict that only has overrides. It needs to be combined with @@ -63,8 +66,13 @@ GradInfoDict = dict # policy id. LearnerStatsDict = dict -# Type of dict returned by compute_gradients() representing model gradients. -ModelGradients = dict +# Represents a generic tensor type. +# This could be an np.ndarray, tf.Tensor, or a torch.Tensor. +TensorType = Any + +# List of grads+var tuples (tf) or list of gradient tensors (torch) +# representing model gradients and returned by compute_gradients(). +ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]] # Type of dict returned by get_weights() representing model weights. ModelWeights = dict @@ -72,9 +80,6 @@ ModelWeights = dict # Some kind of sample batch. SampleBatchType = Union["SampleBatch", "MultiAgentBatch"] -# Represents a generic tensor type. -TensorType = Any - # Either a plain tensor, or a dict or tuple of tensors (or StructTensors). TensorStructType = Union[TensorType, dict, tuple]