[RLlib] Type annotations for policy. (#9248)

This commit is contained in:
Sven Mika
2020-07-05 13:09:51 +02:00
committed by GitHub
parent b71c912da7
commit f43d934817
14 changed files with 862 additions and 465 deletions
+2 -2
View File
@@ -1134,10 +1134,10 @@ py_test(
)
py_test(
name = "tests/test_dependency",
name = "tests/test_dependency_tf",
tags = ["tests_dir", "tests_dir_D"],
size = "small",
srcs = ["tests/test_dependency.py"]
srcs = ["tests/test_dependency_tf.py"]
)
py_test(
+1 -1
View File
@@ -179,7 +179,7 @@ class TorchCustomLossModel(TorchModelV2, nn.Module):
# Add the imitation loss to each already calculated policy loss term.
# Alternatively (if custom loss has its own optimizer):
# return policy_loss + [10 * self.imitation_loss]
return [l + 10 * self.imitation_loss for l in policy_loss]
return [loss_ + 10 * self.imitation_loss for loss_ in policy_loss]
def custom_stats(self):
return {
+1 -1
View File
@@ -251,7 +251,7 @@ class TestModules(unittest.TestCase):
self.train_tf_model(
model, [x] + init_state,
[y, value_labels, memory_labels, mlp_labels],
num_epochs=50,
num_epochs=200,
minibatch_size=B)
+94 -54
View File
@@ -1,27 +1,34 @@
"""Graph mode TF policy built using build_tf_policy()."""
from collections import OrderedDict
import gym
import logging
import numpy as np
from typing import Callable, Dict, List, Optional, Tuple
from ray.util.debug import log_once
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
@DeveloperAPI
class DynamicTFPolicy(TFPolicy):
"""A TFPolicy that auto-defines placeholders dynamically at runtime.
Do not sub-class this class directly (neither should you sub-class
TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy
to generate your custom tf (graph-mode or eager) Policy classes.
Initialization of this class occurs in two phases.
* Phase 1: the model is created and model variables are initialized.
* Phase 2: a fake batch of data is created, sent to the trajectory
@@ -39,61 +46,91 @@ class DynamicTFPolicy(TFPolicy):
dist_class (type): TF action distribution class
"""
@DeveloperAPI
def __init__(self,
obs_space,
action_space,
config,
loss_fn,
stats_fn=None,
grad_stats_fn=None,
before_loss_init=None,
make_model=None,
action_sampler_fn=None,
action_distribution_fn=None,
existing_inputs=None,
existing_model=None,
get_batch_divisibility_req=None,
obs_include_prev_action_reward=True):
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
loss_fn: Callable[
[Policy, ModelV2, type, SampleBatch], TensorType],
*,
stats_fn: Optional[Callable[[Policy, SampleBatch],
Dict[str, TensorType]]] = None,
grad_stats_fn: Optional[Callable[
[Policy, SampleBatch, ModelGradients],
Dict[str, TensorType]]] = None,
before_loss_init: Optional[Callable[
[Policy, gym.spaces.Space, gym.spaces.Space,
TrainerConfigDict], None]] = None,
make_model: Optional[Callable[
[Policy, gym.spaces.Space, gym.spaces.Space,
TrainerConfigDict], ModelV2]] = None,
action_sampler_fn: Optional[Callable[
[TensorType, List[TensorType]], Tuple[
TensorType, TensorType]]] = None,
action_distribution_fn: Optional[Callable[
[Policy, ModelV2, TensorType, TensorType, TensorType],
Tuple[TensorType, type, List[TensorType]]]] = None,
existing_inputs: Optional[Dict[
str, "tf1.placeholder"]] = None,
existing_model: Optional[ModelV2] = None,
get_batch_divisibility_req: Optional[int] = None,
obs_include_prev_action_reward: bool = True):
"""Initialize a dynamic TF policy.
Arguments:
observation_space (gym.Space): Observation space of the policy.
action_space (gym.Space): Action space of the policy.
config (dict): Policy-specific configuration data.
loss_fn (func): function that returns a loss tensor the policy
graph, and dict of experience tensor placeholders
stats_fn (func): optional function that returns a dict of
TF fetches given the policy and batch input tensors
grad_stats_fn (func): optional function that returns a dict of
TF fetches given the policy and loss gradient tensors
before_loss_init (Optional[callable]): Optional function to run
prior to loss init that takes the same arguments as __init__.
make_model (func): optional function that returns a ModelV2 object
given (policy, obs_space, action_space, config).
obs_space (gym.spaces.Space): Observation space of the
policy.
action_space (gym.spaces.Space): Action space of the policy.
config (TrainerConfigDict): Policy-specific configuration data.
loss_fn (Callable[[Policy, ModelV2, type, SampleBatch],
TensorType]): Function that returns a loss tensor for the
policy graph.
stats_fn (Optional[Callable[[Policy, SampleBatch],
Dict[str, TensorType]]]): Optional function that returns a dict
of TF fetches given the policy and batch input tensors.
grad_stats_fn (Optional[Callable[[Policy, SampleBatch,
ModelGradients], Dict[str, TensorType]]]):
Optional function that returns a dict of TF fetches given the
policy, sample batch, and loss gradient tensors.
before_loss_init (Optional[Callable[
[Policy, gym.spaces.Space, gym.spaces.Space,
TrainerConfigDict], None]]): Optional function to run prior to
loss init that takes the same arguments as __init__.
make_model (Optional[Callable[[Policy, gym.spaces.Space,
gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
function that returns a ModelV2 object given
policy, obs_space, action_space, and policy config.
All policy variables should be created in this function. If not
specified, a default model will be created.
action_sampler_fn (Optional[callable]): An optional callable
returning a tuple of action and action prob tensors given
(policy, model, input_dict, obs_space, action_space, config).
If None, a default action distribution will be used.
action_distribution_fn (Optional[callable]): A callable returning
distribution inputs (parameters), a dist-class to generate an
action distribution object from, and internal-state outputs
(or an empty list if not applicable).
action_sampler_fn (Optional[Callable[[Policy, ModelV2, Dict[
str, TensorType], TensorType, TensorType], Tuple[TensorType,
TensorType]]]): A callable returning a sampled action and its
log-likelihood given Policy, ModelV2, input_dict, explore,
timestep, and is_training.
action_distribution_fn (Optional[Callable[[Policy, ModelV2,
Dict[str, TensorType], TensorType, TensorType],
Tuple[TensorType, type, List[TensorType]]]]): A callable
returning distribution inputs (parameters), a dist-class to
generate an action distribution object from, and
internal-state outputs (or an empty list if not applicable).
Note: No Exploration hooks have to be called from within
`action_distribution_fn`. It should only perform a simple
forward pass through some model.
If None, pass inputs through `self.model()` to get the
distribution inputs.
existing_inputs (OrderedDict): When copying a policy, this
specifies an existing dict of placeholders to use instead of
defining new ones
existing_model (ModelV2): when copying a policy, this specifies
an existing model to clone and share weights with
get_batch_divisibility_req (func): optional function that returns
the divisibility requirement for sample batches
obs_include_prev_action_reward (bool): whether to include the
previous action and reward in the model input
If None, pass inputs through `self.model()` to get distribution
inputs.
The callable takes as inputs: Policy, ModelV2, input_dict,
explore, timestep, is_training.
existing_inputs (Optional[Dict[str, tf1.placeholder]]): When
copying a policy, this specifies an existing dict of
placeholders to use instead of defining new ones.
existing_model (Optional[ModelV2]): When copying a policy, this
specifies an existing model to clone and share weights with.
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
Optional callable that returns the divisibility requirement
for sample batches given the Policy.
obs_include_prev_action_reward (bool): Whether to include the
previous action and reward in the model input (default: True).
"""
self.observation_space = obs_space
self.action_space = action_space
@@ -258,10 +295,12 @@ class DynamicTFPolicy(TFPolicy):
before_loss_init(self, obs_space, action_space, config)
if not existing_inputs:
self._initialize_loss()
self._initialize_loss_dynamically()
@override(TFPolicy)
def copy(self, existing_inputs):
@DeveloperAPI
def copy(self,
existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
"""Creates a copy of self using existing input placeholders."""
# Note that there might be RNN state inputs at the end of the list
@@ -306,13 +345,14 @@ class DynamicTFPolicy(TFPolicy):
return instance
@override(Policy)
def get_initial_state(self):
@DeveloperAPI
def get_initial_state(self) -> List[TensorType]:
if self.model:
return self.model.get_initial_state()
else:
return []
def _initialize_loss(self):
def _initialize_loss_dynamically(self):
def fake_array(tensor):
shape = tensor.shape.as_list()
shape = [s if s is not None else 1 for s in shape]
@@ -404,7 +444,7 @@ class DynamicTFPolicy(TFPolicy):
self._grad_stats_fn(self, train_batch, self._grads))
self._sess.run(tf1.global_variables_initializer())
def _do_loss_init(self, train_batch):
def _do_loss_init(self, train_batch: SampleBatch):
loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
if self._stats_fn:
self._stats_fetches.update(self._stats_fn(self, train_batch))
+186 -127
View File
@@ -4,13 +4,15 @@ import numpy as np
import tree
from typing import Dict, List, Optional
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
unbatch
from ray.rllib.utils.types import AgentID
from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \
TensorType, TrainerConfigDict, Tuple, Union
torch, _ = try_import_torch()
@@ -41,7 +43,11 @@ class Policy(metaclass=ABCMeta):
"""
@DeveloperAPI
def __init__(self, observation_space, action_space, config):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict):
"""Initialize the graph.
This is the standard constructor for policies. The policy
@@ -49,9 +55,10 @@ class Policy(metaclass=ABCMeta):
these arguments.
Args:
observation_space (gym.Space): Observation space of the policy.
action_space (gym.Space): Action space of the policy.
config (dict): Policy-specific configuration data.
observation_space (gym.spaces.Space): Observation space of the
policy.
action_space (gym.spaces.Space): Action space of the policy.
config (TrainerConfigDict): Policy-specific configuration data.
"""
self.observation_space = observation_space
self.action_space = action_space
@@ -66,78 +73,95 @@ class Policy(metaclass=ABCMeta):
@abstractmethod
@DeveloperAPI
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs):
def compute_actions(
self,
obs_batch: Union[List[TensorType], TensorType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Union[List[TensorType], TensorType] = None,
prev_reward_batch: Union[List[TensorType], TensorType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Computes actions for the current policy.
Args:
obs_batch (Union[List, np.ndarray]): Batch of observations.
state_batches (Optional[list]): List of RNN state input batches,
if any.
prev_action_batch (Optional[List, np.ndarray]): Batch of previous
action values.
prev_reward_batch (Optional[List, np.ndarray]): Batch of previous
rewards.
info_batch (info): Batch of info objects.
episodes (list): MultiAgentEpisode for each obs in obs_batch.
This provides access to all of the internal episode state,
which may be useful for model-based or multiagent algorithms.
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
timestep (int): The current (sampling) time step.
obs_batch (Union[List[TensorType], TensorType]): Batch of
observations.
state_batches (Optional[List[TensorType]]): List of RNN state input
batches, if any.
prev_action_batch (Union[List[TensorType], TensorType]): Batch of
previous action values.
prev_reward_batch (Union[List[TensorType], TensorType]): Batch of
previous rewards.
info_batch (Optional[Dict[str, list]]): Batch of info objects.
episodes (Optional[List[MultiAgentEpisode]]): List of
MultiAgentEpisode, one for each obs in obs_batch. This provides
access to all of the internal episode state, which may be
useful for model-based or multiagent algorithms.
explore (Optional[bool]): Whether to pick an exploitation or
exploration action. Set to None (default) for using the
value of `self.config["explore"]`.
timestep (Optional[int]): The current (sampling) time step.
Keyword Args:
kwargs: forward compatibility placeholder
Returns:
actions (np.ndarray): batch of output actions, with shape like
[BATCH_SIZE, ACTION_SHAPE].
state_outs (list): list of RNN state output batches, if any, with
shape like [STATE_SIZE, BATCH_SIZE].
info (dict): dictionary of extra feature batches, if any, with
shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
Tuple:
actions (TensorType): Batch of output actions, with shape like
[BATCH_SIZE, ACTION_SHAPE].
state_outs (List[TensorType]): List of RNN state output
batches, if any, with shape like [STATE_SIZE, BATCH_SIZE].
info (Dict[str, TensorType]): Dictionary of extra feature batches, if any,
with shape like
{"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
"""
raise NotImplementedError
@DeveloperAPI
def compute_single_action(self,
obs,
state=None,
prev_action=None,
prev_reward=None,
info=None,
episode=None,
clip_actions=False,
explore=None,
timestep=None,
**kwargs):
def compute_single_action(
self,
obs: TensorType,
state: Optional[List[TensorType]] = None,
prev_action: Optional[TensorType] = None,
prev_reward: Optional[TensorType] = None,
info: dict = None,
episode: Optional["MultiAgentEpisode"] = None,
clip_actions: bool = False,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Unbatched version of compute_actions.
Arguments:
obs (obj): Single observation.
state (list): List of RNN state inputs, if any.
prev_action (obj): Previous action value, if any.
prev_reward (float): Previous reward, if any.
info (dict): info object, if any
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
Args:
obs (TensorType): Single observation.
state (Optional[List[TensorType]]): List of RNN state inputs, if
any.
prev_action (Optional[TensorType]): Previous action value, if any.
prev_reward (Optional[TensorType]): Previous reward, if any.
info (dict): Info object, if any.
episode (Optional[MultiAgentEpisode]): this provides access to all
of the internal episode state, which may be useful for
model-based or multi-agent algorithms.
clip_actions (bool): Should actions be clipped?
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
timestep (int): The current (sampling) time step.
explore (Optional[bool]): Whether to pick an exploitation or
exploration action
(default: None -> use self.config["explore"]).
timestep (Optional[int]): The current (sampling) time step.
Keyword Args:
kwargs: forward compatibility placeholder
Returns:
actions (obj): single action
state_outs (list): list of RNN state outputs, if any
info (dict): dictionary of extra features, if any
Tuple:
actions (TensorType): Single action.
state_outs (List[TensorType]): List of RNN state outputs,
if any.
info (dict): Dictionary of extra features, if any.
"""
prev_action_batch = None
prev_reward_batch = None
@@ -196,7 +220,8 @@ class Policy(metaclass=ABCMeta):
other_trajectories: Dict[AgentID, "Trajectory"],
explore: bool = None,
timestep: Optional[int] = None,
**kwargs):
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Computes actions for the current policy based on the given trajectories.
Note: This is an experimental API method.
@@ -215,59 +240,68 @@ class Policy(metaclass=ABCMeta):
kwargs: forward compatibility placeholder
Returns:
actions (np.ndarray): batch of output actions, with shape like
[BATCH_SIZE, ACTION_SHAPE].
state_outs (list): list of RNN state output batches, if any, with
shape like [STATE_SIZE, BATCH_SIZE].
info (dict): dictionary of extra feature batches, if any, with
shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
Tuple:
actions (TensorType): Batch of output actions, with shape
like [BATCH_SIZE, ACTION_SHAPE].
state_outs (List[TensorType]): List of RNN state output
batches, if any, with shape like [STATE_SIZE, BATCH_SIZE].
info (dict): Dictionary of extra feature batches, if any, with
shape like
{"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
"""
raise NotImplementedError
@DeveloperAPI
def compute_log_likelihoods(self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None):
def compute_log_likelihoods(
self,
actions: Union[List[TensorType], TensorType],
obs_batch: Union[List[TensorType], TensorType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[
Union[List[TensorType], TensorType]] = None,
prev_reward_batch: Optional[
Union[List[TensorType], TensorType]] = None) -> TensorType:
"""Computes the log-prob/likelihood for a given action and observation.
Args:
actions (Union[List,np.ndarray]): Batch of actions, for which to
retrieve the log-probs/likelihoods (given all other inputs:
obs, states, ..).
obs_batch (Union[List,np.ndarray]): Batch of observations.
state_batches (Optional[list]): List of RNN state input batches,
if any.
prev_action_batch (Optional[List,np.ndarray]): Batch of previous
action values.
prev_reward_batch (Optional[List,np.ndarray]): Batch of previous
rewards.
actions (Union[List[TensorType], TensorType]): Batch of actions,
for which to retrieve the log-probs/likelihoods (given all
other inputs: obs, states, ..).
obs_batch (Union[List[TensorType], TensorType]): Batch of
observations.
state_batches (Optional[List[TensorType]]): List of RNN state input
batches, if any.
prev_action_batch (Optional[Union[List[TensorType], TensorType]]):
Batch of previous action values.
prev_reward_batch (Optional[Union[List[TensorType], TensorType]]):
Batch of previous rewards.
Returns:
log-likelihoods (np.ndarray): Batch of log probs/likelihoods, with
shape: [BATCH_SIZE].
TensorType: Batch of log probs/likelihoods, with shape:
[BATCH_SIZE].
"""
raise NotImplementedError
@DeveloperAPI
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
def postprocess_trajectory(
self,
sample_batch: SampleBatch,
other_agent_batches: Optional[
Dict[AgentID, Tuple["Policy", SampleBatch]]] = None,
episode: Optional["MultiAgentEpisode"] = None) -> SampleBatch:
"""Implements algorithm-specific trajectory postprocessing.
This will be called on each trajectory fragment computed during policy
evaluation. Each fragment is guaranteed to be only from one episode.
Arguments:
Args:
sample_batch (SampleBatch): batch of experiences for the policy,
which will contain at most one episode trajectory.
other_agent_batches (dict): In a multi-agent env, this contains a
mapping of agent ids to (policy, agent_batch) tuples
containing the policy and experiences of the other agents.
episode (MultiAgentEpisode): this provides access to all of the
episode (Optional[MultiAgentEpisode]): An optional multi-agent
episode object to provide access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
@@ -277,18 +311,22 @@ class Policy(metaclass=ABCMeta):
return sample_batch
@DeveloperAPI
def learn_on_batch(self, samples):
def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
"""Fused compute gradients and apply gradients call.
Either this or the combination of compute/apply grads must be
implemented by subclasses.
Args:
samples (SampleBatch): The SampleBatch object to learn from.
Returns:
grad_info: dictionary of extra metadata from compute_gradients().
Dict[str, TensorType]: Dictionary of extra metadata from
compute_gradients().
Examples:
>>> batch = ev.sample()
>>> ev.learn_on_batch(samples)
>>> sample_batch = ev.sample()
>>> ev.learn_on_batch(sample_batch)
"""
grads, grad_info = self.compute_gradients(samples)
@@ -296,65 +334,76 @@ class Policy(metaclass=ABCMeta):
return grad_info
@DeveloperAPI
def compute_gradients(self, postprocessed_batch):
def compute_gradients(self, postprocessed_batch: SampleBatch) -> \
Tuple[ModelGradients, Dict[str, TensorType]]:
"""Computes gradients against a batch of experiences.
Either this or learn_on_batch() must be implemented by subclasses.
Args:
postprocessed_batch (SampleBatch): The SampleBatch object to use
for calculating gradients.
Returns:
grads (list): List of gradient output values
info (dict): Extra policy-specific values
Tuple[ModelGradients, Dict[str, TensorType]]:
- List of gradient output values.
- Extra policy-specific info values.
"""
raise NotImplementedError
@DeveloperAPI
def apply_gradients(self, gradients):
def apply_gradients(self, gradients: ModelGradients) -> None:
"""Applies previously computed gradients.
Either this or learn_on_batch() must be implemented by subclasses.
Args:
gradients (ModelGradients): The already calculated gradients to
apply to this Policy.
"""
raise NotImplementedError
@DeveloperAPI
def get_weights(self):
def get_weights(self) -> ModelWeights:
"""Returns model weights.
Returns:
weights (obj): Serializable copy or view of model weights
ModelWeights: Serializable copy or view of model weights.
"""
raise NotImplementedError
@DeveloperAPI
def set_weights(self, weights):
def set_weights(self, weights: ModelWeights) -> None:
"""Sets model weights.
Arguments:
weights (obj): Serializable copy or view of model weights
Args:
weights (ModelWeights): Serializable copy or view of model weights.
"""
raise NotImplementedError
@DeveloperAPI
def get_exploration_info(self):
def get_exploration_info(self) -> Dict[str, TensorType]:
"""Returns the current exploration information of this policy.
This information depends on the policy's Exploration object.
Returns:
any: Serializable information on the `self.exploration` object.
Dict[str, TensorType]: Serializable information on the
`self.exploration` object.
"""
return self.exploration.get_info()
@DeveloperAPI
def is_recurrent(self):
def is_recurrent(self) -> bool:
"""Whether this Policy holds a recurrent Model.
Returns:
bool: True if this Policy has-a RNN-based Model.
"""
return 0
return False
@DeveloperAPI
def num_state_tensors(self):
def num_state_tensors(self) -> int:
"""The number of internal states needed by the RNN-Model of the Policy.
Returns:
@@ -363,73 +412,83 @@ class Policy(metaclass=ABCMeta):
return 0
@DeveloperAPI
def get_initial_state(self):
"""Returns initial RNN state for the current policy."""
def get_initial_state(self) -> List[TensorType]:
"""Returns initial RNN state for the current policy.
Returns:
List[TensorType]: Initial RNN state for the current policy.
"""
return []
@DeveloperAPI
def get_state(self):
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
"""Saves all local state.
Returns:
state (obj): Serialized local state.
Union[Dict[str, TensorType], List[TensorType]]: Serialized local
state.
"""
return self.get_weights()
@DeveloperAPI
def set_state(self, state):
def set_state(self, state: object) -> None:
"""Restores all local state.
Arguments:
Args:
state (obj): Serialized local state.
"""
self.set_weights(state)
@DeveloperAPI
def on_global_var_update(self, global_vars):
def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
"""Called on an update to global vars.
Arguments:
global_vars (dict): Global variables broadcast from the driver.
Args:
global_vars (Dict[str, TensorType]): Global variables by str key,
broadcast from the driver.
"""
# Store the current global time step (sum over all policies' sample
# steps).
self.global_timestep = global_vars["timestep"]
@DeveloperAPI
def export_model(self, export_dir):
def export_model(self, export_dir: str) -> None:
"""Export Policy to local directory for serving.
Arguments:
Args:
export_dir (str): Local writable directory.
"""
raise NotImplementedError
@DeveloperAPI
def export_checkpoint(self, export_dir):
def export_checkpoint(self, export_dir: str) -> None:
"""Export Policy checkpoint to local directory.
Argument:
Args:
export_dir (str): Local writable directory.
"""
raise NotImplementedError
@DeveloperAPI
def import_model_from_h5(self, import_file):
def import_model_from_h5(self, import_file: str) -> None:
"""Imports Policy from local file.
Arguments:
Args:
import_file (str): Local readable file.
"""
raise NotImplementedError
def _create_exploration(self):
def _create_exploration(self) -> Exploration:
"""Creates the Policy's Exploration object.
This method only exists b/c some Trainers do not use TFPolicy nor
TorchPolicy, but inherit directly from Policy. Others inherit from
TFPolicy w/o using DynamicTFPolicy.
TODO(sven): unify these cases."""
TODO(sven): unify these cases.
Returns:
Exploration: The Exploration object to be used by this Policy.
"""
if getattr(self, "exploration", None) is not None:
return self.exploration
+4 -4
View File
@@ -230,8 +230,8 @@ def chop_into_sequences(episode_ids,
f_pad = np.zeros((length, ) + np.shape(f)[1:])
seq_base = 0
i = 0
for l in seq_lens:
for seq_offset in range(l):
for len_ in seq_lens:
for seq_offset in range(len_):
f_pad[seq_base + seq_offset] = f[i]
i += 1
seq_base += max_seq_len
@@ -243,9 +243,9 @@ def chop_into_sequences(episode_ids,
s = np.array(s)
s_init = []
i = 0
for l in seq_lens:
for len_ in seq_lens:
s_init.append(s[i])
i += l
i += len_
initial_states.append(np.array(s_init))
if shuffle:
+143 -35
View File
@@ -2,12 +2,13 @@ import collections
import numpy as np
import sys
import itertools
from typing import Dict, List, Any
from typing import Any, Dict, Iterable, List, Optional, Set, Union
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.memory import concat_aligned
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.types import TensorType
# Default policy id for single agent environments
DEFAULT_POLICY_ID = "default_policy"
@@ -71,15 +72,16 @@ class SampleBatch:
@staticmethod
@PublicAPI
def concat_samples(samples):
def concat_samples(samples: List[Dict[str, TensorType]]) -> \
Union["SampleBatch", "MultiAgentBatch"]:
"""Concatenates n data dicts or MultiAgentBatches.
Args:
samples (List[Dict[np.ndarray]]]): List of dicts of data (numpy).
samples (List[Dict[str, TensorType]]): List of dicts of data (numpy).
Returns:
Union[SampleBatch,MultiAgentBatch]: A new (compressed) SampleBatch/
MultiAgentBatch.
Union[SampleBatch, MultiAgentBatch]: A new (compressed)
SampleBatch or MultiAgentBatch.
"""
if isinstance(samples[0], MultiAgentBatch):
return MultiAgentBatch.concat_samples(samples)
@@ -90,9 +92,17 @@ class SampleBatch:
return SampleBatch(out)
@PublicAPI
def concat(self, other):
def concat(self, other: "SampleBatch") -> "SampleBatch":
"""Returns a new SampleBatch with each data column concatenated.
Args:
other (SampleBatch): The other SampleBatch object to concat to this
one.
Returns:
SampleBatch: The new SampleBatch, resulting from concatenating `other`
to `self`.
Examples:
>>> b1 = SampleBatch({"a": [1, 2]})
>>> b2 = SampleBatch({"a": [3, 4, 5]})
@@ -110,15 +120,24 @@ class SampleBatch:
return SampleBatch(out)
@PublicAPI
def copy(self):
def copy(self) -> "SampleBatch":
"""Creates a (deep) copy of this SampleBatch and returns it.
Returns:
SampleBatch: A (deep) copy of this SampleBatch object.
"""
return SampleBatch(
{k: np.array(v, copy=True)
for (k, v) in self.data.items()})
@PublicAPI
def rows(self):
def rows(self) -> Dict[str, TensorType]:
"""Returns an iterator over data rows, i.e. dicts with column values.
Yields:
Dict[str, TensorType]: The column values of the row in this
iteration.
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> for row in batch.rows():
@@ -135,7 +154,7 @@ class SampleBatch:
yield row
@PublicAPI
def columns(self, keys):
def columns(self, keys: List[str]) -> List[any]:
"""Returns a list of the batch-data in the specified columns.
Args:
@@ -157,7 +176,7 @@ class SampleBatch:
return out
@PublicAPI
def shuffle(self):
def shuffle(self) -> None:
"""Shuffles the rows of this batch in-place."""
permutation = np.random.permutation(self.count)
@@ -165,7 +184,7 @@ class SampleBatch:
self[key] = val[permutation]
@PublicAPI
def split_by_episode(self):
def split_by_episode(self) -> List["SampleBatch"]:
"""Splits this batch's data by `eps_id`.
Returns:
@@ -189,7 +208,7 @@ class SampleBatch:
return slices
@PublicAPI
def slice(self, start, end):
def slice(self, start: int, end: int) -> "SampleBatch":
"""Returns a slice of the row data of this batch.
Args:
@@ -197,13 +216,25 @@ class SampleBatch:
end (int): Ending index.
Returns:
SampleBatch which has a slice of this batch's data.
SampleBatch: A new SampleBatch, which has a slice of this batch's
data.
"""
return SampleBatch({k: v[start:end] for k, v in self.data.items()})
@PublicAPI
def timeslices(self, k: int) -> List["SampleBatch"]:
"""Returns SampleBatches, each one representing a k-slice of this one.
Will start from timestep 0 and produce slices of size=k.
Args:
k (int): The size (in timesteps) of each returned SampleBatch.
Returns:
List[SampleBatch]: The list of (new) SampleBatches (each one of
size k).
"""
out = []
i = 0
while i < self.count:
@@ -212,31 +243,78 @@ class SampleBatch:
return out
@PublicAPI
def keys(self):
def keys(self) -> Iterable[str]:
"""
Returns:
Iterable[str]: The keys() iterable over `self.data`.
"""
return self.data.keys()
@PublicAPI
def items(self):
def items(self) -> Iterable[TensorType]:
"""
Returns:
Iterable[Tuple[str, TensorType]]: The items() iterable over
`self.data`, i.e. (column name, column data) pairs.
"""
return self.data.items()
@PublicAPI
def get(self, key):
def get(self, key: str) -> Optional[TensorType]:
"""Returns one column (by key) from the data or None if key not found.
Args:
key (str): The key (column name) to return.
Returns:
Optional[TensorType]: The data under the given key. None if key
not found in data.
"""
return self.data.get(key)
@PublicAPI
def size_bytes(self) -> int:
    """Returns the summed size of all data columns in this batch.

    The size is computed with `sys.getsizeof()` on each column value, so
    it is a shallow measure: objects referenced from inside a column
    container are not followed.

    Returns:
        int: The overall size in bytes of the data buffer (all columns).
    """
    # Iterate over the column *values*; iterating over `self.data` directly
    # would only measure the column-name strings (the dict keys).
    return sum(sys.getsizeof(d) for d in self.data.values())
@PublicAPI
def __getitem__(self, key):
def __getitem__(self, key: str) -> TensorType:
"""Returns one column (by key) from the data.
Args:
key (str): The key (column name) to return.
Returns:
TensorType: The data under the given key.
"""
return self.data[key]
@PublicAPI
def __setitem__(self, key, item):
def __setitem__(self, key, item) -> None:
"""Inserts (overrides) an entire column (by key) in the data buffer.
Args:
key (str): The column name to set a value for.
item (TensorType): The data to insert.
"""
self.data[key] = item
@DeveloperAPI
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
def compress(
self,
bulk: bool = False,
columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
"""Compresses the data buffers (by column) in place.
Args:
bulk (bool): Whether to compress across the batch dimension (0)
as well. If False will compress n separate list items, where n
is the batch size.
columns (Set[str]): The columns to compress. Default: Only
compress the obs and new_obs columns.
"""
for key in columns:
if key in self.data:
if bulk:
@@ -246,7 +324,19 @@ class SampleBatch:
[pack(o) for o in self.data[key]])
@DeveloperAPI
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
def decompress_if_needed(
self,
columns: Set[str] = frozenset(
["obs", "new_obs"])) -> "SampleBatch":
"""Decompresses data buffers (per column if not compressed) in place.
Args:
columns (Set[str]): The columns to decompress. Default: Only
decompress the obs and new_obs columns.
Returns:
SampleBatch: This very SampleBatch.
"""
for key in columns:
if key in self.data:
arr = self.data[key]
@@ -272,10 +362,17 @@ class SampleBatch:
@PublicAPI
class MultiAgentBatch:
"""A batch of experiences from multiple agents in the environment."""
"""A batch of experiences from multiple agents in the environment.
Attributes:
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
ids to SampleBatches of experiences.
count (int): The number of env steps in this batch.
"""
@PublicAPI
def __init__(self, policy_batches: Dict[PolicyID, SampleBatch],
def __init__(self,
policy_batches: Dict[PolicyID, SampleBatch],
env_steps: int):
"""Initialize a MultiAgentBatch object.
@@ -285,12 +382,8 @@ class MultiAgentBatch:
env_steps (int): The number of timesteps in the environment this
batch contains. This will be less than the number of
transitions this batch contains across all policies in total.
Attributes:
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
ids to SampleBatches of experiences.
count (int): the number of env steps in this batch.
"""
for v in policy_batches.values():
assert isinstance(v, SampleBatch)
self.policy_batches = policy_batches
@@ -303,7 +396,7 @@ class MultiAgentBatch:
"""The number of env steps (there are >= 1 agent steps per env step).
Returns:
int: the number of environment steps contained in this batch.
int: The number of environment steps contained in this batch.
"""
return self.count
@@ -312,7 +405,7 @@ class MultiAgentBatch:
"""The number of agent steps (there are >= 1 agent steps per env step).
Returns:
int: the number of agent steps total in this batch.
int: The number of agent steps total in this batch.
"""
ct = 0
for batch in self.policy_batches.values():
@@ -379,8 +472,9 @@ class MultiAgentBatch:
@staticmethod
@PublicAPI
def wrap_as_needed(policy_batches: Dict[PolicyID, SampleBatch],
env_steps: int) -> Any:
def wrap_as_needed(
policy_batches: Dict[PolicyID, SampleBatch],
env_steps: int) -> Union[SampleBatch, "MultiAgentBatch"]:
"""Returns SampleBatch or MultiAgentBatch, depending on given policies.
Args:
@@ -437,11 +531,19 @@ class MultiAgentBatch:
@PublicAPI
def size_bytes(self) -> int:
    """Adds up `size_bytes()` over every batch in `self.policy_batches`.

    Returns:
        int: The overall size in bytes of all policy batches (all columns).
    """
    total = 0
    for policy_batch in self.policy_batches.values():
        total += policy_batch.size_bytes()
    return total
@DeveloperAPI
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
"""Compresses each policy batch.
def compress(
self,
bulk: bool = False,
columns: Set[str] = frozenset(
["obs", "new_obs"])) -> None:
"""Compresses each policy batch (per column) in place.
Args:
bulk (bool): Whether to compress across the batch dimension (0)
@@ -453,11 +555,17 @@ class MultiAgentBatch:
batch.compress(bulk=bulk, columns=columns)
@DeveloperAPI
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
"""Decompresses each policy batch, if already compressed.
def decompress_if_needed(
self,
columns: Set[str] = frozenset(
["obs", "new_obs"])) -> "MultiAgentBatch":
"""Decompresses each policy batch (per column), if already compressed.
Args:
columns (Set[str]): Set of column names to decompress.
Returns:
MultiAgentBatch: This very MultiAgentBatch.
"""
for batch in self.policy_batches.values():
batch.decompress_if_needed(columns)
+237 -131
View File
@@ -1,7 +1,9 @@
import errno
import gym
import logging
import numpy as np
import os
from typing import Dict, List, Optional, Tuple, Union
import ray
import ray.experimental.tf_utils
@@ -15,6 +17,7 @@ from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
@@ -24,6 +27,11 @@ logger = logging.getLogger(__name__)
class TFPolicy(Policy):
"""An agent policy and loss implemented in TensorFlow.
Do not sub-class this class directly (neither should you sub-class
DynamicTFPolicy), but rather use
rllib.policy.tf_policy_template.build_tf_policy
to generate your custom tf (graph-mode or eager) Policy classes.
Extending this class enables RLlib to perform TensorFlow specific
optimizations on the policy, e.g., parallelization across gpus or
fusing multiple graphs together in the multi-agent setting.
@@ -48,77 +56,83 @@ class TFPolicy(Policy):
@DeveloperAPI
def __init__(self,
observation_space,
action_space,
config,
sess,
obs_input,
sampled_action,
loss,
loss_inputs,
model=None,
sampled_action_logp=None,
action_input=None,
log_likelihood=None,
dist_inputs=None,
dist_class=None,
state_inputs=None,
state_outputs=None,
prev_action_input=None,
prev_reward_input=None,
seq_lens=None,
max_seq_len=20,
batch_divisibility_req=1,
update_ops=None,
explore=None,
timestep=None):
"""Initialize the policy.
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
sess: "tf1.Session",
obs_input: TensorType,
sampled_action: TensorType,
loss: TensorType,
loss_inputs: List[Tuple[str, TensorType]],
model: ModelV2 = None,
sampled_action_logp: Optional[TensorType] = None,
action_input: Optional[TensorType] = None,
log_likelihood: Optional[TensorType] = None,
dist_inputs: Optional[TensorType] = None,
dist_class: Optional[type] = None,
state_inputs: Optional[List[TensorType]] = None,
state_outputs: Optional[List[TensorType]] = None,
prev_action_input: Optional[TensorType] = None,
prev_reward_input: Optional[TensorType] = None,
seq_lens: Optional[TensorType] = None,
max_seq_len: int = 20,
batch_divisibility_req: int = 1,
update_ops: List[TensorType] = None,
explore: Optional[TensorType] = None,
timestep: Optional[TensorType] = None):
"""Initializes a Policy object.
Arguments:
observation_space (gym.Space): Observation space of the env.
action_space (gym.Space): Action space of the env.
config (dict): The Policy config dict.
sess (Session): The TensorFlow session to use.
obs_input (Tensor): Input placeholder for observations, of shape
[BATCH_SIZE, obs...].
sampled_action (Tensor): Tensor for sampling an action, of shape
[BATCH_SIZE, action...]
loss (Tensor): Scalar policy loss output tensor.
loss_inputs (list): A (name, placeholder) tuple for each loss
input argument. Each placeholder name must correspond to a
SampleBatch column key returned by postprocess_trajectory(),
and has shape [BATCH_SIZE, data...]. These keys will be read
from postprocessed sample batches and fed into the specified
placeholders during loss computation.
model (rllib.models.Model): used to integrate custom losses and
Args:
observation_space (gym.spaces.Space): Observation space of the env.
action_space (gym.spaces.Space): Action space of the env.
config (TrainerConfigDict): The Policy config dict.
sess (tf1.Session): The TensorFlow session to use.
obs_input (TensorType): Input placeholder for observations, of
shape [BATCH_SIZE, obs...].
sampled_action (TensorType): Tensor for sampling an action, of
shape [BATCH_SIZE, action...]
loss (TensorType): Scalar policy loss output tensor.
loss_inputs (List[Tuple[str, TensorType]]): A (name, placeholder)
tuple for each loss input argument. Each placeholder name must
correspond to a SampleBatch column key returned by
postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
These keys will be read from postprocessed sample batches and
fed into the specified placeholders during loss computation.
model (ModelV2): used to integrate custom losses and
stats from user-defined RLlib models.
sampled_action_logp (Tensor): log probability of the sampled
action.
action_input (Optional[Tensor]): Input placeholder for actions for
logp/log-likelihood calculations.
log_likelihood (Optional[Tensor]): Tensor to calculate the
sampled_action_logp (Optional[TensorType]): log probability of the
sampled action.
action_input (Optional[TensorType]): Input placeholder for actions
for logp/log-likelihood calculations.
log_likelihood (Optional[TensorType]): Tensor to calculate the
log_likelihood (given action_input and obs_input).
dist_class (Optional[type): An optional ActionDistribution class
dist_class (Optional[type]): An optional ActionDistribution class
to use for generating a dist object from distribution inputs.
dist_inputs (Optional[Tensor]): Tensor to calculate the
dist_inputs (Optional[TensorType]): Tensor to calculate the
distribution inputs/parameters.
state_inputs (list): list of RNN state input Tensors.
state_outputs (list): list of RNN state output Tensors.
prev_action_input (Tensor): placeholder for previous actions
prev_reward_input (Tensor): placeholder for previous rewards
seq_lens (Tensor): Placeholder for RNN sequence lengths, of shape
[NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
state_inputs (Optional[List[TensorType]]): List of RNN state input
Tensors.
state_outputs (Optional[List[TensorType]]): List of RNN state
output Tensors.
prev_action_input (Optional[TensorType]): placeholder for previous
actions.
prev_reward_input (Optional[TensorType]): placeholder for previous
rewards.
seq_lens (Optional[TensorType]): Placeholder for RNN sequence
lengths, of shape [NUM_SEQUENCES].
Note that NUM_SEQUENCES << BATCH_SIZE. See
policy/rnn_sequencing.py for more information.
max_seq_len (int): Max sequence length for LSTM training.
batch_divisibility_req (int): pad all agent experiences batches to
multiples of this value. This only has an effect if not using
a LSTM model.
update_ops (list): override the batchnorm update ops to run when
applying gradients. Otherwise we run all update ops found in
the current variable scope.
explore (Tensor): Placeholder for `explore` parameter into
call to Exploration.get_exploration_action.
timestep (Tensor): Placeholder for the global sampling timestep.
update_ops (List[TensorType]): override the batchnorm update ops to
run when applying gradients. Otherwise we run all update ops
found in the current variable scope.
explore (Optional[TensorType]): Placeholder for `explore` parameter
into call to Exploration.get_exploration_action.
timestep (Optional[TensorType]): Placeholder for the global
sampling timestep.
"""
self.framework = "tf"
super().__init__(observation_space, action_space, config)
@@ -192,33 +206,49 @@ class TFPolicy(Policy):
"""Return the list of all savable variables for this policy."""
return self.model.variables()
def get_placeholder(self, name):
def get_placeholder(self, name) -> "tf1.placeholder":
"""Returns the given action or loss input placeholder by name.
If the loss has not been initialized and a loss input placeholder is
requested, an error is raised.
Args:
name (str): The name of the placeholder to return. One of
SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
`self._loss_input_dict`.
Returns:
tf1.placeholder: The placeholder under the given str key.
"""
obs_inputs = {
SampleBatch.CUR_OBS: self._obs_input,
SampleBatch.PREV_ACTIONS: self._prev_action_input,
SampleBatch.PREV_REWARDS: self._prev_reward_input,
}
if name in obs_inputs:
return obs_inputs[name]
if name == SampleBatch.CUR_OBS:
return self._obs_input
elif name == SampleBatch.PREV_ACTIONS:
return self._prev_action_input
elif name == SampleBatch.PREV_REWARDS:
return self._prev_reward_input
assert self._loss_input_dict is not None, \
"Should have set this before get_placeholder can be called"
return self._loss_input_dict[name]
def get_session(self):
def get_session(self) -> "tf1.Session":
"""Returns a reference to the TF session for this policy."""
return self._sess
def loss_initialized(self):
def loss_initialized(self) -> bool:
"""Returns whether the loss function has been initialized."""
return self._loss is not None
def _initialize_loss(self, loss, loss_inputs):
def _initialize_loss(self,
loss: TensorType,
loss_inputs: List[Tuple[str, TensorType]]) -> None:
"""Initializes the loss op from given loss tensor and placeholders.
Args:
loss (TensorType): The loss op generated by some loss function.
loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples:
(name, tf1.placeholders) needed for calculating the loss.
"""
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
for i, ph in enumerate(self._state_inputs):
@@ -270,16 +300,18 @@ class TFPolicy(Policy):
self._optimizer.variables(), self._sess)
@override(Policy)
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs):
def compute_actions(
self,
obs_batch: Union[List[TensorType], TensorType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Union[List[TensorType], TensorType] = None,
prev_reward_batch: Union[List[TensorType], TensorType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs):
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
@@ -299,12 +331,16 @@ class TFPolicy(Policy):
return fetched
@override(Policy)
def compute_log_likelihoods(self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None):
def compute_log_likelihoods(
self,
actions: Union[List[TensorType], TensorType],
obs_batch: Union[List[TensorType], TensorType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[
Union[List[TensorType], TensorType]] = None,
prev_reward_batch: Optional[
Union[List[TensorType], TensorType]] = None) -> TensorType:
if self._log_likelihood is None:
raise ValueError("Cannot compute log-prob/likelihood w/o a "
"self._log_likelihood op!")
@@ -341,40 +377,51 @@ class TFPolicy(Policy):
return builder.get(fetches)[0]
@override(Policy)
def compute_gradients(self, postprocessed_batch):
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "compute_gradients")
fetches = self._build_compute_gradients(builder, postprocessed_batch)
return builder.get(fetches)
@override(Policy)
def apply_gradients(self, gradients):
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "apply_gradients")
fetches = self._build_apply_gradients(builder, gradients)
builder.get(fetches)
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
@DeveloperAPI
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
str, TensorType]:
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "learn_on_batch")
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
return builder.get(fetches)
@override(Policy)
def get_exploration_info(self):
@DeveloperAPI
def compute_gradients(
self,
postprocessed_batch: SampleBatch) -> \
Tuple[ModelGradients, Dict[str, TensorType]]:
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "compute_gradients")
fetches = self._build_compute_gradients(builder, postprocessed_batch)
return builder.get(fetches)
@override(Policy)
@DeveloperAPI
def apply_gradients(self, gradients: ModelGradients) -> None:
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "apply_gradients")
fetches = self._build_apply_gradients(builder, gradients)
builder.get(fetches)
@override(Policy)
@DeveloperAPI
def get_exploration_info(self) -> Dict[str, TensorType]:
return self.exploration.get_info(sess=self.get_session())
@override(Policy)
def get_weights(self):
@DeveloperAPI
def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
return self._variables.get_weights()
@override(Policy)
def set_weights(self, weights):
@DeveloperAPI
def set_weights(self, weights) -> None:
return self._variables.set_weights(weights)
@override(Policy)
def get_state(self):
@DeveloperAPI
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
# For tf Policies, return Policy weights and optimizer var values.
state = super().get_state()
if self._optimizer_variables and \
@@ -384,7 +431,8 @@ class TFPolicy(Policy):
return state
@override(Policy)
def set_state(self, state):
@DeveloperAPI
def set_state(self, state) -> None:
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
@@ -394,7 +442,8 @@ class TFPolicy(Policy):
super().set_state(state)
@override(Policy)
def export_model(self, export_dir):
@DeveloperAPI
def export_model(self, export_dir: str) -> None:
"""Export tensorflow graph to export_dir for serving."""
with self._sess.graph.as_default():
builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
@@ -407,7 +456,9 @@ class TFPolicy(Policy):
builder.save()
@override(Policy)
def export_checkpoint(self, export_dir, filename_prefix="model"):
@DeveloperAPI
def export_checkpoint(self, export_dir: str,
filename_prefix: str = "model") -> None:
"""Export tensorflow checkpoint to export_dir."""
try:
os.makedirs(export_dir)
@@ -421,7 +472,8 @@ class TFPolicy(Policy):
saver.save(self._sess, save_path)
@override(Policy)
def import_model_from_h5(self, import_file):
@DeveloperAPI
def import_model_from_h5(self, import_file: str) -> None:
"""Imports weights into tf model."""
# Make sure the session is the right one (see issue #7046).
with self._sess.graph.as_default():
@@ -429,31 +481,53 @@ class TFPolicy(Policy):
return self.model.import_from_h5(import_file)
@DeveloperAPI
def copy(self, existing_inputs):
def copy(self,
existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> \
"TFPolicy":
"""Creates a copy of self using existing input placeholders.
Optional, only required to work with the multi-GPU optimizer."""
Optional: Only required to work with the multi-GPU optimizer.
Args:
existing_inputs (List[Tuple[str, tf1.placeholder]]): List of
(name, tf1.placeholder) tuples to re-use (share) with the
returned copy of self.
Returns:
TFPolicy: A copy of self.
"""
raise NotImplementedError
@override(Policy)
def is_recurrent(self):
@DeveloperAPI
def is_recurrent(self) -> bool:
return len(self._state_inputs) > 0
@override(Policy)
def num_state_tensors(self):
@DeveloperAPI
def num_state_tensors(self) -> int:
return len(self._state_inputs)
@DeveloperAPI
def extra_compute_action_feed_dict(self):
"""Extra dict to pass to the compute actions session run."""
def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
"""Extra dict to pass to the compute actions session run.
Returns:
Dict[TensorType, TensorType]: A feed dict to be added to the
feed_dict passed to the compute_actions session.run() call.
"""
return {}
@DeveloperAPI
def extra_compute_action_fetches(self):
def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
"""Extra values to fetch and return from compute_actions().
By default we return action probability/log-likelihood info
and action distribution inputs (if present).
Returns:
Dict[str, TensorType]: An extra fetch-dict to be passed to and
returned from the compute_actions() call.
"""
extra_fetches = {}
# Action-logp and action-prob.
@@ -466,38 +540,70 @@ class TFPolicy(Policy):
return extra_fetches
@DeveloperAPI
def extra_compute_grad_feed_dict(self):
"""Extra dict to pass to the compute gradients session run."""
def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
"""Extra dict to pass to the compute gradients session run.
Returns:
Dict[TensorType, TensorType]: Extra feed_dict to be passed to the
compute_gradients Session.run() call.
"""
return {} # e.g, kl_coeff
@DeveloperAPI
def extra_compute_grad_fetches(self):
"""Extra values to fetch and return from compute_gradients()."""
def extra_compute_grad_fetches(self) -> Dict[str, any]:
"""Extra values to fetch and return from compute_gradients().
Returns:
Dict[str, any]: Extra fetch dict to be added to the fetch dict
of the compute_gradients Session.run() call.
"""
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
@DeveloperAPI
def optimizer(self):
"""TF optimizer to use for policy optimization."""
def optimizer(self) -> "tf.keras.optimizers.Optimizer":
"""TF optimizer to use for policy optimization.
Returns:
tf.keras.optimizers.Optimizer: The local optimizer to use for this
Policy's Model.
"""
if hasattr(self, "config"):
return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
else:
return tf1.train.AdamOptimizer()
@DeveloperAPI
def gradients(self, optimizer, loss):
"""Override for custom gradient computation."""
def gradients(self,
optimizer: "tf.keras.optimizers.Optimizer",
loss: TensorType) -> List[Tuple[TensorType, TensorType]]:
"""Override this for a custom gradient computation behavior.
Returns:
List[Tuple[TensorType, TensorType]]: List of tuples with grad
values and the grad-value's corresponding tf.variable in it.
"""
return optimizer.compute_gradients(loss)
@DeveloperAPI
def build_apply_op(self, optimizer, grads_and_vars):
"""Override for custom gradient apply computation."""
def build_apply_op(
        self,
        optimizer: "tf.keras.optimizers.Optimizer",
        grads_and_vars: List[Tuple[TensorType, TensorType]]) -> \
        "tf.Operation":
    """Override this for a custom gradient apply computation behavior.

    Args:
        optimizer (tf.keras.optimizers.Optimizer): The local tf optimizer
            to use for applying the grads and vars.
        grads_and_vars (List[Tuple[TensorType, TensorType]]): List of
            tuples with grad values and the grad-value's corresponding
            tf.variable in it.

    Returns:
        tf.Operation: The op returned by `optimizer.apply_gradients()`.
    """
    # Specify global_step for TD3 which needs to count the num updates.
    # Apply the gradients handed in by the caller instead of reading
    # `self._grads_and_vars`, so that the argument is actually honored.
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=tf1.train.get_or_create_global_step())
@DeveloperAPI
def _get_is_training_placeholder(self):
"""Get the placeholder for _is_training, i.e., for batch norm layers.
@@ -670,7 +776,7 @@ class TFPolicy(Policy):
def _get_loss_inputs_dict(self, batch, shuffle):
"""Return a feed dict from a batch.
Arguments:
Args:
batch (SampleBatch): batch of data to derive inputs from
shuffle (bool): whether to shuffle batch sequences. Shuffle may
be done in-place. This only makes sense if you're further
+4 -4
View File
@@ -131,10 +131,10 @@ def build_tf_policy(name,
DynamicTFPolicy.__init__(
self,
obs_space,
action_space,
config,
loss_fn,
obs_space=obs_space,
action_space=action_space,
config=config,
loss_fn=loss_fn,
stats_fn=stats_fn,
grad_stats_fn=grad_stats_fn,
before_loss_init=before_loss_init_wrapper,
+176 -99
View File
@@ -1,8 +1,11 @@
import functools
import gym
import numpy as np
import time
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional, Tuple, Union
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
@@ -15,11 +18,13 @@ from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.types import AgentID
from ray.rllib.utils.types import AgentID, ModelGradients, ModelWeights, \
TensorType, TrainerConfigDict
torch, _ = try_import_torch()
@DeveloperAPI
class TorchPolicy(Policy):
"""Template for a PyTorch policy and loss to use with RLlib.
@@ -33,49 +38,63 @@ class TorchPolicy(Policy):
dist_class (type): Torch action distribution class.
"""
@DeveloperAPI
def __init__(self,
observation_space,
action_space,
config,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
*,
model,
loss,
action_distribution_class,
action_sampler_fn=None,
action_distribution_fn=None,
max_seq_len=20,
get_batch_divisibility_req=None):
model: ModelV2,
loss: Callable[
[Policy, ModelV2, type, SampleBatch], TensorType],
action_distribution_class: TorchDistributionWrapper,
action_sampler_fn: Callable[
[TensorType, List[TensorType]], Tuple[
TensorType, TensorType]] = None,
action_distribution_fn: Optional[Callable[
[Policy, ModelV2, TensorType, TensorType, TensorType],
Tuple[TensorType, type, List[TensorType]]]] = None,
max_seq_len: int = 20,
get_batch_divisibility_req: Optional[int] = None):
"""Build a policy from policy and loss torch modules.
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
is set. Only single GPU is supported for now.
Arguments:
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
config (dict): The Policy config dict.
model (nn.Module): PyTorch policy module. Given observations as
Args:
observation_space (gym.spaces.Space): observation space of the
policy.
action_space (gym.spaces.Space): action space of the policy.
config (TrainerConfigDict): The Policy config dict.
model (ModelV2): PyTorch policy module. Given observations as
input, this module must return a list of outputs where the
first item is action logits, and the rest can be any value.
loss (func): Function that takes (policy, model, dist_class,
train_batch) and returns a single scalar loss.
action_distribution_class (ActionDistribution): Class for action
distribution.
action_sampler_fn (Optional[callable]): A callable returning a
sampled action and its log-likelihood given some (obs and
state) inputs.
action_distribution_fn (Optional[callable]): A callable returning
distribution inputs (parameters), a dist-class to generate an
action distribution object from, and internal-state outputs
(or an empty list if not applicable).
loss (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
Function that takes (policy, model, dist_class, train_batch)
and returns a single scalar loss.
action_distribution_class (TorchDistributionWrapper): Class for
a torch action distribution.
action_sampler_fn (Callable[[TensorType, List[TensorType]],
Tuple[TensorType, TensorType]]): A callable returning a
sampled action and its log-likelihood given Policy, ModelV2,
input_dict, explore, timestep, and is_training.
action_distribution_fn (Optional[Callable[[Policy, ModelV2,
Dict[str, TensorType], TensorType, TensorType],
Tuple[TensorType, type, List[TensorType]]]]): A callable
returning distribution inputs (parameters), a dist-class to
generate an action distribution object from, and
internal-state outputs (or an empty list if not applicable).
Note: No Exploration hooks have to be called from within
`action_distribution_fn`. It should only perform a simple
forward pass through some model.
If None, pass inputs through `self.model()` to get the
distribution inputs.
If None, pass inputs through `self.model()` to get distribution
inputs.
The callable takes as inputs: Policy, ModelV2, input_dict,
explore, timestep, is_training.
max_seq_len (int): Max sequence length for LSTM training.
get_batch_divisibility_req (Optional[callable]): Optional callable
that returns the divisibility requirement for sample batches.
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
Optional callable that returns the divisibility requirement
for sample batches given the Policy.
"""
self.framework = "torch"
super().__init__(observation_space, action_space, config)
@@ -100,16 +119,19 @@ class TorchPolicy(Policy):
else 1
@override(Policy)
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs):
@DeveloperAPI
def compute_actions(
self,
obs_batch: Union[List[TensorType], TensorType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Union[List[TensorType], TensorType] = None,
prev_reward_batch: Union[List[TensorType], TensorType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
@@ -149,7 +171,8 @@ class TorchPolicy(Policy):
other_trajectories: Dict[AgentID, "Trajectory"],
explore: bool = None,
timestep: Optional[int] = None,
**kwargs):
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
@@ -237,12 +260,16 @@ class TorchPolicy(Policy):
return actions, state_out, extra_fetches, logp
@override(Policy)
def compute_log_likelihoods(self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None):
@DeveloperAPI
def compute_log_likelihoods(
self,
actions: Union[List[TensorType], TensorType],
obs_batch: Union[List[TensorType], TensorType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[
Union[List[TensorType], TensorType]] = None,
prev_reward_batch: Optional[
Union[List[TensorType], TensorType]] = None) -> TensorType:
if self.action_sampler_fn and self.action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
@@ -282,7 +309,9 @@ class TorchPolicy(Policy):
return log_likelihoods
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
@DeveloperAPI
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
str, TensorType]:
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
@@ -340,7 +369,9 @@ class TorchPolicy(Policy):
return {LEARNER_STATS_KEY: grad_info}
@override(Policy)
def compute_gradients(self, postprocessed_batch):
@DeveloperAPI
def compute_gradients(self,
postprocessed_batch: SampleBatch) -> ModelGradients:
train_batch = self._lazy_tensor_dict(postprocessed_batch)
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))
@@ -367,7 +398,8 @@ class TorchPolicy(Policy):
return grads, {LEARNER_STATS_KEY: grad_info}
@override(Policy)
def apply_gradients(self, gradients):
@DeveloperAPI
def apply_gradients(self, gradients: ModelGradients) -> None:
# TODO(sven): Not supported for multiple optimizers yet.
assert len(self._optimizers) == 1
for g, p in zip(gradients, self.model.parameters()):
@@ -377,19 +409,39 @@ class TorchPolicy(Policy):
self._optimizers[0].step()
@override(Policy)
def get_weights(self):
@DeveloperAPI
def get_weights(self) -> ModelWeights:
return {
k: v.cpu().detach().numpy()
for k, v in self.model.state_dict().items()
}
@override(Policy)
def set_weights(self, weights):
@DeveloperAPI
def set_weights(self, weights: ModelWeights) -> None:
weights = convert_to_torch_tensor(weights, device=self.device)
self.model.load_state_dict(weights)
@override(Policy)
def get_state(self):
@DeveloperAPI
def is_recurrent(self) -> bool:
    """Returns True iff the model's initial state list is non-empty."""
    initial_state = self.model.get_initial_state()
    return len(initial_state) != 0
@override(Policy)
@DeveloperAPI
def num_state_tensors(self) -> int:
    """Returns the number of entries in the model's initial state list."""
    initial_state = self.model.get_initial_state()
    return len(initial_state)
@override(Policy)
@DeveloperAPI
def get_initial_state(self) -> List[TensorType]:
    """Returns the model's initial states, converted via `.numpy()`.

    Returns:
        List[TensorType]: One entry per item of the model's
            `get_initial_state()`, each moved to CPU, detached and
            converted with `.numpy()`.
    """
    numpy_states = []
    for state in self.model.get_initial_state():
        # Move to CPU and detach before the numpy conversion.
        numpy_states.append(state.cpu().detach().numpy())
    return numpy_states
@override(Policy)
@DeveloperAPI
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
state = super().get_state()
state["_optimizer_variables"] = []
for i, o in enumerate(self._optimizers):
@@ -397,7 +449,8 @@ class TorchPolicy(Policy):
return state
@override(Policy)
def set_state(self, state):
@DeveloperAPI
def set_state(self, state: object) -> None:
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
@@ -408,21 +461,11 @@ class TorchPolicy(Policy):
# Then the Policy's (NN) weights.
super().set_state(state)
@override(Policy)
def is_recurrent(self):
return len(self.model.get_initial_state()) > 0
@override(Policy)
def num_state_tensors(self):
return len(self.model.get_initial_state())
@override(Policy)
def get_initial_state(self):
return [
s.cpu().detach().numpy() for s in self.model.get_initial_state()
]
def extra_grad_process(self, optimizer, loss):
@DeveloperAPI
def extra_grad_process(
self,
optimizer: "torch.optim.Optimizer",
loss: TensorType):
"""Called after each optimizer.zero_grad() + loss.backward() call.
Called for each self._optimizers/loss-value pair.
@@ -431,60 +474,94 @@ class TorchPolicy(Policy):
Args:
optimizer (torch.optim.Optimizer): A torch optimizer object.
loss (torch.Tensor): The loss tensor associated with the optimizer.
loss (TensorType): The loss tensor associated with the optimizer.
Returns:
dict: An info dict.
Dict[str, TensorType]: A dict with information on the gradient
processing step.
"""
return {}
def extra_action_out(self, input_dict, state_batches, model, action_dist):
@DeveloperAPI
def extra_action_out(
self,
input_dict: Dict[str, TensorType],
state_batches: List[TensorType],
model: TorchModelV2,
action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
"""Returns dict of extra info to include in experience batch.
Args:
input_dict (dict): Dict of model input tensors.
state_batches (list): List of state tensors.
model (TorchModelV2): Reference to the model.
action_dist (TorchActionDistribution): Torch action dist object
input_dict (Dict[str, TensorType]): Dict of model input tensors.
state_batches (List[TensorType]): List of state tensors.
model (TorchModelV2): Reference to the model object.
action_dist (TorchDistributionWrapper): Torch action dist object
to get log-probs (e.g. for already sampled actions).
Returns:
Dict[str, TensorType]: Extra outputs to return in a
compute_actions() call (3rd return value).
"""
return {}
def extra_grad_info(self, train_batch):
"""Return dict of extra grad info."""
@DeveloperAPI
def extra_grad_info(self, train_batch: SampleBatch) -> Dict[
str, TensorType]:
"""Return dict of extra grad info.
Args:
train_batch (SampleBatch): The training batch for which to produce
extra grad info.
Returns:
Dict[str, TensorType]: The info dict carrying grad info per str
key.
"""
return {}
def optimizer(self):
"""Custom PyTorch optimizer to use."""
@DeveloperAPI
def optimizer(self) -> Union[
List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
"""Customize the local PyTorch optimizer(s) to use.
Returns:
Union[List[torch.optim.Optimizer], torch.optim.Optimizer]:
The local PyTorch optimizer(s) to use for this Policy.
"""
if hasattr(self, "config"):
return torch.optim.Adam(
self.model.parameters(), lr=self.config["lr"])
else:
return torch.optim.Adam(self.model.parameters())
@override(Policy)
@DeveloperAPI
def export_model(self, export_dir: str) -> None:
    """Model export stub; not implemented for torch yet.

    TODO(sven): implement for torch.

    Args:
        export_dir (str): Unused by this stub.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
@override(Policy)
@DeveloperAPI
def export_checkpoint(self, export_dir: str) -> None:
    """Checkpoint export stub; not implemented for torch yet.

    TODO(sven): implement for torch.

    Args:
        export_dir (str): Unused by this stub.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
@override(Policy)
@DeveloperAPI
def import_model_from_h5(self, import_file: str) -> None:
    """Imports weights into torch model.

    Args:
        import_file (str): Passed through unchanged to the model's
            `import_from_h5()`.

    Returns:
        Whatever the model's `import_from_h5()` returns; the signature
        annotates None.
    """
    # Pure delegation: the model object does the actual import.
    return self.model.import_from_h5(import_file)
def _lazy_tensor_dict(self, postprocessed_batch):
    """Wraps a batch in a UsageTrackingDict with a torch get-interceptor.

    Args:
        postprocessed_batch: The batch to wrap.

    Returns:
        UsageTrackingDict: The wrapped batch, with
            `convert_to_torch_tensor` set as its get-interceptor.
    """
    lazy_batch = UsageTrackingDict(postprocessed_batch)
    # NOTE(review): presumably the interceptor is applied when items are
    # read from the dict - confirm against UsageTrackingDict.
    lazy_batch.set_get_interceptor(convert_to_torch_tensor)
    return lazy_batch
@override(Policy)
def export_model(self, export_dir):
"""TODO(sven): implement for torch.
"""
raise NotImplementedError
@override(Policy)
def export_checkpoint(self, export_dir):
"""TODO(sven): implement for torch.
"""
raise NotImplementedError
@override(Policy)
def import_model_from_h5(self, import_file):
"""Imports weights into torch model."""
return self.model.import_from_h5(import_file)
# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
# and for all possible hyperparams, not just lr.
@DeveloperAPI
class LearningRateSchedule:
"""Mixin for TorchPolicy that adds a learning rate schedule."""
+2
View File
@@ -23,3 +23,5 @@ if __name__ == "__main__":
# Clean up.
del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
print("ok")
+2 -2
View File
@@ -5,8 +5,8 @@ from ray.rllib.utils.schedules.schedule import Schedule
tf1, tf, tfv = try_import_tf()
def _linear_interpolation(l, r, alpha):
return l + alpha * (r - l)
def _linear_interpolation(left, right, alpha):
    """Linearly interpolates between `left` and `right`.

    Args:
        left: Value returned for alpha == 0.
        right: Value reached for alpha == 1.
        alpha: Interpolation weight applied to `right - left`.

    Returns:
        `left + alpha * (right - left)`.
    """
    delta = right - left
    return left + alpha * delta
class PiecewiseSchedule(Schedule):
+10 -5
View File
@@ -2,6 +2,9 @@ from typing import Any, Dict, List, Tuple, Union
import gym
# Represents a fully filled out config of a Trainer class.
# Note: Policy config dicts are usually the same as TrainerConfigDict, but
# parts of it may sometimes be altered in e.g. a multi-agent setup,
# where we have >1 Policies in the same Trainer.
TrainerConfigDict = dict
# A trainer config dict that only has overrides. It needs to be combined with
@@ -63,8 +66,13 @@ GradInfoDict = dict
# policy id.
LearnerStatsDict = dict
# Type of dict returned by compute_gradients() representing model gradients.
ModelGradients = dict
# Represents a generic tensor type.
# This could be an np.ndarray, tf.Tensor, or a torch.Tensor.
TensorType = Any
# List of grads+var tuples (tf) or list of gradient tensors (torch)
# representing model gradients and returned by compute_gradients().
ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
# Type of dict returned by get_weights() representing model weights.
ModelWeights = dict
@@ -72,9 +80,6 @@ ModelWeights = dict
# Some kind of sample batch.
SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
# Represents a generic tensor type.
TensorType = Any
# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
TensorStructType = Union[TensorType, dict, tuple]