From bb207a205bde6cdd7e378ed6643d8cc1e8479659 Mon Sep 17 00:00:00 2001 From: cfan Date: Fri, 12 Apr 2019 11:39:14 -0700 Subject: [PATCH] [rllib] Support torch device and distributions. (#4553) --- ci/jenkins_tests/run_rllib_tests.sh | 7 ++ .../agents/a3c/a3c_torch_policy_graph.py | 71 ++++++++++++------- .../rllib/agents/pg/torch_pg_policy_graph.py | 46 ++++++------ .../rllib/evaluation/torch_policy_graph.py | 56 +++++++++++---- python/ray/rllib/models/catalog.py | 16 ++++- python/ray/rllib/models/torch_action_dist.py | 52 ++++++++++++++ 6 files changed, 186 insertions(+), 62 deletions(-) create mode 100644 python/ray/rllib/models/torch_action_dist.py diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index 89939bb2a..8012ce652 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -412,6 +412,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ --stop '{"training_iteration": 1}' \ --config '{"num_workers": 2, "use_pytorch": true, "sample_async": false}' +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output /ray/python/ray/rllib/train.py \ + --env Pendulum-v0 \ + --run A3C \ + --stop '{"training_iteration": 1}' \ + --config '{"num_workers": 2, "use_pytorch": true, "sample_async": false}' + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \ diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py index 18f8f0422..d35aabe0d 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py @@ -17,25 +17,28 @@ from ray.rllib.utils.annotations import override class A3CLoss(nn.Module): - def __init__(self, policy_model, vf_loss_coeff=0.5, entropy_coeff=0.01): + def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01): nn.Module.__init__(self) - self.policy_model = policy_model + self.dist_class = dist_class self.vf_loss_coeff = vf_loss_coeff self.entropy_coeff = entropy_coeff - def forward(self, observations, actions, advantages, value_targets): - logits, _, values, _ = self.policy_model({"obs": observations}, []) - log_probs = F.log_softmax(logits, dim=1) - probs = F.softmax(logits, dim=1) - action_log_probs = log_probs.gather(1, actions.view(-1, 1)) - entropy = -(log_probs * probs).sum(-1).sum() - pi_err = -advantages.dot(action_log_probs.reshape(-1)) - value_err = F.mse_loss(values.reshape(-1), value_targets) + def forward(self, policy_model, observations, actions, advantages, + value_targets): + logits, _, values, _ = policy_model({ + SampleBatch.CUR_OBS: observations + }, []) + dist = self.dist_class(logits) + log_probs = dist.logp(actions) + self.entropy = dist.entropy().mean() + self.pi_err = -advantages.dot(log_probs.reshape(-1)) + self.value_err = F.mse_loss(values.reshape(-1), value_targets) overall_err = sum([ - pi_err, - self.vf_loss_coeff * value_err, - -self.entropy_coeff * entropy, + self.pi_err, + self.vf_loss_coeff * self.value_err, + -self.entropy_coeff * self.entropy, ]) + return overall_err @@ -44,7 +47,7 @@ class A3CPostprocessing(object): @override(TorchPolicyGraph) def extra_action_out(self, model_out): - return {SampleBatch.VF_PREDS: model_out[2].numpy()} + return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} @override(PolicyGraph) def postprocess_trajectory(self, @@ -66,29 +69,47 @@ class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph): def __init__(self, obs_space, action_space, config): config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) self.config = config - _, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, - self.config["model"]) - loss = A3CLoss(self.model, self.config["vf_loss_coeff"], + dist_class, self.logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], torch=True) + model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, + self.config["model"]) + loss = A3CLoss(dist_class, self.config["vf_loss_coeff"], self.config["entropy_coeff"]) TorchPolicyGraph.__init__( self, obs_space, action_space, - self.model, + model, loss, loss_inputs=[ SampleBatch.CUR_OBS, SampleBatch.ACTIONS, Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS - ]) + ], + action_distribution_cls=dist_class) @override(TorchPolicyGraph) def optimizer(self): - return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) + return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"]) + + @override(TorchPolicyGraph) + def extra_grad_process(self): + info = {} + if self.config["grad_clip"]: + total_norm = nn.utils.clip_grad_norm_(self._model.parameters(), + self.config["grad_clip"]) + info["grad_gnorm"] = total_norm + return info + + @override(TorchPolicyGraph) + def extra_grad_info(self): + return { + "policy_entropy": self._loss.entropy.item(), + "policy_loss": self._loss.pi_err.item(), + "vf_loss": self._loss.value_err.item() + } def _value(self, obs): with self.lock: - obs = torch.from_numpy(obs).float().unsqueeze(0) - _, _, vf, _ = self.model({"obs": obs}, []) - return vf.detach().numpy().squeeze() + obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) + _, _, vf, _ = self._model({"obs": obs}, []) + return vf.detach().cpu().numpy().squeeze() diff --git a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py b/python/ray/rllib/agents/pg/torch_pg_policy_graph.py index dde22685d..746ef1bca 100644 --- a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/torch_pg_policy_graph.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function import torch -import torch.nn.functional as F from torch import nn import ray @@ -17,16 +16,18 @@ from ray.rllib.utils.annotations import override class PGLoss(nn.Module): - def __init__(self, policy_model): + def __init__(self, dist_class): nn.Module.__init__(self) - self.policy_model = policy_model + self.dist_class = dist_class - def forward(self, observations, actions, advantages): - logits, _, values, _ = self.policy_model({"obs": observations}, []) - log_probs = F.log_softmax(logits, dim=1) - action_log_probs = log_probs.gather(1, actions.view(-1, 1)) - pi_err = -advantages.dot(action_log_probs.reshape(-1)) - return pi_err + def forward(self, policy_model, observations, actions, advantages): + logits, _, values, _ = policy_model({ + SampleBatch.CUR_OBS: observations + }, []) + dist = self.dist_class(logits) + log_probs = dist.logp(actions) + self.pi_err = -advantages.dot(log_probs.reshape(-1)) + return self.pi_err class PGPostprocessing(object): @@ -34,7 +35,7 @@ class PGPostprocessing(object): @override(TorchPolicyGraph) def extra_action_out(self, model_out): - return {SampleBatch.VF_PREDS: model_out[2].numpy()} + return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} @override(PolicyGraph) def postprocess_trajectory(self, @@ -49,29 +50,34 @@ class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph): def __init__(self, obs_space, action_space, config): config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) self.config = config - _, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, - self.config["model"]) - loss = PGLoss(self.model) + dist_class, self.logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], torch=True) + model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, + self.config["model"]) + loss = PGLoss(dist_class) TorchPolicyGraph.__init__( self, obs_space, action_space, - self.model, + model, loss, loss_inputs=[ SampleBatch.CUR_OBS, SampleBatch.ACTIONS, Postprocessing.ADVANTAGES - ]) + ], + action_distribution_cls=dist_class) @override(TorchPolicyGraph) def optimizer(self): - return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) + return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"]) + + @override(TorchPolicyGraph) + def extra_grad_info(self): + return {"policy_loss": self._loss.pi_err.item()} def _value(self, obs): with self.lock: - obs = torch.from_numpy(obs).float().unsqueeze(0) + obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) _, _, vf, _ = self.model({"obs": obs}, []) - return vf.detach().numpy().squeeze() + return vf.detach().cpu().numpy().squeeze() diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index be55dde86..3f91cd49a 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -2,15 +2,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np from threading import Lock try: import torch - import torch.nn.functional as F except ImportError: pass # soft dep +from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.utils.annotations import override @@ -28,11 +30,11 @@ class TorchPolicyGraph(PolicyGraph): """ def __init__(self, observation_space, action_space, model, loss, - loss_inputs): + loss_inputs, action_distribution_cls): """Build a policy graph from policy and loss torch modules. - Note that module inputs will be CPU tensors. The model and loss modules - are responsible for moving inputs to the right device. + Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES + is set. Only single GPU is supported for now. Arguments: observation_space (gym.Space): observation space of the policy. @@ -47,14 +49,20 @@ class TorchPolicyGraph(PolicyGraph): loss_inputs (list): List of SampleBatch columns that will be passed to the loss module's forward() function when computing the loss. For example, ["obs", "action", "advantages"]. + action_distribution_cls (ActionDistribution): Class for action + distribution. """ self.observation_space = observation_space self.action_space = action_space self.lock = Lock() - self._model = model + self.device = (torch.device("cuda") + if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) + else torch.device("cpu")) + self._model = model.to(self.device) self._loss = loss self._loss_inputs = loss_inputs self._optimizer = self.optimizer() + self._action_dist_cls = action_distribution_cls @override(PolicyGraph) def compute_actions(self, @@ -67,11 +75,14 @@ class TorchPolicyGraph(PolicyGraph): **kwargs): with self.lock: with torch.no_grad(): - ob = torch.from_numpy(np.array(obs_batch)).float() + ob = torch.from_numpy(np.array(obs_batch)) \ + .float().to(self.device) model_out = self._model({"obs": ob}, state_batches) logits, _, vf, state = model_out - actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0) - return (actions.numpy(), [h.numpy() for h in state], + action_dist = self._action_dist_cls(logits) + actions = action_dist.sample() + return (actions.cpu().numpy(), + [h.cpu().numpy() for h in state], self.extra_action_out(model_out)) @override(PolicyGraph) @@ -79,33 +90,40 @@ class TorchPolicyGraph(PolicyGraph): with self.lock: loss_in = [] for key in self._loss_inputs: - loss_in.append(torch.from_numpy(postprocessed_batch[key])) - loss_out = self._loss(*loss_in) + loss_in.append( + torch.from_numpy(postprocessed_batch[key]).to(self.device)) + loss_out = self._loss(self._model, *loss_in) self._optimizer.zero_grad() loss_out.backward() + + grad_process_info = self.extra_grad_process() + # Note that return values are just references; # calling zero_grad will modify the values grads = [] for p in self._model.parameters(): if p.grad is not None: - grads.append(p.grad.data.numpy()) + grads.append(p.grad.data.cpu().numpy()) else: grads.append(None) - return grads, {} + + grad_info = self.extra_grad_info() + grad_info.update(grad_process_info) + return grads, {LEARNER_STATS_KEY: grad_info} @override(PolicyGraph) def apply_gradients(self, gradients): with self.lock: for g, p in zip(gradients, self._model.parameters()): if g is not None: - p.grad = torch.from_numpy(g) + p.grad = torch.from_numpy(g).to(self.device) self._optimizer.step() return {} @override(PolicyGraph) def get_weights(self): with self.lock: - return self._model.state_dict() + return {k: v.cpu() for k, v in self._model.state_dict().items()} @override(PolicyGraph) def set_weights(self, weights): @@ -116,6 +134,11 @@ class TorchPolicyGraph(PolicyGraph): def get_initial_state(self): return [s.numpy() for s in self._model.state_init()] + def extra_grad_process(self): + """Allow subclass to do extra processing on gradients and + return processing info.""" + return {} + def extra_action_out(self, model_out): """Returns dict of extra info to include in experience batch. @@ -123,6 +146,11 @@ class TorchPolicyGraph(PolicyGraph): model_out (list): Outputs of the policy model module.""" return {} + def extra_grad_info(self): + """Return dict of extra grad info.""" + + return {} + def optimizer(self): """Custom PyTorch optimizer to use.""" return torch.optim.Adam(self._model.parameters()) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index bcd79cbfe..776773552 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -15,6 +15,8 @@ from ray.rllib.models.extra_spaces import Simplex from ray.rllib.models.action_dist import (Categorical, MultiCategorical, Deterministic, DiagGaussian, MultiActionDistribution, Dirichlet) +from ray.rllib.models.torch_action_dist import (TorchCategorical, + TorchDiagGaussian) from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork @@ -89,13 +91,14 @@ class ModelCatalog(object): @staticmethod @DeveloperAPI - def get_action_dist(action_space, config, dist_type=None): + def get_action_dist(action_space, config, dist_type=None, torch=False): """Returns action distribution class and size for the given action space. Args: action_space (Space): Action space of the target gym env. config (dict): Optional model config. dist_type (str): Optional identifier of the action distribution. + torch (bool): Optional whether to return PyTorch distribution. Returns: dist_class (ActionDistribution): Python class of the distribution. @@ -111,7 +114,7 @@ class ModelCatalog(object): "Consider reshaping this into a single dimension, " "using a Tuple action space, or the multi-agent API.") if dist_type is None: - dist = DiagGaussian + dist = TorchDiagGaussian if torch else DiagGaussian if config.get("squash_to_range"): raise ValueError( "The squash_to_range option is deprecated. See the " @@ -120,7 +123,8 @@ class ModelCatalog(object): elif dist_type == "deterministic": return Deterministic, action_space.shape[0] elif isinstance(action_space, gym.spaces.Discrete): - return Categorical, action_space.n + dist = TorchCategorical if torch else Categorical + return dist, action_space.n elif isinstance(action_space, gym.spaces.Tuple): child_dist = [] input_lens = [] @@ -129,14 +133,20 @@ class ModelCatalog(object): action, config) child_dist.append(dist) input_lens.append(action_size) + if torch: + raise NotImplementedError return partial( MultiActionDistribution, child_distributions=child_dist, action_space=action_space, input_lens=input_lens), sum(input_lens) elif isinstance(action_space, Simplex): + if torch: + raise NotImplementedError return Dirichlet, action_space.shape[0] elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): + if torch: + raise NotImplementedError return MultiCategorical, sum(action_space.nvec) raise NotImplementedError("Unsupported args: {} {}".format( diff --git a/python/ray/rllib/models/torch_action_dist.py b/python/ray/rllib/models/torch_action_dist.py new file mode 100644 index 000000000..b8becc9a3 --- /dev/null +++ b/python/ray/rllib/models/torch_action_dist.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +try: + import torch +except ImportError: + pass # soft dep + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.utils.annotations import override + + +class TorchDistributionWrapper(ActionDistribution): + """Wrapper class for torch.distributions.""" + + @override(ActionDistribution) + def logp(self, actions): + return self.dist.log_prob(actions) + + @override(ActionDistribution) + def entropy(self): + return self.dist.entropy() + + @override(ActionDistribution) + def kl(self, other): + return torch.distributions.kl.kl_divergence(self.dist, other) + + @override(ActionDistribution) + def sample(self): + return self.dist.sample() + + +class TorchCategorical(TorchDistributionWrapper): + """Wrapper class for PyTorch Categorical distribution.""" + + @override(ActionDistribution) + def __init__(self, inputs): + self.dist = torch.distributions.categorical.Categorical(logits=inputs) + + +class TorchDiagGaussian(TorchDistributionWrapper): + """Wrapper class for PyTorch Normal distribution.""" + + @override(ActionDistribution) + def __init__(self, inputs): + mean, log_std = torch.chunk(inputs, 2, dim=1) + self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std)) + + @override(TorchDistributionWrapper) + def logp(self, actions): + return TorchDistributionWrapper.logp(self, actions).sum(-1)