mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
fix for pytorch-1.5
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -9,7 +9,7 @@ from tensorboardX import SummaryWriter
|
||||
from replay_memory import ReplayMemory
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
||||
parser.add_argument('--env-name', default="HalfCheetah-v2",
|
||||
parser.add_argument('--env-name', default="HalfCheetahBulletEnv-v0",
|
||||
help='Mujoco Gym environment (default: HalfCheetah-v2)')
|
||||
parser.add_argument('--policy', default="Gaussian",
|
||||
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
|
||||
|
||||
BIN
Binary file not shown.
@@ -67,6 +67,14 @@ class SAC(object):
|
||||
qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step
|
||||
qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
|
||||
qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
|
||||
qf_loss = qf1_loss + qf2_loss
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
qf_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
for c_param in self.critic.parameters():
|
||||
c_param.requires_grad = False
|
||||
|
||||
pi, log_pi, _ = self.policy.sample(state_batch)
|
||||
|
||||
@@ -75,18 +83,13 @@ class SAC(object):
|
||||
|
||||
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
qf1_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
qf2_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
self.policy_optim.zero_grad()
|
||||
policy_loss.backward()
|
||||
self.policy_optim.step()
|
||||
|
||||
for c_param in self.critic.parameters():
|
||||
c_param.requires_grad = True
|
||||
|
||||
if self.automatic_entropy_tuning:
|
||||
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user