From 5534d4b0782a024e092c874aaee95d73b3fff2b3 Mon Sep 17 00:00:00 2001 From: wassname Date: Sat, 16 Jan 2021 18:14:57 +0800 Subject: [PATCH] prcoess_obs works, with training by both obs and critic --- Makefile | 2 +- load_demonstrations.py | 2 +- main.py | 18 ++++++-- grconvnet3.py => process_obs.py | 75 +++++++++++++++++++++++++++++---- sac.py | 28 ++++++------ 5 files changed, 99 insertions(+), 26 deletions(-) rename grconvnet3.py => process_obs.py (62%) diff --git a/Makefile b/Makefile index 27e5fbc..76c1ecf 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ python=/home/wassname/anaconda/envs/diygym3/bin/python date=2021-01-03_13-30-07 LOGURU_LEVEL=INFO run: - LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 4 --automatic_entropy_tuning + LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 4 --automatic_entropy_tuning true # LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --alpha 0.1 --tau 1 --target_update_interval 1000 # LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --tau 1 --target_update_interval 1000 --policy Deterministic diff --git a/load_demonstrations.py b/load_demonstrations.py index c6e21f6..174d001 100644 --- a/load_demonstrations.py +++ b/load_demonstrations.py @@ -13,7 +13,7 @@ def load_demonstrations(mem: ReplayMemory, recordings: Path): records = get_recordings(str(recordings)) logger.info('picks in recordings', sum(records['reward']>10)) ends=records["episodes_end_point"] - for i in tqdm(range(len(ends)-1), desc='loading demonstrations'): + for i in range(len(ends)-1): a = ends[i] b = ends[i+1] for s in range(a+1, b): diff --git a/main.py b/main.py index fb7ebb7..440c003 100644 --- a/main.py +++ b/main.py @@ -12,11 +12,13 @@ from replay_memory import ReplayMemory from load_demonstrations import load_demonstrations import apple_gym.env import pickle +from process_obs import ProcessObservation +# from torchinfo import summary from tqdm.auto import tqdm from loguru import logger -from rich import logger.info +from rich import print from rich.logging import RichHandler logging.basicConfig(level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=True)]) logger.configure(handlers=[{"sink": RichHandler(markup=True), @@ -81,13 +83,22 @@ env.action_space.seed(args.seed) keys_to_monitor = ['env_reward/apple_pick/tree/min_fruit_dist_reward', 'env_reward/apple_pick/tree/gripping_fruit_reward', 'env_reward/apple_pick/tree/force_tree_reward', - 'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']: + 'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks'] torch.manual_seed(args.seed) np.random.seed(args.seed) +# A visual network +observation_space=env.observation_space.shape[0] +process_obs=ProcessObservation() +observation_space=observation_space - process_obs.reduce_action_space +logger.info(f"process_obs reduces obs_space {env.observation_space.shape[0]}-{process_obs.reduce_action_space}={observation_space}") + # Agent -agent = SAC(env.observation_space.shape[0], env.action_space, args) +agent = SAC(observation_space, env.action_space, args, process_obs) + +# TODO +# summary(model, input_size=(batch_size, 1, 28, 28)) #Tensorboard log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name, @@ -132,6 +143,7 @@ updates = 0 with tqdm(unit='steps', mininterval=5) as prog: for i_episode in itertools.count(0): + print('1') episode_reward = 0 episode_steps = 0 done = False diff --git a/grconvnet3.py b/process_obs.py similarity index 62% rename from grconvnet3.py rename to process_obs.py index fd2fe34..60827e0 100644 --- a/grconvnet3.py +++ b/process_obs.py @@ -1,4 +1,6 @@ +import os +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -42,7 +44,7 @@ class GenerativeResnet3Headless(nn.Module): self.res2 = ResidualBlock(channel_size * 4, channel_size * 4) self.res3 = ResidualBlock(channel_size * 4, channel_size * 4) self.res4 = ResidualBlock(channel_size * 4, channel_size * 4) - self.res5 = ResidualBlock(channel_size * 4, channel_size * 4) + self.conv4 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, kernel_size=4, stride=2, padding=1, output_padding=1) @@ -65,21 +67,33 @@ class GenerativeResnet3Headless(nn.Module): self.dropout_sin = nn.Dropout(p=prob) self.dropout_wid = nn.Dropout(p=prob) + # freeze above params + for param in self.parameters(): + param.requires_grad = False + + self.res5 = ResidualBlock(channel_size * 4, channel_size * 4) + self.head = nn.Conv2d(64, 4, 1, bias=False) def forward(self, x_in): - x = F.relu(self.bn1(self.conv1(x_in))) - x = F.relu(self.bn2(self.conv2(x))) - x = F.relu(self.bn3(self.conv3(x))) - x = self.res1(x) - x = self.res2(x) - x = self.res3(x) - x = self.res4(x) + # Freeze these layers + with torch.no_grad(): + x = F.relu(self.bn1(self.conv1(x_in))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + x = self.res1(x) + x = self.res2(x) + x = self.res3(x) + x = self.res4(x) + x = self.res5(x) - # 1x1 conv to reduce feature state, with random weights + # 1x1 conv to reduce feature state, init with random weights x = self.head(x) + x = F.max_pool2d(x, kernel_size=3, stride=2) + x = F.max_pool2d(x, kernel_size=3, stride=2) + # ignore the old head which made it larger # x = F.relu(self.bn4(self.conv4(x))) # x = F.relu(self.bn5(self.conv5(x))) # x = self.conv6(x) @@ -96,3 +110,46 @@ class GenerativeResnet3Headless(nn.Module): # width_output = self.width_output(x) return x + + +class ProcessObservation(nn.Module): + def __init__(self, res=(224, 224)): + super().__init__() + self.res = res + + # Load visual model + grconvnet3_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt' + ) + self.feature_extractor = GenerativeResnet3Headless().eval() + self.feature_extractor.load_state_dict(state_dict=torch.load(grconvnet3_path)) + + old_img_size = (res[0], res[1], 8) + new_img_size = (res[0]//16-1, res[1]//16-1, 8) + self.reduce_action_space = int(np.prod(old_img_size) - np.prod(new_img_size)) + + def __call__(self, obs): + """ + Takes in a torch array of observations, processes the images into features. + + This assumes the observations ends in 2 rgbd images with shape (224, 244, 4) + """ + # import pdb; pdb.set_trace() + h, w = self.res + px = h * w + base_rgbd = obs[:, -px * 4:].reshape((-1, h, w, 4)) + arm_rgbd = obs[:, -px * 8: - px * 4].reshape((-1, h, w, 4)) + others = obs[:,: - px * 8] + bs = obs.shape[0] + + # make a batch + x = torch.cat([base_rgbd, arm_rgbd], 0) + x = x.permute((0, 3, 1, 2)) # to ((-1, 4, x, y)) + h = self.feature_extractor(x) + + # undo fake batch + base_h, arm_h = h[:bs].reshape((bs, -1)), h[bs:].reshape((bs, -1)) + # add features together + y = torch.cat([others, base_h, arm_h], 1) + return y diff --git a/sac.py b/sac.py index bdf89cb..6d91743 100644 --- a/sac.py +++ b/sac.py @@ -13,18 +13,19 @@ class SAC(object): self.gamma = args.gamma self.tau = args.tau self.alpha = args.alpha - self.process_obs = process_obs + self.device = torch.device("cuda" if args.cuda else "cpu") self.policy_type = args.policy self.target_update_interval = args.target_update_interval self.automatic_entropy_tuning = args.automatic_entropy_tuning - self.device = torch.device("cuda" if args.cuda else "cpu") + self.process_obs = process_obs.to(self.device) + self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device) + self.critic_optim = Adam( + list(self.critic.parameters()) + list(process_obs.parameters()) + , lr=args.lr) - self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size, process_obs).to(device=self.device) - self.critic_optim = Adam(self.critic.parameters(), lr=args.lr) - - self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size, process_obs).to(self.device) + self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device) hard_update(self.critic_target, self.critic) if self.policy_type == "Gaussian": @@ -34,13 +35,15 @@ class SAC(object): self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) self.alpha_optim = Adam([self.log_alpha], lr=args.lr) - self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space, process_obs).to(self.device) - self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) + self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) + self.policy_optim = Adam( + list(self.policy.parameters()) + list(process_obs.parameters()), + lr=args.lr) else: self.alpha = 0 self.automatic_entropy_tuning = False - self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space, process_obs).to(self.device) + self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) def select_action(self, obs, evaluate=False): @@ -64,8 +67,8 @@ class SAC(object): state_batch = self.process_obs(obs_batch) - next_state_batch = self.process_obs(next_obs_batch) with torch.no_grad(): + next_state_batch = self.process_obs(next_obs_batch) next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch) qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action) min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi @@ -79,6 +82,7 @@ class SAC(object): qf_loss.backward() self.critic_optim.step() + state_batch = self.process_obs(obs_batch) pi, log_pi, _ = self.policy.sample(state_batch) qf1_pi, qf2_pi = self.critic(state_batch, pi) @@ -111,13 +115,13 @@ class SAC(object): # Save model parameters def save_model(self, actor_path=None, critic_path=None): - logger.debug(f'saving models to {actor_path} and {critic_path}')) + logger.debug(f'saving models to {actor_path} and {critic_path}') torch.save(self.policy.state_dict(), actor_path) torch.save(self.critic.state_dict(), critic_path) # Load model parameters def load_model(self, actor_path, critic_path): - logger.info(f'Loading models from {actor_path} and {critic_path}')) + logger.info(f'Loading models from {actor_path} and {critic_path}') if actor_path is not None: self.policy.load_state_dict(torch.load(actor_path)) if critic_path is not None: