mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 23:46:50 +08:00
[rllib] Remove TorchPolicy locks (#5764)
* remove torch lock
* remove lock
This commit is contained in:
@@ -71,10 +71,9 @@ def torch_optimizer(policy, config):
|
||||
|
||||
class ValueNetworkMixin(object):
|
||||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
|
||||
_ = self.model({"obs": obs}, [], [1])
|
||||
return self.model.value_function().detach().cpu().numpy().squeeze()
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
|
||||
_ = self.model({"obs": obs}, [], [1])
|
||||
return self.model.value_function().detach().cpu().numpy().squeeze()
|
||||
|
||||
|
||||
A3CTorchPolicy = build_torch_policy(
|
||||
|
||||
@@ -5,8 +5,6 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from threading import Lock
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
@@ -25,8 +23,6 @@ class TorchPolicy(Policy):
|
||||
Attributes:
|
||||
observation_space (gym.Space): observation space of the policy.
|
||||
action_space (gym.Space): action space of the policy.
|
||||
lock (Lock): Lock that must be held around PyTorch ops on this graph.
|
||||
This is necessary when using the async sampler.
|
||||
config (dict): config of the policy
|
||||
model (TorchModel): Torch model instance
|
||||
dist_class (type): Torch action distribution class
|
||||
@@ -52,7 +48,6 @@ class TorchPolicy(Policy):
|
||||
"""
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.lock = Lock()
|
||||
self.device = (torch.device("cuda")
|
||||
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
else torch.device("cpu"))
|
||||
@@ -70,83 +65,74 @@ class TorchPolicy(Policy):
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
with self.lock:
|
||||
with torch.no_grad():
|
||||
input_dict = self._lazy_tensor_dict({
|
||||
"obs": obs_batch,
|
||||
})
|
||||
if prev_action_batch:
|
||||
input_dict["prev_actions"] = prev_action_batch
|
||||
if prev_reward_batch:
|
||||
input_dict["prev_rewards"] = prev_reward_batch
|
||||
model_out = self.model(input_dict, state_batches, [1])
|
||||
logits, state = model_out
|
||||
action_dist = self.dist_class(logits, self.model)
|
||||
actions = action_dist.sample()
|
||||
return (actions.cpu().numpy(),
|
||||
[h.cpu().numpy() for h in state],
|
||||
self.extra_action_out(input_dict, state_batches,
|
||||
self.model))
|
||||
with torch.no_grad():
|
||||
input_dict = self._lazy_tensor_dict({
|
||||
"obs": obs_batch,
|
||||
})
|
||||
if prev_action_batch:
|
||||
input_dict["prev_actions"] = prev_action_batch
|
||||
if prev_reward_batch:
|
||||
input_dict["prev_rewards"] = prev_reward_batch
|
||||
model_out = self.model(input_dict, state_batches, [1])
|
||||
logits, state = model_out
|
||||
action_dist = self.dist_class(logits, self.model)
|
||||
actions = action_dist.sample()
|
||||
return (actions.cpu().numpy(), [h.cpu().numpy() for h in state],
|
||||
self.extra_action_out(input_dict, state_batches,
|
||||
self.model))
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
train_batch = self._lazy_tensor_dict(postprocessed_batch)
|
||||
|
||||
with self.lock:
|
||||
loss_out = self._loss(self, self.model, self.dist_class,
|
||||
train_batch)
|
||||
self._optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
loss_out = self._loss(self, self.model, self.dist_class, train_batch)
|
||||
self._optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
|
||||
grad_process_info = self.extra_grad_process()
|
||||
self._optimizer.step()
|
||||
grad_process_info = self.extra_grad_process()
|
||||
self._optimizer.step()
|
||||
|
||||
grad_info = self.extra_grad_info(train_batch)
|
||||
grad_info.update(grad_process_info)
|
||||
return {LEARNER_STATS_KEY: grad_info}
|
||||
grad_info = self.extra_grad_info(train_batch)
|
||||
grad_info.update(grad_process_info)
|
||||
return {LEARNER_STATS_KEY: grad_info}
|
||||
|
||||
@override(Policy)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
train_batch = self._lazy_tensor_dict(postprocessed_batch)
|
||||
|
||||
with self.lock:
|
||||
loss_out = self._loss(self, self.model, self.dist_class,
|
||||
train_batch)
|
||||
self._optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
loss_out = self._loss(self, self.model, self.dist_class, train_batch)
|
||||
self._optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
|
||||
grad_process_info = self.extra_grad_process()
|
||||
grad_process_info = self.extra_grad_process()
|
||||
|
||||
# Note that return values are just references;
|
||||
# calling zero_grad will modify the values
|
||||
grads = []
|
||||
for p in self.model.parameters():
|
||||
if p.grad is not None:
|
||||
grads.append(p.grad.data.cpu().numpy())
|
||||
else:
|
||||
grads.append(None)
|
||||
# Note that return values are just references;
|
||||
# calling zero_grad will modify the values
|
||||
grads = []
|
||||
for p in self.model.parameters():
|
||||
if p.grad is not None:
|
||||
grads.append(p.grad.data.cpu().numpy())
|
||||
else:
|
||||
grads.append(None)
|
||||
|
||||
grad_info = self.extra_grad_info(train_batch)
|
||||
grad_info.update(grad_process_info)
|
||||
return grads, {LEARNER_STATS_KEY: grad_info}
|
||||
grad_info = self.extra_grad_info(train_batch)
|
||||
grad_info.update(grad_process_info)
|
||||
return grads, {LEARNER_STATS_KEY: grad_info}
|
||||
|
||||
@override(Policy)
|
||||
def apply_gradients(self, gradients):
|
||||
with self.lock:
|
||||
for g, p in zip(gradients, self.model.parameters()):
|
||||
if g is not None:
|
||||
p.grad = torch.from_numpy(g).to(self.device)
|
||||
self._optimizer.step()
|
||||
for g, p in zip(gradients, self.model.parameters()):
|
||||
if g is not None:
|
||||
p.grad = torch.from_numpy(g).to(self.device)
|
||||
self._optimizer.step()
|
||||
|
||||
@override(Policy)
|
||||
def get_weights(self):
|
||||
with self.lock:
|
||||
return {k: v.cpu() for k, v in self.model.state_dict().items()}
|
||||
return {k: v.cpu() for k, v in self.model.state_dict().items()}
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
with self.lock:
|
||||
self.model.load_state_dict(weights)
|
||||
self.model.load_state_dict(weights)
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
|
||||
Reference in New Issue
Block a user