mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 14:05:08 +08:00
[rllib] Support torch device and distributions. (#4553)
This commit is contained in:
@@ -17,25 +17,28 @@ from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class A3CLoss(nn.Module):
|
||||
def __init__(self, policy_model, vf_loss_coeff=0.5, entropy_coeff=0.01):
|
||||
def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01):
|
||||
nn.Module.__init__(self)
|
||||
self.policy_model = policy_model
|
||||
self.dist_class = dist_class
|
||||
self.vf_loss_coeff = vf_loss_coeff
|
||||
self.entropy_coeff = entropy_coeff
|
||||
|
||||
def forward(self, observations, actions, advantages, value_targets):
|
||||
logits, _, values, _ = self.policy_model({"obs": observations}, [])
|
||||
log_probs = F.log_softmax(logits, dim=1)
|
||||
probs = F.softmax(logits, dim=1)
|
||||
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
|
||||
entropy = -(log_probs * probs).sum(-1).sum()
|
||||
pi_err = -advantages.dot(action_log_probs.reshape(-1))
|
||||
value_err = F.mse_loss(values.reshape(-1), value_targets)
|
||||
def forward(self, policy_model, observations, actions, advantages,
|
||||
value_targets):
|
||||
logits, _, values, _ = policy_model({
|
||||
SampleBatch.CUR_OBS: observations
|
||||
}, [])
|
||||
dist = self.dist_class(logits)
|
||||
log_probs = dist.logp(actions)
|
||||
self.entropy = dist.entropy().mean()
|
||||
self.pi_err = -advantages.dot(log_probs.reshape(-1))
|
||||
self.value_err = F.mse_loss(values.reshape(-1), value_targets)
|
||||
overall_err = sum([
|
||||
pi_err,
|
||||
self.vf_loss_coeff * value_err,
|
||||
-self.entropy_coeff * entropy,
|
||||
self.pi_err,
|
||||
self.vf_loss_coeff * self.value_err,
|
||||
-self.entropy_coeff * self.entropy,
|
||||
])
|
||||
|
||||
return overall_err
|
||||
|
||||
|
||||
@@ -44,7 +47,7 @@ class A3CPostprocessing(object):
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_action_out(self, model_out):
|
||||
return {SampleBatch.VF_PREDS: model_out[2].numpy()}
|
||||
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
@@ -66,29 +69,47 @@ class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
|
||||
self.config["model"])
|
||||
loss = A3CLoss(self.model, self.config["vf_loss_coeff"],
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"], torch=True)
|
||||
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
|
||||
self.config["model"])
|
||||
loss = A3CLoss(dist_class, self.config["vf_loss_coeff"],
|
||||
self.config["entropy_coeff"])
|
||||
TorchPolicyGraph.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
self.model,
|
||||
model,
|
||||
loss,
|
||||
loss_inputs=[
|
||||
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
|
||||
Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS
|
||||
])
|
||||
],
|
||||
action_distribution_cls=dist_class)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def optimizer(self):
|
||||
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
|
||||
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_grad_process(self):
|
||||
info = {}
|
||||
if self.config["grad_clip"]:
|
||||
total_norm = nn.utils.clip_grad_norm_(self._model.parameters(),
|
||||
self.config["grad_clip"])
|
||||
info["grad_gnorm"] = total_norm
|
||||
return info
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_grad_info(self):
|
||||
return {
|
||||
"policy_entropy": self._loss.entropy.item(),
|
||||
"policy_loss": self._loss.pi_err.item(),
|
||||
"vf_loss": self._loss.value_err.item()
|
||||
}
|
||||
|
||||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0)
|
||||
_, _, vf, _ = self.model({"obs": obs}, [])
|
||||
return vf.detach().numpy().squeeze()
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
|
||||
_, _, vf, _ = self._model({"obs": obs}, [])
|
||||
return vf.detach().cpu().numpy().squeeze()
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import ray
|
||||
@@ -17,16 +16,18 @@ from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class PGLoss(nn.Module):
|
||||
def __init__(self, policy_model):
|
||||
def __init__(self, dist_class):
|
||||
nn.Module.__init__(self)
|
||||
self.policy_model = policy_model
|
||||
self.dist_class = dist_class
|
||||
|
||||
def forward(self, observations, actions, advantages):
|
||||
logits, _, values, _ = self.policy_model({"obs": observations}, [])
|
||||
log_probs = F.log_softmax(logits, dim=1)
|
||||
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
|
||||
pi_err = -advantages.dot(action_log_probs.reshape(-1))
|
||||
return pi_err
|
||||
def forward(self, policy_model, observations, actions, advantages):
|
||||
logits, _, values, _ = policy_model({
|
||||
SampleBatch.CUR_OBS: observations
|
||||
}, [])
|
||||
dist = self.dist_class(logits)
|
||||
log_probs = dist.logp(actions)
|
||||
self.pi_err = -advantages.dot(log_probs.reshape(-1))
|
||||
return self.pi_err
|
||||
|
||||
|
||||
class PGPostprocessing(object):
|
||||
@@ -34,7 +35,7 @@ class PGPostprocessing(object):
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_action_out(self, model_out):
|
||||
return {SampleBatch.VF_PREDS: model_out[2].numpy()}
|
||||
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
@@ -49,29 +50,34 @@ class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
|
||||
self.config["model"])
|
||||
loss = PGLoss(self.model)
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"], torch=True)
|
||||
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
|
||||
self.config["model"])
|
||||
loss = PGLoss(dist_class)
|
||||
|
||||
TorchPolicyGraph.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
self.model,
|
||||
model,
|
||||
loss,
|
||||
loss_inputs=[
|
||||
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
|
||||
Postprocessing.ADVANTAGES
|
||||
])
|
||||
],
|
||||
action_distribution_cls=dist_class)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def optimizer(self):
|
||||
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
|
||||
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_grad_info(self):
|
||||
return {"policy_loss": self._loss.pi_err.item()}
|
||||
|
||||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0)
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
|
||||
_, _, vf, _ = self.model({"obs": obs}, [])
|
||||
return vf.detach().numpy().squeeze()
|
||||
return vf.detach().cpu().numpy().squeeze()
|
||||
|
||||
@@ -2,15 +2,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
except ImportError:
|
||||
pass # soft dep
|
||||
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
@@ -28,11 +30,11 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space, action_space, model, loss,
|
||||
loss_inputs):
|
||||
loss_inputs, action_distribution_cls):
|
||||
"""Build a policy graph from policy and loss torch modules.
|
||||
|
||||
Note that module inputs will be CPU tensors. The model and loss modules
|
||||
are responsible for moving inputs to the right device.
|
||||
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
|
||||
is set. Only single GPU is supported for now.
|
||||
|
||||
Arguments:
|
||||
observation_space (gym.Space): observation space of the policy.
|
||||
@@ -47,14 +49,20 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
loss_inputs (list): List of SampleBatch columns that will be
|
||||
passed to the loss module's forward() function when computing
|
||||
the loss. For example, ["obs", "action", "advantages"].
|
||||
action_distribution_cls (ActionDistribution): Class for action
|
||||
distribution.
|
||||
"""
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.lock = Lock()
|
||||
self._model = model
|
||||
self.device = (torch.device("cuda")
|
||||
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
else torch.device("cpu"))
|
||||
self._model = model.to(self.device)
|
||||
self._loss = loss
|
||||
self._loss_inputs = loss_inputs
|
||||
self._optimizer = self.optimizer()
|
||||
self._action_dist_cls = action_distribution_cls
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_actions(self,
|
||||
@@ -67,11 +75,14 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
**kwargs):
|
||||
with self.lock:
|
||||
with torch.no_grad():
|
||||
ob = torch.from_numpy(np.array(obs_batch)).float()
|
||||
ob = torch.from_numpy(np.array(obs_batch)) \
|
||||
.float().to(self.device)
|
||||
model_out = self._model({"obs": ob}, state_batches)
|
||||
logits, _, vf, state = model_out
|
||||
actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
|
||||
return (actions.numpy(), [h.numpy() for h in state],
|
||||
action_dist = self._action_dist_cls(logits)
|
||||
actions = action_dist.sample()
|
||||
return (actions.cpu().numpy(),
|
||||
[h.cpu().numpy() for h in state],
|
||||
self.extra_action_out(model_out))
|
||||
|
||||
@override(PolicyGraph)
|
||||
@@ -79,33 +90,40 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
with self.lock:
|
||||
loss_in = []
|
||||
for key in self._loss_inputs:
|
||||
loss_in.append(torch.from_numpy(postprocessed_batch[key]))
|
||||
loss_out = self._loss(*loss_in)
|
||||
loss_in.append(
|
||||
torch.from_numpy(postprocessed_batch[key]).to(self.device))
|
||||
loss_out = self._loss(self._model, *loss_in)
|
||||
self._optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
|
||||
grad_process_info = self.extra_grad_process()
|
||||
|
||||
# Note that return values are just references;
|
||||
# calling zero_grad will modify the values
|
||||
grads = []
|
||||
for p in self._model.parameters():
|
||||
if p.grad is not None:
|
||||
grads.append(p.grad.data.numpy())
|
||||
grads.append(p.grad.data.cpu().numpy())
|
||||
else:
|
||||
grads.append(None)
|
||||
return grads, {}
|
||||
|
||||
grad_info = self.extra_grad_info()
|
||||
grad_info.update(grad_process_info)
|
||||
return grads, {LEARNER_STATS_KEY: grad_info}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def apply_gradients(self, gradients):
|
||||
with self.lock:
|
||||
for g, p in zip(gradients, self._model.parameters()):
|
||||
if g is not None:
|
||||
p.grad = torch.from_numpy(g)
|
||||
p.grad = torch.from_numpy(g).to(self.device)
|
||||
self._optimizer.step()
|
||||
return {}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_weights(self):
|
||||
with self.lock:
|
||||
return self._model.state_dict()
|
||||
return {k: v.cpu() for k, v in self._model.state_dict().items()}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def set_weights(self, weights):
|
||||
@@ -116,6 +134,11 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
def get_initial_state(self):
|
||||
return [s.numpy() for s in self._model.state_init()]
|
||||
|
||||
def extra_grad_process(self):
|
||||
"""Allow subclass to do extra processing on gradients and
|
||||
return processing info."""
|
||||
return {}
|
||||
|
||||
def extra_action_out(self, model_out):
|
||||
"""Returns dict of extra info to include in experience batch.
|
||||
|
||||
@@ -123,6 +146,11 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
model_out (list): Outputs of the policy model module."""
|
||||
return {}
|
||||
|
||||
def extra_grad_info(self):
|
||||
"""Return dict of extra grad info."""
|
||||
|
||||
return {}
|
||||
|
||||
def optimizer(self):
|
||||
"""Custom PyTorch optimizer to use."""
|
||||
return torch.optim.Adam(self._model.parameters())
|
||||
|
||||
@@ -15,6 +15,8 @@ from ray.rllib.models.extra_spaces import Simplex
|
||||
from ray.rllib.models.action_dist import (Categorical, MultiCategorical,
|
||||
Deterministic, DiagGaussian,
|
||||
MultiActionDistribution, Dirichlet)
|
||||
from ray.rllib.models.torch_action_dist import (TorchCategorical,
|
||||
TorchDiagGaussian)
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.visionnet import VisionNetwork
|
||||
@@ -89,13 +91,14 @@ class ModelCatalog(object):
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_action_dist(action_space, config, dist_type=None):
|
||||
def get_action_dist(action_space, config, dist_type=None, torch=False):
|
||||
"""Returns action distribution class and size for the given action space.
|
||||
|
||||
Args:
|
||||
action_space (Space): Action space of the target gym env.
|
||||
config (dict): Optional model config.
|
||||
dist_type (str): Optional identifier of the action distribution.
|
||||
torch (bool): Optional whether to return PyTorch distribution.
|
||||
|
||||
Returns:
|
||||
dist_class (ActionDistribution): Python class of the distribution.
|
||||
@@ -111,7 +114,7 @@ class ModelCatalog(object):
|
||||
"Consider reshaping this into a single dimension, "
|
||||
"using a Tuple action space, or the multi-agent API.")
|
||||
if dist_type is None:
|
||||
dist = DiagGaussian
|
||||
dist = TorchDiagGaussian if torch else DiagGaussian
|
||||
if config.get("squash_to_range"):
|
||||
raise ValueError(
|
||||
"The squash_to_range option is deprecated. See the "
|
||||
@@ -120,7 +123,8 @@ class ModelCatalog(object):
|
||||
elif dist_type == "deterministic":
|
||||
return Deterministic, action_space.shape[0]
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
return Categorical, action_space.n
|
||||
dist = TorchCategorical if torch else Categorical
|
||||
return dist, action_space.n
|
||||
elif isinstance(action_space, gym.spaces.Tuple):
|
||||
child_dist = []
|
||||
input_lens = []
|
||||
@@ -129,14 +133,20 @@ class ModelCatalog(object):
|
||||
action, config)
|
||||
child_dist.append(dist)
|
||||
input_lens.append(action_size)
|
||||
if torch:
|
||||
raise NotImplementedError
|
||||
return partial(
|
||||
MultiActionDistribution,
|
||||
child_distributions=child_dist,
|
||||
action_space=action_space,
|
||||
input_lens=input_lens), sum(input_lens)
|
||||
elif isinstance(action_space, Simplex):
|
||||
if torch:
|
||||
raise NotImplementedError
|
||||
return Dirichlet, action_space.shape[0]
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
if torch:
|
||||
raise NotImplementedError
|
||||
return MultiCategorical, sum(action_space.nvec)
|
||||
|
||||
raise NotImplementedError("Unsupported args: {} {}".format(
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass # soft dep
|
||||
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class TorchDistributionWrapper(ActionDistribution):
|
||||
"""Wrapper class for torch.distributions."""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, actions):
|
||||
return self.dist.log_prob(actions)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def entropy(self):
|
||||
return self.dist.entropy()
|
||||
|
||||
@override(ActionDistribution)
|
||||
def kl(self, other):
|
||||
return torch.distributions.kl.kl_divergence(self.dist, other)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def sample(self):
|
||||
return self.dist.sample()
|
||||
|
||||
|
||||
class TorchCategorical(TorchDistributionWrapper):
|
||||
"""Wrapper class for PyTorch Categorical distribution."""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def __init__(self, inputs):
|
||||
self.dist = torch.distributions.categorical.Categorical(logits=inputs)
|
||||
|
||||
|
||||
class TorchDiagGaussian(TorchDistributionWrapper):
|
||||
"""Wrapper class for PyTorch Normal distribution."""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def __init__(self, inputs):
|
||||
mean, log_std = torch.chunk(inputs, 2, dim=1)
|
||||
self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))
|
||||
|
||||
@override(TorchDistributionWrapper)
|
||||
def logp(self, actions):
|
||||
return TorchDistributionWrapper.logp(self, actions).sum(-1)
|
||||
Reference in New Issue
Block a user