[rllib] Refactor pytorch custom model support (#3634)

This commit is contained in:
Eric Liang
2019-01-03 13:48:33 +08:00
committed by GitHub
parent b6bcd18d65
commit 47d36d7bd6
19 changed files with 402 additions and 240 deletions
@@ -7,7 +7,6 @@ import torch.nn.functional as F
from torch import nn
import ray
from ray.rllib.models.pytorch.misc import var_to_np
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -23,7 +22,7 @@ class A3CLoss(nn.Module):
self.entropy_coeff = entropy_coeff
def forward(self, observations, actions, advantages, value_targets):
logits, values = self.policy_model(observations)
logits, _, values, _ = self.policy_model({"obs": observations}, [])
log_probs = F.log_softmax(logits, dim=1)
probs = F.softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
@@ -46,8 +45,8 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
self.config = config
_, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_torch_model(
obs_space.shape, self.logit_dim, self.config["model"])
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = A3CLoss(self.model, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])
TorchPolicyGraph.__init__(
@@ -60,7 +59,7 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
@override(TorchPolicyGraph)
def extra_action_out(self, model_out):
return {"vf_preds": var_to_np(model_out[1])}
return {"vf_preds": model_out[2].numpy()}
@override(TorchPolicyGraph)
def optimizer(self):
@@ -82,7 +81,5 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0)
res = self.model.hidden_layers(obs)
res = self.model.value_branch(res)
res = res.squeeze()
return var_to_np(res)
_, _, vf, _ = self.model({"obs": obs}, [])
return vf.numpy().squeeze()
+26 -15
View File
@@ -5,24 +5,35 @@ from __future__ import print_function
from torch import nn
import torch.nn.functional as F
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.pytorch.model import TorchModel
from ray.rllib.utils.annotations import override
# TODO(ekl) we should have common models for pytorch like we do for TF
class RNNModel(nn.Module):
def __init__(self, obs_size, rnn_hidden_dim, n_actions):
nn.Module.__init__(self)
self.rnn_hidden_dim = rnn_hidden_dim
self.n_actions = n_actions
self.fc1 = nn.Linear(obs_size, rnn_hidden_dim)
self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)
def init_hidden(self):
class RNNModel(TorchModel):
"""The default RNN model for QMIX."""
def __init__(self, obs_space, num_outputs, options):
    """Create the fc1 -> GRU cell -> fc2 layers.

    Arguments:
        obs_space: Observation space; its flattened size (via
            `_get_size`) is the input width of `fc1`.
        num_outputs (int): Output width of `fc2`.
        options (dict): Model options; `options["lstm_cell_size"]` is
            used as the GRU hidden size.
    """
    TorchModel.__init__(self, obs_space, num_outputs, options)

    self.obs_size = _get_size(obs_space)
    self.rnn_hidden_dim = options["lstm_cell_size"]

    hidden = self.rnn_hidden_dim
    # Layers are created in the same order as before so that parameter
    # registration order is unchanged.
    self.fc1 = nn.Linear(self.obs_size, hidden)
    self.rnn = nn.GRUCell(hidden, hidden)
    self.fc2 = nn.Linear(hidden, num_outputs)
@override(TorchModel)
def state_init(self):
# make hidden states on same device as model
return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_()
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
def forward(self, inputs, hidden_state):
x = F.relu(self.fc1(inputs.float()))
h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)
@override(TorchModel)
def _forward(self, input_dict, hidden_state):
x = F.relu(self.fc1(input_dict["obs"]))
h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
h = self.rnn(x, h_in)
q = self.fc2(h)
return q, h
return q, h, None, [h]
def _get_size(obs_space):
    """Return the `size` of the preprocessor built for `obs_space`."""
    preprocessor_cls = get_preprocessor(obs_space)
    preprocessor = preprocessor_cls(obs_space)
    return preprocessor.size
@@ -12,13 +12,12 @@ from torch.distributions import Categorical
import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.models.pytorch.misc import var_to_np
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.models.model import _unpack_obs
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.annotations import override
@@ -61,7 +60,7 @@ class QMixLoss(nn.Module):
# Calculate estimated Q-Values
mac_out = []
h = self.model.init_hidden().expand([B, self.n_agents, -1])
h = [s.expand([B, self.n_agents, -1]) for s in self.model.state_init()]
for t in range(T):
q, h = _mac(self.model, obs[:, t], h)
mac_out.append(q)
@@ -73,8 +72,10 @@ class QMixLoss(nn.Module):
# Calculate the Q-Values necessary for the target
target_mac_out = []
target_h = self.target_model.init_hidden().expand(
[B, self.n_agents, -1])
target_h = [
s.expand([B, self.n_agents, -1])
for s in self.target_model.state_init()
]
for t in range(T):
target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
target_mac_out.append(target_q)
@@ -154,13 +155,22 @@ class QMixPolicyGraph(PolicyGraph):
(self.n_actions, ), mask_shape))
self.has_action_mask = True
self.obs_size = _get_size(agent_obs_space.spaces["obs"])
# The real agent obs space is nested inside the dict
agent_obs_space = agent_obs_space.spaces["obs"]
else:
self.has_action_mask = False
self.obs_size = _get_size(agent_obs_space)
self.model = RNNModel(self.obs_size, self.h_size, self.n_actions)
self.target_model = RNNModel(self.obs_size, self.h_size,
self.n_actions)
self.model = ModelCatalog.get_torch_model(
agent_obs_space,
self.n_actions,
config["model"],
default_model_cls=RNNModel)
self.target_model = ModelCatalog.get_torch_model(
agent_obs_space,
self.n_actions,
config["model"],
default_model_cls=RNNModel)
# Setup the mixer network.
# The global state is just the stacked agent observations for now.
@@ -203,13 +213,12 @@ class QMixPolicyGraph(PolicyGraph):
episodes=None,
**kwargs):
obs_batch, action_mask = self._unpack_observation(obs_batch)
assert len(state_batches) == self.n_agents, state_batches
state_batches = np.stack(state_batches, axis=1)
# Compute actions
with th.no_grad():
q_values, hiddens = _mac(self.model, th.from_numpy(obs_batch),
th.from_numpy(state_batches))
q_values, hiddens = _mac(
self.model, th.from_numpy(obs_batch),
[th.from_numpy(np.array(s)) for s in state_batches])
avail = th.from_numpy(action_mask).float()
masked_q_values = q_values.clone()
masked_q_values[avail == 0.0] = -float("inf")
@@ -219,11 +228,10 @@ class QMixPolicyGraph(PolicyGraph):
random_actions = Categorical(avail).sample().long()
actions = (pick_random * random_actions +
(1 - pick_random) * masked_q_values.max(dim=2)[1])
actions = var_to_np(actions)
hiddens = var_to_np(hiddens)
actions = actions.numpy()
hiddens = [s.numpy() for s in hiddens]
return (TupleActions(list(actions.transpose([1, 0]))),
hiddens.transpose([1, 0, 2]), {})
return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
@override(PolicyGraph)
def compute_apply(self, samples):
@@ -239,7 +247,7 @@ class QMixPolicyGraph(PolicyGraph):
samples["dones"], obs_batch
],
[samples["state_in_{}".format(k)]
for k in range(self.n_agents)],
for k in range(len(self.get_initial_state()))],
max_seq_len=self.config["model"]["max_seq_len"],
dynamic_max=True,
_extra_padding=1)
@@ -292,8 +300,8 @@ class QMixPolicyGraph(PolicyGraph):
@override(PolicyGraph)
def get_initial_state(self):
return [
self.model.init_hidden().numpy().squeeze()
for _ in range(self.n_agents)
s.expand([self.n_agents, -1]).numpy()
for s in self.model.state_init()
]
@override(PolicyGraph)
@@ -342,6 +350,12 @@ class QMixPolicyGraph(PolicyGraph):
return group_rewards
def _unpack_observation(self, obs_batch):
"""Unpacks the action mask / tuple obs from agent grouping.
Returns:
obs (Tensor): flattened obs tensor of shape [B, n_agents, obs_size]
mask (Tensor): action mask, if any
"""
unpacked = _unpack_obs(
np.array(obs_batch),
self.observation_space.original_space,
@@ -388,17 +402,13 @@ def _validate(obs_space, action_space):
"must be homogeneous, got {}".format(action_space.spaces))
def _get_size(obs_space):
    """Return the `size` of the preprocessor built for `obs_space`."""
    make_preprocessor = get_preprocessor(obs_space)
    return make_preprocessor(obs_space).size
def _mac(model, obs, h):
"""Forward pass of the multi-agent controller.
Arguments:
model: Model that produces q-values for a 1d agent batch
model: TorchModel class
obs: Tensor of shape [B, n_agents, obs_size]
h: Tensor of shape [B, n_agents, h_size]
h: List of tensors of shape [B, n_agents, h_size]
Returns:
q_vals: Tensor of shape [B, n_agents, n_actions]
@@ -406,6 +416,7 @@ def _mac(model, obs, h):
"""
B, n_agents = obs.size(0), obs.size(1)
obs_flat = obs.reshape([B * n_agents, -1])
h_flat = h.reshape([B * n_agents, -1])
q_flat, h_flat = model.forward(obs_flat, h_flat)
return q_flat.reshape([B, n_agents, -1]), h_flat.reshape([B, n_agents, -1])
h_flat = [s.reshape([B * n_agents, -1]) for s in h]
q_flat, _, _, h_flat = model.forward({"obs": obs_flat}, h_flat)
return q_flat.reshape(
[B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]
@@ -8,7 +8,6 @@ from threading import Lock
try:
import torch
import torch.nn.functional as F
from ray.rllib.models.pytorch.misc import var_to_np
except ImportError:
pass # soft dep
@@ -66,15 +65,14 @@ class TorchPolicyGraph(PolicyGraph):
info_batch=None,
episodes=None,
**kwargs):
if state_batches:
raise NotImplementedError("Torch RNN support")
with self.lock:
with torch.no_grad():
ob = torch.from_numpy(np.array(obs_batch)).float()
model_out = self._model(ob)
logits = model_out[0] # assume the first output is the logits
model_out = self._model({"obs": ob}, state_batches)
logits, _, vf, state = model_out
actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
return var_to_np(actions), [], self.extra_action_out(model_out)
return (actions.numpy(), [h.numpy() for h in state],
self.extra_action_out(model_out))
@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
@@ -87,7 +85,7 @@ class TorchPolicyGraph(PolicyGraph):
loss_out.backward()
# Note that return values are just references;
# calling zero_grad will modify the values
grads = [var_to_np(p.grad.data) for p in self._model.parameters()]
grads = [p.grad.data.numpy() for p in self._model.parameters()]
return grads, {}
@override(PolicyGraph)
@@ -108,6 +106,10 @@ class TorchPolicyGraph(PolicyGraph):
with self.lock:
self._model.load_state_dict(weights)
@override(PolicyGraph)
def get_initial_state(self):
    """Return the model's initial hidden states as numpy arrays.

    Returns:
        list: One numpy array per tensor returned by
            `self._model.state_init()` (empty if the model has no state).
    """
    # Tensor.numpy() raises for tensors that require grad or that do not
    # live in host memory, so detach and copy to CPU before converting.
    return [
        s.detach().cpu().numpy() for s in self._model.state_init()
    ]
def extra_action_out(self, model_out):
"""Returns dict of extra info to include in experience batch.
+24 -15
View File
@@ -52,8 +52,6 @@ MODEL_DEFAULTS = {
"framestack": True,
# Final resized frame dimension
"dim": 84,
# Pytorch conv requires images to be channel-major
"channel_major": False,
# (deprecated) Converts ATARI frame to 1 Channel Grayscale image
"grayscale": False,
# (deprecated) Changes frame to range from [-1, 1] if true
@@ -232,14 +230,17 @@ class ModelCatalog(object):
options)
@staticmethod
def get_torch_model(input_shape, num_outputs, options=None):
"""Returns a PyTorch suitable model. This is currently only supported
in A3C.
def get_torch_model(obs_space,
num_outputs,
options=None,
default_model_cls=None):
"""Returns a custom model for PyTorch algorithms.
Args:
input_shape (tuple): The input shape to the model.
obs_space (Space): The input observation space.
num_outputs (int): The size of the output vector of the model.
options (dict): Optional args to pass to the model constructor.
default_model_cls (cls): Optional class to use if no custom model.
Returns:
model (models.Model): Neural network model.
@@ -250,21 +251,29 @@ class ModelCatalog(object):
PyTorchVisionNet)
options = options or MODEL_DEFAULTS
if options.get("custom_model"):
model = options["custom_model"]
logger.info("Using custom torch model {}".format(model))
return _global_registry.get(RLLIB_MODEL, model)(
input_shape, num_outputs, options)
logger.debug("Using custom torch model {}".format(model))
return _global_registry.get(RLLIB_MODEL,
model)(obs_space, num_outputs, options)
# TODO(alok): fix to handle Discrete(n) state spaces
obs_rank = len(input_shape) - 1
if options.get("use_lstm"):
raise NotImplementedError(
"LSTM auto-wrapping not implemented for torch")
if default_model_cls:
return default_model_cls(obs_space, num_outputs, options)
if isinstance(obs_space, gym.spaces.Discrete):
obs_rank = 1
else:
obs_rank = len(obs_space.shape)
if obs_rank > 1:
return PyTorchVisionNet(input_shape, num_outputs, options)
return PyTorchVisionNet(obs_space, num_outputs, options)
# TODO(alok): overhaul PyTorchFCNet so it can just
# take input shape directly
return PyTorchFCNet(input_shape[0], num_outputs, options)
return PyTorchFCNet(obs_space, num_outputs, options)
@staticmethod
def get_preprocessor(env, options=None):
+5 -2
View File
@@ -160,11 +160,14 @@ class Model(object):
self._num_outputs, shape))
def _restore_original_dimensions(input_dict, obs_space):
def _restore_original_dimensions(input_dict, obs_space, tensorlib=tf):
if hasattr(obs_space, "original_space"):
return dict(
input_dict,
obs=_unpack_obs(input_dict["obs"], obs_space.original_space))
obs=_unpack_obs(
input_dict["obs"],
obs_space.original_space,
tensorlib=tensorlib))
return input_dict
-6
View File
@@ -64,15 +64,11 @@ class GenericPixelPreprocessor(Preprocessor):
self._grayscale = options.get("grayscale")
self._zero_mean = options.get("zero_mean")
self._dim = options.get("dim")
self._channel_major = options.get("channel_major")
if self._grayscale:
shape = (self._dim, self._dim, 1)
else:
shape = (self._dim, self._dim, 3)
# channel_major requires (# in-channels, row dim, col dim)
if self._channel_major:
shape = shape[-1:] + shape[:-1]
return shape
@override(Preprocessor)
@@ -94,8 +90,6 @@ class GenericPixelPreprocessor(Preprocessor):
scaled = (scaled - 128) / 128
else:
scaled *= 1.0 / 255.0
if self._channel_major:
scaled = np.reshape(scaled, self.shape)
return scaled
+26 -33
View File
@@ -3,30 +3,27 @@ from __future__ import division
from __future__ import print_function
import logging
from ray.rllib.models.pytorch.model import Model, SlimFC
from ray.rllib.models.pytorch.misc import normc_initializer
import numpy as np
import torch.nn as nn
from ray.rllib.models.pytorch.model import TorchModel
from ray.rllib.models.pytorch.misc import normc_initializer, SlimFC, \
_get_activation_fn
from ray.rllib.utils.annotations import override
logger = logging.getLogger(__name__)
class FullyConnectedNetwork(Model):
"""TODO(rliaw): Logits, Value should both be contained here"""
def _build_layers(self, inputs, num_outputs, options):
assert type(inputs) is int
hiddens = options.get("fcnet_hiddens", [256, 256])
fcnet_activation = options.get("fcnet_activation", "tanh")
activation = None
if fcnet_activation == "tanh":
activation = nn.Tanh
elif fcnet_activation == "relu":
activation = nn.ReLU
logger.info("Constructing fcnet {} {}".format(hiddens, activation))
class FullyConnectedNetwork(TorchModel):
"""Generic fully connected network."""
def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
hiddens = options.get("fcnet_hiddens")
activation = _get_activation_fn(options.get("fcnet_activation"))
logger.debug("Constructing fcnet {} {}".format(hiddens, activation))
layers = []
last_layer_size = inputs
last_layer_size = np.product(obs_space.shape)
for size in hiddens:
layers.append(
SlimFC(
@@ -36,29 +33,25 @@ class FullyConnectedNetwork(Model):
activation_fn=activation))
last_layer_size = size
self.hidden_layers = nn.Sequential(*layers)
self._hidden_layers = nn.Sequential(*layers)
self.logits = SlimFC(
self._logits = SlimFC(
in_size=last_layer_size,
out_size=num_outputs,
initializer=normc_initializer(0.01),
activation_fn=None)
self.value_branch = SlimFC(
self._value_branch = SlimFC(
in_size=last_layer_size,
out_size=1,
initializer=normc_initializer(1.0),
activation_fn=None)
def forward(self, obs):
""" Internal method - pass in torch tensors, not numpy arrays
Args:
obs: observations and features
Return:
logits: logits to be sampled from for each state
value: value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res).squeeze(1)
return logits, value
@override(nn.Module)
def forward(self, input_dict, hidden_state):
    """Compute logits, features and value for a batch of observations.

    Arguments:
        input_dict (dict): Must contain "obs"; it is reshaped to
            [obs.shape[0], -1] before the hidden layers.
        hidden_state: Returned unchanged as the last tuple element.

    Returns:
        (logits, features, value, hidden_state)
    """
    # forward() is overridden here rather than _forward() so that the
    # observation arrives still flattened.
    batch = input_dict["obs"]
    flat_obs = batch.reshape(batch.shape[0], -1)
    features = self._hidden_layers(flat_obs)
    logits = self._logits(features)
    # Drop axis 1 of the value head output.
    value = self._value_branch(features).squeeze(1)
    return logits, features, value, hidden_state
+66 -4
View File
@@ -5,10 +5,7 @@ from __future__ import print_function
import numpy as np
import torch
def var_to_np(var):
return var.cpu().detach().numpy()
import torch.nn as nn
def normc_initializer(std=1.0):
@@ -51,3 +48,68 @@ def valid_padding(in_size, filter_size, stride_size):
padding = (pad_left, pad_right, pad_top, pad_bottom)
output = (out_height, out_width)
return padding, output
def _get_activation_fn(name):
    """Map an activation name to the matching `torch.nn` module class.

    Arguments:
        name (str): Either "tanh" or "relu".

    Returns:
        `nn.Tanh` or `nn.ReLU` (the class itself, not an instance).

    Raises:
        ValueError: If `name` is neither "tanh" nor "relu".
    """
    if name == "tanh":
        return nn.Tanh
    if name == "relu":
        return nn.ReLU
    raise ValueError("Unknown activation: {}".format(name))
class SlimConv2d(nn.Module):
    """Conv2d wrapper loosely modeled on tf.slim's conv2d.

    Chains optional zero padding, a convolution and an optional
    activation into one `nn.Sequential`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel,
                 stride,
                 padding,
                 initializer=nn.init.xavier_uniform_,
                 activation_fn=nn.ReLU,
                 bias_init=0):
        """Assemble the padding / conv / activation pipeline.

        Arguments:
            in_channels, out_channels, kernel, stride: Forwarded to
                `nn.Conv2d`.
            padding: If truthy, passed to `nn.ZeroPad2d` ahead of the conv.
            initializer: If truthy, called on the conv weight.
            activation_fn: If truthy, instantiated and appended last.
            bias_init: Constant the conv bias is filled with.
        """
        super(SlimConv2d, self).__init__()

        stages = [nn.ZeroPad2d(padding)] if padding else []

        conv_layer = nn.Conv2d(in_channels, out_channels, kernel, stride)
        if initializer:
            initializer(conv_layer.weight)
        nn.init.constant_(conv_layer.bias, bias_init)
        stages.append(conv_layer)

        if activation_fn:
            stages.append(activation_fn())

        self._model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the assembled pipeline to `x`."""
        return self._model(x)
class SlimFC(nn.Module):
    """Linear layer with optional weight initializer and activation."""

    def __init__(self,
                 in_size,
                 out_size,
                 initializer=None,
                 activation_fn=None,
                 bias_init=0):
        """Assemble the linear / activation pipeline.

        Arguments:
            in_size, out_size: Forwarded to `nn.Linear`.
            initializer: If truthy, called on the linear weight.
            activation_fn: If truthy, instantiated and appended after the
                linear layer.
            bias_init: Constant the bias is filled with.
        """
        super(SlimFC, self).__init__()

        fc = nn.Linear(in_size, out_size)
        if initializer:
            initializer(fc.weight)
        nn.init.constant_(fc.bias, bias_init)

        stages = [fc]
        if activation_fn:
            stages.append(activation_fn())
        self._model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the assembled pipeline to `x`."""
        return self._model(x)
+46 -68
View File
@@ -2,79 +2,57 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
from ray.rllib.models.model import _restore_original_dimensions
class Model(nn.Module):
def __init__(self, obs_space, ac_space, options):
super(Model, self).__init__()
self._build_layers(obs_space, ac_space, options)
def _build_layers(self, inputs, num_outputs, options):
raise NotImplementedError
class TorchModel(nn.Module):
"""Defines an abstract network model for use with RLlib / PyTorch."""
def forward(self, obs):
"""Forward pass for the model. Internal method - should only
be passed PyTorch Tensors.
def __init__(self, obs_space, num_outputs, options):
"""All custom RLlib torch models must support this constructor.
PyTorch automatically overloads the given model
with this function. Recommended that model(obs)
is used instead of model.forward(obs). See
https://discuss.pytorch.org/t/any-different-between-model
-input-and-model-forward-input/3690
Arguments:
obs_space (gym.Space): Input observation space.
num_outputs (int): Output tensor must be of size
[BATCH_SIZE, num_outputs].
options (dict): Dictionary of model options.
"""
nn.Module.__init__(self)
self.obs_space = obs_space
self.num_outputs = num_outputs
self.options = options
def forward(self, input_dict, hidden_state):
    """Wraps _forward() to unpack flattened Dict and Tuple observations.

    Arguments:
        input_dict (dict): Tensor inputs; must contain "obs". The
            caller's dict is left untouched.
        hidden_state (list): Hidden state tensors, handed to _forward()
            as-is.

    Returns:
        The (outputs, features, vf, h) tuple produced by _forward().
    """
    # Build a new dict rather than assigning into the caller's, so the
    # float cast does not leak out of this call.
    input_dict = dict(
        input_dict,
        obs=input_dict["obs"].float())  # TODO(ekl): avoid cast
    input_dict = _restore_original_dimensions(
        input_dict, self.obs_space, tensorlib=torch)
    # Unpacking enforces that _forward() returns exactly four values.
    outputs, features, vf, h = self._forward(input_dict, hidden_state)
    return outputs, features, vf, h
def state_init(self):
    """Return the list of initial hidden state tensors.

    This implementation returns an empty list.
    """
    return []
def _forward(self, input_dict, hidden_state):
    """Forward pass of the model; this implementation always raises.

    Implement this rather than forward() so that Dict and Tuple
    observations are handled properly.

    Arguments:
        input_dict (dict): Tensor inputs, commonly including "obs",
            "prev_action" and "prev_reward", each shaped
            [BATCH_SIZE, ...].
        hidden_state (list): Hidden state tensors, each shaped
            [BATCH_SIZE, h_size].

    Returns:
        (outputs, feature_layer, values, state): Tensors of size
            [BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size],
            [BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size].

    Raises:
        NotImplementedError: Always, in this implementation.
    """
    raise NotImplementedError
class SlimConv2d(nn.Module):
    """Conv2d wrapper loosely modeled on tf.slim's conv2d.

    Chains optional zero padding, a convolution and an optional
    activation into one `nn.Sequential`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel,
                 stride,
                 padding,
                 initializer=nn.init.xavier_uniform_,
                 activation_fn=nn.ReLU,
                 bias_init=0):
        """Assemble the padding / conv / activation pipeline.

        Arguments:
            in_channels, out_channels, kernel, stride: Forwarded to
                `nn.Conv2d`.
            padding: If truthy, passed to `nn.ZeroPad2d` ahead of the conv.
            initializer: If truthy, called on the conv weight.
            activation_fn: If truthy, instantiated and appended last.
            bias_init: Constant the conv bias is filled with.
        """
        super(SlimConv2d, self).__init__()

        pipeline = []
        if padding:
            pipeline.append(nn.ZeroPad2d(padding))

        convolution = nn.Conv2d(in_channels, out_channels, kernel, stride)
        if initializer:
            initializer(convolution.weight)
        nn.init.constant_(convolution.bias, bias_init)

        pipeline += [convolution]
        if activation_fn:
            pipeline += [activation_fn()]
        self._model = nn.Sequential(*pipeline)

    def forward(self, x):
        """Apply the assembled pipeline to `x`."""
        return self._model(x)
class SlimFC(nn.Module):
    """Linear layer with optional weight initializer and activation."""

    def __init__(self,
                 in_size,
                 out_size,
                 initializer=None,
                 activation_fn=None,
                 bias_init=0):
        """Assemble the linear / activation pipeline.

        Arguments:
            in_size, out_size: Forwarded to `nn.Linear`.
            initializer: If truthy, called on the linear weight.
            activation_fn: If truthy, instantiated and appended after the
                linear layer.
            bias_init: Constant the bias is filled with.
        """
        super(SlimFC, self).__init__()

        dense = nn.Linear(in_size, out_size)
        if initializer:
            initializer(dense.weight)
        nn.init.constant_(dense.bias, bias_init)

        pipeline = [dense]
        if activation_fn:
            pipeline += [activation_fn()]
        self._model = nn.Sequential(*pipeline)

    def forward(self, x):
        """Apply the assembled pipeline to `x`."""
        return self._model(x)
+24 -38
View File
@@ -4,28 +4,25 @@ from __future__ import print_function
import torch.nn as nn
from ray.rllib.models.pytorch.model import Model, SlimConv2d, SlimFC
from ray.rllib.models.pytorch.misc import normc_initializer, valid_padding
from ray.rllib.models.pytorch.model import TorchModel
from ray.rllib.models.pytorch.misc import normc_initializer, valid_padding, \
SlimConv2d, SlimFC
from ray.rllib.models.visionnet import _get_filter_config
from ray.rllib.utils.annotations import override
class VisionNetwork(Model):
"""Generic vision network"""
class VisionNetwork(TorchModel):
"""Generic vision network."""
def _build_layers(self, inputs, num_outputs, options):
"""TF visionnet in PyTorch.
Params:
inputs (tuple): (channels, rows/height, cols/width)
num_outputs (int): logits size
"""
filters = options.get("conv_filters") or [
[16, [8, 8], 4],
[32, [4, 4], 2],
[512, [11, 11], 1],
]
def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
filters = options.get("conv_filters")
if not filters:
filters = _get_filter_config(obs_space.shape)
layers = []
in_channels, in_size = inputs[0], inputs[1:]
(w, h, in_channels) = obs_space.shape
in_size = [w, h]
for out_channels, kernel, stride in filters[:-1]:
padding, out_size = valid_padding(in_size, kernel,
[stride, stride])
@@ -39,31 +36,20 @@ class VisionNetwork(Model):
SlimConv2d(in_channels, out_channels, kernel, stride, None))
self._convs = nn.Sequential(*layers)
self.logits = SlimFC(
self._logits = SlimFC(
out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
self.value_branch = SlimFC(
self._value_branch = SlimFC(
out_channels, 1, initializer=normc_initializer())
def hidden_layers(self, obs):
""" Internal method - pass in torch tensors, not numpy arrays
@override(TorchModel)
def _forward(self, input_dict, hidden_state):
    """Compute logits, conv features and value for `input_dict["obs"]`.

    Returns:
        (logits, features, value, hidden_state), where `hidden_state`
        is returned unchanged.
    """
    conv_features = self._hidden_layers(input_dict["obs"])
    action_logits = self._logits(conv_features)
    # Drop axis 1 of the value head output.
    state_value = self._value_branch(conv_features).squeeze(1)
    return action_logits, conv_features, state_value, hidden_state
args:
obs: observations and features"""
res = self._convs(obs)
def _hidden_layers(self, obs):
    """Run the conv stack on `obs` and squeeze axes 3 and 2 of the result.

    `obs` is permuted with (0, 3, 1, 2) first, i.e. switched to
    channel-major, before it is passed to `self._convs`.
    """
    channel_major = obs.permute(0, 3, 1, 2)
    conv_out = self._convs(channel_major)
    # Axis 3 is squeezed before axis 2, as in the original sequence.
    return conv_out.squeeze(3).squeeze(2)
def forward(self, obs):
"""Internal method. Implements the
Args:
obs (PyTorch): observations and features
Return:
logits (PyTorch): logits to be sampled from for each state
value (PyTorch): value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res).squeeze(1)
return logits, value
+6 -6
View File
@@ -18,7 +18,7 @@ class VisionNetwork(Model):
inputs = input_dict["obs"]
filters = options.get("conv_filters")
if not filters:
filters = _get_filter_config(inputs)
filters = _get_filter_config(inputs.shape.as_list()[1:])
activation = get_activation_fn(options.get("conv_activation"))
@@ -49,7 +49,8 @@ class VisionNetwork(Model):
return flatten(fc2), flatten(fc1)
def _get_filter_config(inputs):
def _get_filter_config(shape):
shape = list(shape)
filters_84x84 = [
[16, [8, 8], 4],
[32, [4, 4], 2],
@@ -60,15 +61,14 @@ def _get_filter_config(inputs):
[32, [4, 4], 2],
[256, [11, 11], 1],
]
shape = inputs.shape.as_list()[1:]
if len(shape) == 3 and shape[:2] == [84, 84]:
return filters_84x84
elif len(shape) == 3 and shape[:2] == [42, 42]:
return filters_42x42
else:
raise ValueError(
"No default configuration for obs input {}".format(inputs) +
"No default configuration for obs shape {}".format(shape) +
", you must specify `conv_filters` manually as a model option. "
"Default configurations are only available for inputs of size "
"[?, 42, 42, K] and [?, 84, 84, K]. You may alternatively want "
"Default configurations are only available for inputs of shape "
"[42, 42, K] and [84, 84, K]. You may alternatively want "
"to use a custom model or preprocessor.")
@@ -12,6 +12,7 @@ import tensorflow as tf
import unittest
import ray
from ray.rllib.agents.a3c import A2CAgent
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.env import MultiAgentEnv
@@ -19,6 +20,8 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.model import Model
from ray.rllib.models.pytorch.fcnet import FullyConnectedNetwork
from ray.rllib.models.pytorch.model import TorchModel
from ray.rllib.rollout import rollout
from ray.rllib.test.test_external_env import SimpleServing
from ray.tune.registry import register_env
@@ -129,6 +132,29 @@ class InvalidModel2(Model):
return tf.constant(0), tf.constant(0)
class TorchSpyModel(TorchModel):
    """Torch model that records the nested observations it is given.

    Each `_forward()` call pickles (position, front_cam[0], task) into the
    ray internal KV store under "torch_spy_in_<capture_index>".
    """

    # Class-wide count of _forward() calls; used as the KV key suffix.
    capture_index = 0

    def __init__(self, obs_space, num_outputs, options):
        TorchModel.__init__(self, obs_space, num_outputs, options)
        sensor_spaces = obs_space.original_space.spaces["sensors"].spaces
        # The wrapped network only consumes the "position" sensor.
        self.fc = FullyConnectedNetwork(sensor_spaces["position"],
                                        num_outputs, options)

    def _forward(self, input_dict, hidden_state):
        nested_obs = input_dict["obs"]
        pos = nested_obs["sensors"]["position"].numpy()
        front_cam = nested_obs["sensors"]["front_cam"][0].numpy()
        task = nested_obs["inner_state"]["job_status"]["task"].numpy()

        kv_key = "torch_spy_in_{}".format(TorchSpyModel.capture_index)
        payload = pickle.dumps((pos, front_cam, task))
        ray.experimental.internal_kv._internal_kv_put(
            kv_key, payload, overwrite=True)
        TorchSpyModel.capture_index += 1

        fc_input = {"obs": input_dict["obs"]["sensors"]["position"]}
        return self.fc(fc_input, hidden_state)
class DictSpyModel(Model):
capture_index = 0
@@ -359,6 +385,36 @@ class NestedSpacesTest(unittest.TestCase):
# Test rollout works on restore
rollout(agent2, "nested", 100)
def testPyTorchModel(self):
    """Train A2C (torch) once with the spy model and check what it saw."""
    ModelCatalog.register_custom_model("composite", TorchSpyModel)
    register_env("nested", lambda _: NestedDictEnv())

    agent_config = {
        "num_workers": 0,
        "use_pytorch": True,
        "sample_batch_size": 5,
        "train_batch_size": 5,
        "model": {
            "custom_model": "composite",
        },
    }
    a2c = A2CAgent(env="nested", config=agent_config)
    a2c.train()

    # Check that the model sees the correct reconstructed observations
    for i in range(4):
        kv_key = "torch_spy_in_{}".format(i)
        captured = ray.experimental.internal_kv._internal_kv_get(kv_key)
        seen = pickle.loads(captured)

        sample = DICT_SAMPLES[i]
        pos_i = sample["sensors"]["position"].tolist()
        cam_i = sample["sensors"]["front_cam"][0].tolist()
        task_i = one_hot(sample["inner_state"]["job_status"]["task"], 5)

        # Index 0 picks the first row of each captured array.
        self.assertEqual(seen[0][0].tolist(), pos_i)
        self.assertEqual(seen[1][0].tolist(), cam_i)
        self.assertEqual(seen[2][0].tolist(), task_i)
if __name__ == "__main__":
ray.init(num_cpus=5)
@@ -14,7 +14,6 @@ pong-a3c-pytorch-cnn:
observation_filter: NoFilter
model:
use_lstm: false
channel_major: true
dim: 84
grayscale: true
zero_mean: false
@@ -0,0 +1,9 @@
cartpole-a2c-torch:
env: CartPole-v0
run: A2C
stop:
episode_reward_mean: 100
timesteps_total: 100000
config:
num_workers: 0
use_pytorch: true