Add Reg. Loss

This commit is contained in:
pranz24
2019-04-06 21:51:07 +05:30
parent 01f1793ca5
commit e0ee7fcb83
2 changed files with 8 additions and 4 deletions
+1 -1
View File
@@ -91,7 +91,7 @@ class GaussianPolicy(nn.Module):
# Enforcing Action Bound
log_prob -= torch.log(1 - action.pow(2) + epsilon)
log_prob = log_prob.sum(1, keepdim=True)
return action, log_prob, torch.tanh(mean)
return action, log_prob, mean, log_std
class DeterministicPolicy(nn.Module):
def __init__(self, num_inputs, num_actions, hidden_dim):
+7 -3
View File
@@ -34,9 +34,10 @@ class SAC(object):
def select_action(self, state, eval=False):
    """Pick an action for a single environment state.

    Args:
        state: One observation, in any form ``torch.FloatTensor`` accepts
            (presumably a 1-D array/list of floats -- confirm with caller).
        eval: When falsy, return the action drawn by ``self.policy.sample``.
            When truthy, return ``tanh`` of the policy mean instead.

    Returns:
        A NumPy array holding the action for ``state`` (the batch axis
        added internally is stripped again).
    """
    # The policy works on batches, so add a leading batch axis of size 1.
    state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
    # ``sample`` returns (action, log_prob, mean, log_std).
    sampled_action, _, mean, _ = self.policy.sample(state)
    if eval:
        # ``mean`` comes back un-squashed, so apply the tanh bound here.
        action = torch.tanh(mean)
    else:
        action = sampled_action
    # Drop autograd history and move to host memory before converting.
    action = action.detach().cpu().numpy()
    return action[0]
@@ -68,12 +69,15 @@ class SAC(object):
qf2_loss.backward()
self.critic_optim.step()
pi, log_pi, _ = self.policy.sample(state_batch)
pi, log_pi, mean, log_std = self.policy.sample(state_batch)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
# Regularization Loss
reg_loss = 0.001 * (mean.pow(2).mean() + log_std.pow(2).mean())
policy_loss += reg_loss
vf = self.value(state_batch)
with torch.no_grad():