mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 17:01:47 +08:00
smaller camera, detach critic inputs
This commit is contained in:
@@ -9,9 +9,10 @@ run:
|
||||
-m pdb -c continue \
|
||||
main.py \
|
||||
--cuda \
|
||||
--opt_level O1 \
|
||||
--automatic_entropy_tuning true \
|
||||
--replay_size 10000 \
|
||||
--demonstrations data/demonstrations \
|
||||
--replay_size 100000 \
|
||||
# --demonstrations data/demonstrations \
|
||||
# --load auto \
|
||||
# ${python} -m pdb main.py --cuda --automatic_entropy_tuning true --replay_size 10000 --load auto --start_steps 200
|
||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --automatic_entropy_tuning true --replay_size 20000 --load auto
|
||||
|
||||
@@ -66,6 +66,8 @@ def get_args():
|
||||
help='Load models')
|
||||
parser.add_argument('-r', '--render', action="store_true",
|
||||
help='show')
|
||||
parser.add_argument('-O', '--opt_level', default='O0',
|
||||
help='Apex Amp Optimisation level')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -95,16 +97,16 @@ observation_dim=observation_dim - process_obs.reduce_obs_space
|
||||
logger.info(f"process_obs reduces obs_space {env.observation_space.shape[0]}-{process_obs.reduce_obs_space}={observation_dim}")
|
||||
|
||||
# Agent
|
||||
agent = SAC(observation_dim, env.action_space, args, process_obs)
|
||||
agent = SAC(observation_dim, env.action_space, args, process_obs, args.opt_level)
|
||||
|
||||
# from torchinfo import summary
|
||||
# print('process_obs')
|
||||
# summary(process_obs, input_size=(2, *env.observation_space.shape), depth=2)
|
||||
# print('critic')
|
||||
# summary(agent.critic, input_size=((2, observation_dim), (2, action_dim)))
|
||||
# print('policy')
|
||||
# summary(agent.policy, input_size=(2, observation_dim))
|
||||
# # print(process_obs, agent.critic, agent.policy)
|
||||
from torchinfo import summary
|
||||
print('process_obs')
|
||||
summary(process_obs, input_size=(2, *env.observation_space.shape), depth=2)
|
||||
print('critic')
|
||||
summary(agent.critic, input_size=((2, observation_dim), (2, action_dim)))
|
||||
print('policy')
|
||||
summary(agent.policy, input_size=(2, observation_dim))
|
||||
# print(process_obs, agent.critic, agent.policy)
|
||||
|
||||
#Tensorboard
|
||||
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
||||
@@ -172,7 +174,7 @@ with RichTQDM() as prog:
|
||||
else:
|
||||
action = agent.select_action(state) # Sample action from policy
|
||||
|
||||
if len(memory) > args.batch_size and (total_numsteps%20==0):
|
||||
if len(memory) > args.batch_size and (total_numsteps%1==0):
|
||||
# Number of updates per step in environment
|
||||
for i in range(args.updates_per_step):
|
||||
# Update parameters of all the networks
|
||||
@@ -211,7 +213,7 @@ with RichTQDM() as prog:
|
||||
logger.info("\nEpisode: {}, total numsteps: {}, episode steps: {}, reward: {}, updates: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), updates))
|
||||
prog.desc = "e: {}, r: {}, u: {}, m: {}".format(i_episode, round(episode_reward, 2), updates, len(memory))
|
||||
|
||||
if (i_episode % 100 == 0) and (args.eval is True) and i_episode>0:
|
||||
if (i_episode % 10 == 0) and (args.eval is True) and i_episode>0:
|
||||
avg_reward = 0.
|
||||
episodes = 10
|
||||
for _ in range(episodes):
|
||||
@@ -244,7 +246,7 @@ with RichTQDM() as prog:
|
||||
if total_numsteps >= args.num_steps:
|
||||
break
|
||||
|
||||
if args.train:
|
||||
save(save_dir)
|
||||
if args.train:
|
||||
save(save_dir)
|
||||
|
||||
env.close()
|
||||
|
||||
+2
-2
@@ -113,7 +113,7 @@ class GenerativeResnet3Headless(nn.Module):
|
||||
|
||||
|
||||
class ProcessObservation(nn.Module):
|
||||
def __init__(self, res=(224, 224)):
|
||||
def __init__(self, res=(112, 112)):
|
||||
super().__init__()
|
||||
self.res = res
|
||||
|
||||
@@ -122,7 +122,7 @@ class ProcessObservation(nn.Module):
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt'
|
||||
)
|
||||
self.feature_extractor = GenerativeResnet3Headless().train().half()
|
||||
self.feature_extractor = GenerativeResnet3Headless().train()
|
||||
self.feature_extractor.load_state_dict(state_dict=torch.load(grconvnet3_path), strict=False)
|
||||
|
||||
old_img_size = (res[0], res[1], 8)
|
||||
|
||||
@@ -5,10 +5,11 @@ from torch.optim import Adam
|
||||
from utils import soft_update, hard_update
|
||||
from model import GaussianPolicy, QNetwork, DeterministicPolicy
|
||||
from loguru import logger
|
||||
from apex import amp
|
||||
|
||||
|
||||
class SAC(object):
|
||||
def __init__(self, num_inputs, action_space, args, process_obs=None):
|
||||
def __init__(self, num_inputs, action_space, args, process_obs=None, opt_level='O1'):
|
||||
|
||||
self.gamma = args.gamma
|
||||
self.tau = args.tau
|
||||
@@ -49,6 +50,12 @@ class SAC(object):
|
||||
list(self.policy.parameters()) + list(process_obs.parameters()),
|
||||
lr=args.lr)
|
||||
|
||||
if opt_level is not None:
|
||||
model, optimizer = amp.initialize(
|
||||
[self.policy, self.process_obs, self.critic, self.critic_target],
|
||||
[self.policy_optim, self.critic_optim],
|
||||
opt_level=opt_level)
|
||||
|
||||
def select_action(self, obs, evaluate=False):
|
||||
with torch.no_grad():
|
||||
obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0).to(self.dtype)
|
||||
@@ -85,20 +92,22 @@ class SAC(object):
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
assert torch.isfinite(qf_loss).all()
|
||||
qf_loss.backward()
|
||||
with amp.scale_loss(qf_loss, self.critic_optim) as qf_loss:
|
||||
qf_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
state_batch = self.process_obs(obs_batch)
|
||||
pi, log_pi, _ = self.policy.sample(state_batch)
|
||||
|
||||
qf1_pi, qf2_pi = self.critic(state_batch, pi)
|
||||
qf1_pi, qf2_pi = self.critic(state_batch.detach(), pi)
|
||||
min_qf_pi = torch.min(qf1_pi, qf2_pi)
|
||||
|
||||
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
|
||||
|
||||
self.policy_optim.zero_grad()
|
||||
assert torch.isfinite(policy_loss).all()
|
||||
policy_loss.backward()
|
||||
with amp.scale_loss(policy_loss, self.policy_optim) as policy_loss:
|
||||
policy_loss.backward()
|
||||
self.policy_optim.step()
|
||||
|
||||
if self.automatic_entropy_tuning:
|
||||
|
||||
Reference in New Issue
Block a user