mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:12:15 +08:00
[rllib] Refactor pytorch custom model support (#3634)
This commit is contained in:
@@ -7,7 +7,6 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.pytorch.misc import var_to_np
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
@@ -23,7 +22,7 @@ class A3CLoss(nn.Module):
|
||||
self.entropy_coeff = entropy_coeff
|
||||
|
||||
def forward(self, observations, actions, advantages, value_targets):
|
||||
logits, values = self.policy_model(observations)
|
||||
logits, _, values, _ = self.policy_model({"obs": observations}, [])
|
||||
log_probs = F.log_softmax(logits, dim=1)
|
||||
probs = F.softmax(logits, dim=1)
|
||||
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
|
||||
@@ -46,8 +45,8 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
|
||||
self.config = config
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self.model = ModelCatalog.get_torch_model(
|
||||
obs_space.shape, self.logit_dim, self.config["model"])
|
||||
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
|
||||
self.config["model"])
|
||||
loss = A3CLoss(self.model, self.config["vf_loss_coeff"],
|
||||
self.config["entropy_coeff"])
|
||||
TorchPolicyGraph.__init__(
|
||||
@@ -60,7 +59,7 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_action_out(self, model_out):
|
||||
return {"vf_preds": var_to_np(model_out[1])}
|
||||
return {"vf_preds": model_out[2].numpy()}
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def optimizer(self):
|
||||
@@ -82,7 +81,5 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
|
||||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0)
|
||||
res = self.model.hidden_layers(obs)
|
||||
res = self.model.value_branch(res)
|
||||
res = res.squeeze()
|
||||
return var_to_np(res)
|
||||
_, _, vf, _ = self.model({"obs": obs}, [])
|
||||
return vf.numpy().squeeze()
|
||||
|
||||
@@ -5,24 +5,35 @@ from __future__ import print_function
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.pytorch.model import TorchModel
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
# TODO(ekl) we should have common models for pytorch like we do for TF
|
||||
class RNNModel(nn.Module):
|
||||
def __init__(self, obs_size, rnn_hidden_dim, n_actions):
|
||||
nn.Module.__init__(self)
|
||||
self.rnn_hidden_dim = rnn_hidden_dim
|
||||
self.n_actions = n_actions
|
||||
self.fc1 = nn.Linear(obs_size, rnn_hidden_dim)
|
||||
self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
|
||||
self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)
|
||||
|
||||
def init_hidden(self):
|
||||
class RNNModel(TorchModel):
|
||||
"""The default RNN model for QMIX."""
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
self.obs_size = _get_size(obs_space)
|
||||
self.rnn_hidden_dim = options["lstm_cell_size"]
|
||||
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
|
||||
self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
|
||||
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
|
||||
|
||||
@override(TorchModel)
|
||||
def state_init(self):
|
||||
# make hidden states on same device as model
|
||||
return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_()
|
||||
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
|
||||
|
||||
def forward(self, inputs, hidden_state):
|
||||
x = F.relu(self.fc1(inputs.float()))
|
||||
h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)
|
||||
@override(TorchModel)
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
x = F.relu(self.fc1(input_dict["obs"]))
|
||||
h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
|
||||
h = self.rnn(x, h_in)
|
||||
q = self.fc2(h)
|
||||
return q, h
|
||||
return q, h, None, [h]
|
||||
|
||||
|
||||
def _get_size(obs_space):
    """Return the ``size`` reported by the preprocessor for ``obs_space``.

    The preprocessor class is looked up with ``get_preprocessor``,
    instantiated on the same space, and its ``size`` attribute is returned.

    Arguments:
        obs_space: Observation space to measure.
    """
    return get_preprocessor(obs_space)(obs_space).size
|
||||
|
||||
@@ -12,13 +12,12 @@ from torch.distributions import Categorical
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
|
||||
from ray.rllib.agents.qmix.model import RNNModel
|
||||
from ray.rllib.agents.qmix.model import RNNModel, _get_size
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.models.pytorch.misc import var_to_np
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.models.model import _unpack_obs
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.env.constants import GROUP_REWARDS
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
@@ -61,7 +60,7 @@ class QMixLoss(nn.Module):
|
||||
|
||||
# Calculate estimated Q-Values
|
||||
mac_out = []
|
||||
h = self.model.init_hidden().expand([B, self.n_agents, -1])
|
||||
h = [s.expand([B, self.n_agents, -1]) for s in self.model.state_init()]
|
||||
for t in range(T):
|
||||
q, h = _mac(self.model, obs[:, t], h)
|
||||
mac_out.append(q)
|
||||
@@ -73,8 +72,10 @@ class QMixLoss(nn.Module):
|
||||
|
||||
# Calculate the Q-Values necessary for the target
|
||||
target_mac_out = []
|
||||
target_h = self.target_model.init_hidden().expand(
|
||||
[B, self.n_agents, -1])
|
||||
target_h = [
|
||||
s.expand([B, self.n_agents, -1])
|
||||
for s in self.target_model.state_init()
|
||||
]
|
||||
for t in range(T):
|
||||
target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
|
||||
target_mac_out.append(target_q)
|
||||
@@ -154,13 +155,22 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
(self.n_actions, ), mask_shape))
|
||||
self.has_action_mask = True
|
||||
self.obs_size = _get_size(agent_obs_space.spaces["obs"])
|
||||
# The real agent obs space is nested inside the dict
|
||||
agent_obs_space = agent_obs_space.spaces["obs"]
|
||||
else:
|
||||
self.has_action_mask = False
|
||||
self.obs_size = _get_size(agent_obs_space)
|
||||
|
||||
self.model = RNNModel(self.obs_size, self.h_size, self.n_actions)
|
||||
self.target_model = RNNModel(self.obs_size, self.h_size,
|
||||
self.n_actions)
|
||||
self.model = ModelCatalog.get_torch_model(
|
||||
agent_obs_space,
|
||||
self.n_actions,
|
||||
config["model"],
|
||||
default_model_cls=RNNModel)
|
||||
self.target_model = ModelCatalog.get_torch_model(
|
||||
agent_obs_space,
|
||||
self.n_actions,
|
||||
config["model"],
|
||||
default_model_cls=RNNModel)
|
||||
|
||||
# Setup the mixer network.
|
||||
# The global state is just the stacked agent observations for now.
|
||||
@@ -203,13 +213,12 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
obs_batch, action_mask = self._unpack_observation(obs_batch)
|
||||
assert len(state_batches) == self.n_agents, state_batches
|
||||
state_batches = np.stack(state_batches, axis=1)
|
||||
|
||||
# Compute actions
|
||||
with th.no_grad():
|
||||
q_values, hiddens = _mac(self.model, th.from_numpy(obs_batch),
|
||||
th.from_numpy(state_batches))
|
||||
q_values, hiddens = _mac(
|
||||
self.model, th.from_numpy(obs_batch),
|
||||
[th.from_numpy(np.array(s)) for s in state_batches])
|
||||
avail = th.from_numpy(action_mask).float()
|
||||
masked_q_values = q_values.clone()
|
||||
masked_q_values[avail == 0.0] = -float("inf")
|
||||
@@ -219,11 +228,10 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
random_actions = Categorical(avail).sample().long()
|
||||
actions = (pick_random * random_actions +
|
||||
(1 - pick_random) * masked_q_values.max(dim=2)[1])
|
||||
actions = var_to_np(actions)
|
||||
hiddens = var_to_np(hiddens)
|
||||
actions = actions.numpy()
|
||||
hiddens = [s.numpy() for s in hiddens]
|
||||
|
||||
return (TupleActions(list(actions.transpose([1, 0]))),
|
||||
hiddens.transpose([1, 0, 2]), {})
|
||||
return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_apply(self, samples):
|
||||
@@ -239,7 +247,7 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
samples["dones"], obs_batch
|
||||
],
|
||||
[samples["state_in_{}".format(k)]
|
||||
for k in range(self.n_agents)],
|
||||
for k in range(len(self.get_initial_state()))],
|
||||
max_seq_len=self.config["model"]["max_seq_len"],
|
||||
dynamic_max=True,
|
||||
_extra_padding=1)
|
||||
@@ -292,8 +300,8 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
return [
|
||||
self.model.init_hidden().numpy().squeeze()
|
||||
for _ in range(self.n_agents)
|
||||
s.expand([self.n_agents, -1]).numpy()
|
||||
for s in self.model.state_init()
|
||||
]
|
||||
|
||||
@override(PolicyGraph)
|
||||
@@ -342,6 +350,12 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
return group_rewards
|
||||
|
||||
def _unpack_observation(self, obs_batch):
|
||||
"""Unpacks the action mask / tuple obs from agent grouping.
|
||||
|
||||
Returns:
|
||||
obs (Tensor): flattened obs tensor of shape [B, n_agents, obs_size]
|
||||
mask (Tensor): action mask, if any
|
||||
"""
|
||||
unpacked = _unpack_obs(
|
||||
np.array(obs_batch),
|
||||
self.observation_space.original_space,
|
||||
@@ -388,17 +402,13 @@ def _validate(obs_space, action_space):
|
||||
"must be homogeneous, got {}".format(action_space.spaces))
|
||||
|
||||
|
||||
def _get_size(obs_space):
    """Return the ``size`` reported by the preprocessor for ``obs_space``.

    The preprocessor class is looked up with ``get_preprocessor``,
    instantiated on the same space, and its ``size`` attribute is returned.

    Arguments:
        obs_space: Observation space to measure.
    """
    return get_preprocessor(obs_space)(obs_space).size
|
||||
|
||||
|
||||
def _mac(model, obs, h):
|
||||
"""Forward pass of the multi-agent controller.
|
||||
|
||||
Arguments:
|
||||
model: Model that produces q-values for a 1d agent batch
|
||||
model: TorchModel class
|
||||
obs: Tensor of shape [B, n_agents, obs_size]
|
||||
h: Tensor of shape [B, n_agents, h_size]
|
||||
h: List of tensors of shape [B, n_agents, h_size]
|
||||
|
||||
Returns:
|
||||
q_vals: Tensor of shape [B, n_agents, n_actions]
|
||||
@@ -406,6 +416,7 @@ def _mac(model, obs, h):
|
||||
"""
|
||||
B, n_agents = obs.size(0), obs.size(1)
|
||||
obs_flat = obs.reshape([B * n_agents, -1])
|
||||
h_flat = h.reshape([B * n_agents, -1])
|
||||
q_flat, h_flat = model.forward(obs_flat, h_flat)
|
||||
return q_flat.reshape([B, n_agents, -1]), h_flat.reshape([B, n_agents, -1])
|
||||
h_flat = [s.reshape([B * n_agents, -1]) for s in h]
|
||||
q_flat, _, _, h_flat = model.forward({"obs": obs_flat}, h_flat)
|
||||
return q_flat.reshape(
|
||||
[B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]
|
||||
|
||||
@@ -8,7 +8,6 @@ from threading import Lock
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from ray.rllib.models.pytorch.misc import var_to_np
|
||||
except ImportError:
|
||||
pass # soft dep
|
||||
|
||||
@@ -66,15 +65,14 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
if state_batches:
|
||||
raise NotImplementedError("Torch RNN support")
|
||||
with self.lock:
|
||||
with torch.no_grad():
|
||||
ob = torch.from_numpy(np.array(obs_batch)).float()
|
||||
model_out = self._model(ob)
|
||||
logits = model_out[0] # assume the first output is the logits
|
||||
model_out = self._model({"obs": ob}, state_batches)
|
||||
logits, _, vf, state = model_out
|
||||
actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
|
||||
return var_to_np(actions), [], self.extra_action_out(model_out)
|
||||
return (actions.numpy(), [h.numpy() for h in state],
|
||||
self.extra_action_out(model_out))
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
@@ -87,7 +85,7 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
loss_out.backward()
|
||||
# Note that return values are just references;
|
||||
# calling zero_grad will modify the values
|
||||
grads = [var_to_np(p.grad.data) for p in self._model.parameters()]
|
||||
grads = [p.grad.data.numpy() for p in self._model.parameters()]
|
||||
return grads, {}
|
||||
|
||||
@override(PolicyGraph)
|
||||
@@ -108,6 +106,10 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
with self.lock:
|
||||
self._model.load_state_dict(weights)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
return [s.numpy() for s in self._model.state_init()]
|
||||
|
||||
def extra_action_out(self, model_out):
|
||||
"""Returns dict of extra info to include in experience batch.
|
||||
|
||||
|
||||
@@ -52,8 +52,6 @@ MODEL_DEFAULTS = {
|
||||
"framestack": True,
|
||||
# Final resized frame dimension
|
||||
"dim": 84,
|
||||
# Pytorch conv requires images to be channel-major
|
||||
"channel_major": False,
|
||||
# (deprecated) Converts ATARI frame to 1 Channel Grayscale image
|
||||
"grayscale": False,
|
||||
# (deprecated) Changes frame to range from [-1, 1] if true
|
||||
@@ -232,14 +230,17 @@ class ModelCatalog(object):
|
||||
options)
|
||||
|
||||
@staticmethod
|
||||
def get_torch_model(input_shape, num_outputs, options=None):
|
||||
"""Returns a PyTorch suitable model. This is currently only supported
|
||||
in A3C.
|
||||
def get_torch_model(obs_space,
|
||||
num_outputs,
|
||||
options=None,
|
||||
default_model_cls=None):
|
||||
"""Returns a custom model for PyTorch algorithms.
|
||||
|
||||
Args:
|
||||
input_shape (tuple): The input shape to the model.
|
||||
obs_space (Space): The input observation space.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
default_model_cls (cls): Optional class to use if no custom model.
|
||||
|
||||
Returns:
|
||||
model (models.Model): Neural network model.
|
||||
@@ -250,21 +251,29 @@ class ModelCatalog(object):
|
||||
PyTorchVisionNet)
|
||||
|
||||
options = options or MODEL_DEFAULTS
|
||||
|
||||
if options.get("custom_model"):
|
||||
model = options["custom_model"]
|
||||
logger.info("Using custom torch model {}".format(model))
|
||||
return _global_registry.get(RLLIB_MODEL, model)(
|
||||
input_shape, num_outputs, options)
|
||||
logger.debug("Using custom torch model {}".format(model))
|
||||
return _global_registry.get(RLLIB_MODEL,
|
||||
model)(obs_space, num_outputs, options)
|
||||
|
||||
# TODO(alok): fix to handle Discrete(n) state spaces
|
||||
obs_rank = len(input_shape) - 1
|
||||
if options.get("use_lstm"):
|
||||
raise NotImplementedError(
|
||||
"LSTM auto-wrapping not implemented for torch")
|
||||
|
||||
if default_model_cls:
|
||||
return default_model_cls(obs_space, num_outputs, options)
|
||||
|
||||
if isinstance(obs_space, gym.spaces.Discrete):
|
||||
obs_rank = 1
|
||||
else:
|
||||
obs_rank = len(obs_space.shape)
|
||||
|
||||
if obs_rank > 1:
|
||||
return PyTorchVisionNet(input_shape, num_outputs, options)
|
||||
return PyTorchVisionNet(obs_space, num_outputs, options)
|
||||
|
||||
# TODO(alok): overhaul PyTorchFCNet so it can just
|
||||
# take input shape directly
|
||||
return PyTorchFCNet(input_shape[0], num_outputs, options)
|
||||
return PyTorchFCNet(obs_space, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor(env, options=None):
|
||||
|
||||
@@ -160,11 +160,14 @@ class Model(object):
|
||||
self._num_outputs, shape))
|
||||
|
||||
|
||||
def _restore_original_dimensions(input_dict, obs_space):
|
||||
def _restore_original_dimensions(input_dict, obs_space, tensorlib=tf):
|
||||
if hasattr(obs_space, "original_space"):
|
||||
return dict(
|
||||
input_dict,
|
||||
obs=_unpack_obs(input_dict["obs"], obs_space.original_space))
|
||||
obs=_unpack_obs(
|
||||
input_dict["obs"],
|
||||
obs_space.original_space,
|
||||
tensorlib=tensorlib))
|
||||
return input_dict
|
||||
|
||||
|
||||
|
||||
@@ -64,15 +64,11 @@ class GenericPixelPreprocessor(Preprocessor):
|
||||
self._grayscale = options.get("grayscale")
|
||||
self._zero_mean = options.get("zero_mean")
|
||||
self._dim = options.get("dim")
|
||||
self._channel_major = options.get("channel_major")
|
||||
if self._grayscale:
|
||||
shape = (self._dim, self._dim, 1)
|
||||
else:
|
||||
shape = (self._dim, self._dim, 3)
|
||||
|
||||
# channel_major requires (# in-channels, row dim, col dim)
|
||||
if self._channel_major:
|
||||
shape = shape[-1:] + shape[:-1]
|
||||
return shape
|
||||
|
||||
@override(Preprocessor)
|
||||
@@ -94,8 +90,6 @@ class GenericPixelPreprocessor(Preprocessor):
|
||||
scaled = (scaled - 128) / 128
|
||||
else:
|
||||
scaled *= 1.0 / 255.0
|
||||
if self._channel_major:
|
||||
scaled = np.reshape(scaled, self.shape)
|
||||
return scaled
|
||||
|
||||
|
||||
|
||||
@@ -3,30 +3,27 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
|
||||
from ray.rllib.models.pytorch.model import Model, SlimFC
|
||||
from ray.rllib.models.pytorch.misc import normc_initializer
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.pytorch.model import TorchModel
|
||||
from ray.rllib.models.pytorch.misc import normc_initializer, SlimFC, \
|
||||
_get_activation_fn
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FullyConnectedNetwork(Model):
|
||||
"""TODO(rliaw): Logits, Value should both be contained here"""
|
||||
|
||||
def _build_layers(self, inputs, num_outputs, options):
|
||||
assert type(inputs) is int
|
||||
hiddens = options.get("fcnet_hiddens", [256, 256])
|
||||
fcnet_activation = options.get("fcnet_activation", "tanh")
|
||||
activation = None
|
||||
if fcnet_activation == "tanh":
|
||||
activation = nn.Tanh
|
||||
elif fcnet_activation == "relu":
|
||||
activation = nn.ReLU
|
||||
logger.info("Constructing fcnet {} {}".format(hiddens, activation))
|
||||
class FullyConnectedNetwork(TorchModel):
|
||||
"""Generic fully connected network."""
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
hiddens = options.get("fcnet_hiddens")
|
||||
activation = _get_activation_fn(options.get("fcnet_activation"))
|
||||
logger.debug("Constructing fcnet {} {}".format(hiddens, activation))
|
||||
layers = []
|
||||
last_layer_size = inputs
|
||||
last_layer_size = np.product(obs_space.shape)
|
||||
for size in hiddens:
|
||||
layers.append(
|
||||
SlimFC(
|
||||
@@ -36,29 +33,25 @@ class FullyConnectedNetwork(Model):
|
||||
activation_fn=activation))
|
||||
last_layer_size = size
|
||||
|
||||
self.hidden_layers = nn.Sequential(*layers)
|
||||
self._hidden_layers = nn.Sequential(*layers)
|
||||
|
||||
self.logits = SlimFC(
|
||||
self._logits = SlimFC(
|
||||
in_size=last_layer_size,
|
||||
out_size=num_outputs,
|
||||
initializer=normc_initializer(0.01),
|
||||
activation_fn=None)
|
||||
self.value_branch = SlimFC(
|
||||
self._value_branch = SlimFC(
|
||||
in_size=last_layer_size,
|
||||
out_size=1,
|
||||
initializer=normc_initializer(1.0),
|
||||
activation_fn=None)
|
||||
|
||||
def forward(self, obs):
|
||||
""" Internal method - pass in torch tensors, not numpy arrays
|
||||
|
||||
Args:
|
||||
obs: observations and features
|
||||
|
||||
Return:
|
||||
logits: logits to be sampled from for each state
|
||||
value: value function for each state"""
|
||||
res = self.hidden_layers(obs)
|
||||
logits = self.logits(res)
|
||||
value = self.value_branch(res).squeeze(1)
|
||||
return logits, value
|
||||
@override(nn.Module)
|
||||
def forward(self, input_dict, hidden_state):
|
||||
# Note that we override forward() and not _forward() to get the
|
||||
# flattened obs here.
|
||||
obs = input_dict["obs"]
|
||||
features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
|
||||
logits = self._logits(features)
|
||||
value = self._value_branch(features).squeeze(1)
|
||||
return logits, features, value, hidden_state
|
||||
|
||||
@@ -5,10 +5,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def var_to_np(var):
|
||||
return var.cpu().detach().numpy()
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def normc_initializer(std=1.0):
|
||||
@@ -51,3 +48,68 @@ def valid_padding(in_size, filter_size, stride_size):
|
||||
padding = (pad_left, pad_right, pad_top, pad_bottom)
|
||||
output = (out_height, out_width)
|
||||
return padding, output
|
||||
|
||||
|
||||
def _get_activation_fn(name):
|
||||
activation = None
|
||||
if name == "tanh":
|
||||
activation = nn.Tanh
|
||||
elif name == "relu":
|
||||
activation = nn.ReLU
|
||||
else:
|
||||
raise ValueError("Unknown activation: {}".format(name))
|
||||
return activation
|
||||
|
||||
|
||||
class SlimConv2d(nn.Module):
    """Simple mock of tf.slim Conv2d.

    Chains an optional ``nn.ZeroPad2d``, an ``nn.Conv2d`` and an optional
    activation module into one ``nn.Sequential`` stored as ``self._model``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel,
                 stride,
                 padding,
                 initializer=nn.init.xavier_uniform_,
                 activation_fn=nn.ReLU,
                 bias_init=0):
        """Build the padding / convolution / activation stack.

        Arguments:
            in_channels: Passed to ``nn.Conv2d`` as the input channel count.
            out_channels: Passed to ``nn.Conv2d`` as the output channel
                count.
            kernel: Passed to ``nn.Conv2d`` as the kernel size.
            stride: Passed to ``nn.Conv2d`` as the stride.
            padding: Argument for ``nn.ZeroPad2d``; if falsy, no padding
                layer is added.
            initializer: Callable applied in place to the conv weight; if
                falsy, the ``nn.Conv2d`` default initialization is kept.
            activation_fn: Module class instantiated with no arguments and
                appended after the convolution; if falsy, none is added.
            bias_init: Constant the conv bias is filled with.
        """
        super(SlimConv2d, self).__init__()
        layers = []
        if padding:
            layers.append(nn.ZeroPad2d(padding))
        conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
        if initializer:
            initializer(conv.weight)
        # The bias is always overwritten, independent of ``initializer``.
        nn.init.constant_(conv.bias, bias_init)

        layers.append(conv)
        if activation_fn:
            layers.append(activation_fn())
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the padding / convolution / activation stack to ``x``."""
        return self._model(x)
|
||||
|
||||
|
||||
class SlimFC(nn.Module):
    """Simple PyTorch version of `linear` function.

    Wraps an ``nn.Linear`` and an optional activation module in one
    ``nn.Sequential`` stored as ``self._model``.
    """

    def __init__(self,
                 in_size,
                 out_size,
                 initializer=None,
                 activation_fn=None,
                 bias_init=0):
        """Build the linear layer and optional activation.

        Arguments:
            in_size: Passed to ``nn.Linear`` as the input feature count.
            out_size: Passed to ``nn.Linear`` as the output feature count.
            initializer: Callable applied in place to the linear weight; if
                falsy, the ``nn.Linear`` default initialization is kept.
            activation_fn: Module class instantiated with no arguments and
                appended after the linear layer; if falsy, none is added.
            bias_init: Constant the linear bias is filled with.
        """
        super(SlimFC, self).__init__()
        layers = []
        linear = nn.Linear(in_size, out_size)
        if initializer:
            initializer(linear.weight)
        # The bias is always overwritten, independent of ``initializer``.
        nn.init.constant_(linear.bias, bias_init)
        layers.append(linear)
        if activation_fn:
            layers.append(activation_fn())
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear layer (and activation, if any) to ``x``."""
        return self._model(x)
|
||||
|
||||
@@ -2,79 +2,57 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.model import _restore_original_dimensions
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, obs_space, ac_space, options):
|
||||
super(Model, self).__init__()
|
||||
self._build_layers(obs_space, ac_space, options)
|
||||
|
||||
def _build_layers(self, inputs, num_outputs, options):
|
||||
raise NotImplementedError
|
||||
class TorchModel(nn.Module):
|
||||
"""Defines an abstract network model for use with RLlib / PyTorch."""
|
||||
|
||||
def forward(self, obs):
|
||||
"""Forward pass for the model. Internal method - should only
|
||||
be passed PyTorch Tensors.
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
"""All custom RLlib torch models must support this constructor.
|
||||
|
||||
PyTorch automatically overloads the given model
|
||||
with this function. Recommended that model(obs)
|
||||
is used instead of model.forward(obs). See
|
||||
https://discuss.pytorch.org/t/any-different-between-model
|
||||
-input-and-model-forward-input/3690
|
||||
Arguments:
|
||||
obs_space (gym.Space): Input observation space.
|
||||
num_outputs (int): Output tensor must be of size
|
||||
[BATCH_SIZE, num_outputs].
|
||||
options (dict): Dictionary of model options.
|
||||
"""
|
||||
nn.Module.__init__(self)
|
||||
self.obs_space = obs_space
|
||||
self.num_outputs = num_outputs
|
||||
self.options = options
|
||||
|
||||
def forward(self, input_dict, hidden_state):
|
||||
"""Wraps _forward() to unpack flattened Dict and Tuple observations."""
|
||||
input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast
|
||||
input_dict = _restore_original_dimensions(
|
||||
input_dict, self.obs_space, tensorlib=torch)
|
||||
outputs, features, vf, h = self._forward(input_dict, hidden_state)
|
||||
return outputs, features, vf, h
|
||||
|
||||
def state_init(self):
|
||||
"""Returns a list of initial hidden state tensors, if any."""
|
||||
return []
|
||||
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
"""Forward pass for the model.
|
||||
|
||||
Prefer implementing this instead of forward() directly for proper
|
||||
handling of Dict and Tuple observations.
|
||||
|
||||
Arguments:
|
||||
input_dict (dict): Dictionary of tensor inputs, commonly
|
||||
including "obs", "prev_action", "prev_reward", each of shape
|
||||
[BATCH_SIZE, ...].
|
||||
hidden_state (list): List of hidden state tensors, each of shape
|
||||
[BATCH_SIZE, h_size].
|
||||
|
||||
Returns:
|
||||
(outputs, feature_layer, values, state): Tensors of size
|
||||
[BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size],
|
||||
[BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size].
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlimConv2d(nn.Module):
    """Simple mock of tf.slim Conv2d.

    Chains an optional ``nn.ZeroPad2d``, an ``nn.Conv2d`` and an optional
    activation module into one ``nn.Sequential`` stored as ``self._model``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel,
                 stride,
                 padding,
                 initializer=nn.init.xavier_uniform_,
                 activation_fn=nn.ReLU,
                 bias_init=0):
        """Build the padding / convolution / activation stack.

        Arguments:
            in_channels: Passed to ``nn.Conv2d`` as the input channel count.
            out_channels: Passed to ``nn.Conv2d`` as the output channel
                count.
            kernel: Passed to ``nn.Conv2d`` as the kernel size.
            stride: Passed to ``nn.Conv2d`` as the stride.
            padding: Argument for ``nn.ZeroPad2d``; if falsy, no padding
                layer is added.
            initializer: Callable applied in place to the conv weight; if
                falsy, the ``nn.Conv2d`` default initialization is kept.
            activation_fn: Module class instantiated with no arguments and
                appended after the convolution; if falsy, none is added.
            bias_init: Constant the conv bias is filled with.
        """
        super(SlimConv2d, self).__init__()
        layers = []
        if padding:
            layers.append(nn.ZeroPad2d(padding))
        conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
        if initializer:
            initializer(conv.weight)
        # The bias is always overwritten, independent of ``initializer``.
        nn.init.constant_(conv.bias, bias_init)

        layers.append(conv)
        if activation_fn:
            layers.append(activation_fn())
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the padding / convolution / activation stack to ``x``."""
        return self._model(x)
|
||||
|
||||
|
||||
class SlimFC(nn.Module):
    """Simple PyTorch version of `linear` function.

    Wraps an ``nn.Linear`` and an optional activation module in one
    ``nn.Sequential`` stored as ``self._model``.
    """

    def __init__(self,
                 in_size,
                 out_size,
                 initializer=None,
                 activation_fn=None,
                 bias_init=0):
        """Build the linear layer and optional activation.

        Arguments:
            in_size: Passed to ``nn.Linear`` as the input feature count.
            out_size: Passed to ``nn.Linear`` as the output feature count.
            initializer: Callable applied in place to the linear weight; if
                falsy, the ``nn.Linear`` default initialization is kept.
            activation_fn: Module class instantiated with no arguments and
                appended after the linear layer; if falsy, none is added.
            bias_init: Constant the linear bias is filled with.
        """
        super(SlimFC, self).__init__()
        layers = []
        linear = nn.Linear(in_size, out_size)
        if initializer:
            initializer(linear.weight)
        # The bias is always overwritten, independent of ``initializer``.
        nn.init.constant_(linear.bias, bias_init)
        layers.append(linear)
        if activation_fn:
            layers.append(activation_fn())
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear layer (and activation, if any) to ``x``."""
        return self._model(x)
|
||||
|
||||
@@ -4,28 +4,25 @@ from __future__ import print_function
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.pytorch.model import Model, SlimConv2d, SlimFC
|
||||
from ray.rllib.models.pytorch.misc import normc_initializer, valid_padding
|
||||
from ray.rllib.models.pytorch.model import TorchModel
|
||||
from ray.rllib.models.pytorch.misc import normc_initializer, valid_padding, \
|
||||
SlimConv2d, SlimFC
|
||||
from ray.rllib.models.visionnet import _get_filter_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class VisionNetwork(Model):
|
||||
"""Generic vision network"""
|
||||
class VisionNetwork(TorchModel):
|
||||
"""Generic vision network."""
|
||||
|
||||
def _build_layers(self, inputs, num_outputs, options):
|
||||
"""TF visionnet in PyTorch.
|
||||
|
||||
Params:
|
||||
inputs (tuple): (channels, rows/height, cols/width)
|
||||
num_outputs (int): logits size
|
||||
"""
|
||||
filters = options.get("conv_filters") or [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
[512, [11, 11], 1],
|
||||
]
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
filters = options.get("conv_filters")
|
||||
if not filters:
|
||||
filters = _get_filter_config(obs_space.shape)
|
||||
layers = []
|
||||
in_channels, in_size = inputs[0], inputs[1:]
|
||||
|
||||
(w, h, in_channels) = obs_space.shape
|
||||
in_size = [w, h]
|
||||
for out_channels, kernel, stride in filters[:-1]:
|
||||
padding, out_size = valid_padding(in_size, kernel,
|
||||
[stride, stride])
|
||||
@@ -39,31 +36,20 @@ class VisionNetwork(Model):
|
||||
SlimConv2d(in_channels, out_channels, kernel, stride, None))
|
||||
self._convs = nn.Sequential(*layers)
|
||||
|
||||
self.logits = SlimFC(
|
||||
self._logits = SlimFC(
|
||||
out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
|
||||
self.value_branch = SlimFC(
|
||||
self._value_branch = SlimFC(
|
||||
out_channels, 1, initializer=normc_initializer())
|
||||
|
||||
def hidden_layers(self, obs):
|
||||
""" Internal method - pass in torch tensors, not numpy arrays
|
||||
@override(TorchModel)
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
features = self._hidden_layers(input_dict["obs"])
|
||||
logits = self._logits(features)
|
||||
value = self._value_branch(features).squeeze(1)
|
||||
return logits, features, value, hidden_state
|
||||
|
||||
args:
|
||||
obs: observations and features"""
|
||||
res = self._convs(obs)
|
||||
def _hidden_layers(self, obs):
|
||||
res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major
|
||||
res = res.squeeze(3)
|
||||
res = res.squeeze(2)
|
||||
return res
|
||||
|
||||
def forward(self, obs):
|
||||
"""Internal method. Implements the
|
||||
|
||||
Args:
|
||||
obs (PyTorch): observations and features
|
||||
|
||||
Return:
|
||||
logits (PyTorch): logits to be sampled from for each state
|
||||
value (PyTorch): value function for each state"""
|
||||
res = self.hidden_layers(obs)
|
||||
logits = self.logits(res)
|
||||
value = self.value_branch(res).squeeze(1)
|
||||
return logits, value
|
||||
|
||||
@@ -18,7 +18,7 @@ class VisionNetwork(Model):
|
||||
inputs = input_dict["obs"]
|
||||
filters = options.get("conv_filters")
|
||||
if not filters:
|
||||
filters = _get_filter_config(inputs)
|
||||
filters = _get_filter_config(inputs.shape.as_list()[1:])
|
||||
|
||||
activation = get_activation_fn(options.get("conv_activation"))
|
||||
|
||||
@@ -49,7 +49,8 @@ class VisionNetwork(Model):
|
||||
return flatten(fc2), flatten(fc1)
|
||||
|
||||
|
||||
def _get_filter_config(inputs):
|
||||
def _get_filter_config(shape):
|
||||
shape = list(shape)
|
||||
filters_84x84 = [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
@@ -60,15 +61,14 @@ def _get_filter_config(inputs):
|
||||
[32, [4, 4], 2],
|
||||
[256, [11, 11], 1],
|
||||
]
|
||||
shape = inputs.shape.as_list()[1:]
|
||||
if len(shape) == 3 and shape[:2] == [84, 84]:
|
||||
return filters_84x84
|
||||
elif len(shape) == 3 and shape[:2] == [42, 42]:
|
||||
return filters_42x42
|
||||
else:
|
||||
raise ValueError(
|
||||
"No default configuration for obs input {}".format(inputs) +
|
||||
"No default configuration for obs shape {}".format(shape) +
|
||||
", you must specify `conv_filters` manually as a model option. "
|
||||
"Default configurations are only available for inputs of size "
|
||||
"[?, 42, 42, K] and [?, 84, 84, K]. You may alternatively want "
|
||||
"Default configurations are only available for inputs of shape "
|
||||
"[42, 42, K] and [84, 84, K]. You may alternatively want "
|
||||
"to use a custom model or preprocessor.")
|
||||
|
||||
@@ -12,6 +12,7 @@ import tensorflow as tf
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c import A2CAgent
|
||||
from ray.rllib.agents.pg import PGAgent
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
@@ -19,6 +20,8 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.pytorch.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.pytorch.model import TorchModel
|
||||
from ray.rllib.rollout import rollout
|
||||
from ray.rllib.test.test_external_env import SimpleServing
|
||||
from ray.tune.registry import register_env
|
||||
@@ -129,6 +132,29 @@ class InvalidModel2(Model):
|
||||
return tf.constant(0), tf.constant(0)
|
||||
|
||||
|
||||
class TorchSpyModel(TorchModel):
    """Torch model that records the nested observations it receives.

    Each ``_forward`` call pickles the position, first front-camera entry
    and task tensors (as numpy arrays) into Ray's internal KV store under
    ``torch_spy_in_<capture_index>``, then delegates to a fully connected
    network fed with the position observation only.
    """

    # Class-level counter: number of forward passes captured so far.
    capture_index = 0

    def __init__(self, obs_space, num_outputs, options):
        TorchModel.__init__(self, obs_space, num_outputs, options)
        # The wrapped network is sized for the "position" sub-space only.
        self.fc = FullyConnectedNetwork(
            obs_space.original_space.spaces["sensors"].spaces["position"],
            num_outputs, options)

    def _forward(self, input_dict, hidden_state):
        pos = input_dict["obs"]["sensors"]["position"].numpy()
        front_cam = input_dict["obs"]["sensors"]["front_cam"][0].numpy()
        task = input_dict["obs"]["inner_state"]["job_status"]["task"].numpy()
        ray.experimental.internal_kv._internal_kv_put(
            "torch_spy_in_{}".format(TorchSpyModel.capture_index),
            pickle.dumps((pos, front_cam, task)),
            overwrite=True)
        TorchSpyModel.capture_index += 1
        return self.fc({
            "obs": input_dict["obs"]["sensors"]["position"]
        }, hidden_state)
|
||||
|
||||
|
||||
class DictSpyModel(Model):
|
||||
capture_index = 0
|
||||
|
||||
@@ -359,6 +385,36 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
# Test rollout works on restore
|
||||
rollout(agent2, "nested", 100)
|
||||
|
||||
def testPyTorchModel(self):
    """Train A2C with the torch spy model and check what the model saw.

    Registers ``TorchSpyModel`` as the "composite" custom model, runs one
    training iteration on the nested dict env, and compares the first four
    captured observations against ``DICT_SAMPLES``.
    """
    ModelCatalog.register_custom_model("composite", TorchSpyModel)
    register_env("nested", lambda _: NestedDictEnv())
    a2c = A2CAgent(
        env="nested",
        config={
            "num_workers": 0,
            "use_pytorch": True,
            "sample_batch_size": 5,
            "train_batch_size": 5,
            "model": {
                "custom_model": "composite",
            },
        })

    a2c.train()

    # Check that the model sees the correct reconstructed observations
    for i in range(4):
        seen = pickle.loads(
            ray.experimental.internal_kv._internal_kv_get(
                "torch_spy_in_{}".format(i)))
        pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
        cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
        task_i = one_hot(
            DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
        self.assertEqual(seen[0][0].tolist(), pos_i)
        self.assertEqual(seen[1][0].tolist(), cam_i)
        self.assertEqual(seen[2][0].tolist(), task_i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init(num_cpus=5)
|
||||
|
||||
@@ -14,7 +14,6 @@ pong-a3c-pytorch-cnn:
|
||||
observation_filter: NoFilter
|
||||
model:
|
||||
use_lstm: false
|
||||
channel_major: true
|
||||
dim: 84
|
||||
grayscale: true
|
||||
zero_mean: false
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
cartpole-a2c-torch:
    env: CartPole-v0
    run: A2C
    stop:
        episode_reward_mean: 100
        timesteps_total: 100000
    config:
        num_workers: 0
        use_pytorch: true
|
||||
Reference in New Issue
Block a user