[rllib] Remove TorchPolicy locks (#5764)

* remove torch lock

* remove lock
This commit is contained in:
Eric Liang
2019-09-24 17:52:16 -07:00
committed by GitHub
parent 10f21fa313
commit c6919d315d
2 changed files with 47 additions and 62 deletions
+3 -4
View File
@@ -71,10 +71,9 @@ def torch_optimizer(policy, config):
class ValueNetworkMixin(object):
def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_ = self.model({"obs": obs}, [], [1])
return self.model.value_function().detach().cpu().numpy().squeeze()
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_ = self.model({"obs": obs}, [], [1])
return self.model.value_function().detach().cpu().numpy().squeeze()
A3CTorchPolicy = build_torch_policy(
+44 -58
View File
@@ -5,8 +5,6 @@ from __future__ import print_function
import numpy as np
import os
from threading import Lock
try:
import torch
except ImportError:
@@ -25,8 +23,6 @@ class TorchPolicy(Policy):
Attributes:
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
lock (Lock): Lock that must be held around PyTorch ops on this graph.
This is necessary when using the async sampler.
config (dict): config of the policy
model (TorchModel): Torch model instance
dist_class (type): Torch action distribution class
@@ -52,7 +48,6 @@ class TorchPolicy(Policy):
"""
self.observation_space = observation_space
self.action_space = action_space
self.lock = Lock()
self.device = (torch.device("cuda")
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
else torch.device("cpu"))
@@ -70,83 +65,74 @@ class TorchPolicy(Policy):
info_batch=None,
episodes=None,
**kwargs):
with self.lock:
with torch.no_grad():
input_dict = self._lazy_tensor_dict({
"obs": obs_batch,
})
if prev_action_batch:
input_dict["prev_actions"] = prev_action_batch
if prev_reward_batch:
input_dict["prev_rewards"] = prev_reward_batch
model_out = self.model(input_dict, state_batches, [1])
logits, state = model_out
action_dist = self.dist_class(logits, self.model)
actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
self.extra_action_out(input_dict, state_batches,
self.model))
with torch.no_grad():
input_dict = self._lazy_tensor_dict({
"obs": obs_batch,
})
if prev_action_batch:
input_dict["prev_actions"] = prev_action_batch
if prev_reward_batch:
input_dict["prev_rewards"] = prev_reward_batch
model_out = self.model(input_dict, state_batches, [1])
logits, state = model_out
action_dist = self.dist_class(logits, self.model)
actions = action_dist.sample()
return (actions.cpu().numpy(), [h.cpu().numpy() for h in state],
self.extra_action_out(input_dict, state_batches,
self.model))
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
train_batch = self._lazy_tensor_dict(postprocessed_batch)
with self.lock:
loss_out = self._loss(self, self.model, self.dist_class,
train_batch)
self._optimizer.zero_grad()
loss_out.backward()
loss_out = self._loss(self, self.model, self.dist_class, train_batch)
self._optimizer.zero_grad()
loss_out.backward()
grad_process_info = self.extra_grad_process()
self._optimizer.step()
grad_process_info = self.extra_grad_process()
self._optimizer.step()
grad_info = self.extra_grad_info(train_batch)
grad_info.update(grad_process_info)
return {LEARNER_STATS_KEY: grad_info}
grad_info = self.extra_grad_info(train_batch)
grad_info.update(grad_process_info)
return {LEARNER_STATS_KEY: grad_info}
@override(Policy)
def compute_gradients(self, postprocessed_batch):
train_batch = self._lazy_tensor_dict(postprocessed_batch)
with self.lock:
loss_out = self._loss(self, self.model, self.dist_class,
train_batch)
self._optimizer.zero_grad()
loss_out.backward()
loss_out = self._loss(self, self.model, self.dist_class, train_batch)
self._optimizer.zero_grad()
loss_out.backward()
grad_process_info = self.extra_grad_process()
grad_process_info = self.extra_grad_process()
# Note that return values are just references;
# calling zero_grad will modify the values
grads = []
for p in self.model.parameters():
if p.grad is not None:
grads.append(p.grad.data.cpu().numpy())
else:
grads.append(None)
# Note that return values are just references;
# calling zero_grad will modify the values
grads = []
for p in self.model.parameters():
if p.grad is not None:
grads.append(p.grad.data.cpu().numpy())
else:
grads.append(None)
grad_info = self.extra_grad_info(train_batch)
grad_info.update(grad_process_info)
return grads, {LEARNER_STATS_KEY: grad_info}
grad_info = self.extra_grad_info(train_batch)
grad_info.update(grad_process_info)
return grads, {LEARNER_STATS_KEY: grad_info}
@override(Policy)
def apply_gradients(self, gradients):
with self.lock:
for g, p in zip(gradients, self.model.parameters()):
if g is not None:
p.grad = torch.from_numpy(g).to(self.device)
self._optimizer.step()
for g, p in zip(gradients, self.model.parameters()):
if g is not None:
p.grad = torch.from_numpy(g).to(self.device)
self._optimizer.step()
@override(Policy)
def get_weights(self):
with self.lock:
return {k: v.cpu() for k, v in self.model.state_dict().items()}
return {k: v.cpu() for k, v in self.model.state_dict().items()}
@override(Policy)
def set_weights(self, weights):
with self.lock:
self.model.load_state_dict(weights)
self.model.load_state_dict(weights)
@override(Policy)
def get_initial_state(self):