mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 04:12:08 +08:00
160 lines
5.5 KiB
Python
160 lines
5.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
|
|
class PolicyGraph(object):
|
|
"""An agent policy and loss, i.e., a TFPolicyGraph or other subclass.
|
|
|
|
This object defines how to act in the environment, and also losses used to
|
|
improve the policy based on its experiences. Note that both policy and
|
|
loss are defined together for convenience, though the policy itself is
|
|
logically separate.
|
|
|
|
All policies can directly extend PolicyGraph, however TensorFlow users may
|
|
find TFPolicyGraph simpler to implement. TFPolicyGraph also enables RLlib
|
|
to apply TensorFlow-specific optimizations such as fusing multiple policy
|
|
graphs and multi-GPU support.
|
|
|
|
Attributes:
|
|
observation_space (gym.Space): Observation space of the policy.
|
|
action_space (gym.Space): Action space of the policy.
|
|
"""
|
|
|
|
def __init__(self, observation_space, action_space, config):
|
|
"""Initialize the graph.
|
|
|
|
This is the standard constructor for policy graphs. The policy graph
|
|
class you pass into PolicyEvaluator will be constructed with
|
|
these arguments.
|
|
|
|
Args:
|
|
observation_space (gym.Space): Observation space of the policy.
|
|
action_space (gym.Space): Action space of the policy.
|
|
config (dict): Policy-specific configuration data.
|
|
"""
|
|
|
|
self.observation_space = observation_space
|
|
self.action_space = action_space
|
|
|
|
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
|
"""Compute actions for the current policy.
|
|
|
|
Arguments:
|
|
obs_batch (np.ndarray): batch of observations
|
|
state_batches (list): list of RNN state input batches, if any
|
|
is_training (bool): whether we are training the policy
|
|
|
|
Returns:
|
|
actions (np.ndarray): batch of output actions, with shape like
|
|
[BATCH_SIZE, ACTION_SHAPE].
|
|
state_outs (list): list of RNN state output batches, if any, with
|
|
shape like [STATE_SIZE, BATCH_SIZE].
|
|
info (dict): dictionary of extra feature batches, if any, with
|
|
shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def compute_single_action(self, obs, state, is_training=False):
|
|
"""Unbatched version of compute_actions.
|
|
|
|
Arguments:
|
|
obs (obj): single observation
|
|
state_batches (list): list of RNN state inputs, if any
|
|
is_training (bool): whether we are training the policy
|
|
|
|
Returns:
|
|
actions (obj): single action
|
|
state_outs (list): list of RNN state outputs, if any
|
|
info (dict): dictionary of extra features, if any
|
|
"""
|
|
|
|
[action], state_out, info = self.compute_actions(
|
|
[obs], [[s] for s in state], is_training)
|
|
return action, [s[0] for s in state_out], \
|
|
{k: v[0] for k, v in info.items()}
|
|
|
|
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
|
|
"""Implements algorithm-specific trajectory postprocessing.
|
|
|
|
Arguments:
|
|
sample_batch (SampleBatch): batch of experiences for the policy,
|
|
which will contain at most one episode trajectory.
|
|
other_agent_batches (dict): In a multi-agent env, this contains a
|
|
mapping of agent ids to (policy_graph, agent_batch) tuples
|
|
containing the policy graph and experiences of the other agent.
|
|
|
|
Returns:
|
|
SampleBatch: postprocessed sample batch.
|
|
"""
|
|
return sample_batch
|
|
|
|
def compute_gradients(self, postprocessed_batch):
|
|
"""Computes gradients against a batch of experiences.
|
|
|
|
Returns:
|
|
grads (list): List of gradient output values
|
|
info (dict): Extra policy-specific values
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def apply_gradients(self, gradients):
|
|
"""Applies previously computed gradients.
|
|
|
|
Returns:
|
|
info (dict): Extra policy-specific values
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def compute_apply(self, samples):
|
|
"""Fused compute gradients and apply gradients call.
|
|
|
|
Returns:
|
|
grad_info: dictionary of extra metadata from compute_gradients().
|
|
apply_info: dictionary of extra metadata from apply_gradients().
|
|
|
|
Examples:
|
|
>>> batch = ev.sample()
|
|
>>> ev.compute_apply(samples)
|
|
"""
|
|
|
|
grads, grad_info = self.compute_gradients(samples)
|
|
apply_info = self.apply_gradients(grads)
|
|
return grad_info, apply_info
|
|
|
|
def get_weights(self):
|
|
"""Returns model weights.
|
|
|
|
Returns:
|
|
weights (obj): Serializable copy or view of model weights
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def set_weights(self, weights):
|
|
"""Sets model weights.
|
|
|
|
Arguments:
|
|
weights (obj): Serializable copy or view of model weights
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def get_initial_state(self):
|
|
"""Returns initial RNN state for the current policy."""
|
|
return []
|
|
|
|
def get_state(self):
|
|
"""Saves all local state.
|
|
|
|
Returns:
|
|
state (obj): Serialized local state.
|
|
"""
|
|
return self.get_weights()
|
|
|
|
def set_state(self, state):
|
|
"""Restores all local state.
|
|
|
|
Arguments:
|
|
state (obj): Serialized local state.
|
|
"""
|
|
self.set_weights(state)
|