From f28ca774ed11d5df093c87203a47534a8c8b0427 Mon Sep 17 00:00:00 2001 From: wassname Date: Sun, 17 Jan 2021 18:27:36 +0800 Subject: [PATCH] smaller camera, detach critic inputs --- Makefile | 5 +++-- main.py | 28 +++++++++++++++------------- process_obs.py | 4 ++-- sac.py | 17 +++++++++++++---- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/Makefile b/Makefile index 6c241be..3820eb7 100644 --- a/Makefile +++ b/Makefile @@ -9,9 +9,10 @@ run: -m pdb -c continue \ main.py \ --cuda \ + --opt_level O1 \ --automatic_entropy_tuning true \ - --replay_size 10000 \ - --demonstrations data/demonstrations \ + --replay_size 100000 \ + # --demonstrations data/demonstrations \ # --load auto \ # ${python} -m pdb main.py --cuda --automatic_entropy_tuning true --replay_size 10000 --load auto --start_steps 200 # LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --automatic_entropy_tuning true --replay_size 20000 --load auto diff --git a/main.py b/main.py index db185c1..93a0b59 100644 --- a/main.py +++ b/main.py @@ -66,6 +66,8 @@ def get_args(): help='Load models') parser.add_argument('-r', '--render', action="store_true", help='show') + parser.add_argument('-O', '--opt_level', default='O0', + help='Apex Amp Optimisation level') args = parser.parse_args() return args @@ -95,16 +97,16 @@ observation_dim=observation_dim - process_obs.reduce_obs_space logger.info(f"process_obs reduces obs_space {env.observation_space.shape[0]}-{process_obs.reduce_obs_space}={observation_dim}") # Agent -agent = SAC(observation_dim, env.action_space, args, process_obs) +agent = SAC(observation_dim, env.action_space, args, process_obs, args.opt_level) -# from torchinfo import summary -# print('process_obs') -# summary(process_obs, input_size=(2, *env.observation_space.shape), depth=2) -# print('critic') -# summary(agent.critic, input_size=((2, observation_dim), (2, action_dim))) -# print('policy') -# summary(agent.policy, input_size=(2, observation_dim)) -# # print(process_obs, agent.critic, agent.policy) +from torchinfo import summary +print('process_obs') +summary(process_obs, input_size=(2, *env.observation_space.shape), depth=2) +print('critic') +summary(agent.critic, input_size=((2, observation_dim), (2, action_dim))) +print('policy') +summary(agent.policy, input_size=(2, observation_dim)) +# print(process_obs, agent.critic, agent.policy) #Tensorboard log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name, @@ -172,7 +174,7 @@ with RichTQDM() as prog: else: action = agent.select_action(state) # Sample action from policy - if len(memory) > args.batch_size and (total_numsteps%20==0): + if len(memory) > args.batch_size and (total_numsteps%1==0): # Number of updates per step in environment for i in range(args.updates_per_step): # Update parameters of all the networks @@ -211,7 +213,7 @@ with RichTQDM() as prog: logger.info("\nEpisode: {}, total numsteps: {}, episode steps: {}, reward: {}, updates: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), updates)) prog.desc = "e: {}, r: {}, u: {}, m: {}".format(i_episode, round(episode_reward, 2), updates, len(memory)) - if (i_episode % 100 == 0) and (args.eval is True) and i_episode>0: + if (i_episode % 10 == 0) and (args.eval is True) and i_episode>0: avg_reward = 0. episodes = 10 for _ in range(episodes): @@ -244,7 +246,7 @@ with RichTQDM() as prog: if total_numsteps >= args.num_steps: break - if args.train: - save(save_dir) + if args.train: + save(save_dir) env.close() diff --git a/process_obs.py b/process_obs.py index 6ef5825..b8db477 100644 --- a/process_obs.py +++ b/process_obs.py @@ -113,7 +113,7 @@ class GenerativeResnet3Headless(nn.Module): class ProcessObservation(nn.Module): - def __init__(self, res=(224, 224)): + def __init__(self, res=(112, 112)): super().__init__() self.res = res @@ -122,7 +122,7 @@ class ProcessObservation(nn.Module): os.path.dirname(os.path.abspath(__file__)), 'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt' ) - self.feature_extractor = GenerativeResnet3Headless().train().half() + self.feature_extractor = GenerativeResnet3Headless().train() self.feature_extractor.load_state_dict(state_dict=torch.load(grconvnet3_path), strict=False) old_img_size = (res[0], res[1], 8) diff --git a/sac.py b/sac.py index 8197745..aa11b8f 100644 --- a/sac.py +++ b/sac.py @@ -5,10 +5,11 @@ from torch.optim import Adam from utils import soft_update, hard_update from model import GaussianPolicy, QNetwork, DeterministicPolicy from loguru import logger +from apex import amp class SAC(object): - def __init__(self, num_inputs, action_space, args, process_obs=None): + def __init__(self, num_inputs, action_space, args, process_obs=None, opt_level='O1'): self.gamma = args.gamma self.tau = args.tau @@ -49,6 +50,12 @@ class SAC(object): list(self.policy.parameters()) + list(process_obs.parameters()), lr=args.lr) + if opt_level is not None: + model, optimizer = amp.initialize( + [self.policy, self.process_obs, self.critic, self.critic_target], + [self.policy_optim, self.critic_optim], + opt_level=opt_level) + def select_action(self, obs, evaluate=False): with torch.no_grad(): obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0).to(self.dtype) @@ -85,20 +92,22 @@ class SAC(object): self.critic_optim.zero_grad() assert torch.isfinite(qf_loss).all() - qf_loss.backward() + with amp.scale_loss(qf_loss, self.critic_optim) as qf_loss: + qf_loss.backward() self.critic_optim.step() state_batch = self.process_obs(obs_batch) pi, log_pi, _ = self.policy.sample(state_batch) - qf1_pi, qf2_pi = self.critic(state_batch, pi) + qf1_pi, qf2_pi = self.critic(state_batch.detach(), pi) min_qf_pi = torch.min(qf1_pi, qf2_pi) policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))] self.policy_optim.zero_grad() assert torch.isfinite(policy_loss).all() - policy_loss.backward() + with amp.scale_loss(policy_loss, self.policy_optim) as policy_loss: + policy_loss.backward() self.policy_optim.step() if self.automatic_entropy_tuning: