[rllib] Support torch device and distributions. (#4553)

This commit is contained in:
cfan
2019-04-12 11:39:14 -07:00
committed by Eric Liang
parent 5cfbfe5df6
commit bb207a205b
6 changed files with 186 additions and 62 deletions
@@ -17,25 +17,28 @@ from ray.rllib.utils.annotations import override
class A3CLoss(nn.Module):
def __init__(self, policy_model, vf_loss_coeff=0.5, entropy_coeff=0.01):
def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01):
nn.Module.__init__(self)
self.policy_model = policy_model
self.dist_class = dist_class
self.vf_loss_coeff = vf_loss_coeff
self.entropy_coeff = entropy_coeff
def forward(self, observations, actions, advantages, value_targets):
logits, _, values, _ = self.policy_model({"obs": observations}, [])
log_probs = F.log_softmax(logits, dim=1)
probs = F.softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
entropy = -(log_probs * probs).sum(-1).sum()
pi_err = -advantages.dot(action_log_probs.reshape(-1))
value_err = F.mse_loss(values.reshape(-1), value_targets)
def forward(self, policy_model, observations, actions, advantages,
value_targets):
logits, _, values, _ = policy_model({
SampleBatch.CUR_OBS: observations
}, [])
dist = self.dist_class(logits)
log_probs = dist.logp(actions)
self.entropy = dist.entropy().mean()
self.pi_err = -advantages.dot(log_probs.reshape(-1))
self.value_err = F.mse_loss(values.reshape(-1), value_targets)
overall_err = sum([
pi_err,
self.vf_loss_coeff * value_err,
-self.entropy_coeff * entropy,
self.pi_err,
self.vf_loss_coeff * self.value_err,
-self.entropy_coeff * self.entropy,
])
return overall_err
@@ -44,7 +47,7 @@ class A3CPostprocessing(object):
@override(TorchPolicyGraph)
def extra_action_out(self, model_out):
return {SampleBatch.VF_PREDS: model_out[2].numpy()}
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
@override(PolicyGraph)
def postprocess_trajectory(self,
@@ -66,29 +69,47 @@ class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph):
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
_, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = A3CLoss(self.model, self.config["vf_loss_coeff"],
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], torch=True)
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = A3CLoss(dist_class, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])
TorchPolicyGraph.__init__(
self,
obs_space,
action_space,
self.model,
model,
loss,
loss_inputs=[
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS
])
],
action_distribution_cls=dist_class)
@override(TorchPolicyGraph)
def optimizer(self):
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
@override(TorchPolicyGraph)
def extra_grad_process(self):
    """Clip gradients by global norm when ``grad_clip`` is configured.

    Returns:
        dict: ``{"grad_gnorm": <value returned by clip_grad_norm_>}`` when
            ``self.config["grad_clip"]`` is truthy, otherwise an empty dict.
    """
    max_norm = self.config["grad_clip"]
    if not max_norm:
        # Clipping disabled: gradients are left untouched.
        return {}
    # clip_grad_norm_ rescales the gradients in place and returns the
    # total norm, which is reported back as a stat.
    params = self._model.parameters()
    return {"grad_gnorm": nn.utils.clip_grad_norm_(params, max_norm)}
@override(TorchPolicyGraph)
def extra_grad_info(self):
    """Report the scalar loss terms stored on the loss module.

    Reads the ``entropy``, ``pi_err`` and ``value_err`` attributes that the
    loss module assigns in its ``forward()`` and converts each to a Python
    scalar via ``.item()``.

    Returns:
        dict: ``policy_entropy``, ``policy_loss`` and ``vf_loss`` entries.
    """
    loss = self._loss
    stats = {}
    stats["policy_entropy"] = loss.entropy.item()
    stats["policy_loss"] = loss.pi_err.item()
    stats["vf_loss"] = loss.value_err.item()
    return stats
def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0)
_, _, vf, _ = self.model({"obs": obs}, [])
return vf.detach().numpy().squeeze()
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_, _, vf, _ = self._model({"obs": obs}, [])
return vf.detach().cpu().numpy().squeeze()
@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import torch
import torch.nn.functional as F
from torch import nn
import ray
@@ -17,16 +16,18 @@ from ray.rllib.utils.annotations import override
class PGLoss(nn.Module):
def __init__(self, policy_model):
def __init__(self, dist_class):
nn.Module.__init__(self)
self.policy_model = policy_model
self.dist_class = dist_class
def forward(self, observations, actions, advantages):
logits, _, values, _ = self.policy_model({"obs": observations}, [])
log_probs = F.log_softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
pi_err = -advantages.dot(action_log_probs.reshape(-1))
return pi_err
def forward(self, policy_model, observations, actions, advantages):
logits, _, values, _ = policy_model({
SampleBatch.CUR_OBS: observations
}, [])
dist = self.dist_class(logits)
log_probs = dist.logp(actions)
self.pi_err = -advantages.dot(log_probs.reshape(-1))
return self.pi_err
class PGPostprocessing(object):
@@ -34,7 +35,7 @@ class PGPostprocessing(object):
@override(TorchPolicyGraph)
def extra_action_out(self, model_out):
return {SampleBatch.VF_PREDS: model_out[2].numpy()}
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
@override(PolicyGraph)
def postprocess_trajectory(self,
@@ -49,29 +50,34 @@ class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph):
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
_, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = PGLoss(self.model)
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], torch=True)
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = PGLoss(dist_class)
TorchPolicyGraph.__init__(
self,
obs_space,
action_space,
self.model,
model,
loss,
loss_inputs=[
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
Postprocessing.ADVANTAGES
])
],
action_distribution_cls=dist_class)
@override(TorchPolicyGraph)
def optimizer(self):
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
@override(TorchPolicyGraph)
def extra_grad_info(self):
    """Report the policy loss stored on the loss module.

    Returns:
        dict: single ``policy_loss`` entry holding ``self._loss.pi_err``
            converted to a Python scalar via ``.item()``.
    """
    policy_loss = self._loss.pi_err.item()
    return {"policy_loss": policy_loss}
# Value-function estimate for a single observation: the observation is
# converted to a float tensor with a leading batch dimension, run through the
# model, and the third model output is returned as a squeezed numpy array.
# The model call is made while holding self.lock.
def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0)
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
# NOTE(review): this line is unchanged by the commit and still reads
# `self.model`, but the updated __init__ in this same diff binds the network
# to a local `model` and optimizer() now uses `self._model`. `self.model`
# appears to be unset after this change -- presumably this should be
# `self._model`, as in the A3C `_value`. TODO confirm.
_, _, vf, _ = self.model({"obs": obs}, [])
return vf.detach().numpy().squeeze()
# .cpu() moves the result to host memory before the numpy conversion.
return vf.detach().cpu().numpy().squeeze()
@@ -2,15 +2,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from threading import Lock
try:
import torch
import torch.nn.functional as F
except ImportError:
pass # soft dep
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.utils.annotations import override
@@ -28,11 +30,11 @@ class TorchPolicyGraph(PolicyGraph):
"""
def __init__(self, observation_space, action_space, model, loss,
loss_inputs):
loss_inputs, action_distribution_cls):
"""Build a policy graph from policy and loss torch modules.
Note that module inputs will be CPU tensors. The model and loss modules
are responsible for moving inputs to the right device.
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
is set. Only single GPU is supported for now.
Arguments:
observation_space (gym.Space): observation space of the policy.
@@ -47,14 +49,20 @@ class TorchPolicyGraph(PolicyGraph):
loss_inputs (list): List of SampleBatch columns that will be
passed to the loss module's forward() function when computing
the loss. For example, ["obs", "action", "advantages"].
action_distribution_cls (ActionDistribution): Class for action
distribution.
"""
self.observation_space = observation_space
self.action_space = action_space
self.lock = Lock()
self._model = model
self.device = (torch.device("cuda")
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
else torch.device("cpu"))
self._model = model.to(self.device)
self._loss = loss
self._loss_inputs = loss_inputs
self._optimizer = self.optimizer()
self._action_dist_cls = action_distribution_cls
@override(PolicyGraph)
def compute_actions(self,
@@ -67,11 +75,14 @@ class TorchPolicyGraph(PolicyGraph):
**kwargs):
with self.lock:
with torch.no_grad():
ob = torch.from_numpy(np.array(obs_batch)).float()
ob = torch.from_numpy(np.array(obs_batch)) \
.float().to(self.device)
model_out = self._model({"obs": ob}, state_batches)
logits, _, vf, state = model_out
actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
return (actions.numpy(), [h.numpy() for h in state],
action_dist = self._action_dist_cls(logits)
actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
self.extra_action_out(model_out))
@override(PolicyGraph)
@@ -79,33 +90,40 @@ class TorchPolicyGraph(PolicyGraph):
with self.lock:
loss_in = []
for key in self._loss_inputs:
loss_in.append(torch.from_numpy(postprocessed_batch[key]))
loss_out = self._loss(*loss_in)
loss_in.append(
torch.from_numpy(postprocessed_batch[key]).to(self.device))
loss_out = self._loss(self._model, *loss_in)
self._optimizer.zero_grad()
loss_out.backward()
grad_process_info = self.extra_grad_process()
# Note that return values are just references;
# calling zero_grad will modify the values
grads = []
for p in self._model.parameters():
if p.grad is not None:
grads.append(p.grad.data.numpy())
grads.append(p.grad.data.cpu().numpy())
else:
grads.append(None)
return grads, {}
grad_info = self.extra_grad_info()
grad_info.update(grad_process_info)
return grads, {LEARNER_STATS_KEY: grad_info}
@override(PolicyGraph)
def apply_gradients(self, gradients):
with self.lock:
for g, p in zip(gradients, self._model.parameters()):
if g is not None:
p.grad = torch.from_numpy(g)
p.grad = torch.from_numpy(g).to(self.device)
self._optimizer.step()
return {}
@override(PolicyGraph)
def get_weights(self):
with self.lock:
return self._model.state_dict()
return {k: v.cpu() for k, v in self._model.state_dict().items()}
@override(PolicyGraph)
def set_weights(self, weights):
@@ -116,6 +134,11 @@ class TorchPolicyGraph(PolicyGraph):
def get_initial_state(self):
    """Return the model's initial state as a list of numpy arrays.

    Returns:
        list: one numpy array per tensor yielded by
            ``self._model.state_init()``.
    """
    # Tensor.numpy() only works on host tensors. The model may now be
    # placed on a CUDA device, so copy to CPU first (a no-op for tensors
    # that are already on the CPU), matching the other tensor -> numpy
    # conversions in this class.
    return [s.cpu().numpy() for s in self._model.state_init()]
def extra_grad_process(self):
    """Hook for subclasses to post-process gradients.

    Subclasses may override this to do extra processing on the gradients
    and return a dict of processing info. The base implementation does
    nothing.

    Returns:
        dict: empty dict.
    """
    return dict()
def extra_action_out(self, model_out):
"""Returns dict of extra info to include in experience batch.
@@ -123,6 +146,11 @@ class TorchPolicyGraph(PolicyGraph):
model_out (list): Outputs of the policy model module."""
return {}
def extra_grad_info(self):
    """Hook for subclasses to report extra gradient info.

    Returns:
        dict: extra grad info; the base implementation returns an empty
            dict.
    """
    return dict()
def optimizer(self):
    """Build the PyTorch optimizer to use.

    Subclasses may override this to supply a custom optimizer.

    Returns:
        torch.optim.Adam: Adam with torch's default hyperparameters over
            the parameters of ``self._model``.
    """
    trainable = self._model.parameters()
    return torch.optim.Adam(trainable)
+13 -3
View File
@@ -15,6 +15,8 @@ from ray.rllib.models.extra_spaces import Simplex
from ray.rllib.models.action_dist import (Categorical, MultiCategorical,
Deterministic, DiagGaussian,
MultiActionDistribution, Dirichlet)
from ray.rllib.models.torch_action_dist import (TorchCategorical,
TorchDiagGaussian)
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
@@ -89,13 +91,14 @@ class ModelCatalog(object):
@staticmethod
@DeveloperAPI
def get_action_dist(action_space, config, dist_type=None):
def get_action_dist(action_space, config, dist_type=None, torch=False):
"""Returns action distribution class and size for the given action space.
Args:
action_space (Space): Action space of the target gym env.
config (dict): Optional model config.
dist_type (str): Optional identifier of the action distribution.
torch (bool): Optional whether to return PyTorch distribution.
Returns:
dist_class (ActionDistribution): Python class of the distribution.
@@ -111,7 +114,7 @@ class ModelCatalog(object):
"Consider reshaping this into a single dimension, "
"using a Tuple action space, or the multi-agent API.")
if dist_type is None:
dist = DiagGaussian
dist = TorchDiagGaussian if torch else DiagGaussian
if config.get("squash_to_range"):
raise ValueError(
"The squash_to_range option is deprecated. See the "
@@ -120,7 +123,8 @@ class ModelCatalog(object):
elif dist_type == "deterministic":
return Deterministic, action_space.shape[0]
elif isinstance(action_space, gym.spaces.Discrete):
return Categorical, action_space.n
dist = TorchCategorical if torch else Categorical
return dist, action_space.n
elif isinstance(action_space, gym.spaces.Tuple):
child_dist = []
input_lens = []
@@ -129,14 +133,20 @@ class ModelCatalog(object):
action, config)
child_dist.append(dist)
input_lens.append(action_size)
if torch:
raise NotImplementedError
return partial(
MultiActionDistribution,
child_distributions=child_dist,
action_space=action_space,
input_lens=input_lens), sum(input_lens)
elif isinstance(action_space, Simplex):
if torch:
raise NotImplementedError
return Dirichlet, action_space.shape[0]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
if torch:
raise NotImplementedError
return MultiCategorical, sum(action_space.nvec)
raise NotImplementedError("Unsupported args: {} {}".format(
@@ -0,0 +1,52 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
import torch
except ImportError:
pass # soft dep
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override
class TorchDistributionWrapper(ActionDistribution):
    """Wrapper class for torch.distributions.

    Subclasses are expected to assign a ``torch.distributions`` object to
    ``self.dist`` in their ``__init__``; every method here delegates to it.
    """

    @override(ActionDistribution)
    def logp(self, actions):
        """Return ``self.dist.log_prob(actions)``."""
        return self.dist.log_prob(actions)

    @override(ActionDistribution)
    def entropy(self):
        """Return ``self.dist.entropy()``."""
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other):
        """Return the KL divergence from this distribution to ``other``.

        Arguments:
            other: another wrapper exposing a ``dist`` attribute, or a raw
                ``torch.distributions`` object.
        """
        # kl_divergence dispatches on torch Distribution types, so unwrap
        # the other wrapper; a raw torch distribution is passed through.
        other_dist = getattr(other, "dist", other)
        return torch.distributions.kl.kl_divergence(self.dist, other_dist)

    @override(ActionDistribution)
    def sample(self):
        """Return ``self.dist.sample()``."""
        return self.dist.sample()
class TorchCategorical(TorchDistributionWrapper):
    """Wrapper class for PyTorch Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs):
        """Build the categorical distribution from model outputs.

        Arguments:
            inputs: tensor passed to the torch Categorical as ``logits``.
        """
        categorical_cls = torch.distributions.categorical.Categorical
        self.dist = categorical_cls(logits=inputs)
class TorchDiagGaussian(TorchDistributionWrapper):
    """Wrapper class for PyTorch Normal distribution.

    The wrapped ``Normal`` is elementwise, so ``logp``, ``entropy`` and
    ``kl`` each sum over the last dimension to yield one value per row of
    ``inputs`` rather than one per action component.
    """

    @override(ActionDistribution)
    def __init__(self, inputs):
        """Build the distribution from model outputs.

        Arguments:
            inputs: tensor split into two equal chunks along dim 1; the
                first is used as the mean and the second as the log
                standard deviation.
        """
        mean, log_std = torch.chunk(inputs, 2, dim=1)
        self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))

    @override(TorchDistributionWrapper)
    def logp(self, actions):
        """Return the log-probability of ``actions``, summed over dim -1."""
        return TorchDistributionWrapper.logp(self, actions).sum(-1)

    @override(TorchDistributionWrapper)
    def entropy(self):
        """Return the entropy, summed over dim -1 (consistent with logp)."""
        return TorchDistributionWrapper.entropy(self).sum(-1)

    @override(TorchDistributionWrapper)
    def kl(self, other):
        """Return KL(self || other), summed over dim -1.

        Arguments:
            other: another wrapper exposing a ``dist`` attribute, or a raw
                ``torch.distributions`` object.
        """
        # Unwrap here so this method does not depend on how the base class
        # handles wrapper arguments.
        other_dist = getattr(other, "dist", other)
        per_dim = torch.distributions.kl.kl_divergence(self.dist, other_dist)
        return per_dim.sum(-1)