Files
2021-01-17 18:27:36 +08:00

149 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from utils import soft_update, hard_update
from model import GaussianPolicy, QNetwork, DeterministicPolicy
from loguru import logger
from apex import amp
class SAC(object):
    """Soft Actor-Critic agent with a trainable observation encoder.

    Raw observations are passed through ``process_obs`` before reaching the
    policy and the twin Q-network. The encoder's parameters are registered in
    both ``critic_optim`` and ``policy_optim``, so both the critic loss and the
    policy loss update it. Mixed precision uses NVIDIA apex ``amp`` when
    ``opt_level`` is not ``None``.

    Args:
        num_inputs: Input feature size of the policy and Q-networks
            (presumably the output size of ``process_obs`` -- confirm).
        action_space: Action space; ``action_space.shape[0]`` is used as the
            action dimension.
        args: Config object; this class reads ``gamma``, ``tau``, ``alpha``,
            ``cuda``, ``policy``, ``target_update_interval``,
            ``automatic_entropy_tuning``, ``hidden_size`` and ``lr``.
        process_obs: ``torch.nn.Module`` mapping raw observations to the state
            fed to the networks. ``None`` means identity (raw observations are
            fed directly to the networks).
        opt_level: apex amp optimization level (e.g. ``'O1'``), or ``None`` to
            train without amp.
    """

    def __init__(self, num_inputs, action_space, args, process_obs=None, opt_level='O1'):
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self.dtype = torch.float
        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning
        # amp.scale_loss requires amp.initialize to have run, so remember
        # whether amp is active and fall back to plain backward otherwise.
        self.use_amp = opt_level is not None

        if process_obs is None:
            # Without an encoder, observations go to the networks unchanged.
            process_obs = torch.nn.Identity()
        self.process_obs = process_obs.to(self.device).to(self.dtype)
        encoder_params = list(self.process_obs.parameters())

        action_dim = action_space.shape[0]
        self.critic = QNetwork(num_inputs, action_dim, args.hidden_size).to(device=self.device).to(self.dtype)
        self.critic_optim = Adam(list(self.critic.parameters()) + encoder_params, lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_dim, args.hidden_size).to(self.device).to(self.dtype)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            if self.automatic_entropy_tuning is True:
                # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as given in the paper.
                self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device, dtype=self.dtype)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
            self.policy = GaussianPolicy(num_inputs, action_dim, args.hidden_size, action_space).to(self.device).to(self.dtype)
        else:
            # A deterministic policy has no entropy term.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs, action_dim, args.hidden_size, action_space).to(self.device).to(self.dtype)
        self.policy_optim = Adam(list(self.policy.parameters()) + encoder_params, lr=args.lr)

        if self.use_amp:
            # apex requires using the returned models/optimizers in place of
            # the originals, so rebind them instead of discarding the result.
            models, optimizers = amp.initialize(
                [self.policy, self.process_obs, self.critic, self.critic_target],
                [self.policy_optim, self.critic_optim],
                opt_level=opt_level)
            self.policy, self.process_obs, self.critic, self.critic_target = models
            self.policy_optim, self.critic_optim = optimizers

    def _backward(self, loss, optimizer):
        """Backpropagate ``loss``, applying amp loss scaling when amp is enabled.

        The scaled loss is bound to its own name so the caller's unscaled
        ``loss`` object is left as-is for reporting.
        """
        if self.use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def select_action(self, obs, evaluate=False):
        """Return an action for a single raw observation as a NumPy array.

        A batch dimension is added before encoding and removed from the
        result. When ``evaluate is False`` the first output of
        ``policy.sample`` is used; otherwise its third output (presumably the
        deterministic mean action -- confirm in ``model``).
        """
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0).to(self.dtype)
            state = self.process_obs(obs)
            if evaluate is False:
                action, _, _ = self.policy.sample(state)
            else:
                _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Run one critic, policy and (optional) temperature update.

        Args:
            memory: Replay buffer whose ``sample(batch_size=...)`` returns
                ``(obs, action, reward, next_obs, mask)``.
            batch_size: Number of transitions to sample.
            updates: Update counter; the target critic is soft-updated when it
                is a multiple of ``target_update_interval``.

        Returns:
            Floats ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)``;
            all losses are unscaled even when amp is enabled.
        """
        obs_batch, action_batch, reward_batch, next_obs_batch, mask_batch = memory.sample(batch_size=batch_size)
        obs_batch = torch.FloatTensor(obs_batch).to(self.device).to(self.dtype)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device).to(self.dtype)
        action_batch = torch.FloatTensor(action_batch).to(self.device).to(self.dtype)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1).to(self.dtype)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1).to(self.dtype)

        # ---- Critic update (also trains the encoder through critic_optim) ----
        state_batch = self.process_obs(obs_batch)
        with torch.no_grad():
            next_state_batch = self.process_obs(next_obs_batch)
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target
        # Two Q-functions to mitigate positive bias in the policy improvement step.
        qf1, qf2 = self.critic(state_batch, action_batch)
        # J_Q = E_{(s_t,a_t)~D}[0.5 * (Q_i(s_t,a_t) - r(s_t,a_t) - gamma * E_{s_{t+1}~p}[V(s_{t+1})])^2], i = 1, 2
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss
        self.critic_optim.zero_grad()
        assert torch.isfinite(qf_loss).all()
        self._backward(qf_loss, self.critic_optim)
        self.critic_optim.step()

        # ---- Policy update ----
        # Re-encode: the critic step just changed the encoder weights, and the
        # backward pass above released the previous graph.
        state_batch = self.process_obs(obs_batch)
        pi, log_pi, _ = self.policy.sample(state_batch)
        # The critic sees a detached state, so the encoder gets the policy
        # gradient only through the policy's input.
        qf1_pi, qf2_pi = self.critic(state_batch.detach(), pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        # J_pi = E_{s_t~D, eps_t~N}[alpha * log pi(f(eps_t; s_t) | s_t) - Q(s_t, f(eps_t; s_t))]
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()
        self.policy_optim.zero_grad()
        assert torch.isfinite(policy_loss).all()
        self._backward(policy_loss, self.policy_optim)
        self.policy_optim.step()

        # ---- Temperature (alpha) update ----
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            # Detached so later policy losses do not backpropagate into log_alpha.
            self.alpha = self.log_alpha.detach().exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device).to(self.dtype)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self, actor_path=None, critic_path=None, process_obs_path=None):
        """Save the state dicts of the policy, critic and observation encoder.

        Any path left as ``None`` skips that component, mirroring
        ``load_model``.
        """
        logger.debug(f'saving models to {actor_path} and {critic_path} and {process_obs_path}')
        targets = (
            (self.policy, actor_path),
            (self.critic, critic_path),
            (self.process_obs, process_obs_path),
        )
        for module, path in targets:
            if path is not None:
                torch.save(module.state_dict(), path)

    # Load model parameters
    def load_model(self, actor_path=None, critic_path=None, process_obs_path=None):
        """Load state dicts for the policy, critic and observation encoder.

        Any path left as ``None`` is skipped. Tensors are mapped onto
        ``self.device``, so checkpoints saved on GPU can be loaded on CPU.
        After loading the critic, the target critic is synced with it using
        the same ``hard_update`` call ``__init__`` uses.
        """
        logger.info(f'Loading models from {actor_path} and {critic_path} and {process_obs_path}')
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path, map_location=self.device))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
            # Otherwise the target keeps its random init and resumed training
            # would bootstrap from garbage targets.
            hard_update(self.critic_target, self.critic)
        if process_obs_path is not None:
            self.process_obs.load_state_dict(torch.load(process_obs_path, map_location=self.device))