mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 18:06:10 +08:00
Add Reg. Loss
This commit is contained in:
@@ -91,7 +91,7 @@ class GaussianPolicy(nn.Module):
|
||||
# Enforcing Action Bound
|
||||
log_prob -= torch.log(1 - action.pow(2) + epsilon)
|
||||
log_prob = log_prob.sum(1, keepdim=True)
|
||||
return action, log_prob, torch.tanh(mean)
|
||||
return action, log_prob, mean, log_std
|
||||
|
||||
class DeterministicPolicy(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_dim):
|
||||
|
||||
@@ -34,9 +34,10 @@ class SAC(object):
|
||||
def select_action(self, state, eval=False):
|
||||
state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
|
||||
if eval == False:
|
||||
action, _, _ = self.policy.sample(state)
|
||||
action, _, _, _ = self.policy.sample(state)
|
||||
else:
|
||||
_, _, action = self.policy.sample(state)
|
||||
_, _, action, _ = self.policy.sample(state)
|
||||
action = torch.tanh(action)
|
||||
action = action.detach().cpu().numpy()
|
||||
return action[0]
|
||||
|
||||
@@ -68,12 +69,15 @@ class SAC(object):
|
||||
qf2_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
pi, log_pi, _ = self.policy.sample(state_batch)
|
||||
pi, log_pi, mean, log_std = self.policy.sample(state_batch)
|
||||
|
||||
qf1_pi, qf2_pi = self.critic(state_batch, pi)
|
||||
min_qf_pi = torch.min(qf1_pi, qf2_pi)
|
||||
|
||||
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
|
||||
# Regularization Loss
|
||||
reg_loss = 0.001 * (mean.pow(2).mean() + log_std.pow(2).mean())
|
||||
policy_loss += reg_loss
|
||||
|
||||
vf = self.value(state_batch)
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user