fix for pytorch-1.5

This commit is contained in:
pranz24
2020-06-06 00:19:15 +05:30
parent 847edf58a5
commit e961172767
9 changed files with 12 additions and 9 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1 -1
View File
@@ -9,7 +9,7 @@ from tensorboardX import SummaryWriter
from replay_memory import ReplayMemory
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="HalfCheetah-v2",
parser.add_argument('--env-name', default="HalfCheetahBulletEnv-v0",
help='Mujoco Gym environment (default: HalfCheetah-v2)')
parser.add_argument('--policy', default="Gaussian",
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
+11 -8
View File
@@ -67,6 +67,14 @@ class SAC(object):
qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step
qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
qf_loss = qf1_loss + qf2_loss
self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
for c_param in self.critic.parameters():
c_param.requires_grad = False
pi, log_pi, _ = self.policy.sample(state_batch)
@@ -75,18 +83,13 @@ class SAC(object):
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼stD,εtN[α * logπ(f(εt;st)|st) Q(st,f(εt;st))]
self.critic_optim.zero_grad()
qf1_loss.backward()
self.critic_optim.step()
self.critic_optim.zero_grad()
qf2_loss.backward()
self.critic_optim.step()
self.policy_optim.zero_grad()
policy_loss.backward()
self.policy_optim.step()
for c_param in self.critic.parameters():
c_param.requires_grad = True
if self.automatic_entropy_tuning:
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()